Coverage for dantinox/generator.py: 45%
159 statements
coverage.py v7.13.5, created at 2026-05-05 11:22 +0200
from __future__ import annotations

import logging
import os
from collections.abc import Iterator

import jax
import jax.numpy as jnp
import msgpack
import yaml
from flax import nnx
from flax.serialization import _msgpack_ext_unpack

from core.config import Config
from core.generation import generate as _generate
from core.model import Transformer
from dantinox.exceptions import CheckpointError
from utils.tokenizer import Tokenizer, get_tokenizer, load_tokenizer_from_file

log = logging.getLogger(__name__)
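
# Post-processing table for byte-level BPE output: "Ġ" is the BPE marker for a
# leading space, and the accented pairs map UTF-8 bytes rendered as Latin-1
# back to the intended Italian characters. Applied in order by
# Generator._bpe_fix().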
_BPE_REPLACEMENTS = [
    (" ", ""),
    ("Ġ", " "),
    ("âĢĻ", "'"),
    ("Ã¹", "ù"),
    ("Ã¬", "ì"),
    ("Ã©", "é"),
    ("Ã¨", "è"),
    ("Ã²", "ò"),
    ("Ã", "à"),
]

# ── JIT-compiled streaming step functions ────────────────────────────────────

@nnx.jit
def _stream_prefill(model: nnx.Module, x: jnp.ndarray, kv_cache: tuple) -> tuple:
    """Full prompt forward pass. Returns (logits [B,T,V], filled_kv_cache)."""
    logits, new_cache, _ = model(x, True, kv_cache, 0, deterministic=True)
    return logits, new_cache


@nnx.jit
def _stream_decode(model: nnx.Module, tok: jnp.ndarray, kv_cache: tuple, pos: jax.Array) -> tuple:
    """Single-token decode step. Returns (logits [B,1,V], new_kv_cache)."""
    logits, new_cache, _ = model(tok, True, kv_cache, pos, deterministic=True)
    return logits, new_cache
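
# Prefill and decode are kept as two separate nnx.jit functions so each traces
# once for a fixed input shape (the full [1, max_context] prompt vs. a single
# [1, 1] token), and the per-token loop in Generator.stream() never retriggers
# compilation.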


def _sample_logit(
    logits: jnp.ndarray,
    key: jax.Array,
    greedy: bool,
    temperature: float,
    top_k: int | None,
    top_p: float | None,
) -> tuple[int, jax.Array]:
    """Sample one token id from a [V] logit vector. Returns (token_id, new_key)."""
    log_probs = jax.nn.log_softmax(logits[0].astype(jnp.float32) / temperature)

    if greedy:
        return int(jnp.argmax(log_probs)), key

    if top_k is not None:
        top_k_vals, top_k_idx = jax.lax.top_k(jnp.exp(log_probs), top_k)
        filtered = jnp.full_like(log_probs, -jnp.inf)
        filtered = filtered.at[top_k_idx].set(
            jnp.log(top_k_vals / top_k_vals.sum() + 1e-10)
        )
        log_probs = filtered

    if top_p is not None:
        probs = jnp.exp(log_probs)
        sorted_idx = jnp.argsort(probs)[::-1]
        sorted_probs = probs[sorted_idx]
        cum = jnp.cumsum(sorted_probs)
        mask = (cum - sorted_probs) >= top_p
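        # e.g. with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.8 the
        # shifted cumulative sums are [0.0, 0.5, 0.8, 0.95], so only the first two
        # tokens survive the mask; everything else is zeroed out below.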
        filtered_p = jnp.where(mask, 0.0, sorted_probs)
        filtered_p = filtered_p / (filtered_p.sum() + 1e-10)
        filtered_lp = jnp.full_like(log_probs, -jnp.inf)
        filtered_lp = filtered_lp.at[sorted_idx].set(jnp.log(filtered_p + 1e-10))
        log_probs = filtered_lp

    new_key, subkey = jax.random.split(key)
    tok_id = int(jax.random.categorical(subkey, log_probs))
    return tok_id, new_key


# ── Checkpoint loader ─────────────────────────────────────────────────────────

def _load_checkpoint(run_dir: str, seed: int) -> tuple[Config, Transformer, Tokenizer]:
    """Return (config, model, tokenizer) loaded from a local run directory."""
    config_path = os.path.join(run_dir, "config.yaml")
    weights_path = os.path.join(run_dir, "model_weights.msgpack")

    if not os.path.isdir(run_dir):
        raise CheckpointError(f"Run directory not found: {run_dir}")
    if not os.path.exists(config_path):
        raise CheckpointError(f"Config file not found: {config_path}")

    with open(config_path) as f:
        raw = yaml.safe_load(f)
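
    # The run config may be organised into nested YAML sections; merge every
    # top-level mapping into one flat dict for Config.from_dict, falling back
    # to the raw mapping when it is already flat.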
    flat: dict = {}
    for section in raw.values():
        if isinstance(section, dict):
            flat.update(section)
    if not flat:
        flat = raw

    config = Config.from_dict(flat)

    if config.mla:
        config.inference = True

    tok_path = os.path.join(run_dir, "tokenizer.json")
    if os.path.exists(tok_path):
        tokenizer = load_tokenizer_from_file(tok_path)
        log.info("Loaded tokenizer from %s", tok_path)
    else:
        log.warning(
            "tokenizer.json not found in %r — rebuilding from original corpus "
            "(this only happens once; the file will be saved for future calls).",
            run_dir,
        )
        if config.dataset_source == "huggingface":
            import logging as _logging
            # Silence the noisy httpx / datasets HTTP logs during the one-time download.
            for _noisy in ("httpx", "datasets", "huggingface_hub"):
                _logging.getLogger(_noisy).setLevel(_logging.WARNING)
            from datasets import load_dataset
            raw_dataset = load_dataset(config.dataset_name, split="train")
            text = " ".join(raw_dataset["text"])
        else:
            if not os.path.exists(config.dataset_name):
                raise CheckpointError(
                    f"tokenizer.json not found and dataset file {config.dataset_name!r} "
                    "is also missing. Cannot rebuild the tokenizer vocabulary."
                )
            with open(config.dataset_name, encoding="utf-8") as f:
                text = f.read()
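
        # Re-group the corpus into three-line blocks separated by blank lines
        # (presumably mirroring the training-time preprocessing) before the
        # tokenizer is rebuilt from it.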
        lines = [line.rstrip() for line in text.split("\n") if line.strip()]
        blocks = ["\n".join(lines[i : i + 3]) for i in range(0, len(lines), 3)]
        text = "\n\n".join(blocks) + "\n"
        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)
        # Persist so the next call loads instantly without touching the corpus.
        tokenizer.save(tok_path)
        log.warning("Saved tokenizer to %s — subsequent calls will skip the download.", tok_path)
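
    # Whichever tokenizer was loaded or rebuilt defines the model's vocabulary,
    # so the config is updated before the Transformer is constructed.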
    config.vocab_size = tokenizer.vocab_size

    rngs = nnx.Rngs(seed)
    model = Transformer(config, rngs=rngs)

    if os.path.exists(weights_path):
        log.info("Loading weights from %s", weights_path)
        with open(weights_path, "rb") as f:
            state_dict = msgpack.unpackb(
                f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
            )
        nnx.update(model, state_dict)
    else:
        log.warning("No weights file found at %s — using random initialisation", weights_path)

    return config, model, tokenizer
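
# Minimal usage sketch for the loader (the run directory is illustrative):
#     config, model, tokenizer = _load_checkpoint("runs/run_20260101_120000", seed=42)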


# ── Generator ─────────────────────────────────────────────────────────────────

class Generator:
    """
    Loads a trained DantinoX checkpoint and generates text.

    Accepts either a **local run directory** or a **HuggingFace Hub repo ID**
    — the checkpoint is downloaded automatically when needed.

    Parameters
    ----------
    run_dir : str
        Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
        as ``"my-org/dantinox-dante"``.
    seed : int
        RNG seed used for sampling (default 42).
    token : str, optional
        HuggingFace access token for private repositories.
    revision : str, optional
        Branch, tag, or commit SHA to download from the Hub.

    Raises
    ------
    CheckpointError
        If the checkpoint cannot be found locally or downloaded from the Hub.

    Examples
    --------
    >>> gen = Generator("runs/run_20260101_120000")            # local
    >>> gen = Generator("my-org/dantinox-dante")               # HF Hub
    >>> gen = Generator("my-org/private-model", token="hf_…")  # private Hub
    >>> text = gen.generate("Nel mezzo del cammin ")
    >>> print(text)
    """

    def __init__(
        self,
        run_dir: str,
        *,
        seed: int = 42,
        token: str | None = None,
        revision: str | None = None,
    ) -> None:
        from dantinox.hub import resolve_checkpoint

        self.seed = seed
        # Resolve once: download from Hub if needed, then use the local path
        local_dir = resolve_checkpoint(run_dir, token=token, revision=revision)
        self.run_dir = local_dir
        self.config, self.model, self.tokenizer = _load_checkpoint(local_dir, seed)

    def __repr__(self) -> str:
        attn = "MLA" if self.config.mla else ("GQA" if self.config.kv_heads < self.config.n_heads else "MHA")
        return f"Generator(run_dir={self.run_dir!r}, attn={attn}, seed={self.seed})"

    def _bpe_fix(self, text: str) -> str:
        if self.config.tokenizer_type == "bpe":
            for old, new in _BPE_REPLACEMENTS:
                text = text.replace(old, new)
        return text

    # ── Single-prompt generation ──────────────────────────────────────────────

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> str:
        """
        Generate text continuing from ``prompt``.

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Number of tokens to generate (default 150).
        greedy : bool
            Use greedy decoding instead of sampling (default False).
        top_k : int, optional
            Keep only the top-k logits before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache for faster generation (default True).

        Returns
        -------
        str
            The full generated string (prompt + continuation).
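
        Examples
        --------
        Illustrative call (the checkpoint path and sampling values are
        placeholders):

        >>> gen = Generator("runs/my_run")
        >>> text = gen.generate("Nel mezzo del cammin ", max_new_tokens=50, top_p=0.9)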
        """
        tokens = self.tokenizer.encode(prompt)
        x = jnp.array([tokens], dtype=jnp.int32)

        log.debug(
            "Generating %d tokens from prompt of %d tokens (greedy=%s, cache=%s)",
            max_new_tokens, len(tokens), greedy, use_cache,
        )

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        output.block_until_ready()

        return self._bpe_fix(self.tokenizer.decode(output[0].tolist()))

    # ── Batched generation ────────────────────────────────────────────────────

    def generate_batch(
        self,
        prompts: list[str],
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> list[str]:
        """
        Generate text for multiple prompts in a single batched forward pass.

        Shorter prompts are left-padded with zeros so all share the same
        sequence length. This runs a true batch through the model, so
        throughput scales with GPU parallelism.

        Parameters
        ----------
        prompts : list[str]
            Input prefixes to generate from.
        max_new_tokens : int
            Tokens to generate per prompt (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache (default True).

        Returns
        -------
        list[str]
            Generated strings (prompt + continuation) in the same order as
            ``prompts``.
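
        Examples
        --------
        Illustrative prompts and checkpoint path:

        >>> gen = Generator("runs/my_run")
        >>> outs = gen.generate_batch(["Nel mezzo", "Per me si va"], max_new_tokens=30)
        >>> len(outs)
        2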
        """
        if not prompts:
            return []

        encoded = [self.tokenizer.encode(p) for p in prompts]
        max_len = max(len(e) for e in encoded)

        # Left-pad shorter prompts with zeros so all share the same start position.
        padded = [([0] * (max_len - len(e))) + e for e in encoded]
        x = jnp.array(padded, dtype=jnp.int32)  # [B, max_len]

        log.debug("Batch generating: B=%d max_prompt_len=%d max_new=%d", len(prompts), max_len, max_new_tokens)

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        output.block_until_ready()

        results = []
        for i, enc in enumerate(encoded):
            # Strip the left-padding: prompt starts at (max_len - len(enc))
            start = max_len - len(enc)
            tokens_out = output[i, start:].tolist()
            results.append(self._bpe_fix(self.tokenizer.decode(tokens_out)))
        return results

    # ── Streaming generation ──────────────────────────────────────────────────

    def stream(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
    ) -> Iterator[str]:
        """
        Stream generated tokens one at a time as they are produced.

        Uses the KV-cache path: the prompt is prefilled in one forward pass,
        then each new token is decoded individually. Each ``yield`` returns
        the string for one generated token (may be a character or a BPE
        subword).

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Maximum number of tokens to generate (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).

        Yields
        ------
        str
            Decoded string for each generated token.

        Examples
        --------
        >>> gen = Generator("runs/my_run")
        >>> for chunk in gen.stream("Nel mezzo", max_new_tokens=50):
        ...     print(chunk, end="", flush=True)
        """
        tokens = self.tokenizer.encode(prompt)
        T = len(tokens)
        max_ctx = self.config.max_context  # type: ignore[attr-defined]
        num_blocks = self.config.num_blocks  # type: ignore[attr-defined]

        # Build full-context input with prompt at the start.
        x = jnp.zeros((1, max_ctx), dtype=jnp.int32)
        x = x.at[0, :T].set(jnp.array(tokens, dtype=jnp.int32))
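
        # One empty (keys, values) slot per transformer block; the prefill pass
        # below fills them from the prompt.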
        init_kv_cache = tuple((None, None) for _ in range(num_blocks))
        key = jax.random.key(self.seed)

        # Prefill: one pass over the entire prompt, populate KV cache.
        logits, kv_cache = _stream_prefill(self.model, x, init_kv_cache)

        # Sample the first generated token from the last prompt position.
        tok_id, key = _sample_logit(logits[:, T - 1, :], key, greedy, temperature, top_k, top_p)
        yield self._bpe_fix(self.tokenizer.decode([tok_id]))

        # Autoregressive decode loop.
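        # The first token was already produced from the prefill logits, so only
        # max_new_tokens - 1 decode steps remain; stop early once the context
        # window is exhausted.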
        for pos in range(T, T + max_new_tokens - 1):
            if pos >= max_ctx:
                break
            tok = jnp.array([[tok_id]], dtype=jnp.int32)
            logits, kv_cache = _stream_decode(self.model, tok, kv_cache, jnp.array(pos))
            tok_id, key = _sample_logit(logits[:, 0, :], key, greedy, temperature, top_k, top_p)
            yield self._bpe_fix(self.tokenizer.decode([tok_id]))