Coverage for core/old_model.py: 0%
195 statements
import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config

class Attention(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.max_context: int = config.max_context
        self.head_size: int = config.head_size
        self.n_heads: int = config.n_heads
        self.dim: int = config.dim
        self.kv_heads: int = config.kv_heads if config.kv_heads is not None else self.n_heads
        self.qkv: nnx.Linear = nnx.Linear(self.dim,
                                          self.dim + 2 * self.kv_heads * self.head_size,
                                          use_bias=False,
                                          rngs=rngs)
        self.tril: jnp.ndarray = jnp.tril(
            jnp.ones((self.max_context, self.max_context), dtype=bool)
        )
        self.o_proj: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)
        self.no_sink: bool = config.no_sink

        self.W: nnx.Linear = nnx.Linear(self.dim, self.dim, rngs=rngs)

        self.sliding_window: bool = config.sliding_window

        if self.sliding_window:
            table = jnp.arange(self.max_context)[:, None] - jnp.arange(self.max_context)[None, :]
            mask = (table <= config.context_window) & (table >= 0)
            self.window = jnp.where(mask, 0, -1e9)

        self.use_rotary: bool = config.use_rotary_pos
        if self.use_rotary:
            def __compute_angle(T: int, C: int) -> jnp.ndarray:
                P = jnp.arange(T)
                W = 1 / (1000 ** (jnp.arange(C // 2) / C))
                degree = jnp.einsum('i,j->ij', P, W)[None, None, None, :, :]
                return degree

            self.angle: jnp.ndarray = __compute_angle(self.max_context, self.head_size)

        self.attn_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.resid_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)

    def __apply_rotation(self, x: jnp.ndarray, cache_index: int) -> jnp.ndarray:
        T = x.shape[3]
        odd = x[:, :, :, :, 0::2]
        even = x[:, :, :, :, 1::2]

        angle = jax.lax.dynamic_slice_in_dim(
            self.angle,
            start_index=cache_index,
            slice_size=T,
            axis=3
        )
        x_odd = jax.lax.cos(angle) * odd - jax.lax.sin(angle) * even
        x_even = jax.lax.sin(angle) * odd + jax.lax.cos(angle) * even

        y = jnp.stack([x_even, x_odd], axis=-1).reshape(x.shape)
        return y
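
    # __apply_rotation implements rotary position embeddings: head features are taken
    # in pairs (positions 0::2 and 1::2 of the last axis), each pair is rotated by an
    # angle proportional to its absolute position (offset by cache_index during cached
    # decoding), and the rotated halves are re-interleaved back into the input shape.
    # Because the same rotation is applied to q and k, their dot product depends only
    # on the relative distance between positions.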

    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple]:

        B, T, _ = x.shape
        assert self.max_context >= T, "Sequence too long"

        qkv = self.qkv(x)

        q_size = self.dim
        kv_size = self.kv_heads * self.head_size

        q, k, v = jax.lax.split(
            qkv,
            (q_size, kv_size, kv_size),
            axis=-1
        )

        reshaping = lambda x, n_heads: jnp.reshape(x, (B, T, n_heads, self.head_size))

        q = reshaping(q, self.n_heads).reshape(B, T, self.kv_heads,
                                               self.n_heads // self.kv_heads, self.head_size)

        k, v = map(reshaping, (k, v), (self.kv_heads, self.kv_heads))

        k, v = map(lambda x: jnp.expand_dims(x, axis=3), (k, v))

        permute = lambda x: jnp.transpose(x, (0, 2, 3, 1, 4))

        q, k, v = map(permute, (q, k, v))
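
        # Grouped-query attention layout after the reshapes and transpose above:
        #   q:    (B, kv_heads, n_heads // kv_heads, T, head_size)
        #   k, v: (B, kv_heads, 1,                   T, head_size)
        # so each group of query heads broadcasts against a single shared k/v head.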

        if self.use_rotary:
            q, k = map(self.__apply_rotation, (q, k), (cache_index, cache_index))

        if use_cache:
            if kv_cache == (None, None):
                k_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=k.dtype)
                v_cache = jnp.zeros((B, self.kv_heads, 1, self.max_context, self.head_size), dtype=v.dtype)
                k_cache, v_cache = k_cache.at[:, :, :, :T, :].set(k), v_cache.at[:, :, :, :T, :].set(v)
            else:
                k_cache, v_cache = map(lambda x, y, index: jax.lax.dynamic_update_slice(x, y, (0, 0, 0, index, 0)),
                                       (kv_cache[0], kv_cache[1]), (k, v), (cache_index, cache_index))

            kv_cache = (k_cache, v_cache)
            k, v = k_cache, v_cache
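
        # KV cache layout: (B, kv_heads, 1, max_context, head_size). On the first call
        # (kv_cache == (None, None)) the fresh keys/values fill positions [0, T); on
        # later calls they are written in place at cache_index via dynamic_update_slice,
        # and attention then runs against the full cached buffer.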

        k = jnp.swapaxes(k, -2, -1)
        attn = q @ k / math.sqrt(self.head_size)
        # mask = jax.lax.dynamic_slice_in_dim(operand=self.tril,
        #                                     start_index=cache_index,
        #                                     slice_size=T,
        #                                     axis=0)
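        # Queries in this call occupy absolute positions [cache_index, cache_index + T),
        # so the matching rows of the causal mask are sliced out below; the columns span
        # all S key positions (the full cache length when use_cache is True, otherwise T).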
        S = attn.shape[-1]
        mask = jax.lax.dynamic_slice(
            self.tril,
            start_indices=(cache_index, 0),
            slice_sizes=(T, S)
        )
        trilled = jnp.where(mask, 0.0, -1e9)

        attn = attn + trilled

        if self.sliding_window:
            window_mask = jax.lax.dynamic_slice(
                self.window,
                start_indices=(cache_index, 0),
                slice_sizes=(T, S)
            )
            attn = attn + window_mask

        causal_attn = jax.nn.softmax(attn)
        causal_attn = self.attn_dropout(causal_attn, deterministic=deterministic)

        y = causal_attn @ v

        y = jnp.transpose(y, (0, 3, 1, 2, 4)).reshape(B, T, self.dim)

        if self.no_sink:
            y = y * jax.nn.sigmoid(self.W(x))

        out = self.resid_dropout(self.o_proj(y), deterministic=deterministic)
        return out, kv_cache

class Activation(nnx.Module):
    def __init__(self, activation_name: str):
        self.activation_name = activation_name

    def __call__(self, x: jnp.ndarray):
        act_fn = getattr(jax.nn, self.activation_name.lower(), jax.nn.gelu)
        return act_fn(x)


class Swiglu(nnx.Module):
    def __call__(self, x: jnp.ndarray):
        gate, data = jnp.split(x, 2, axis=-1)
        return jax.nn.silu(gate) * data
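
# Swiglu expects an even final dimension: the first half gates the second half through
# SiLU, halving the width on the way out. This is why MLP.up_proj below doubles the
# intermediate dimension when config.use_swiglu is set.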

class MLP(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        intermediate_dim = config.dim * config.expansion
        up_proj_dim = intermediate_dim * 2 if config.use_swiglu else intermediate_dim
        self.up_proj = nnx.Linear(config.dim, up_proj_dim, rngs=rngs)
        self.down_proj = nnx.Linear(intermediate_dim, config.dim, rngs=rngs)
        self.activation = Swiglu() if config.use_swiglu else Activation(config.activation)
        self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.mlp_loss = 0

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, float]:
        x = self.up_proj(x)
        x = self.activation(x)
        x = self.down_proj(x)
        return self.dropout(x, deterministic=deterministic), self.mlp_loss

class MoE(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.n_experts: int = config.n_experts
        self.experts: nnx.List = nnx.List([MLP(config, rngs) for _ in range(self.n_experts)])
        self.router: nnx.Linear = nnx.Linear(config.dim, self.n_experts, use_bias=False, rngs=rngs)
        self.top_k_mlp: int = config.top_k_mlp

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, jnp.ndarray]:
        B, T, _ = x.shape
        x_routed = self.router(x)
        probs = jax.nn.softmax(x_routed)
        values, indices = jax.lax.top_k(probs, self.top_k_mlp)
        values = values / jnp.sum(values, axis=-1, keepdims=True)
        y = jnp.zeros_like(x)

        expert_mean_prob = jnp.mean(jnp.reshape(probs, (B * T, self.n_experts)), axis=0)
        freq = jnp.mean(jnp.sum(jax.nn.one_hot(indices, self.n_experts), axis=2), axis=(0, 1))
        moe_loss = jnp.sum(freq * expert_mean_prob) * self.n_experts

        for i in range(self.n_experts):
            mask = (indices == i)
            expert_weight = jnp.sum(jnp.where(mask, values, 0), axis=-1, keepdims=True)
            expert_out, _ = self.experts[i](x, deterministic=deterministic)
            y = y + (expert_weight * expert_out)
        return y, moe_loss
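
# The auxiliary loss has the usual load-balancing form: with f_e the average number of
# tokens whose top-k selection includes expert e and P_e the mean router probability
# assigned to expert e, moe_loss = n_experts * sum_e f_e * P_e, which is smallest when
# routing is uniform across experts. Note that this is a dense formulation: every
# expert processes the full batch and the outputs are blended with the renormalized
# top-k weights, rather than dispatching tokens sparsely.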

class Block(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.attention: Attention = Attention(config, rngs)
        self.ln1: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.ln2: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.use_moe: bool = config.use_moe
        if self.use_moe:
            self.moe = MoE(config, rngs)
        else:
            self.mlp = MLP(config, rngs)

    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:

        x_attn, kv_cache = self.attention(self.ln1(x),
                                          use_cache=use_cache,
                                          kv_cache=kv_cache,
                                          cache_index=cache_index,
                                          deterministic=deterministic)
        x = x + x_attn
        if self.use_moe:
            ff, balancing_loss = self.moe(self.ln2(x), deterministic=deterministic)
        else:
            ff, balancing_loss = self.mlp(self.ln2(x), deterministic=deterministic)
        x = x + ff
        return x, kv_cache, balancing_loss
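
# Block uses the pre-LayerNorm residual layout: x + Attention(LN(x)) followed by
# x + FF(LN(x)), where FF is either the dense MLP or the MoE layer whose balancing
# loss is passed upward to the Transformer.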

class Transformer(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.num_blocks: int = config.num_blocks
        self.blocks: nnx.List = nnx.List([Block(config, rngs=rngs) for _ in range(self.num_blocks)])
        self.lm_head: nnx.Linear = nnx.Linear(config.dim, config.vocab_size, rngs=rngs)
        self.wte: nnx.Embed = nnx.Embed(config.vocab_size, config.dim, rngs=rngs)
        self.trainable_pos: bool = config.trainable_pos
        self.absolute_pos: bool = config.absolute_pos
        self.max_context: int = config.max_context
        self.gradient_checkpointing: bool = config.gradient_checkpointing
        self.ln_f: nnx.LayerNorm = nnx.LayerNorm(config.dim, rngs=rngs)
        self.emb_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.use_moe: bool = config.use_moe
        self.alpha_balance: float = config.alpha_balance

        if config.weight_tying:
            self.lm_head.kernel = self.wte.embedding.T
        if self.trainable_pos:
            self.wpe: nnx.Embed = nnx.Embed(config.max_context, config.dim, rngs=rngs)
        elif self.absolute_pos:
            def _build_compute_absolute_pos(T: int, C: int) -> jnp.ndarray:
                pos = jnp.zeros((T, C))
                row = jnp.arange(T)
                col = jnp.arange(0, C, 2)
                k = 1.0 / (10000 ** (col / C))

                ratio = jnp.einsum('i,j->ij', row, k)
                pos = pos.at[:, 0::2].set(jnp.sin(ratio))
                pos = pos.at[:, 1::2].set(jnp.cos(ratio))

                return jnp.expand_dims(pos, axis=0)

            self.wpe: jnp.ndarray = _build_compute_absolute_pos(config.max_context, config.dim)

    def __call__(self,
                 x: jnp.ndarray,
                 use_cache: bool,
                 kv_caches: tuple | None,
                 cache_index: int | None,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:

        B, T = x.shape
        x = self.wte(x)
        if kv_caches is None:
            kv_caches = tuple((None, None) for _ in range(self.num_blocks))
        if self.absolute_pos:
            wpe_slice = jax.lax.dynamic_slice_in_dim(
                self.wpe,
                start_index=cache_index,
                slice_size=T,
                axis=1
            )
            x = x + wpe_slice
        elif self.trainable_pos:
            # position indices must stay integers for the embedding lookup
            x = x + self.wpe(jnp.arange(T))

        x = self.emb_dropout(x, deterministic=deterministic)

        def block_fn(block_module, hidden_state, kv_c, det):
            return block_module(
                hidden_state,
                use_cache=use_cache,
                kv_cache=kv_c,
                cache_index=cache_index,
                deterministic=det
            )

        if self.gradient_checkpointing and not use_cache:
            checkpointed_block = nnx.remat(lambda bm, hs, kvc: block_fn(bm, hs, kvc, deterministic))
        else:
            checkpointed_block = lambda bm, hs, kvc: block_fn(bm, hs, kvc, deterministic)

        new_kv_caches = []
        balancing_loss_total = 0
        for i, block in enumerate(self.blocks):
            x, new_kv, balancing_loss = checkpointed_block(block, x, kv_caches[i] if kv_caches else None)
            new_kv_caches.append(new_kv)
            balancing_loss_total += balancing_loss

        x = self.ln_f(x)

        # return the balancing loss accumulated over all blocks, not just the last one
        return self.lm_head(x), tuple(new_kv_caches), balancing_loss_total
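
# Illustrative usage sketch (not part of the module; assumes a `Config` instance named
# `config` whose fields match the attributes read above: dim, n_heads, head_size,
# kv_heads, vocab_size, max_context, dropout_rate, use_moe, and so on):
#
#     import flax.nnx as nnx
#     import jax.numpy as jnp
#
#     model = Transformer(config, nnx.Rngs(0))
#     tokens = jnp.zeros((1, 8), dtype=jnp.int32)   # (batch, time) token ids
#     logits, kv_caches, aux_loss = model(
#         tokens, use_cache=False, kv_caches=None, cache_index=0, deterministic=True
#     )
#     # logits: (1, 8, config.vocab_size); aux_loss is the summed MoE balancing loss
#     # (a constant 0 when use_moe is False, since MLP returns mlp_loss = 0).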