Coverage for core/block.py: 100% (32 statements)
import flax.nnx as nnx
import jax.numpy as jnp

from .attention import Attention
from .config import Config
from .mlp import MLP
from .moe import MoE


# ── Normalisation ─────────────────────────────────────────────────────────────

class RMSNorm(nnx.Module):
    """
    Root Mean Square Layer Normalisation (Zhang & Sennrich, 2019).

    Faster than LayerNorm (no mean subtraction, no bias) with comparable
    empirical performance on modern LLMs (LLaMA, Mistral, Gemma, …).
    """

    def __init__(self, dim: int, *, eps: float = 1e-6, rngs: nnx.Rngs) -> None:
        self.scale = nnx.Param(jnp.ones(dim))
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        return (x / rms) * self.scale[...]
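
# A minimal usage sketch, kept as a comment so import-time behaviour is
# unchanged (the 512-dim size and the Rngs seed are illustrative, not from
# this module):
#
#     norm = RMSNorm(512, rngs=nnx.Rngs(0))
#     y = norm(jnp.ones((2, 16, 512)))   # output shape matches the input
#     # Each feature vector now has unit root-mean-square (up to eps) before
#     # the learned per-channel `scale` is applied.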


def _build_norm(config: Config, dim: int, rngs: nnx.Rngs) -> nnx.Module:
    """Return an RMSNorm or LayerNorm depending on ``config.norm_type``."""
    if config.norm_type == "rmsnorm":
        return RMSNorm(dim, rngs=rngs)
    return nnx.LayerNorm(dim, rngs=rngs)
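
# Dispatch sketch (comment only; the Config instance is assumed to exist with
# the `norm_type` and `dim` fields this module already relies on):
#
#     norm = _build_norm(config, config.dim, nnx.Rngs(0))
#     # norm_type == "rmsnorm" -> RMSNorm above; anything else -> nnx.LayerNorm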


# ── Transformer block ─────────────────────────────────────────────────────────

class Block(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.attention: Attention = Attention(config, rngs)
        self.ln1: nnx.Module = _build_norm(config, config.dim, rngs)
        self.ln2: nnx.Module = _build_norm(config, config.dim, rngs)
        self.use_moe: bool = config.use_moe
        if self.use_moe:
            self.moe = MoE(config, rngs)
        else:
            self.mlp = MLP(config, rngs)

    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:
        # Pre-norm attention sub-layer with a residual connection.
        x_attn, kv_cache = self.attention(self.ln1(x),
                                          use_cache=use_cache,
                                          kv_cache=kv_cache,
                                          cache_index=cache_index,
                                          deterministic=deterministic)
        x = x + x_attn
        # Feed-forward sub-layer; both branches return (output, balancing_loss).
        ff, balancing_loss = (
            self.moe(self.ln2(x), deterministic=deterministic)
            if self.use_moe
            else self.mlp(self.ln2(x), deterministic=deterministic)
        )
        x = x + ff
        return x, kv_cache, balancing_loss
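
# End-to-end sketch of one block step (comment only; the batch/sequence shape,
# the empty-tuple cache, and the seed are assumptions, since the real cache
# layout is defined by Attention):
#
#     block = Block(config, nnx.Rngs(0))
#     x = jnp.zeros((1, 128, config.dim))
#     x, kv_cache, aux_loss = block(x,
#                                   use_cache=False,
#                                   kv_cache=(),
#                                   cache_index=0)
#     # `aux_loss` is the MoE load-balancing term when config.use_moe is True.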