Coverage for core/moe.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 16:48 +0200

1 

2import flax.nnx as nnx 

3import jax 

4import jax.numpy as jnp 

5 

6from .config import Config 

7from .mlp import MLP 

8 

9 

class MoE(nnx.Module):
    """Mixture-of-Experts layer.

    A bias-free linear router scores each token against every expert,
    the top-k experts are selected per token, and their MLP outputs are
    mixed with gates renormalized to sum to 1.

    NOTE(review): each expert is applied densely to *all* tokens and the
    result is masked afterwards — simple, but compute scales with
    n_experts rather than with the routed token count.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.n_experts: int = config.n_experts
        # One independent MLP per expert.
        self.experts: nnx.List = nnx.List(
            [MLP(config, rngs) for _ in range(self.n_experts)]
        )
        # Router produces one logit per expert for each token.
        self.router: nnx.Linear = nnx.Linear(
            config.dim, self.n_experts, use_bias=False, rngs=rngs
        )
        self.top_k_mlp: int = config.top_k_mlp

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Route each token through its top-k experts.

        Args:
            x: token activations — assumed shape (B, T, dim), matching the
               router's input width (TODO confirm against callers).
            deterministic: forwarded unchanged to every expert MLP.

        Returns:
            A pair of (mixed expert output, same shape as ``x``;
            scalar load-balancing auxiliary loss).
        """
        B, T, _ = x.shape
        logits = self.router(x)                       # (B, T, n_experts)
        probs = jax.nn.softmax(logits)                # softmax over experts (last axis)
        gate_vals, gate_idx = jax.lax.top_k(probs, self.top_k_mlp)
        # Renormalize the selected gates so they sum to 1 per token.
        gate_vals = gate_vals / jnp.sum(gate_vals, axis=-1, keepdims=True)

        # Switch-style auxiliary loss: per-expert mean router probability
        # times the fraction of top-k assignments that expert received,
        # scaled by n_experts.
        mean_prob = jnp.mean(jnp.reshape(probs, (B * T, self.n_experts)), axis=0)
        assign_frac = jnp.mean(
            jnp.sum(jax.nn.one_hot(gate_idx, self.n_experts), axis=2),
            axis=(0, 1),
        )
        aux_loss = jnp.sum(assign_frac * mean_prob) * self.n_experts

        # Dense mixture: run every expert on the full batch, then weight
        # its output by the (possibly zero) gate each token assigned it.
        mixed = jnp.zeros_like(x)
        for e in range(self.n_experts):
            chosen = gate_idx == e                    # (B, T, top_k) bool
            weight = jnp.sum(jnp.where(chosen, gate_vals, 0), axis=-1, keepdims=True)
            out, _ = self.experts[e](x, deterministic=deterministic)
            mixed = mixed + weight * out
        return mixed, aux_loss