Coverage for core/model.py: 60%
87 statements
coverage.py v7.13.5, created at 2026-05-05 11:22 +0200
from __future__ import annotations

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .block import Block, _build_norm
from .config import Config
from .output import ModelOutput


class Transformer(nnx.Module, pytree=False):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.num_blocks: int = config.num_blocks
        self.blocks: list = [Block(config, rngs=rngs) for _ in range(self.num_blocks)]
        self.wte: nnx.Embed = nnx.Embed(config.vocab_size, config.dim, rngs=rngs)
        self.weight_tying: bool = config.weight_tying
        self.trainable_pos: bool = config.trainable_pos
        self.absolute_pos: bool = config.absolute_pos
        self.max_context: int = config.max_context
        self.gradient_checkpointing: bool = config.gradient_checkpointing
        self.ln_f: nnx.Module = _build_norm(config, config.dim, rngs)
        self.emb_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.use_moe: bool = config.use_moe
        self.alpha_balance: float = config.alpha_balance

        if config.weight_tying:
            # Weight tying: reuse wte.embedding as the output projection (see
            # __call__) rather than materialising a separate lm_head. Assigning
            # embedding.T (a raw array) would silently drop it from NNX's
            # state graph and cause DynamicJaxprTracer errors inside @nnx.jit.
            self.lm_head: nnx.Linear | None = None
        else:
            self.lm_head = nnx.Linear(config.dim, config.vocab_size, rngs=rngs)
        if self.trainable_pos:
            self.wpe: nnx.Embed = nnx.Embed(config.max_context, config.dim, rngs=rngs)
        elif self.absolute_pos:
            def _build_compute_absolute_pos(T: int, C: int) -> jnp.ndarray:
                pos = jnp.zeros((T, C))
                row = jnp.arange(T)
                col = jnp.arange(0, C, 2)
                k = 1.0 / (10000 ** (col / C))
                ratio = jnp.einsum('i,j->ij', row, k)
                pos = pos.at[:, 0::2].set(jnp.sin(ratio))
                pos = pos.at[:, 1::2].set(jnp.cos(ratio))
                return jnp.expand_dims(pos, axis=0)
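            # The table built above is the fixed sinusoidal encoding: channel 2i
            # holds sin(t / 10000^(2i/C)) and channel 2i+1 holds cos(t / 10000^(2i/C)),
            # with a leading batch axis so it broadcasts over (B, T, C) activations.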
            self.wpe: jnp.ndarray = _build_compute_absolute_pos(config.max_context, config.dim)  # type: ignore[assignment, no-redef]

    def __call__(self,
                 x: jnp.ndarray,
                 use_cache: bool,
                 kv_caches: tuple | None,
                 cache_index: int | None,
                 deterministic: bool = False) -> ModelOutput:

        B, T = x.shape
        x = self.wte(x)
        if kv_caches is None:
            kv_caches = tuple((None, None) for _ in range(self.num_blocks))
        if self.absolute_pos:
            wpe_slice = jax.lax.dynamic_slice_in_dim(
                self.wpe,  # type: ignore[arg-type]
                start_index=cache_index,  # type: ignore[arg-type]
                slice_size=T,
                axis=1
            )
            x = x + wpe_slice
        elif self.trainable_pos:
            # Learned positional embeddings; nnx.Embed expects integer indices.
            x = x + self.wpe(jnp.arange(T, dtype=jnp.int32))
        x = self.emb_dropout(x, deterministic=deterministic)

        def block_fn(block_module: object, hidden_state: jnp.ndarray, kv_c: object, det: bool) -> tuple:
            return block_module(  # type: ignore[call-arg, operator]
                hidden_state,
                use_cache=use_cache,
                kv_cache=kv_c,
                cache_index=cache_index,
                deterministic=det
            )

        def _apply_block(bm: object, hs: jnp.ndarray, kvc: object) -> tuple:
            return block_fn(bm, hs, kvc, deterministic)

        if self.gradient_checkpointing and not use_cache:
            checkpointed_block = nnx.remat(_apply_block)
        else:
            checkpointed_block = _apply_block
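        # nnx.remat makes each block recompute its activations during the backward
        # pass instead of storing them; it is skipped when use_cache is set,
        # presumably because cached calls are inference-only and have no backward pass.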
        new_kv_caches = []
        balancing_loss_total = 0.0
        for i, block in enumerate(self.blocks):
            x, new_kv, balancing_loss = checkpointed_block(block, x, kv_caches[i] if kv_caches else None)
            new_kv_caches.append(new_kv)
            balancing_loss_total += balancing_loss
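        # balancing_loss_total accumulates each block's auxiliary load-balancing
        # term (relevant when use_moe is enabled) and is surfaced as aux_loss;
        # the trainer presumably scales it by alpha_balance when forming the loss.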
        x = self.ln_f(x)
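        # With weight tying the output projection reuses the token-embedding
        # matrix transposed; `[...]` pulls the underlying array out of the Param.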
        logits = x @ self.wte.embedding[...].T if self.weight_tying else self.lm_head(x)  # type: ignore[union-attr, misc]

        return ModelOutput(
            logits=logits,
            kv_caches=tuple(new_kv_caches),
            aux_loss=balancing_loss_total,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo: str,
        rngs: nnx.Rngs | None = None,
        *,
        best: bool = True,
        token: str | None = None,
        revision: str | None = None,
    ) -> Transformer:
        """Load a trained Transformer from a local directory or HuggingFace Hub.

        Parameters
        ----------
        path_or_repo:
            Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
            as ``"my-org/dantinox-dante"``. The checkpoint is downloaded
            automatically when a Hub ID is given.
        rngs:
            PRNG state for initialisation. Defaults to ``nnx.Rngs(0)``.
        best:
            When ``True`` (default), loads ``best_model_weights.msgpack``
            if it exists, otherwise falls back to ``model_weights.msgpack``.
        token:
            HuggingFace access token for private repositories.
        revision:
            Branch, tag, or commit SHA to download from the Hub.
        """
        import contextlib
        import os

        import msgpack

        # Lazy import to avoid circular dependency (core ← dantinox)
        from dantinox.hub import resolve_checkpoint  # type: ignore[import]

        run_dir = resolve_checkpoint(path_or_repo, token=token, revision=revision)

        if rngs is None:
            rngs = nnx.Rngs(0)

        config = Config.from_yaml(os.path.join(run_dir, "config.yaml"))
        model = cls(config, rngs=rngs)

        weights_path = os.path.join(run_dir, "best_model_weights.msgpack")
        if not best or not os.path.exists(weights_path):
            weights_path = os.path.join(run_dir, "model_weights.msgpack")

        with open(weights_path, "rb") as f:
            raw = f.read()

        # Use the same private hook that trainer.py uses for consistency.
        _ext_hook: object = None
        with contextlib.suppress(ImportError):
            from flax.serialization import _msgpack_ext_unpack  # type: ignore[attr-defined]
            _ext_hook = _msgpack_ext_unpack

        state_dict = msgpack.unpackb(raw, ext_hook=_ext_hook, strict_map_key=False)
        nnx.update(model, state_dict)
        return model
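
# ---------------------------------------------------------------------------
# Illustrative usage sketch: load a checkpoint and run one uncached forward
# pass. The repo ID and shapes are placeholders; for incremental decoding one
# would pass the returned kv_caches back in with use_cache=True and an advanced
# cache_index (an assumption based on the signature above, not shown here).
#
#     model = Transformer.from_pretrained("my-org/dantinox-dante")
#     tokens = jnp.zeros((1, 16), dtype=jnp.int32)   # (B, T) token ids
#     out = model(tokens, use_cache=False, kv_caches=None,
#                 cache_index=0, deterministic=True)
#     out.logits.shape                               # (1, 16, vocab_size)
# ---------------------------------------------------------------------------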