Coverage for core/mlp.py: 90%
31 statements
coverage.py v7.13.5, created at 2026-05-01 21:16 +0200
import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config
from .lora import LoRALinear
class Activation(nnx.Module):
    """Wraps a jax.nn activation selected by name."""

    def __init__(self, activation_name: str):
        self.activation_name = activation_name

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Resolve e.g. "relu" -> jax.nn.relu; unknown names fall back to jax.nn.gelu.
        act_fn = getattr(jax.nn, self.activation_name.lower(), jax.nn.gelu)
        return act_fn(x)
class Swiglu(nnx.Module):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Split the doubled up-projection in half: the first half gates the second via SiLU.
        gate, data = jnp.split(x, 2, axis=-1)
        return jax.nn.silu(gate) * data
class MLP(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        intermediate_dim = config.dim * config.expansion
        # SwiGLU needs twice the width on the way up: one half gates, the other carries the data.
        up_proj_dim = intermediate_dim * 2 if config.use_swiglu else intermediate_dim

        # Use LoRA-adapted projections only when the config both enables LoRA and targets the MLP.
        _use_lora_mlp = getattr(config, "use_lora", False) and getattr(config, "lora_targets", "attention") in ("mlp", "all")
        _lora_kw: dict = dict(
            rank=getattr(config, "lora_rank", 8),
            alpha=getattr(config, "lora_alpha", 16.0),
            dropout_rate=getattr(config, "lora_dropout", 0.0),
            rngs=rngs,
        )

        self.up_proj: nnx.Linear | LoRALinear = (
            LoRALinear(config.dim, up_proj_dim, **_lora_kw) if _use_lora_mlp
            else nnx.Linear(config.dim, up_proj_dim, rngs=rngs)
        )
        self.down_proj: nnx.Linear | LoRALinear = (
            LoRALinear(intermediate_dim, config.dim, **_lora_kw) if _use_lora_mlp
            else nnx.Linear(intermediate_dim, config.dim, rngs=rngs)
        )
        self.activation = Swiglu() if config.use_swiglu else Activation(config.activation)
        self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        # Auxiliary loss reported alongside the output; a plain MLP contributes none.
        self.mlp_loss = 0.0
    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, float]:
        x = self.up_proj(x)
        x = self.activation(x)
        x = self.down_proj(x)
        return self.dropout(x, deterministic=deterministic), self.mlp_loss
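
For reference, a minimal usage sketch of the module covered above. The real .config.Config lives outside this file, so a hypothetical DemoConfig stand-in with only the fields MLP reads is assumed here, along with standard Flax NNX seeding via nnx.Rngs; shapes and field defaults are illustrative, not taken from the project.

import dataclasses

import flax.nnx as nnx
import jax.numpy as jnp


@dataclasses.dataclass
class DemoConfig:
    # Hypothetical stand-in for .config.Config; only the fields MLP reads are included.
    dim: int = 64
    expansion: int = 4
    use_swiglu: bool = True
    activation: str = "gelu"
    dropout_rate: float = 0.1
    use_lora: bool = False


config = DemoConfig()
mlp = MLP(config, rngs=nnx.Rngs(0))        # MLP as defined in core/mlp.py above
x = jnp.ones((2, 16, config.dim))          # (batch, sequence, features)
y, aux_loss = mlp(x, deterministic=True)   # deterministic=True disables dropout
assert y.shape == x.shape                  # the block preserves the model dimension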