Coverage for core/old_model.py: 0% (195 statements)

coverage.py v7.13.5, created at 2026-04-30 16:48 +0200

import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config


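# Causal self-attention with grouped-query attention (kv_heads key/value
# heads shared across n_heads query heads), optional rotary positions,
# optional sliding-window masking, and a KV cache for incremental decoding.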
class Attention(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.max_context: int = config.max_context
        self.head_size: int = config.head_size
        self.n_heads: int = config.n_heads
        self.dim: int = config.dim
        self.kv_heads: int = config.kv_heads if config.kv_heads is not None else self.n_heads
        self.qkv: nnx.Linear = nnx.Linear(self.dim,
                                          self.dim + 2 * self.kv_heads * self.head_size,
                                          use_bias=False,
                                          rngs=rngs)
        self.tril: jnp.ndarray = jnp.tril(
            jnp.ones((self.max_context, self.max_context), dtype=bool)
        )
        self.o_proj: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)
        self.no_sink: bool = config.no_sink

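        # Gating projection for the no_sink path: the attention output is
        # scaled elementwise by sigmoid(W(x)) at the end of __call__.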
        self.W: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)

        self.sliding_window: bool = config.sliding_window

        if self.sliding_window:
            table = jnp.arange(self.max_context)[:, None] - jnp.arange(self.max_context)[None, :]
            mask = (table <= config.context_window) & (table >= 0)
            self.window = jnp.where(mask, 0, -1e9)

        self.use_rotary: bool = config.use_rotary_pos
        if self.use_rotary:
            def __compute_angle(T: int, C: int) -> jnp.ndarray:
                P = jnp.arange(T)
                W = 1 / (1000 ** (jnp.arange(C // 2) / C))
                degree = jnp.einsum('i,j->ij', P, W)[None, None, None, :, :]
                return degree

            self.angle: jnp.ndarray = __compute_angle(self.max_context, self.head_size)

        self.attn_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.resid_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)

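    # Rotary embedding: interleaved channel pairs are rotated by a
    # position-dependent angle; cache_index shifts into the angle table so
    # positions stay absolute during incremental decoding.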
    def __apply_rotation(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        T = x.shape[3]
        odd = x[:, :, :, :, 0::2]
        even = x[:, :, :, :, 1::2]

        angle = jax.lax.dynamic_slice_in_dim(
            self.angle,
            start_index=cache_index,
            slice_size=T,
            axis=3
        )
        x_odd = jax.lax.cos(angle) * odd - jax.lax.sin(angle) * even
        x_even = jax.lax.sin(angle) * odd + jax.lax.cos(angle) * even

        y = jnp.stack([x_even, x_odd], axis=-1).reshape(x.shape)
        return y

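    # Shapes inside __call__: q is (B, kv_heads, n_heads // kv_heads, T,
    # head_size); k and v are (B, kv_heads, 1, T, head_size), so the group
    # axis broadcasts (grouped-query attention).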
    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple]:

        B, T, _ = x.shape
        assert self.max_context >= T, "Sequence too long"

        qkv = self.qkv(x)

        q_size = self.dim
        kv_size = self.kv_heads * self.head_size

        q, k, v = jax.lax.split(
            qkv,
            (q_size, kv_size, kv_size),
            axis=-1
        )

        reshaping = lambda x, n_heads: jnp.reshape(x, (B, T, n_heads, self.head_size))

        q = reshaping(q, self.n_heads).reshape(B, T, self.kv_heads,
                                               self.n_heads // self.kv_heads, self.head_size)

        k, v = map(reshaping, (k, v), (self.kv_heads, self.kv_heads))

        k, v = map(lambda x: jnp.expand_dims(x, axis=3), (k, v))

        permute = lambda x: jnp.transpose(x, (0, 2, 3, 1, 4))

        q, k, v = map(permute, (q, k, v))

        if self.use_rotary:
            q, k = map(self.__apply_rotation, (q, k), (cache_index, cache_index))

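        # KV cache layout: (B, kv_heads, 1, max_context, head_size). The first
        # cached call allocates the buffers and writes the T-token prefix;
        # later calls write the new keys/values at cache_index.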
        if use_cache:
            if kv_cache == (None, None):
                k_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=k.dtype)
                v_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=v.dtype)
                k_cache, v_cache = k_cache.at[:, :, :, :T, :].set(k), v_cache.at[:, :, :, :T, :].set(v)
            else:
                k_cache, v_cache = map(lambda x, y, index: jax.lax.dynamic_update_slice(x, y, (0, 0, 0, index, 0)),
                                       (kv_cache[0], kv_cache[1]), (k, v), (cache_index, cache_index))

            kv_cache = (k_cache, v_cache)
            k, v = k_cache, v_cache

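        # Scores span the full key width S (max_context when caching, T
        # otherwise); the causal mask rows are sliced from tril at cache_index
        # so each query attends only to positions up to its absolute index.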
        k = jnp.swapaxes(k, -2, -1)
        attn = q @ k / math.sqrt(self.head_size)
        S = attn.shape[-1]
        mask = jax.lax.dynamic_slice(
            self.tril,
            start_indices=(cache_index, 0),
            slice_sizes=(T, S)
        )
        trilled = jnp.where(mask, 0.0, -1e9)

        attn = attn + trilled

        if self.sliding_window:
            window_mask = jax.lax.dynamic_slice(
                self.window,
                start_indices=(cache_index, 0),
                slice_sizes=(T, S)
            )
            attn = attn + window_mask

        causal_attn = jax.nn.softmax(attn)
        causal_attn = self.attn_dropout(causal_attn, deterministic=deterministic)

        y = causal_attn @ v

        y = jnp.transpose(y, (0, 3, 1, 2, 4)).reshape(B, T, self.dim)

        if self.no_sink:
            y = y * jax.nn.sigmoid(self.W(x))

        out = self.resid_dropout(self.o_proj(y), deterministic=deterministic)
        return out, kv_cache


class Activation(nnx.Module):
    def __init__(self, activation_name: str):
        self.activation_name = activation_name

    def __call__(self, x: jnp.ndarray):
        act_fn = getattr(jax.nn, self.activation_name.lower(), jax.nn.gelu)
        return act_fn(x)


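# SwiGLU activation: the up-projection doubles the hidden width so half can
# act as a SiLU gate on the other half (see up_proj_dim in MLP below).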
class Swiglu(nnx.Module):
    def __call__(self, x: jnp.ndarray):
        gate, data = jnp.split(x, 2, axis=-1)
        return jax.nn.silu(gate) * data


class MLP(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        intermediate_dim = config.dim * config.expansion
        up_proj_dim = intermediate_dim * 2 if config.use_swiglu else intermediate_dim
        self.up_proj = nnx.Linear(config.dim, up_proj_dim, rngs=rngs)
        self.down_proj = nnx.Linear(intermediate_dim, config.dim, rngs=rngs)
        self.activation = Swiglu() if config.use_swiglu else Activation(config.activation)
        self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.mlp_loss = 0

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, float]:
        x = self.up_proj(x)
        x = self.activation(x)
        x = self.down_proj(x)
        return self.dropout(x, deterministic=deterministic), self.mlp_loss


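# Mixture-of-experts MLP with softmax top-k routing. moe_loss is a
# load-balancing auxiliary term in the style of the Switch Transformer loss:
# per-expert selection frequency times mean routing probability, summed and
# scaled by n_experts. Note that every expert runs on every token here; the
# routing weights only blend the outputs.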
class MoE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.n_experts: int = config.n_experts
        self.experts: nnx.List = nnx.List([MLP(config, rngs) for _ in range(self.n_experts)])
        self.router: nnx.Linear = nnx.Linear(config.dim, self.n_experts, use_bias=False, rngs=rngs)
        self.top_k_mlp: int = config.top_k_mlp

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, jnp.ndarray]:
        B, T, _ = x.shape
        x_routed = self.router(x)
        probs = jax.nn.softmax(x_routed)
        values, indices = jax.lax.top_k(probs, self.top_k_mlp)
        values = values / jnp.sum(values, axis=-1, keepdims=True)
        y = jnp.zeros_like(x)

        expert_mean_prob = jnp.mean(jnp.reshape(probs, (B * T, self.n_experts)), axis=0)
        freq = jnp.mean(jnp.sum(jax.nn.one_hot(indices, self.n_experts), axis=2), axis=(0, 1))
        moe_loss = jnp.sum(freq * expert_mean_prob) * self.n_experts

        for i in range(self.n_experts):
            mask = (indices == i)
            expert_weight = jnp.sum(jnp.where(mask, values, 0), axis=-1, keepdims=True)
            expert_out, _ = self.experts[i](x, deterministic=deterministic)
            y = y + (expert_weight * expert_out)
        return y, moe_loss


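# Pre-LayerNorm transformer block: attention and MLP/MoE sublayers, each with
# a residual connection; threads the KV cache and balancing loss through.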
class Block(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.attention: Attention = Attention(config, rngs)
        self.ln1: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.ln2: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.use_moe: bool = config.use_moe
        if self.use_moe:
            self.moe = MoE(config, rngs)
        else:
            self.mlp = MLP(config, rngs)

    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:

        x_attn, kv_cache = self.attention(self.ln1(x),
                                          use_cache=use_cache,
                                          kv_cache=kv_cache,
                                          cache_index=cache_index,
                                          deterministic=deterministic)
        x = x + x_attn
        if self.use_moe:
            ff, balancing_loss = self.moe(self.ln2(x), deterministic=deterministic)
        else:
            ff, balancing_loss = self.mlp(self.ln2(x), deterministic=deterministic)
        x = x + ff
        return x, kv_cache, balancing_loss


class Transformer(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.num_blocks: int = config.num_blocks
        self.blocks: nnx.List = nnx.List([Block(config, rngs=rngs) for _ in range(self.num_blocks)])
        self.lm_head: nnx.Linear = nnx.Linear(config.dim, config.vocab_size, rngs=rngs)
        self.wte: nnx.Embed = nnx.Embed(config.vocab_size, config.dim, rngs=rngs)
        self.trainable_pos: bool = config.trainable_pos
        self.absolute_pos: bool = config.absolute_pos
        self.max_context: int = config.max_context
        self.gradient_checkpointing: bool = config.gradient_checkpointing
        self.ln_f: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.emb_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.use_moe: bool = config.use_moe
        self.alpha_balance: float = config.alpha_balance

        if config.weight_tying:
            self.lm_head.kernel = self.wte.embedding.T
        if self.trainable_pos:
            self.wpe: nnx.Embed = nnx.Embed(config.max_context, config.dim, rngs=rngs)
        elif self.absolute_pos:
            def _build_compute_absolute_pos(T: int, C: int) -> jnp.ndarray:
                pos = jnp.zeros((T, C))
                row = jnp.arange(T)
                col = jnp.arange(0, C, 2)
                k = 1.0 / (10000 ** (col / C))

                ratio = jnp.einsum('i,j->ij', row, k)
                pos = pos.at[:, 0::2].set(jnp.sin(ratio))
                pos = pos.at[:, 1::2].set(jnp.cos(ratio))

                return jnp.expand_dims(pos, axis=0)

            self.wpe: jnp.ndarray = _build_compute_absolute_pos(config.max_context, config.dim)

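    # Forward pass over all blocks. Returns logits, the per-block KV caches,
    # and the summed MoE balancing loss. alpha_balance is stored in __init__
    # but not applied here; scaling the loss appears to be left to the caller.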
    def __call__(self,
                 x: jnp.ndarray,
                 use_cache: bool,
                 kv_caches: tuple | None,
                 cache_index: int | None,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:

        B, T = x.shape
        x = self.wte(x)
        if kv_caches is None:
            kv_caches = tuple((None, None) for _ in range(self.num_blocks))
        if self.absolute_pos:
            wpe_slice = jax.lax.dynamic_slice_in_dim(
                self.wpe,
                start_index=cache_index,
                slice_size=T,
                axis=1
            )
            x = x + wpe_slice
        elif self.trainable_pos:
            offset = 0 if cache_index is None else cache_index
            x = x + self.wpe(jnp.arange(T) + offset)

        x = self.emb_dropout(x, deterministic=deterministic)

        def block_fn(block_module, hidden_state, kv_c, det):
            return block_module(
                hidden_state,
                use_cache=use_cache,
                kv_cache=kv_c,
                cache_index=cache_index,
                deterministic=det
            )

        if self.gradient_checkpointing and not use_cache:
            checkpointed_block = nnx.remat(lambda bm, hs, kvc: block_fn(bm, hs, kvc, deterministic))
        else:
            checkpointed_block = lambda bm, hs, kvc: block_fn(bm, hs, kvc, deterministic)

        new_kv_caches = []
        balancing_loss_total = 0
        for i, block in enumerate(self.blocks):
            x, new_kv, balancing_loss = checkpointed_block(block, x, kv_caches[i])
            new_kv_caches.append(new_kv)
            balancing_loss_total += balancing_loss

        x = self.ln_f(x)

        return self.lm_head(x), tuple(new_kv_caches), balancing_loss_total
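
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test of
# Transformer. The real Config lives in .config; a types.SimpleNamespace with
# the fields this file reads stands in for it here so the sketch stays
# self-contained. All field values below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        max_context=64, head_size=16, n_heads=4, kv_heads=2, dim=64,
        vocab_size=256, num_blocks=2, expansion=4, dropout_rate=0.0,
        activation="gelu", use_swiglu=False, use_moe=False, n_experts=0,
        top_k_mlp=0, no_sink=False, sliding_window=False, context_window=0,
        use_rotary_pos=True, trainable_pos=False, absolute_pos=False,
        gradient_checkpointing=False, weight_tying=False, alpha_balance=0.0,
    )
    model = Transformer(cfg, rngs=nnx.Rngs(0))

    # Full-sequence forward pass (training-style, no cache).
    tokens = jnp.zeros((1, 8), dtype=jnp.int32)
    logits, _, aux_loss = model(
        tokens, use_cache=False, kv_caches=None, cache_index=0,
        deterministic=True,
    )
    print(logits.shape)  # (1, 8, vocab_size)

    # Incremental decoding: prefill once with use_cache=True, then feed one
    # token at a time, advancing cache_index by the tokens already cached.
    logits, kv_caches, _ = model(
        tokens, use_cache=True, kv_caches=None, cache_index=0,
        deterministic=True,
    )
    next_token = jnp.argmax(logits[:, -1:, :], axis=-1).astype(jnp.int32)
    logits, kv_caches, _ = model(
        next_token, use_cache=True, kv_caches=kv_caches, cache_index=8,
        deterministic=True,
    )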