Coverage for core/lora.py: 100%

24 statements  

coverage.py v7.13.5, created at 2026-05-01 21:16 +0200

from __future__ import annotations

import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp


class LoRAParam(nnx.Variable):
    """Trainable LoRA variable — distinct type so base nnx.Param weights stay frozen."""
    pass


class LoRALinear(nnx.Module):
    """Drop-in replacement for nnx.Linear with frozen base weight and trainable low-rank delta.

    The effective weight is W_eff = W_base + (alpha / rank) * A @ B,
    where A is initialised with scaled Gaussian noise and B with zeros,
    so the adapter contributes nothing at initialisation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        rank: int = 8,
        alpha: float = 16.0,
        dropout_rate: float = 0.0,
        use_bias: bool = False,
        rngs: nnx.Rngs,
    ) -> None:
        self.base = nnx.Linear(in_features, out_features, use_bias=use_bias, rngs=rngs)
        self.scale = alpha / rank

        key = rngs.params()
        k_a, k_b = jax.random.split(key)
        self.lora_A = LoRAParam(
            jax.random.normal(k_a, (in_features, rank)) / math.sqrt(in_features)
        )
        self.lora_B = LoRAParam(jnp.zeros((rank, out_features)))

        self.dropout: nnx.Dropout | None = (
            nnx.Dropout(dropout_rate, rngs=rngs) if dropout_rate > 0.0 else None
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
        out = self.base(x)
        delta = x @ self.lora_A[...]
        if self.dropout is not None:
            delta = self.dropout(delta, deterministic=deterministic)
        return out + (delta @ self.lora_B[...]) * self.scale

    def merge_weights(self) -> jnp.ndarray:
        """Return fused kernel W + (alpha/r) * A @ B for export or deployment."""
        return self.base.kernel[...] + self.scale * (self.lora_A[...] @ self.lora_B[...])
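A minimal usage sketch follows, showing why LoRAParam is a distinct Variable type: flax.nnx's functional split/merge API can then select only the adapter weights for gradient updates while the base nnx.Param kernel stays frozen. This assumes recent flax.nnx and optax APIs (nnx.split / nnx.merge / nnx.update with a Variable-type filter, optax.adamw); the toy data, shapes, and hyperparameters are illustrative and not part of core/lora.py.

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax

model = LoRALinear(8, 8, rank=2, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(1), (4, 8))
y = jnp.ones((4, 8))

# Split the state by variable type: LoRAParam leaves are trainable,
# everything else (the base kernel, an nnx.Param) stays frozen.
graphdef, lora_state, frozen_state = nnx.split(model, LoRAParam, ...)

tx = optax.adamw(1e-3)
opt_state = tx.init(lora_state)

def loss_fn(lora_state):
    m = nnx.merge(graphdef, lora_state, frozen_state)
    pred = m(x, deterministic=True)
    return jnp.mean((pred - y) ** 2)

# One gradient step on the adapter weights only.
grads = jax.grad(loss_fn)(lora_state)
updates, opt_state = tx.update(grads, opt_state, lora_state)
lora_state = optax.apply_updates(lora_state, updates)
nnx.update(model, lora_state)  # write the trained adapter back into the module

# Sanity check: the fused kernel reproduces the adapted forward pass
# (use_bias=False by default, so base(x) == x @ kernel).
assert jnp.allclose(model(x, deterministic=True), x @ model.merge_weights(), atol=1e-5)

The final assertion follows from the docstring's identity: base(x) + scale * (x @ A) @ B equals x @ (W_base + scale * A @ B), which is exactly what merge_weights returns.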