# dantinox/generator.py: coverage 45% (159 statements)

from __future__ import annotations

import logging
import os
from collections.abc import Iterator

import jax
import jax.numpy as jnp
import msgpack
import yaml
from flax import nnx
from flax.serialization import _msgpack_ext_unpack

from core.config import Config
from core.generation import generate as _generate
from core.model import Transformer
from dantinox.exceptions import CheckpointError
from utils.tokenizer import Tokenizer, get_tokenizer, load_tokenizer_from_file

log = logging.getLogger(__name__)

# Post-decode cleanup for the byte-level BPE tokenizer: strip the literal spaces
# in the raw decode, turn "Ġ" space markers back into real spaces, and undo
# UTF-8 mojibake for the apostrophe and common accented Italian letters.
_BPE_REPLACEMENTS = [
    (" ", ""),
    ("Ġ", " "),
    ("âĢĻ", "'"),
    ("Ã¹", "ù"),
    ("Ã¬", "ì"),
    ("Ã©", "é"),
    ("Ã¨", "è"),
    ("Ã²", "ò"),
    ("Ã", "à"),
]
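
# Illustrative effect of the first two replacements (assuming the decoder joins
# tokens with spaces): a raw decode such as "Nel Ġmezzo Ġdel" becomes
# "NelĠmezzoĠdel" and then "Nel mezzo del".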


# ── JIT-compiled streaming step functions ────────────────────────────────────

@nnx.jit
def _stream_prefill(model: nnx.Module, x: jnp.ndarray, kv_cache: tuple) -> tuple:
    """Full prompt forward pass. Returns (logits [B,T,V], filled_kv_cache)."""
    logits, new_cache, _ = model(x, True, kv_cache, 0, deterministic=True)
    return logits, new_cache


@nnx.jit
def _stream_decode(model: nnx.Module, tok: jnp.ndarray, kv_cache: tuple, pos: jax.Array) -> tuple:
    """Single-token decode step. Returns (logits [B,1,V], new_kv_cache)."""
    logits, new_cache, _ = model(tok, True, kv_cache, pos, deterministic=True)
    return logits, new_cache
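
# Keeping prefill and decode as two separately jitted functions means each is
# traced for one input shape (the full prompt versus a single token), which
# presumably avoids retracing inside the streaming loop further down.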


def _sample_logit(
    logits: jnp.ndarray,
    key: jax.Array,
    greedy: bool,
    temperature: float,
    top_k: int | None,
    top_p: float | None,
) -> tuple[int, jax.Array]:
    """Sample one token id from a [1, V] logit slice. Returns (token_id, new_key)."""
    log_probs = jax.nn.log_softmax(logits[0].astype(jnp.float32) / temperature)

    if greedy:
        return int(jnp.argmax(log_probs)), key

    if top_k is not None:
        top_k_vals, top_k_idx = jax.lax.top_k(jnp.exp(log_probs), top_k)
        filtered = jnp.full_like(log_probs, -jnp.inf)
        filtered = filtered.at[top_k_idx].set(
            jnp.log(top_k_vals / top_k_vals.sum() + 1e-10)
        )
        log_probs = filtered

    if top_p is not None:
        probs = jnp.exp(log_probs)
        sorted_idx = jnp.argsort(probs)[::-1]
        sorted_probs = probs[sorted_idx]
        cum = jnp.cumsum(sorted_probs)
        mask = (cum - sorted_probs) >= top_p
        filtered_p = jnp.where(mask, 0.0, sorted_probs)
        filtered_p = filtered_p / (filtered_p.sum() + 1e-10)
        filtered_lp = jnp.full_like(log_probs, -jnp.inf)
        filtered_lp = filtered_lp.at[sorted_idx].set(jnp.log(filtered_p + 1e-10))
        log_probs = filtered_lp

    new_key, subkey = jax.random.split(key)
    tok_id = int(jax.random.categorical(subkey, log_probs))
    return tok_id, new_key

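# Illustrative only: in greedy mode `_sample_logit` reduces to an argmax over
# the temperature-scaled logits. With a toy [1, V] input such as
#
#     logits = jnp.array([[0.1, 2.0, -1.0]])
#     tok, key = _sample_logit(logits, jax.random.key(0), greedy=True,
#                              temperature=1.0, top_k=None, top_p=None)
#
# tok is 1 (the index of the largest logit) and key is returned unchanged,
# since no random sampling takes place.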


# ── Checkpoint loader ─────────────────────────────────────────────────────────

def _load_checkpoint(run_dir: str, seed: int) -> tuple[Config, Transformer, Tokenizer]:
    """Return (config, model, tokenizer) loaded from a local run directory."""
    config_path = os.path.join(run_dir, "config.yaml")
    weights_path = os.path.join(run_dir, "model_weights.msgpack")

    if not os.path.isdir(run_dir):
        raise CheckpointError(f"Run directory not found: {run_dir}")
    if not os.path.exists(config_path):
        raise CheckpointError(f"Config file not found: {config_path}")

    with open(config_path) as f:
        raw = yaml.safe_load(f)

    # Flatten one level of nested config sections into a single dict.
    flat: dict = {}
    for section in raw.values():
        if isinstance(section, dict):
            flat.update(section)
    if not flat:
        flat = raw

    config = Config.from_dict(flat)

    if config.mla:
        config.inference = True

    tok_path = os.path.join(run_dir, "tokenizer.json")
    if os.path.exists(tok_path):
        tokenizer = load_tokenizer_from_file(tok_path)
        log.info("Loaded tokenizer from %s", tok_path)
    else:
        log.warning(
            "tokenizer.json not found in %r — rebuilding from original corpus "
            "(this only happens once; the file will be saved for future calls).",
            run_dir,
        )
        if config.dataset_source == "huggingface":
            import logging as _logging
            # Silence the noisy httpx / datasets HTTP logs during the one-time download.
            for _noisy in ("httpx", "datasets", "huggingface_hub"):
                _logging.getLogger(_noisy).setLevel(_logging.WARNING)
            from datasets import load_dataset
            raw_dataset = load_dataset(config.dataset_name, split="train")
            text = " ".join(raw_dataset["text"])
        else:
            if not os.path.exists(config.dataset_name):
                raise CheckpointError(
                    f"tokenizer.json not found and dataset file {config.dataset_name!r} "
                    "is also missing. Cannot rebuild the tokenizer vocabulary."
                )
            with open(config.dataset_name, encoding="utf-8") as f:
                text = f.read()
            lines = [line.rstrip() for line in text.split("\n") if line.strip()]
            blocks = ["\n".join(lines[i : i + 3]) for i in range(0, len(lines), 3)]
            text = "\n\n".join(blocks) + "\n"
        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)
        # Persist so the next call loads instantly without touching the corpus.
        tokenizer.save(tok_path)
        log.warning("Saved tokenizer to %s — subsequent calls will skip the download.", tok_path)

    config.vocab_size = tokenizer.vocab_size

    rngs = nnx.Rngs(seed)
    model = Transformer(config, rngs=rngs)

    if os.path.exists(weights_path):
        log.info("Loading weights from %s", weights_path)
        with open(weights_path, "rb") as f:
            state_dict = msgpack.unpackb(
                f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
            )
        nnx.update(model, state_dict)
    else:
        log.warning("No weights file found at %s — using random initialisation", weights_path)

    return config, model, tokenizer

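# Expected run-directory layout, as read by the loader above (only names that
# appear in the code; nothing else is assumed):
#
#   <run_dir>/
#       config.yaml            # training configuration (nested sections are flattened)
#       model_weights.msgpack  # serialized nnx state dict (optional; random init if absent)
#       tokenizer.json         # saved tokenizer (rebuilt from the corpus if missing)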


# ── Generator ─────────────────────────────────────────────────────────────────

class Generator:
    """
    Loads a trained DantinoX checkpoint and generates text.

    Accepts either a **local run directory** or a **HuggingFace Hub repo ID**
    — the checkpoint is downloaded automatically when needed.

    Parameters
    ----------
    run_dir : str
        Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
        as ``"my-org/dantinox-dante"``.
    seed : int
        RNG seed used for sampling (default 42).
    token : str, optional
        HuggingFace access token for private repositories.
    revision : str, optional
        Branch, tag, or commit SHA to download from the Hub.

    Raises
    ------
    CheckpointError
        If the checkpoint cannot be found locally or downloaded from the Hub.

    Examples
    --------
    >>> gen = Generator("runs/run_20260101_120000")             # local
    >>> gen = Generator("my-org/dantinox-dante")                # HF Hub
    >>> gen = Generator("my-org/private-model", token="hf_…")   # private Hub
    >>> text = gen.generate("Nel mezzo del cammin ")
    >>> print(text)
    """

    def __init__(
        self,
        run_dir: str,
        *,
        seed: int = 42,
        token: str | None = None,
        revision: str | None = None,
    ) -> None:
        from dantinox.hub import resolve_checkpoint

        self.seed = seed
        # Resolve once: download from Hub if needed, then use the local path
        local_dir = resolve_checkpoint(run_dir, token=token, revision=revision)
        self.run_dir = local_dir
        self.config, self.model, self.tokenizer = _load_checkpoint(local_dir, seed)

    def __repr__(self) -> str:
        attn = "MLA" if self.config.mla else ("GQA" if self.config.kv_heads < self.config.n_heads else "MHA")
        return f"Generator(run_dir={self.run_dir!r}, attn={attn}, seed={self.seed})"

    def _bpe_fix(self, text: str) -> str:
        if self.config.tokenizer_type == "bpe":
            for old, new in _BPE_REPLACEMENTS:
                text = text.replace(old, new)
        return text

    # ── Single-prompt generation ──────────────────────────────────────────────

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> str:
        """
        Generate text continuing from ``prompt``.

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Number of tokens to generate (default 150).
        greedy : bool
            Use greedy decoding instead of sampling (default False).
        top_k : int, optional
            Keep only the top-k logits before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache for faster generation (default True).

        Returns
        -------
        str
            The full generated string (prompt + continuation).
        """
        tokens = self.tokenizer.encode(prompt)
        x = jnp.array([tokens], dtype=jnp.int32)

        log.debug(
            "Generating %d tokens from prompt of %d tokens (greedy=%s, cache=%s)",
            max_new_tokens, len(tokens), greedy, use_cache,
        )

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        # JAX dispatches work asynchronously; wait for the device computation to
        # finish before decoding the result on the host.
        output.block_until_ready()

        return self._bpe_fix(self.tokenizer.decode(output[0].tolist()))

    # ── Batched generation ────────────────────────────────────────────────────

    def generate_batch(
        self,
        prompts: list[str],
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> list[str]:
        """
        Generate text for multiple prompts in a single batched forward pass.

        Shorter prompts are left-padded with zeros so all share the same
        sequence length. This runs a true batch through the model, so
        throughput scales with GPU parallelism.

        Parameters
        ----------
        prompts : list[str]
            Input prefixes to generate from.
        max_new_tokens : int
            Tokens to generate per prompt (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache (default True).

        Returns
        -------
        list[str]
            Generated strings (prompt + continuation) in the same order as
            ``prompts``.
        """
        if not prompts:
            return []

        encoded = [self.tokenizer.encode(p) for p in prompts]
        max_len = max(len(e) for e in encoded)

        # Left-pad shorter prompts with zeros so all share the same start position.
        padded = [([0] * (max_len - len(e))) + e for e in encoded]
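        # Illustrative: encoded [[5, 6], [7]] gives max_len = 2 and padded
        # [[5, 6], [0, 7]], so every prompt's last real token sits at the same index.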

        x = jnp.array(padded, dtype=jnp.int32)  # [B, max_len]

        log.debug("Batch generating: B=%d max_prompt_len=%d max_new=%d", len(prompts), max_len, max_new_tokens)

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        output.block_until_ready()

        results = []
        for i, enc in enumerate(encoded):
            # Strip the left-padding: prompt starts at (max_len - len(enc))
            start = max_len - len(enc)
            tokens_out = output[i, start:].tolist()
            results.append(self._bpe_fix(self.tokenizer.decode(tokens_out)))
        return results

    # ── Streaming generation ──────────────────────────────────────────────────

    def stream(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
    ) -> Iterator[str]:
        """
        Stream generated tokens one at a time as they are produced.

        Uses the KV-cache path: the prompt is prefilled in one forward pass,
        then each new token is decoded individually. Each ``yield`` returns
        the string for one generated token (may be a character or a BPE
        subword).

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Maximum number of tokens to generate (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).

        Yields
        ------
        str
            Decoded string for each generated token.

        Examples
        --------
        >>> gen = Generator("runs/my_run")
        >>> for chunk in gen.stream("Nel mezzo", max_new_tokens=50):
        ...     print(chunk, end="", flush=True)
        """
        tokens = self.tokenizer.encode(prompt)
        T = len(tokens)
        max_ctx = self.config.max_context  # type: ignore[attr-defined]
        num_blocks = self.config.num_blocks  # type: ignore[attr-defined]

        # Build full-context input with prompt at the start.
        x = jnp.zeros((1, max_ctx), dtype=jnp.int32)
        x = x.at[0, :T].set(jnp.array(tokens, dtype=jnp.int32))

        init_kv_cache = tuple((None, None) for _ in range(num_blocks))
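        # Presumably one (key, value) slot per transformer block; the prefill
        # pass below fills these from the prompt.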

        key = jax.random.key(self.seed)

        # Prefill: one pass over the entire prompt, populate KV cache.
        logits, kv_cache = _stream_prefill(self.model, x, init_kv_cache)

        # Sample the first generated token from the last prompt position.
        tok_id, key = _sample_logit(logits[:, T - 1, :], key, greedy, temperature, top_k, top_p)
        yield self._bpe_fix(self.tokenizer.decode([tok_id]))

        # Autoregressive decode loop.
        for pos in range(T, T + max_new_tokens - 1):
            if pos >= max_ctx:
                break
            tok = jnp.array([[tok_id]], dtype=jnp.int32)
            logits, kv_cache = _stream_decode(self.model, tok, kv_cache, jnp.array(pos))
            tok_id, key = _sample_logit(logits[:, 0, :], key, greedy, temperature, top_k, top_p)
            yield self._bpe_fix(self.tokenizer.decode([tok_id]))
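

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The run directory below is a
    # placeholder in the style of the Generator docstring examples; point it at
    # a real checkpoint produced by Trainer.fit() or a HuggingFace Hub repo ID.
    logging.basicConfig(level=logging.INFO)

    gen = Generator("runs/run_20260101_120000")  # hypothetical checkpoint path
    print(gen)

    # Single prompt.
    print(gen.generate("Nel mezzo del cammin ", max_new_tokens=50, top_k=40))

    # Several prompts in one batched forward pass.
    for out in gen.generate_batch(["Nel mezzo", "Nel mezzo del cammin "], max_new_tokens=30):
        print(out)

    # Token-by-token streaming.
    for chunk in gen.stream("Nel mezzo", max_new_tokens=30):
        print(chunk, end="", flush=True)
    print()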