# Coverage for dantinox/trainer.py: 67% (330 statements)
from __future__ import annotations

import csv
import datetime
import json
import logging
import os
import time

import flax.serialization
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax.nnx.transforms.autodiff import DiffState
from tqdm import tqdm

from core.config import Config
from core.lora import LoRAParam
from core.model import Transformer
from core.sharding import make_mesh, num_devices, replicate, shard_batch
from dantinox.exceptions import ConfigError
from utils.helpers import compute_loss, get_batch
from utils.tokenizer import get_tokenizer, load_tokenizer_from_file

log = logging.getLogger(__name__)

# ── Helpers ──────────────────────────────────────────────────────────────────


def _build_schedule(config: Config, total_steps: int) -> optax.Schedule:
    """Return an optax schedule for the requested ``config.lr_schedule``."""
    warmup_steps = min(
        getattr(config, "warmup_steps", int(total_steps * 0.1)),
        int(total_steps * 0.3),
    )
    safe_total = max(total_steps, warmup_steps + 1)
    peak = config.lr
    end = peak * 0.01

    kind = getattr(config, "lr_schedule", "cosine")

    if kind == "cosine":
        return optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=peak,
            warmup_steps=warmup_steps,
            decay_steps=safe_total,
            end_value=end,
        )

    if kind == "linear":
        warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
        decay = optax.linear_schedule(init_value=peak, end_value=end,
                                      transition_steps=safe_total - warmup_steps)
        return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])

    if kind == "constant":
        warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
        constant = optax.constant_schedule(peak)
        return optax.join_schedules([warmup, constant], boundaries=[warmup_steps])

    # wsd: warmup → stable (40 % of budget) → cosine decay to end
    stable_steps = int(safe_total * 0.4)
    decay_steps = safe_total - warmup_steps - stable_steps
    warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
    stable = optax.constant_schedule(peak)
    decay = optax.cosine_decay_schedule(init_value=peak, decay_steps=max(decay_steps, 1), alpha=end / peak)
    return optax.join_schedules(
        [warmup, stable, decay],
        boundaries=[warmup_steps, warmup_steps + stable_steps],
    )
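
# Sketch of how the returned schedule behaves (hypothetical numbers, assuming
# ``config`` carries no ``warmup_steps`` override): an ``optax.Schedule`` is a
# callable from step index to learning rate, so with total_steps=1000 and
# config.lr=3e-4 under the default "cosine" kind:
#
#     schedule = _build_schedule(config, 1000)
#     schedule(0)     # 0.0 at the start of warmup
#     schedule(100)   # ~3e-4 at the peak, after 10% warmup
#     schedule(1000)  # ~3e-6, decayed to 1% of peak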


def _build_optimizer(config: Config, total_steps: int) -> optax.GradientTransformation:
    schedule = _build_schedule(config, total_steps)

    name = config.optimizer.lower()
    if name == "adamw":
        base_opt: optax.GradientTransformation = optax.adamw(learning_rate=schedule)
    elif name == "adafactor":
        base_opt = optax.adafactor(learning_rate=schedule)
    elif name == "lion":
        base_opt = optax.lion(learning_rate=schedule)
    else:
        base_opt = optax.adam(learning_rate=schedule)

    grad_clip = getattr(config, "grad_clip", 0.0)
    if grad_clip > 0:
        return optax.chain(optax.clip_by_global_norm(grad_clip), base_opt)
    return base_opt
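
# Minimal usage sketch (mirrors what ``Trainer.fit`` does below): the returned
# gradient transformation is paired with the module via ``nnx.Optimizer``,
# together with the parameter set to differentiate with respect to.
#
#     tx = _build_optimizer(config, total_steps)
#     optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)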


def _load_text(config: Config, data_path: str | None) -> str:
    if config.dataset_source == "huggingface":
        from datasets import load_dataset  # type: ignore[import]

        name = config.dataset_name
        subset = getattr(config, "dataset_config", "") or None
        split = getattr(config, "dataset_split", "train") or "train"
        field = getattr(config, "dataset_text_field", "text") or "text"

        load_kw: dict = {"split": split, "streaming": config.streaming}
        if subset:
            load_kw["name"] = subset

        raw = load_dataset(name, **load_kw)

        if config.streaming:
            # IterableDataset — materialise to a single string
            parts = [ex[field] for ex in raw if field in ex and ex[field]]
        else:
            parts = [t for t in raw[field] if t]

        if not parts:
            raise ConfigError(
                f"No text found in column '{field}' of '{name}'. "
                f"Set dataset_text_field to the correct column name."
            )
        return " ".join(parts)

    path = data_path or config.dataset_name
    if not path:
        raise ConfigError(
            "No data_path provided and config.dataset_name is empty. "
            "Pass data_path to Trainer.fit() or set dataset_name in the config."
        )
    if not os.path.exists(path):
        raise ConfigError(f"Data file not found: {path}")
    with open(path, encoding="utf-8") as f:
        return f.read()
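
# Example config fragment consumed by the Hugging Face branch above (field
# names as read via ``getattr``; the dataset name and YAML layout are purely
# illustrative, not part of this repo):
#
#     dataset_source: huggingface
#     dataset_name: roneneldan/TinyStories
#     dataset_split: train
#     dataset_text_field: text
#     streaming: false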


def _format_text(text: str) -> str:
    lines = [line.rstrip() for line in text.split("\n") if line.strip()]
    blocks = ["\n".join(lines[i : i + 3]) for i in range(0, len(lines), 3)]
    return "\n\n".join(blocks) + "\n"


def _model_summary(model: Transformer, config: Config, optimizer: nnx.Optimizer) -> dict:
    params = nnx.state(model, nnx.Param)
    total = sum(x.size for x in jax.tree_util.tree_leaves(params))
    opt_state = nnx.state(optimizer)
    opt_params = sum(
        x.size for x in jax.tree_util.tree_leaves(opt_state) if isinstance(x, jax.Array)
    )
    bpp = 2 if getattr(config, "use_bf16", False) else 4
    act = config.batch_size * config.max_context * config.dim * config.num_blocks * 8 * bpp
    return {
        "total_params_M": round(total / 1e6, 2),
        "dtype": "bfloat16" if bpp == 2 else "float32",
        "weights_mem_MB": round(total * bpp / 1e6, 2),
        "optimizer_mem_MB": round(opt_params * bpp / 1e6, 2),
        "est_activations_MB": round(act / 1e6, 2),
    }
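
# The activation figure above is a back-of-the-envelope estimate,
#
#     batch_size * max_context * dim * num_blocks * 8 * bytes_per_value,
#
# where the factor 8 stands in for the number of activation-sized buffers kept
# alive per block; treat it as an order-of-magnitude planning number, not a
# measured peak.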


def _cast_params(model: Transformer, dtype: jnp.dtype) -> None:
    params = nnx.state(model, nnx.Param)
    nnx.update(
        model,
        jax.tree_util.tree_map(
            lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
            params,
        ),
    )


def _save_weights(model: Transformer, path: str) -> None:
    state_dict = nnx.state(model, nnx.Param).to_pure_dict()
    with open(path, "wb") as f:
        f.write(flax.serialization.msgpack_serialize(state_dict))
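
# The matching load path, as a sketch: ``fit`` below deserialises with raw
# ``msgpack`` plus flax's private ext hook, but the public
# ``flax.serialization.msgpack_restore`` is the documented inverse of
# ``msgpack_serialize``.
#
#     def _load_weights(model: Transformer, path: str) -> None:
#         with open(path, "rb") as f:
#             nnx.update(model, flax.serialization.msgpack_restore(f.read()))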


# ── Trainer ───────────────────────────────────────────────────────────────────


class Trainer:
    """
    High-level training interface for DantinoX.

    Parameters
    ----------
    config : Config
        Model and training configuration.

    Examples
    --------
    >>> trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
    >>> run_dir = trainer.fit("data/corpus.txt")
    """

    def __init__(self, config: Config) -> None:
        self.config = config

    def __repr__(self) -> str:
        return f"Trainer(config={self.config!r})"

    def fit(
        self,
        data_path: str | None = None,
        *,
        run_dir: str | None = None,
        wandb_project: str | None = None,
        resume: bool = False,
    ) -> str:
        """
        Train a model and save the checkpoint.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus. Falls back to ``config.dataset_name``.
        run_dir : str, optional
            Directory to write the checkpoint and logs. Auto-generated if omitted.
        wandb_project : str, optional
            If provided, metrics are logged to Weights & Biases.
        resume : bool
            If ``True`` and a previous checkpoint exists in ``run_dir``, training
            resumes from the saved step. Optimizer state is not preserved.

        Returns
        -------
        str
            Path to the run directory containing the saved checkpoint.

        Raises
        ------
        ConfigError
            If the data path is missing or the file cannot be found.
        """
        config = self.config

        if run_dir is None:
            run_dir = os.path.join(
                "runs", datetime.datetime.now().strftime("run_%Y%m%d_%H%M%S")
            )
        os.makedirs(run_dir, exist_ok=True)
        config.save_yaml(os.path.join(run_dir, "config.yaml"))

        log.info("Run directory: %s", run_dir)

        text = _format_text(_load_text(config, data_path))

        tok_path = os.path.join(run_dir, "tokenizer.json")
        if resume and os.path.exists(tok_path):
            tokenizer = load_tokenizer_from_file(tok_path)
            log.info("Resumed tokenizer from %s", tok_path)
        else:
            tokenizer = get_tokenizer(config.tokenizer_type)
            if config.tokenizer_type == "char":
                tokenizer.train_from_text(text)
            elif config.tokenizer_type == "bpe":
                tokenizer.train_from_text(text, vocab_size=config.vocab_size)
            tokenizer.save(tok_path)
            log.info("Tokenizer saved to %s", tok_path)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        n = int(0.9 * len(full_data))
        train_data, val_data = full_data[:n], full_data[n:]

        tokens_per_step = config.batch_size * config.max_context
        steps_per_epoch = max(1, len(train_data) // tokens_per_step)
        total_steps = steps_per_epoch * config.epochs

        tx = _build_optimizer(config, total_steps)
        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)
            log.info("Model cast to bfloat16")

        # LoRA: only train adapter params; base weights are frozen nnx.Param
        wrt_type = LoRAParam if getattr(config, "use_lora", False) else nnx.Param
        optimizer = nnx.Optimizer(model, tx, wrt=wrt_type)

        # Multi-GPU: build a data-parallel mesh when more than one device is requested
        n_dev_cfg = getattr(config, "n_devices", 0)
        n_local = len(jax.local_devices())
        use_multi_gpu = (n_dev_cfg != 1) and n_local > 1
        mesh = make_mesh(n_dev_cfg) if use_multi_gpu else None
        if use_multi_gpu:
            n_dev = num_devices(mesh)  # type: ignore[arg-type]
            if config.batch_size % n_dev != 0:
                raise ConfigError(
                    f"batch_size ({config.batch_size}) must be divisible by "
                    f"n_devices ({n_dev}) for data-parallel training."
                )
            log.info("Data-parallel training on %d devices", n_dev)

        start_step = 0
        cursor_path = os.path.join(run_dir, "training_cursor.json")
        resume_weights = os.path.join(run_dir, "model_weights.msgpack")
        if resume and os.path.exists(cursor_path) and os.path.exists(resume_weights):
            with open(cursor_path) as cursor_f:
                cursor = json.load(cursor_f)
            start_step = int(cursor.get("step", 0)) + 1
            import msgpack
            from flax.serialization import _msgpack_ext_unpack

            with open(resume_weights, "rb") as weights_f:
                state_dict = msgpack.unpackb(
                    weights_f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
                )
            nnx.update(model, state_dict)
            log.info("Resumed training from step %d", start_step)

        summary = _model_summary(model, config, optimizer)
        with open(os.path.join(run_dir, "model_summary.json"), "w") as f:
            json.dump(summary, f, indent=4)
        log.info(
            "Model: %sM params | Est. VRAM: %sMB",
            summary["total_params_M"],
            summary["weights_mem_MB"] + summary["optimizer_mem_MB"] + summary["est_activations_MB"],
        )

        if wandb_project is not None:
            import wandb

            wandb.init(project=wandb_project, config=config.to_dict())  # type: ignore[attr-defined]

        micro_bs = config.batch_size // config.grad_accum
        _wrt = wrt_type  # captured in closure for JIT

        @nnx.jit
        def train_step(model, opt, metrics, full_x, full_y):
            xs = full_x.reshape(config.grad_accum, micro_bs, -1)
            ys = full_y.reshape(config.grad_accum, micro_bs, -1)

            def _loss(model, x, y):
                logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
                loss = compute_loss(logits, y)
                if getattr(model, "use_moe", False):
                    loss = loss + model.alpha_balance * bal
                return loss, bal

            grad_fn = nnx.value_and_grad(_loss, argnums=DiffState(0, _wrt), has_aux=True)
            acc = jax.tree_util.tree_map(jnp.zeros_like, nnx.state(model, _wrt))
            total_loss = jnp.array(0.0)
            total_bal = jnp.array(0.0)
            for i in range(config.grad_accum):
                (loss, bal), grads = grad_fn(model, xs[i], ys[i])
                acc = jax.tree_util.tree_map(
                    lambda a, g: a + g / config.grad_accum, acc, grads
                )
                total_loss += loss / config.grad_accum
                total_bal += bal / config.grad_accum
            opt.update(model, acc)
            metrics.update(loss=total_loss)
            return total_loss, total_bal
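
        # Note on the accumulation loop above: under ``nnx.jit`` the Python
        # ``for`` unrolls at trace time, so ``config.grad_accum`` acts as a
        # static compile-time constant; each micro-batch contributes gradients
        # and loss scaled by 1/grad_accum before a single optimizer update.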

        @nnx.jit
        def eval_step(model, x, y):
            logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
            loss = compute_loss(logits, y)
            if getattr(model, "use_moe", False):
                loss = loss + model.alpha_balance * bal
            return loss, bal

        def estimate_loss(key):
            out: dict[str, float] = {}
            for split, d in [("train", train_data), ("val", val_data)]:
                losses, bals = [], []
                for _ in range(config.eval_iters):
                    key, sub = jax.random.split(key)
                    x, y = get_batch(d, 1, config.max_context, sub)
                    loss_val, b = eval_step(model, x, y)
                    losses.append(float(loss_val))
                    bals.append(float(b))
                out[split] = sum(losses) / len(losses)
                out[f"{split}_bal"] = sum(bals) / len(bals)
            return out, key

        log_path = os.path.join(run_dir, "training_log.csv")
        weights_path = os.path.join(run_dir, "model_weights.msgpack")
        best_weights_path = os.path.join(run_dir, "best_model_weights.msgpack")

        key = jax.random.PRNGKey(config.seed)
        metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))

        # Multi-GPU: replicate model/optimizer/metrics to all devices once
        if use_multi_gpu:
            assert mesh is not None
            state = replicate(nnx.state((model, optimizer, metrics)), mesh)
            nnx.update((model, optimizer, metrics), state)

        pbar = tqdm(
            range(start_step, total_steps),
            desc="Training",
            unit="step",
            dynamic_ncols=True,
            initial=start_step,
            total=total_steps,
        )
        t0 = time.time()

        patience = getattr(config, "patience", 0)
        best_val_loss = float("inf")
        no_improve = 0

        with open(log_path, "a", newline="") as log_f:
            log_w = csv.writer(log_f)
            if os.path.getsize(log_path) == 0:
                log_w.writerow(
                    ["step", "train_loss", "val_loss", "train_bal", "val_bal", "ms_per_step"]
                )
            try:
                for step in pbar:
                    key, sub = jax.random.split(key)
                    x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
                    if use_multi_gpu:
                        assert mesh is not None
                        x = shard_batch(x, mesh)
                        y = shard_batch(y, mesh)
                    train_step(model, optimizer, metrics, x, y)

                    if step % 50 == 0:
                        t1 = time.time()
                        dt = (t1 - t0) * 1000 / 50
                        t0 = t1
                        losses, key = estimate_loss(key)
                        val_loss = losses["val"]
                        pbar.set_postfix(
                            train=f"{losses['train']:.4f}",
                            val=f"{val_loss:.4f}",
                        )
                        log.info(
                            "step %d/%d | train=%.4f val=%.4f bal=%.4f",
                            step, total_steps,
                            losses["train"], val_loss, losses["train_bal"],
                        )
                        log_w.writerow(
                            [
                                step,
                                float(losses["train"]),
                                float(val_loss),
                                float(losses["train_bal"]),
                                float(losses["val_bal"]),
                                round(dt, 2),
                            ]
                        )
                        log_f.flush()

                        # Periodic checkpoint for resume
                        _save_weights(model, weights_path)
                        with open(cursor_path, "w") as cf:
                            json.dump({"step": step}, cf)

                        # Best checkpoint tracking
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            no_improve = 0
                            _save_weights(model, best_weights_path)
                            log.info("New best val loss %.4f — saved best checkpoint", best_val_loss)
                        else:
                            no_improve += 1
                            if patience > 0 and no_improve >= patience:
                                log.info(
                                    "Early stopping at step %d (no improvement for %d evals)",
                                    step, patience,
                                )
                                break

                        if wandb_project is not None:
                            import wandb

                            wandb.log({"train_loss": losses["train"], "val_loss": val_loss, "step": step})  # type: ignore[attr-defined]
            finally:
                pbar.close()
                if wandb_project is not None:
                    import wandb

                    wandb.finish()  # type: ignore[attr-defined]

        _save_weights(model, weights_path)
        log.info("Checkpoint saved: %s", weights_path)
        return run_dir
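
    # Resume sketch (paths as written by ``fit`` above): pointing ``fit`` at an
    # existing run directory with ``resume=True`` reloads the tokenizer, the
    # last saved weights and the step cursor; optimizer state starts fresh.
    #
    #     run_dir = trainer.fit("data/corpus.txt")
    #     trainer.fit("data/corpus.txt", run_dir=run_dir, resume=True)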

    def find_lr(
        self,
        data_path: str | None = None,
        *,
        min_lr: float = 1e-7,
        max_lr: float = 1.0,
        num_steps: int = 100,
        smoothing: float = 0.9,
    ) -> tuple[float, list[float], list[float]]:
        """
        LR range test (Smith, 2015).

        Trains for ``num_steps`` steps while exponentially increasing the
        learning rate from ``min_lr`` to ``max_lr``. Returns a tuple of
        ``(suggested_lr, lr_history, loss_history)``.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus.
        min_lr : float
            Starting learning rate (default 1e-7).
        max_lr : float
            Maximum learning rate (default 1.0).
        num_steps : int
            Number of steps in the sweep (default 100).
        smoothing : float
            Exponential smoothing factor for the loss curve (default 0.9).

        Returns
        -------
        tuple[float, list[float], list[float]]
            ``(suggested_lr, lr_history, loss_history)``
        """
        import math

        config = self.config
        text = _format_text(_load_text(config, data_path))

        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        train_data = full_data[: int(0.9 * len(full_data))]

        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)

        log_multiplier = math.log(max_lr / min_lr) / max(1, num_steps - 1)

        def _lr_fn(step: jnp.ndarray) -> jnp.ndarray:
            return jnp.array(min_lr, jnp.float32) * jnp.exp(
                step.astype(jnp.float32) * jnp.array(log_multiplier, jnp.float32)
            )

        tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=_lr_fn))
        optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

        @nnx.jit
        def _step(model, opt, x, y):
            def loss_fn(m):
                logits, _, _ = m(x, use_cache=False, kv_caches=None, cache_index=0)
                return compute_loss(logits, y)

            loss, grads = nnx.value_and_grad(loss_fn)(model)
            opt.update(model, grads)
            return loss

        key = jax.random.PRNGKey(config.seed)
        lr_history: list[float] = []
        loss_history: list[float] = []
        smooth_loss = 0.0
        best_loss = float("inf")

        pbar = tqdm(range(num_steps), desc="LR finder", unit="step", dynamic_ncols=True)
        for step in pbar:
            key, sub = jax.random.split(key)
            x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
            loss_val = float(_step(model, optimizer, x, y))

            # EMA from zero with Adam-style bias correction; seeding the EMA
            # with loss_val at step 0 and then debiasing would inflate the
            # first values by up to 1/(1 - smoothing), skewing best-loss
            # tracking and the slope-based suggestion.
            smooth_loss = smoothing * smooth_loss + (1 - smoothing) * loss_val
            debiased = smooth_loss / (1 - smoothing ** (step + 1))
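            # e.g. with smoothing=0.9: after step 0, smooth = 0.1*loss and
            # debiased = 0.1*loss / (1 - 0.9) = loss, as expected.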

            current_lr = float(_lr_fn(jnp.array(step)))
            lr_history.append(current_lr)
            loss_history.append(debiased)

            pbar.set_postfix(lr=f"{current_lr:.2e}", loss=f"{debiased:.4f}")

            if debiased < best_loss:
                best_loss = debiased
            if debiased > 4 * best_loss:
                log.info("Loss diverging at step %d — stopping sweep early", step)
                break

        pbar.close()

        if len(loss_history) > 2:
            slopes = [loss_history[i + 1] - loss_history[i] for i in range(len(loss_history) - 1)]
            suggested_lr = lr_history[min(range(len(slopes)), key=lambda i: slopes[i])]
        else:
            suggested_lr = min_lr

        log.info(
            "LR finder done — suggested lr=%.2e (sweep range [%.2e, %.2e])",
            suggested_lr, min_lr, max_lr,
        )
        return suggested_lr, lr_history, loss_history
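

# Usage sketch for the LR finder (hypothetical paths): run the sweep, adopt
# the suggestion, then train.
#
#     trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
#     lr, lrs, losses = trainer.find_lr("data/corpus.txt", num_steps=200)
#     trainer.config.lr = lr
#     trainer.fit("data/corpus.txt")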