Coverage for core / block.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 16:17 +0200

1import flax.nnx as nnx 

2import jax.numpy as jnp 

3 

4from .attention import Attention 

5from .config import Config 

6from .mlp import MLP 

7from .moe import MoE 

8 

9# ── Normalisation ───────────────────────────────────────────────────────────── 

10 

class RMSNorm(nnx.Module):
    """
    Root Mean Square Layer Normalisation (Zhang & Sennrich, 2019).

    Cheaper than LayerNorm — no mean subtraction and no bias term — while
    matching its empirical quality on modern LLMs (LLaMA, Mistral, Gemma, …).
    """

    def __init__(self, dim: int, *, eps: float = 1e-6, rngs: nnx.Rngs) -> None:
        # Learnable per-feature gain, initialised to the identity (all ones).
        self.scale = nnx.Param(jnp.ones(dim))
        # Numerical-stability floor added under the square root.
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Normalise ``x`` by its RMS over the last axis, then rescale."""
        mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        normed = x / jnp.sqrt(mean_sq + self.eps)
        return normed * self.scale[...]

26 

27 

def _build_norm(config: Config, dim: int, rngs: nnx.Rngs) -> nnx.Module:
    """Build the normalisation layer selected by ``config.norm_type``.

    ``"rmsnorm"`` yields :class:`RMSNorm`; any other value falls back to
    ``nnx.LayerNorm``.
    """
    if config.norm_type != "rmsnorm":
        return nnx.LayerNorm(dim, rngs=rngs)
    return RMSNorm(dim, rngs=rngs)

33 

34 

35# ── Transformer block ───────────────────────────────────────────────────────── 

36 

class Block(nnx.Module):
    """One pre-norm transformer layer: attention then feed-forward,
    each wrapped in a residual connection.

    The feed-forward half is either a mixture-of-experts (``MoE``) or a
    dense ``MLP``, chosen once at construction via ``config.use_moe``.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.attention: Attention = Attention(config, rngs)
        self.ln1: nnx.Module = _build_norm(config, config.dim, rngs)
        self.ln2: nnx.Module = _build_norm(config, config.dim, rngs)
        self.use_moe: bool = config.use_moe
        # Exactly one feed-forward variant is instantiated per block.
        if self.use_moe:
            self.moe = MoE(config, rngs)
        else:
            self.mlp = MLP(config, rngs)

    def __call__(self, x: jnp.ndarray,
                 use_cache: bool,
                 kv_cache: tuple,
                 cache_index: int,
                 deterministic: bool = False) -> tuple[jnp.ndarray, tuple, jnp.ndarray | float]:
        """Apply the layer to ``x``.

        Returns ``(hidden, kv_cache, balancing_loss)`` — the updated hidden
        states, the (possibly updated) key/value cache, and the auxiliary
        load-balancing loss produced by the feed-forward sub-layer.
        """
        # Attention sub-layer (pre-norm) plus residual.
        attn_out, kv_cache = self.attention(self.ln1(x),
                                            use_cache=use_cache,
                                            kv_cache=kv_cache,
                                            cache_index=cache_index,
                                            deterministic=deterministic)
        x = x + attn_out

        # Feed-forward sub-layer: both variants return an (output, loss) pair.
        normed = self.ln2(x)
        if self.use_moe:
            ff_out, balancing_loss = self.moe(normed, deterministic=deterministic)
        else:
            ff_out, balancing_loss = self.mlp(normed, deterministic=deterministic)

        return x + ff_out, kv_cache, balancing_loss