Coverage for core/attention.py: 93%
198 statements
import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config
from .lora import LoRALinear


class Attention(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.max_context: int = config.max_context
        self.head_size: int = config.head_size
        self.n_heads: int = config.n_heads
        self.dim: int = config.dim
        self.kv_heads: int = config.kv_heads if config.kv_heads is not None else self.n_heads

        _use_lora_attn = getattr(config, "use_lora", False) and getattr(config, "lora_targets", "attention") in ("attention", "all")
        _lora_kw: dict = dict(rank=getattr(config, "lora_rank", 8), alpha=getattr(config, "lora_alpha", 16.0), dropout_rate=getattr(config, "lora_dropout", 0.0), rngs=rngs)
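        # Fused QKV projection width: queries use the full model dim (which the
        # reshapes below treat as n_heads * head_size), while keys and values use
        # kv_heads * head_size each, so GQA shares K/V across query groups.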
        qkv_out = self.dim + 2 * self.kv_heads * self.head_size
        self.qkv: nnx.Linear | LoRALinear = (
            LoRALinear(self.dim, qkv_out, use_bias=False, **_lora_kw)
            if _use_lora_attn
            else nnx.Linear(self.dim, qkv_out, use_bias=False, rngs=rngs)
        )
        self.tril: jnp.ndarray = jnp.tril(
            jnp.ones((self.max_context, self.max_context), dtype=bool)
        )
        self.o_proj: nnx.Linear | LoRALinear = (
            LoRALinear(self.dim, self.dim, **_lora_kw)
            if _use_lora_attn
            else nnx.Linear(self.dim, self.dim, rngs=rngs)
        )
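        # When no_sink is enabled, self.W provides a per-feature sigmoid gate
        # applied to the attention output in __call__ (y * sigmoid(W(x))).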
        self.no_sink: bool = config.no_sink
        self.W: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)
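        # Additive sliding-window mask: position j is visible from position i
        # only when 0 <= i - j <= context_window; everything else gets -1e9.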
        self.sliding_window: bool = config.sliding_window
        if self.sliding_window:
            table = jnp.arange(self.max_context)[:, None] - jnp.arange(self.max_context)[None, :]
            mask = (table <= config.context_window) & (table >= 0)
            self.window = jnp.where(mask, 0, -1e9)

        self.use_rotary: bool = config.use_rotary_pos
        self.use_flash_attention: bool = config.use_flash_attention

        self.attn_dropout: nnx.Dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.resid_dropout: nnx.Dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)

        # MLA projections (always allocated; only used when config.mla=True)
        self.down_q: nnx.Linear = nnx.Linear(config.dim, config.down_dim_q, rngs=rngs)
        self.down_kv: nnx.Linear = nnx.Linear(config.dim, config.down_dim_kv, rngs=rngs)
        self.up_q: nnx.Linear = nnx.Linear(config.down_dim_q, config.head_size * config.n_heads, rngs=rngs)
        self.up_k: nnx.Linear = nnx.Linear(config.down_dim_kv, config.head_size * config.kv_heads, rngs=rngs)
        self.up_v: nnx.Linear = nnx.Linear(config.down_dim_kv, config.head_size * config.kv_heads, rngs=rngs)

        self.down_dim_q: int = config.down_dim_q
        self.down_dim_kv: int = config.down_dim_kv
        self.mla: bool = config.mla
        self.inference: bool = config.inference
        self.rope_dim: int = config.rope_dim if hasattr(config, "rope_dim") else self.head_size // 2
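        # MLA uses a decoupled rotary component: q_pe/k_pe project each token to a
        # single rope_dim-wide positional head that is broadcast across all heads.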
        if self.mla:
            self.q_pe: nnx.Linear = nnx.Linear(config.dim, self.rope_dim, rngs=rngs)
            self.k_pe: nnx.Linear = nnx.Linear(config.dim, self.rope_dim, rngs=rngs)
            self.norm_q = nnx.LayerNorm(config.down_dim_q, rngs=rngs)
            self.norm_kv = nnx.LayerNorm(config.down_dim_kv, rngs=rngs)

        # RoPE frequency table — scaled by rope_scale_factor for NTK-aware extension
        rope_dim_size = self.head_size if not self.mla else self.rope_dim
        self._rope_scale = getattr(config, "rope_scale_factor", 1.0)
        self.angle: jnp.ndarray = self._compute_angle(self.max_context, rope_dim_size)

    # ── RoPE helpers ──────────────────────────────────────────────────────────

    def _compute_angle(self, T: int, C: int) -> jnp.ndarray:
        """Precompute RoPE angles [1, 1, 1, T, C//2].

        When ``rope_scale_factor > 1`` the base frequency is compressed
        (NTK-aware scaling), allowing the model to generalise to contexts
        longer than ``max_context`` without fine-tuning.
        """
        P = jnp.arange(T, dtype=jnp.float32)
        base = 10000.0 * self._rope_scale
        inv_freq = 1.0 / (base ** (jnp.arange(0, C, 2, dtype=jnp.float32) / C))
        degree = jnp.einsum("i,j->ij", P, inv_freq)
        return degree[None, None, None, :, :]  # [1, 1, 1, T, C//2]

    def _apply_rope_grouped(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        """Apply RoPE to a tensor in grouped [B, H, G, T, D] layout (cache path)."""
        T = x.shape[3]
        angle = jax.lax.dynamic_slice_in_dim(self.angle, cache_index, T, axis=3)
        cos_a, sin_a = jax.lax.cos(angle), jax.lax.sin(angle)
        out = jnp.empty_like(x)
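        # Rotate each consecutive (even, odd) feature pair by its position- and
        # frequency-dependent angle: the standard RoPE 2-D rotation.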
        out = out.at[..., 0::2].set(x[..., 0::2] * cos_a - x[..., 1::2] * sin_a)
        out = out.at[..., 1::2].set(x[..., 0::2] * sin_a + x[..., 1::2] * cos_a)
        return out

    # kept as an alias so MLA code that calls __apply_rotation still works
    __apply_rotation = _apply_rope_grouped

    def _apply_rope_thd(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        """Apply RoPE to a tensor in [B, T, H, D] layout (Flash Attention path)."""
        T = x.shape[1]  # static under JIT
        angle = jax.lax.dynamic_slice_in_dim(
            self.angle[0, 0, 0],  # [T_max, D//2]
            start_index=cache_index, slice_size=T, axis=0,
        )  # [T, D//2]
        angle = angle[None, :, None, :]  # [1, T, 1, D//2]
        cos_a, sin_a = jnp.cos(angle), jnp.sin(angle)
        out = jnp.empty_like(x)
        out = out.at[..., 0::2].set(x[..., 0::2] * cos_a - x[..., 1::2] * sin_a)
        out = out.at[..., 1::2].set(x[..., 0::2] * sin_a + x[..., 1::2] * cos_a)
        return out

    # ── Cache helpers ─────────────────────────────────────────────────────────

    def _compute_cache(self, kv_cache, cache_index, B, T, k=None, v=None, c_kv=None, k_rope=None):
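        """Initialise the KV cache on the first step, or write this step's entries at ``cache_index``.

        MHA/GQA: stores per-head K/V as [B, kv_heads, 1, max_context, head_size].
        MLA: stores the compressed latent c_kv as [B, max_context, down_dim_kv]
        plus the shared rotary key as [B, 1, 1, max_context, rope_dim].
        Returns (kv_cache, k_rope_cache, k, v) with k/v spanning the full cache.
        """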
        if self.mla is False:
            if kv_cache[0] is None:
                k_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=k.dtype)
                v_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=v.dtype)
                k_cache = k_cache.at[:, :, :, :T, :].set(k)
                v_cache = v_cache.at[:, :, :, :T, :].set(v)
            else:
                k_cache, v_cache = map(
                    lambda x, y, index: jax.lax.dynamic_update_slice(x, y, (0, 0, 0, index, 0)),
                    (kv_cache[0], kv_cache[1]), (k, v), (cache_index, cache_index)
                )
            kv_cache = (k_cache, v_cache)
            k, v = k_cache, v_cache
            k_rope_cache = None
        else:
            if kv_cache[0] is None:
                c_cache = jnp.zeros((B, self.max_context, self.down_dim_kv), dtype=c_kv.dtype)
                c_cache = c_cache.at[:, :T, :].set(c_kv)
                k_rope_cache = jnp.zeros((B, 1, 1, self.max_context, self.rope_dim), dtype=c_kv.dtype)
                k_rope_cache = k_rope_cache.at[:, :, :, :T, :].set(k_rope)
            else:
                c_cache = jax.lax.dynamic_update_slice(kv_cache[0], c_kv, (0, cache_index, 0))
                k_rope_cache = jax.lax.dynamic_update_slice(kv_cache[1], k_rope, (0, 0, 0, cache_index, 0))
            kv_cache = (c_cache, k_rope_cache)
            k, v = c_cache, c_cache
        return kv_cache, k_rope_cache, k, v

    # ── Head reshape (grouped layout for cache path) ──────────────────────────

    def reshape_head(self, B, T, q, k, v):
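        """Split heads into the grouped layout used by the manual attention path.

        q becomes [B, kv_heads, n_heads // kv_heads, T, head_size] and k/v become
        [B, kv_heads, 1, T, head_size], so GQA broadcasts over the group axis.
        """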
        def _reshaping(x, n_heads):
            return jnp.reshape(x, (B, T, n_heads, self.head_size))

        def _permute(x):
            return jnp.transpose(x, (0, 2, 3, 1, 4))

        q = _reshaping(q, self.n_heads).reshape(
            B, T, self.kv_heads, self.n_heads // self.kv_heads, self.head_size
        )
        k, v = map(_reshaping, (k, v), (self.kv_heads, self.kv_heads))
        k, v = map(lambda x: jnp.expand_dims(x, axis=3), (k, v))
        q, k, v = map(_permute, (q, k, v))
        return q, k, v

    def _compute_attention(self, use_cache, kv_cache, cache_index, B, T, q, k, v):
        """Manual (non-Flash) attention for the cache/sliding-window path."""
        q, k, v = self.reshape_head(B, T, q, k, v)
        if self.use_rotary:
            q, k = map(self._apply_rope_grouped, (q, k), (cache_index, cache_index))
        if use_cache:
            kv_cache, _, k, v = self._compute_cache(kv_cache, cache_index, B, T, k, v)
        k = jnp.swapaxes(k, -2, -1)
        attn = q @ k / math.sqrt(self.head_size)
        return kv_cache, q, k, v, attn

    # ── Forward ───────────────────────────────────────────────────────────────

    def __call__(
        self,
        x: jnp.ndarray,
        use_cache: bool,
        kv_cache: tuple,
        cache_index: int,
        deterministic: bool = False,
    ) -> tuple[jnp.ndarray, tuple]:
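        """Run attention over ``x`` of shape [B, T, dim].

        ``use_cache`` enables incremental decoding: new keys/values are written
        into ``kv_cache`` (a tuple whose first element is ``None`` on the first
        step) at ``cache_index``, which is also the absolute position offset
        used for RoPE and the causal mask. Returns (output, updated kv_cache).
        """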
        B, T, _ = x.shape
        assert self.max_context >= T, "Sequence too long"

        # ── Flash Attention fast path (MHA/GQA, training only) ────────────────
        # Conditions: not MLA, no KV-cache (training), no sliding window.
        # Uses jax.nn.dot_product_attention for fused scaled-dot-product with
        # a causal mask. Requires JAX >= 0.4.25.
        if self.mla is False and not use_cache and not self.sliding_window and self.use_flash_attention:
            q_size = self.dim
            kv_size = self.kv_heads * self.head_size
            qkv = self.qkv(x)
            q, k, v = jax.lax.split(qkv, (q_size, kv_size, kv_size), axis=-1)

            # [B, T, H, D] layout expected by dot_product_attention
            q_fa = q.reshape(B, T, self.n_heads, self.head_size)
            k_fa = k.reshape(B, T, self.kv_heads, self.head_size)
            v_fa = v.reshape(B, T, self.kv_heads, self.head_size)

            if self.use_rotary:
                q_fa = self._apply_rope_thd(q_fa, cache_index)
                k_fa = self._apply_rope_thd(k_fa, cache_index)

            # GQA: broadcast k/v to full head count for JAX < 0.4.31 compat
            if self.kv_heads < self.n_heads:
                g = self.n_heads // self.kv_heads
                k_fa = jnp.repeat(k_fa, g, axis=2)
                v_fa = jnp.repeat(v_fa, g, axis=2)

            y = jax.nn.dot_product_attention(q_fa, k_fa, v_fa, is_causal=True)
            # y: [B, T, n_heads, head_size] → [B, T, dim]
            y = y.reshape(B, T, self.dim)

            if self.no_sink:
                y = y * jax.nn.sigmoid(self.W(x))
            out = self.o_proj(y)
            out = self.resid_dropout(out, deterministic=deterministic)
            return out, kv_cache

        # ── MLA path ──────────────────────────────────────────────────────────
        if self.mla is True:
            q = self.down_q(x)
            c_kv = self.down_kv(x)
            q = self.norm_q(q)
            c_kv = self.norm_kv(c_kv)

            if self.use_rotary:
                q_rope = self.q_pe(x)[:, None, None, :, :]
                k_rope = self.k_pe(x)[:, None, None, :, :]
                q_rope, k_rope = map(
                    self._apply_rope_grouped, (q_rope, k_rope), (cache_index, cache_index)
                )

            if self.inference is False:
                q = self.up_q(q)
                k, v = map(lambda f, vec: f(vec), (self.up_k, self.up_v), (c_kv, c_kv))
                q, k, v = self.reshape_head(B, T, q, k, v)

                q_rope_ext, k_rope_ext = map(
                    lambda x, m: jnp.broadcast_to(x, m.shape[:-1] + (self.rope_dim,)),
                    (q_rope, k_rope), (q, k)
                )
                q_full = jnp.concatenate([q, q_rope_ext], axis=-1)
                k_full = jnp.concatenate([k, k_rope_ext], axis=-1)
                k_full = jnp.swapaxes(k_full, -2, -1)
                attn = q_full @ k_full / math.sqrt(self.head_size + self.rope_dim)
            else:
                if use_cache:
                    kv_cache, k_rope, k, v = self._compute_cache(
                        kv_cache, cache_index, B, T, c_kv=c_kv, k_rope=k_rope
                    )
                else:
                    k = c_kv
                    v = c_kv
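                # Weight absorption: fold up_q and up_k into a single map so the
                # attention scores are taken directly against the compressed latent
                # cache, without materialising per-head keys at inference time.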
                q_proj = self.up_q.kernel.reshape(
                    self.down_dim_q, self.kv_heads, self.n_heads // self.kv_heads, self.head_size
                )
                k_proj = self.up_k.kernel.reshape(self.down_dim_kv, self.kv_heads, self.head_size)
                attn_proj = jnp.einsum("qngh, knh -> ngqk", q_proj, k_proj)
                attn_proj = jnp.einsum("btq, ngqk -> btngk", q, attn_proj)
                attn = jnp.einsum("btngk, bsk -> bngts", attn_proj, k)
                attn_rope = q_rope @ jnp.swapaxes(k_rope, -2, -1)
                attn = (attn + attn_rope) / math.sqrt(self.head_size + self.rope_dim)

        # ── Manual MHA/GQA path (cache or sliding window) ─────────────────────
        elif self.mla is False:
            q_size = self.dim
            kv_size = self.kv_heads * self.head_size
            qkv = self.qkv(x)
            q, k, v = jax.lax.split(qkv, (q_size, kv_size, kv_size), axis=-1)
            kv_cache, q, k, v, attn = self._compute_attention(
                use_cache, kv_cache, cache_index, B, T, q, k, v
            )

        # ── Causal mask + softmax ─────────────────────────────────────────────
        S = attn.shape[-1]
        mask = jax.lax.dynamic_slice(self.tril, (cache_index, 0), (T, S))
        attn = attn + jnp.where(mask, 0.0, -1e9)

        if self.sliding_window:
            window_mask = jax.lax.dynamic_slice(self.window, (cache_index, 0), (T, S))
            attn = attn + window_mask

        causal_attn = jax.nn.softmax(attn)
        causal_attn = self.attn_dropout(causal_attn, deterministic=deterministic)

        # ── Output projection ─────────────────────────────────────────────────
        if self.inference is True and self.mla is True:
            L = jnp.einsum("bngts, bsd -> bngtd", causal_attn, v)
            W_v = self.up_v.kernel.reshape(self.down_dim_kv, self.kv_heads, self.head_size)
            if self.no_sink:
                y_heads = jnp.einsum("bngtd, dnh -> bngth", L, W_v)
                y = jnp.transpose(y_heads, (0, 3, 1, 2, 4)).reshape(B, T, self.dim)
                y = y * jax.nn.sigmoid(self.W(x))
                out = self.o_proj(y)
            else:
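                # Fuse the value up-projection with the output projection (W_vo) so
                # the attention-weighted latents map straight to the model dimension.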
                W_o = self.o_proj.kernel.reshape(  # type: ignore[union-attr]
                    self.kv_heads, self.n_heads // self.kv_heads, self.head_size, self.dim
                )
                W_vo = jnp.einsum("dnh, nghc -> dngc", W_v, W_o)
                out = jnp.einsum("bngtd, dngc -> btc", L, W_vo)
        else:
            y = causal_attn @ v
            y = jnp.transpose(y, (0, 3, 1, 2, 4)).reshape(B, T, self.dim)
            if self.no_sink:
                y = y * jax.nn.sigmoid(self.W(x))
            out = self.o_proj(y)

        out = self.resid_dropout(out, deterministic=deterministic)
        return out, kv_cache