Coverage for core/lora.py: 100% (24 statements)
from __future__ import annotations

import math

import flax.nnx as nnx
import jax
import jax.numpy as jnp


class LoRAParam(nnx.Variable):
    """Trainable LoRA variable: a distinct type so base nnx.Param weights stay frozen."""


class LoRALinear(nnx.Module):
    """Drop-in replacement for nnx.Linear with a frozen base weight and a trainable low-rank delta.

    The effective weight is W_eff = W_base + (alpha / rank) * A @ B,
    where A is initialised with scaled Gaussian noise and B with zeros,
    so the adapter contributes nothing at initialisation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        rank: int = 8,
        alpha: float = 16.0,
        dropout_rate: float = 0.0,
        use_bias: bool = False,
        rngs: nnx.Rngs,
    ) -> None:
        self.base = nnx.Linear(in_features, out_features, use_bias=use_bias, rngs=rngs)
        self.scale = alpha / rank

        # Only one key is needed: lora_B is zero-initialised, not sampled.
        k_a, _ = jax.random.split(rngs.params())
        self.lora_A = LoRAParam(
            jax.random.normal(k_a, (in_features, rank)) / math.sqrt(in_features)
        )
        self.lora_B = LoRAParam(jnp.zeros((rank, out_features)))

        self.dropout: nnx.Dropout | None = (
            nnx.Dropout(dropout_rate, rngs=rngs) if dropout_rate > 0.0 else None
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
        out = self.base(x)
        delta = x @ self.lora_A[...]
        if self.dropout is not None:
            # Dropout acts on the rank-r activations, i.e. after projecting through A.
            delta = self.dropout(delta, deterministic=deterministic)
        return out + (delta @ self.lora_B[...]) * self.scale

    def merge_weights(self) -> jnp.ndarray:
        """Return the fused kernel W + (alpha/r) * A @ B for export or deployment."""
        return self.base.kernel[...] + self.scale * (self.lora_A[...] @ self.lora_B[...])
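
A minimal usage sketch (illustrative, not part of the covered module; the layer sizes, rank, and alpha are arbitrary): because lora_B is zero-initialised, a fresh LoRALinear reproduces its frozen base layer exactly, so fine-tuning starts from the pretrained behaviour.

    import flax.nnx as nnx
    import jax.numpy as jnp

    layer = LoRALinear(16, 32, rank=4, alpha=8.0, rngs=nnx.Rngs(0))
    x = jnp.ones((2, 16))
    # The adapter delta is exactly zero at init, so the layer equals its base projection.
    assert jnp.allclose(layer(x, deterministic=True), layer.base(x))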
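Because LoRAParam is a distinct nnx.Variable subclass, it doubles as an nnx filter, so only the adapter weights are exposed to the optimizer while the base nnx.Param weights stay frozen. A hedged sketch, assuming optax is installed and a recent flax release where nnx.Optimizer accepts a wrt filter; the AdamW learning rate is illustrative.

    import optax

    # Selects only lora_A / lora_B; the base kernel (and any bias) is left out.
    lora_state = nnx.state(layer, LoRAParam)
    optimizer = nnx.Optimizer(layer, optax.adamw(1e-4), wrt=LoRAParam)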
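For deployment, merge_weights lets the adapter be folded away entirely. A hedged export sketch under the shapes above: the fused kernel is copied into a plain nnx.Linear, which then matches the adapted layer in a single matmul.

    fused = nnx.Linear(16, 32, use_bias=False, rngs=nnx.Rngs(0))
    fused.kernel.value = layer.merge_weights()  # W_base + (alpha/rank) * A @ B
    assert jnp.allclose(fused(x), layer(x, deterministic=True))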