Coverage for core/mlp.py: 90%

31 statements

coverage.py v7.13.5, created at 2026-05-01 21:16 +0200

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .config import Config
from .lora import LoRALinear


class Activation(nnx.Module):
    def __init__(self, activation_name: str):
        self.activation_name = activation_name

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        act_fn = getattr(jax.nn, self.activation_name.lower(), jax.nn.gelu)
        return act_fn(x)


class Swiglu(nnx.Module):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        gate, data = jnp.split(x, 2, axis=-1)
        return jax.nn.silu(gate) * data


class MLP(nnx.Module):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        intermediate_dim = config.dim * config.expansion
        up_proj_dim = intermediate_dim * 2 if config.use_swiglu else intermediate_dim

        _use_lora_mlp = getattr(config, "use_lora", False) and getattr(
            config, "lora_targets", "attention"
        ) in ("mlp", "all")
        _lora_kw: dict = dict(
            rank=getattr(config, "lora_rank", 8),
            alpha=getattr(config, "lora_alpha", 16.0),
            dropout_rate=getattr(config, "lora_dropout", 0.0),
            rngs=rngs,
        )

        self.up_proj: nnx.Linear | LoRALinear = (
            LoRALinear(config.dim, up_proj_dim, **_lora_kw) if _use_lora_mlp
            else nnx.Linear(config.dim, up_proj_dim, rngs=rngs)
        )
        self.down_proj: nnx.Linear | LoRALinear = (
            LoRALinear(intermediate_dim, config.dim, **_lora_kw) if _use_lora_mlp
            else nnx.Linear(intermediate_dim, config.dim, rngs=rngs)
        )
        self.activation = Swiglu() if config.use_swiglu else Activation(config.activation)
        self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.mlp_loss = 0

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> tuple[jnp.ndarray, float]:
        x = self.up_proj(x)
        x = self.activation(x)
        x = self.down_proj(x)
        return self.dropout(x, deterministic=deterministic), self.mlp_loss
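
A minimal usage sketch of the module above. The real Config class lives in core/config.py and is not shown in this report, so SketchConfig below is a hypothetical stand-in carrying only the fields MLP reads; the import path core.mlp is assumed from the report title.

from dataclasses import dataclass

import flax.nnx as nnx
import jax.numpy as jnp

from core.mlp import MLP  # import path assumed from "core/mlp.py"


@dataclass
class SketchConfig:
    # Hypothetical stand-in for core.config.Config with only the fields MLP uses.
    dim: int = 64
    expansion: int = 4
    use_swiglu: bool = True
    activation: str = "gelu"
    dropout_rate: float = 0.1
    use_lora: bool = False  # LoRA fields fall back to the getattr() defaults in MLP.__init__


mlp = MLP(SketchConfig(), rngs=nnx.Rngs(0))
x = jnp.ones((2, 16, 64))                 # (batch, seq, dim)
y, aux_loss = mlp(x, deterministic=True)  # y.shape == (2, 16, 64); aux_loss == 0

With use_swiglu=True the up projection is widened to 2 * dim * expansion, the SwiGLU split halves it back to dim * expansion, and the down projection maps it back to dim, so input and output shapes match.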