Coverage for core/attention.py: 93%

198 statements  

coverage.py v7.13.5, created at 2026-05-02 15:35 +0200

import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config
from .lora import LoRALinear


class Attention(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.max_context: int = config.max_context
        self.head_size: int = config.head_size
        self.n_heads: int = config.n_heads
        self.dim: int = config.dim
        self.kv_heads: int = config.kv_heads if config.kv_heads is not None else self.n_heads

        _use_lora_attn = getattr(config, "use_lora", False) and getattr(config, "lora_targets", "attention") in ("attention", "all")
        _lora_kw: dict = dict(rank=getattr(config, "lora_rank", 8), alpha=getattr(config, "lora_alpha", 16.0), dropout_rate=getattr(config, "lora_dropout", 0.0), rngs=rngs)

        qkv_out = self.dim + 2 * self.kv_heads * self.head_size
        self.qkv: nnx.Linear | LoRALinear = (
            LoRALinear(self.dim, qkv_out, use_bias=False, **_lora_kw)
            if _use_lora_attn
            else nnx.Linear(self.dim, qkv_out, use_bias=False, rngs=rngs)
        )
        self.tril: jnp.ndarray = jnp.tril(
            jnp.ones((self.max_context, self.max_context), dtype=bool)
        )
        self.o_proj: nnx.Linear | LoRALinear = (
            LoRALinear(self.dim, self.dim, **_lora_kw)
            if _use_lora_attn
            else nnx.Linear(self.dim, self.dim, rngs=rngs)
        )
        self.no_sink: bool = config.no_sink
        self.W: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)

        self.sliding_window: bool = config.sliding_window
        if self.sliding_window:
            table = jnp.arange(self.max_context)[:, None] - jnp.arange(self.max_context)[None, :]
            mask = (table <= config.context_window) & (table >= 0)
            self.window = jnp.where(mask, 0, -1e9)

        self.use_rotary: bool = config.use_rotary_pos
        self.use_flash_attention: bool = config.use_flash_attention

        self.attn_dropout: nnx.Dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.resid_dropout: nnx.Dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)

        # MLA projections (always allocated; only used when config.mla=True)
        self.down_q: nnx.Linear = nnx.Linear(config.dim, config.down_dim_q, rngs=rngs)
        self.down_kv: nnx.Linear = nnx.Linear(config.dim, config.down_dim_kv, rngs=rngs)
        self.up_q: nnx.Linear = nnx.Linear(config.down_dim_q, config.head_size * config.n_heads, rngs=rngs)
        self.up_k: nnx.Linear = nnx.Linear(config.down_dim_kv, config.head_size * config.kv_heads, rngs=rngs)
        self.up_v: nnx.Linear = nnx.Linear(config.down_dim_kv, config.head_size * config.kv_heads, rngs=rngs)

        self.down_dim_q: int = config.down_dim_q
        self.down_dim_kv: int = config.down_dim_kv
        self.mla: bool = config.mla
        self.inference: bool = config.inference
        self.rope_dim: int = config.rope_dim if hasattr(config, "rope_dim") else self.head_size // 2

        if self.mla:
            self.q_pe: nnx.Linear = nnx.Linear(config.dim, self.rope_dim, rngs=rngs)
            self.k_pe: nnx.Linear = nnx.Linear(config.dim, self.rope_dim, rngs=rngs)
            self.norm_q = nnx.LayerNorm(config.down_dim_q, rngs=rngs)
            self.norm_kv = nnx.LayerNorm(config.down_dim_kv, rngs=rngs)

        # RoPE frequency table — scaled by rope_scale_factor for NTK-aware extension
        rope_dim_size = self.head_size if not self.mla else self.rope_dim
        self._rope_scale = getattr(config, "rope_scale_factor", 1.0)
        self.angle: jnp.ndarray = self._compute_angle(self.max_context, rope_dim_size)


    # ── RoPE helpers ──────────────────────────────────────────────────────────

    def _compute_angle(self, T: int, C: int) -> jnp.ndarray:
        """Precompute RoPE angles [1, 1, 1, T, C//2].

        When ``rope_scale_factor > 1`` the rotary base is scaled up, lowering
        the frequencies (NTK-aware scaling) and allowing the model to
        generalise to contexts longer than ``max_context`` without
        fine-tuning.
        """
        P = jnp.arange(T, dtype=jnp.float32)
        base = 10000.0 * self._rope_scale
        inv_freq = 1.0 / (base ** (jnp.arange(0, C, 2, dtype=jnp.float32) / C))
        degree = jnp.einsum("i,j->ij", P, inv_freq)
        return degree[None, None, None, :, :]  # [1, 1, 1, T, C//2]
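    # Note on _compute_angle (illustrative, hypothetical values): with C = 64
    # and rope_scale_factor = 4.0 the base becomes 40000, so every inv_freq
    # entry after the first shrinks and the slowest rotary component takes
    # roughly 4x as many positions to complete a period; that is what lets
    # positions beyond max_context reuse angle ranges already seen in training.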

    def _apply_rope_grouped(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        """Apply RoPE to a tensor in grouped [B, H, G, T, D] layout (cache path)."""
        T = x.shape[3]
        angle = jax.lax.dynamic_slice_in_dim(self.angle, cache_index, T, axis=3)
        cos_a, sin_a = jax.lax.cos(angle), jax.lax.sin(angle)
        out = jnp.empty_like(x)
        out = out.at[..., 0::2].set(x[..., 0::2] * cos_a - x[..., 1::2] * sin_a)
        out = out.at[..., 1::2].set(x[..., 0::2] * sin_a + x[..., 1::2] * cos_a)
        return out

    # kept as an alias so MLA code that calls __apply_rotation still works
    __apply_rotation = _apply_rope_grouped

    def _apply_rope_thd(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        """Apply RoPE to a tensor in [B, T, H, D] layout (Flash Attention path)."""
        T = x.shape[1]  # static under JIT
        angle = jax.lax.dynamic_slice_in_dim(
            self.angle[0, 0, 0],  # [T_max, D//2]
            start_index=cache_index, slice_size=T, axis=0,
        )  # [T, D//2]
        angle = angle[None, :, None, :]  # [1, T, 1, D//2]
        cos_a, sin_a = jnp.cos(angle), jnp.sin(angle)
        out = jnp.empty_like(x)
        out = out.at[..., 0::2].set(x[..., 0::2] * cos_a - x[..., 1::2] * sin_a)
        out = out.at[..., 1::2].set(x[..., 0::2] * sin_a + x[..., 1::2] * cos_a)
        return out

116 

117 # ── Cache helpers ───────────────────────────────────────────────────────── 

118 

119 def _compute_cache(self, kv_cache, cache_index, B, T, k=None, v=None, c_kv=None, k_rope=None): 

120 if self.mla is False: 

121 if kv_cache[0] is None: 

122 k_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=k.dtype) 

123 v_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=v.dtype) 

124 k_cache = k_cache.at[:, :, :, :T, :].set(k) 

125 v_cache = v_cache.at[:, :, :, :T, :].set(v) 

126 else: 

127 k_cache, v_cache = map( 

128 lambda x, y, index: jax.lax.dynamic_update_slice(x, y, (0, 0, 0, index, 0)), 

129 (kv_cache[0], kv_cache[1]), (k, v), (cache_index, cache_index) 

130 ) 

131 kv_cache = (k_cache, v_cache) 

132 k, v = k_cache, v_cache 

133 k_rope_cache = None 

134 else: 

135 if kv_cache[0] is None: 

136 c_cache = jnp.zeros((B, self.max_context, self.down_dim_kv), dtype=c_kv.dtype) 

137 c_cache = c_cache.at[:, :T, :].set(c_kv) 

138 k_rope_cache = jnp.zeros((B, 1, 1, self.max_context, self.rope_dim), dtype=c_kv.dtype) 

139 k_rope_cache = k_rope_cache.at[:, :, :, :T, :].set(k_rope) 

140 else: 

141 c_cache = jax.lax.dynamic_update_slice(kv_cache[0], c_kv, (0, cache_index, 0)) 

142 k_rope_cache = jax.lax.dynamic_update_slice(kv_cache[1], k_rope, (0, 0, 0, cache_index, 0)) 

143 kv_cache = (c_cache, k_rope_cache) 

144 k, v = c_cache, c_cache 

145 return kv_cache, k_rope_cache, k, v 

146 

147 # ── Head reshape (grouped layout for cache path) ────────────────────────── 

148 

149 def reshape_head(self, B, T, q, k, v): 

150 def _reshaping(x, n_heads): 

151 return jnp.reshape(x, (B, T, n_heads, self.head_size)) 

152 

153 def _permute(x): 

154 return jnp.transpose(x, (0, 2, 3, 1, 4)) 

155 

156 q = _reshaping(q, self.n_heads).reshape( 

157 B, T, self.kv_heads, self.n_heads // self.kv_heads, self.head_size 

158 ) 

159 k, v = map(_reshaping, (k, v), (self.kv_heads, self.kv_heads)) 

160 k, v = map(lambda x: jnp.expand_dims(x, axis=3), (k, v)) 

161 q, k, v = map(_permute, (q, k, v)) 

162 return q, k, v 
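    # reshape_head output layout (grouped GQA):
    #   q: [B, kv_heads, n_heads // kv_heads, T, head_size]
    #   k, v: [B, kv_heads, 1, T, head_size]  (group axis of 1, broadcast against q)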

    def _compute_attention(self, use_cache, kv_cache, cache_index, B, T, q, k, v):
        """Manual (non-Flash) attention for the cache/sliding-window path."""
        q, k, v = self.reshape_head(B, T, q, k, v)
        if self.use_rotary:
            q, k = map(self._apply_rope_grouped, (q, k), (cache_index, cache_index))
        if use_cache:
            kv_cache, _, k, v = self._compute_cache(kv_cache, cache_index, B, T, k, v)
        k = jnp.swapaxes(k, -2, -1)
        attn = q @ k / math.sqrt(self.head_size)
        return kv_cache, q, k, v, attn

174 

175 # ── Forward ─────────────────────────────────────────────────────────────── 

176 

177 def __call__( 

178 self, 

179 x: jnp.ndarray, 

180 use_cache: bool, 

181 kv_cache: tuple, 

182 cache_index: int, 

183 deterministic: bool = False, 

184 ) -> tuple[jnp.ndarray, tuple]: 

185 

186 B, T, _ = x.shape 

187 assert self.max_context >= T, "Sequence too long" 

188 

189 # ── Flash Attention fast path (MHA/GQA, training only) ──────────────── 

190 # Conditions: not MLA, no KV-cache (training), no sliding window. 

191 # Uses jax.nn.dot_product_attention for fused scaled-dot-product with 

192 # a causal mask. Requires JAX >= 0.4.25. 

193 if self.mla is False and not use_cache and not self.sliding_window and self.use_flash_attention: 

194 q_size = self.dim 

195 kv_size = self.kv_heads * self.head_size 

196 qkv = self.qkv(x) 

197 q, k, v = jax.lax.split(qkv, (q_size, kv_size, kv_size), axis=-1) 

198 

199 # [B, T, H, D] layout expected by dot_product_attention 

200 q_fa = q.reshape(B, T, self.n_heads, self.head_size) 

201 k_fa = k.reshape(B, T, self.kv_heads, self.head_size) 

202 v_fa = v.reshape(B, T, self.kv_heads, self.head_size) 

203 

204 if self.use_rotary: 

205 q_fa = self._apply_rope_thd(q_fa, cache_index) 

206 k_fa = self._apply_rope_thd(k_fa, cache_index) 

207 

208 # GQA: broadcast k/v to full head count for JAX < 0.4.31 compat 

209 if self.kv_heads < self.n_heads: 

210 g = self.n_heads // self.kv_heads 

211 k_fa = jnp.repeat(k_fa, g, axis=2) 

212 v_fa = jnp.repeat(v_fa, g, axis=2) 

213 

214 y = jax.nn.dot_product_attention(q_fa, k_fa, v_fa, is_causal=True) 

215 # y: [B, T, n_heads, head_size] → [B, T, dim] 

216 y = y.reshape(B, T, self.dim) 

217 

218 if self.no_sink: 

219 y = y * jax.nn.sigmoid(self.W(x)) 

220 out = self.o_proj(y) 

221 out = self.resid_dropout(out, deterministic=deterministic) 

222 return out, kv_cache 

223 

224 # ── MLA path ────────────────────────────────────────────────────────── 

225 if self.mla is True: 

226 q = self.down_q(x) 

227 c_kv = self.down_kv(x) 

228 q = self.norm_q(q) 

229 c_kv = self.norm_kv(c_kv) 

230 

231 if self.use_rotary: 

232 q_rope = self.q_pe(x)[:, None, None, :, :] 

233 k_rope = self.k_pe(x)[:, None, None, :, :] 

234 q_rope, k_rope = map( 

235 self._apply_rope_grouped, (q_rope, k_rope), (cache_index, cache_index) 

236 ) 

237 

238 if self.inference is False: 

239 q = self.up_q(q) 

240 k, v = map(lambda f, vec: f(vec), (self.up_k, self.up_v), (c_kv, c_kv)) 

241 q, k, v = self.reshape_head(B, T, q, k, v) 

242 

243 q_rope_ext, k_rope_ext = map( 

244 lambda x, m: jnp.broadcast_to(x, m.shape[:-1] + (self.rope_dim,)), 

245 (q_rope, k_rope), (q, k) 

246 ) 

247 q_full = jnp.concatenate([q, q_rope_ext], axis=-1) 

248 k_full = jnp.concatenate([k, k_rope_ext], axis=-1) 

249 k_full = jnp.swapaxes(k_full, -2, -1) 

250 attn = q_full @ k_full / math.sqrt(self.head_size + self.rope_dim) 

251 else: 

252 if use_cache: 

253 kv_cache, k_rope, k, v = self._compute_cache( 

254 kv_cache, cache_index, B, T, c_kv=c_kv, k_rope=k_rope 

255 ) 

256 else: 

257 k = c_kv 

258 v = c_kv 

259 

260 q_proj = self.up_q.kernel.reshape( 

261 self.down_dim_q, self.kv_heads, self.n_heads // self.kv_heads, self.head_size 

262 ) 

263 k_proj = self.up_k.kernel.reshape(self.down_dim_kv, self.kv_heads, self.head_size) 

264 attn_proj = jnp.einsum("qngh, knh -> ngqk", q_proj, k_proj) 

265 attn_proj = jnp.einsum("btq, ngqk -> btngk", q, attn_proj) 

266 attn = jnp.einsum("btngk, bsk -> bngts", attn_proj, k) 

267 attn_rope = q_rope @ jnp.swapaxes(k_rope, -2, -1) 

268 attn = (attn + attn_rope) / math.sqrt(self.head_size + self.rope_dim) 

269 

270 # ── Manual MHA/GQA path (cache or sliding window) ───────────────────── 

271 elif self.mla is False: 

272 q_size = self.dim 

273 kv_size = self.kv_heads * self.head_size 

274 qkv = self.qkv(x) 

275 q, k, v = jax.lax.split(qkv, (q_size, kv_size, kv_size), axis=-1) 

276 kv_cache, q, k, v, attn = self._compute_attention( 

277 use_cache, kv_cache, cache_index, B, T, q, k, v 

278 ) 

279 

280 # ── Causal mask + softmax ───────────────────────────────────────────── 

281 S = attn.shape[-1] 

282 mask = jax.lax.dynamic_slice(self.tril, (cache_index, 0), (T, S)) 

283 attn = attn + jnp.where(mask, 0.0, -1e9) 

284 

285 if self.sliding_window: 

286 window_mask = jax.lax.dynamic_slice(self.window, (cache_index, 0), (T, S)) 

287 attn = attn + window_mask 

288 

289 causal_attn = jax.nn.softmax(attn) 

290 causal_attn = self.attn_dropout(causal_attn, deterministic=deterministic) 

291 

292 # ── Output projection ───────────────────────────────────────────────── 

293 if self.inference is True and self.mla is True: 

294 L = jnp.einsum("bngts, bsd -> bngtd", causal_attn, v) 

295 W_v = self.up_v.kernel.reshape(self.down_dim_kv, self.kv_heads, self.head_size) 

296 if self.no_sink: 

297 y_heads = jnp.einsum("bngtd, dnh -> bngth", L, W_v) 

298 y = jnp.transpose(y_heads, (0, 3, 1, 2, 4)).reshape(B, T, self.dim) 

299 y = y * jax.nn.sigmoid(self.W(x)) 

300 out = self.o_proj(y) 

301 else: 

302 W_o = self.o_proj.kernel.reshape( # type: ignore[union-attr] 

303 self.kv_heads, self.n_heads // self.kv_heads, self.head_size, self.dim 

304 ) 

305 W_vo = jnp.einsum("dnh, nghc -> dngc", W_v, W_o) 

306 out = jnp.einsum("bngtd, dngc -> btc", L, W_vo) 

307 else: 

308 y = causal_attn @ v 

309 y = jnp.transpose(y, (0, 3, 1, 2, 4)).reshape(B, T, self.dim) 

310 if self.no_sink: 

311 y = y * jax.nn.sigmoid(self.W(x)) 

312 out = self.o_proj(y) 

313 

314 out = self.resid_dropout(out, deterministic=deterministic) 

315 return out, kv_cache
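
# ── Usage sketch (illustrative) ────────────────────────────────────────────────
# A minimal, hypothetical example of driving this module. It assumes `Config`
# is a keyword-constructible object exposing exactly the fields read in
# __init__ above; the field values, import paths, and the call below are
# assumptions for illustration, not part of the covered code.
#
#   import flax.nnx as nnx
#   import jax.numpy as jnp
#   from core.config import Config
#   from core.attention import Attention
#
#   cfg = Config(dim=256, n_heads=8, kv_heads=4, head_size=32, max_context=128,
#                dropout_rate=0.0, no_sink=True, sliding_window=False,
#                use_rotary_pos=True, use_flash_attention=False,
#                down_dim_q=64, down_dim_kv=64, mla=False, inference=False)
#   attn = Attention(cfg, rngs=nnx.Rngs(0))
#   x = jnp.zeros((2, 16, cfg.dim))  # [batch, time, dim]
#   out, kv_cache = attn(x, use_cache=False, kv_cache=(None, None),
#                        cache_index=0, deterministic=True)
#   # out: [2, 16, cfg.dim]; kv_cache is returned unchanged when use_cache=False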