Coverage for core / moe.py: 100%
27 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-30 16:48 +0200
2import flax.nnx as nnx
3import jax
4import jax.numpy as jnp
6from .config import Config
7from .mlp import MLP
class MoE(nnx.Module):
    """Mixture-of-Experts layer: routes each token to its top-k expert MLPs.

    Dense dispatch: every expert runs on the full input and its output is
    masked/weighted afterwards, so the result is a per-token weighted sum
    over the selected experts. Also returns a Switch-style load-balancing
    auxiliary loss.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.n_experts: int = config.n_experts
        # One independent MLP per expert.
        self.experts: nnx.List = nnx.List([MLP(config, rngs) for _ in range(self.n_experts)])
        # Bias-free linear router: model dim -> one logit per expert.
        self.router: nnx.Linear = nnx.Linear(config.dim, self.n_experts, use_bias=False, rngs=rngs)
        self.top_k_mlp: int = config.top_k_mlp

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(mixed_output, aux_loss)`` for input ``x`` of shape (B, T, dim)."""
        batch, seq_len, _ = x.shape

        # Router probabilities over experts for every token.
        logits = self.router(x)
        probs = jax.nn.softmax(logits)

        # Keep the k largest routing probabilities per token and
        # renormalize them so the selected weights sum to one.
        top_vals, top_idx = jax.lax.top_k(probs, self.top_k_mlp)
        top_vals = top_vals / jnp.sum(top_vals, axis=-1, keepdims=True)

        # Load-balancing loss: (mean router prob per expert) dotted with
        # (fraction of tokens that selected that expert), scaled by E.
        mean_prob = jnp.mean(jnp.reshape(probs, (batch * seq_len, self.n_experts)), axis=0)
        selection_freq = jnp.mean(
            jnp.sum(jax.nn.one_hot(top_idx, self.n_experts), axis=2), axis=(0, 1)
        )
        aux_loss = jnp.sum(selection_freq * mean_prob) * self.n_experts

        # Accumulate each expert's contribution, weighted per token.
        mixed = jnp.zeros_like(x)
        for expert_id in range(self.n_experts):
            hit = (top_idx == expert_id)
            # Per-token weight for this expert (zero where it was not selected).
            weight = jnp.sum(jnp.where(hit, top_vals, 0), axis=-1, keepdims=True)
            expert_out, _ = self.experts[expert_id](x, deterministic=deterministic)
            mixed = mixed + (weight * expert_out)
        return mixed, aux_loss