Coverage for core/model.py: 60%

87 statements  

coverage.py v7.13.5, created at 2026-05-05 11:22 +0200

from __future__ import annotations

import flax.nnx as nnx
import jax
import jax.numpy as jnp

from .block import Block, _build_norm
from .config import Config
from .output import ModelOutput


class Transformer(nnx.Module, pytree=False):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.num_blocks: int = config.num_blocks
        self.blocks: list = [Block(config, rngs=rngs) for _ in range(self.num_blocks)]
        self.wte: nnx.Embed = nnx.Embed(config.vocab_size, config.dim, rngs=rngs)
        self.weight_tying: bool = config.weight_tying
        self.trainable_pos: bool = config.trainable_pos
        self.absolute_pos: bool = config.absolute_pos
        self.max_context: int = config.max_context
        self.gradient_checkpointing: bool = config.gradient_checkpointing
        self.ln_f: nnx.Module = _build_norm(config, config.dim, rngs)
        self.emb_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.use_moe: bool = config.use_moe
        self.alpha_balance: float = config.alpha_balance

        if config.weight_tying:
            # Tie the output projection to the embedding Param at call time (see
            # __call__) rather than storing embedding.T in a Linear kernel: assigning
            # that raw array would silently drop it from NNX's state graph and cause
            # DynamicJaxprTracer errors inside @nnx.jit.
            self.lm_head: nnx.Linear | None = None
        else:
            self.lm_head = nnx.Linear(config.dim, config.vocab_size, rngs=rngs)
        if self.trainable_pos:
            self.wpe: nnx.Embed = nnx.Embed(config.max_context, config.dim, rngs=rngs)
        elif self.absolute_pos:
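            # Fixed sinusoidal table: pos[t, 2i] = sin(t / 10000^(2i/C)),
            # pos[t, 2i+1] = cos(t / 10000^(2i/C)); shaped (1, max_context, dim) so it
            # broadcasts over the batch dimension.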

            def _build_compute_absolute_pos(T: int, C: int) -> jnp.ndarray:
                pos = jnp.zeros((T, C))
                row = jnp.arange(T)
                col = jnp.arange(0, C, 2)
                k = 1.0 / (10000 ** (col / C))
                ratio = jnp.einsum('i,j->ij', row, k)
                pos = pos.at[:, 0::2].set(jnp.sin(ratio))
                pos = pos.at[:, 1::2].set(jnp.cos(ratio))
                return jnp.expand_dims(pos, axis=0)

            self.wpe: jnp.ndarray = _build_compute_absolute_pos(config.max_context, config.dim)  # type: ignore[assignment, no-redef]

    def __call__(self,
                 x: jnp.ndarray,
                 use_cache: bool,
                 kv_caches: tuple | None,
                 cache_index: int | None,
                 deterministic: bool = False) -> ModelOutput:

        B, T = x.shape
        x = self.wte(x)
        if kv_caches is None:
            kv_caches = tuple((None, None) for _ in range(self.num_blocks))
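        # Positional information is added next: slice [cache_index, cache_index + T)
        # from the fixed sinusoidal table (dynamic_slice stays jit-friendly when
        # cache_index is traced), or look up learned position embeddings.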

        if self.absolute_pos:
            wpe_slice = jax.lax.dynamic_slice_in_dim(
                self.wpe,  # type: ignore[arg-type]
                start_index=cache_index,  # type: ignore[arg-type]
                slice_size=T,
                axis=1
            )
            x = x + wpe_slice
        elif self.trainable_pos:
            # Position indices must be integers for the nnx.Embed lookup.
            x = x + self.wpe(jnp.arange(T))

        x = self.emb_dropout(x, deterministic=deterministic)

        def block_fn(block_module: object, hidden_state: jnp.ndarray, kv_c: object, det: bool) -> tuple:
            return block_module(  # type: ignore[call-arg, operator]
                hidden_state,
                use_cache=use_cache,
                kv_cache=kv_c,
                cache_index=cache_index,
                deterministic=det
            )

        def _apply_block(bm: object, hs: jnp.ndarray, kvc: object) -> tuple:
            return block_fn(bm, hs, kvc, deterministic)
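        # nnx.remat rematerialises each block's activations in the backward pass to
        # reduce peak memory; the plain function is used when decoding with a KV cache.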

        if self.gradient_checkpointing and not use_cache:
            checkpointed_block = nnx.remat(_apply_block)
        else:
            checkpointed_block = _apply_block

        new_kv_caches = []
        balancing_loss_total = 0.0
        for i, block in enumerate(self.blocks):
            x, new_kv, balancing_loss = checkpointed_block(block, x, kv_caches[i] if kv_caches else None)
            new_kv_caches.append(new_kv)
            balancing_loss_total += balancing_loss

        x = self.ln_f(x)
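        # Weight tying reuses the transposed token-embedding matrix as the output
        # projection; otherwise the dedicated lm_head built in __init__ is applied.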

        logits = x @ self.wte.embedding[...].T if self.weight_tying else self.lm_head(x)  # type: ignore[union-attr, misc]

        return ModelOutput(
            logits=logits,
            kv_caches=tuple(new_kv_caches),
            aux_loss=balancing_loss_total,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo: str,
        rngs: nnx.Rngs | None = None,
        *,
        best: bool = True,
        token: str | None = None,
        revision: str | None = None,
    ) -> Transformer:
        """Load a trained Transformer from a local directory or HuggingFace Hub.

        Parameters
        ----------
        path_or_repo:
            Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
            as ``"my-org/dantinox-dante"``. The checkpoint is downloaded
            automatically when a Hub ID is given.
        rngs:
            PRNG state for initialisation. Defaults to ``nnx.Rngs(0)``.
        best:
            When ``True`` (default), loads ``best_model_weights.msgpack``
            if it exists, otherwise falls back to ``model_weights.msgpack``.
        token:
            HuggingFace access token for private repositories.
        revision:
            Branch, tag, or commit SHA to download from the Hub.
        """

        import contextlib
        import os

        import msgpack

        # Lazy import to avoid circular dependency (core ← dantinox)
        from dantinox.hub import resolve_checkpoint  # type: ignore[import]

        run_dir = resolve_checkpoint(path_or_repo, token=token, revision=revision)

        if rngs is None:
            rngs = nnx.Rngs(0)

        config = Config.from_yaml(os.path.join(run_dir, "config.yaml"))
        model = cls(config, rngs=rngs)
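        # Prefer the best-checkpoint weights; fall back to the final weights when
        # best=False or the best checkpoint was never written.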

        weights_path = os.path.join(run_dir, "best_model_weights.msgpack")
        if not best or not os.path.exists(weights_path):
            weights_path = os.path.join(run_dir, "model_weights.msgpack")

        with open(weights_path, "rb") as f:
            raw = f.read()

        # Use the same private hook that trainer.py uses for consistency.
        _ext_hook: object = None
        with contextlib.suppress(ImportError):
            from flax.serialization import _msgpack_ext_unpack  # type: ignore[attr-defined]
            _ext_hook = _msgpack_ext_unpack

        state_dict = msgpack.unpackb(raw, ext_hook=_ext_hook, strict_map_key=False)
        nnx.update(model, state_dict)
        return model
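
A minimal usage sketch of the two entry points above. The import path ("core"), the
checkpoint location, and the token shapes are illustrative assumptions rather than part
of the covered module; use_cache, kv_caches, and cache_index are passed explicitly
because __call__ defines no defaults for them.

import flax.nnx as nnx
import jax.numpy as jnp

from core.config import Config      # assumes the package is importable as "core"
from core.model import Transformer

# Load a trained checkpoint (local run directory or Hub repo ID; the path is hypothetical).
model = Transformer.from_pretrained("runs/example-run")

# Alternatively, build a fresh model from a YAML config:
# model = Transformer(Config.from_yaml("config.yaml"), rngs=nnx.Rngs(0))

tokens = jnp.zeros((1, 16), dtype=jnp.int32)   # (batch, time) token ids
out = model(tokens, use_cache=False, kv_caches=None, cache_index=0, deterministic=True)

print(out.logits.shape)   # (1, 16, vocab_size)
print(out.aux_loss)       # summed per-block balancing loss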