from __future__ import annotations

import csv
import datetime
import json
import logging
import os
import time

import flax.serialization
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax.nnx.transforms.autodiff import DiffState
from tqdm import tqdm

from core.config import Config
from core.lora import LoRAParam
from core.model import Transformer
from core.sharding import make_mesh, num_devices, replicate, shard_batch
from dantinox.exceptions import ConfigError
from utils.helpers import compute_loss, get_batch
from utils.tokenizer import get_tokenizer, load_tokenizer_from_file

log = logging.getLogger(__name__)


# ── Helpers ──────────────────────────────────────────────────────────────────

def _build_schedule(config: Config, total_steps: int) -> optax.Schedule:
    """Return an optax schedule for the requested ``config.lr_schedule``."""
    warmup_steps = min(
        getattr(config, "warmup_steps", int(total_steps * 0.1)),
        int(total_steps * 0.3),
    )
    safe_total = max(total_steps, warmup_steps + 1)
    peak = config.lr
    end = peak * 0.01

    kind = getattr(config, "lr_schedule", "cosine")

    if kind == "cosine":
        return optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=peak,
            warmup_steps=warmup_steps,
            decay_steps=safe_total,
            end_value=end,
        )

    if kind == "linear":
        warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
        decay = optax.linear_schedule(
            init_value=peak, end_value=end, transition_steps=safe_total - warmup_steps
        )
        return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])

    if kind == "constant":
        warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
        constant = optax.constant_schedule(peak)
        return optax.join_schedules([warmup, constant], boundaries=[warmup_steps])

    # wsd (and any unrecognised value): warmup → stable (40 % of budget) → cosine decay to end
    stable_steps = int(safe_total * 0.4)
    decay_steps = safe_total - warmup_steps - stable_steps
    warmup = optax.linear_schedule(init_value=0.0, end_value=peak, transition_steps=warmup_steps)
    stable = optax.constant_schedule(peak)
    decay = optax.cosine_decay_schedule(init_value=peak, decay_steps=max(decay_steps, 1), alpha=end / peak)
    return optax.join_schedules(
        [warmup, stable, decay],
        boundaries=[warmup_steps, warmup_steps + stable_steps],
    )
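
# Example (illustrative; assumes config.lr == 3e-4 and lr_schedule == "cosine"):
#
#     sched = _build_schedule(config, total_steps=1_000)
#     float(sched(0))      # 0.0, start of warmup
#     float(sched(1_000))  # ≈ 3e-6, i.e. 1 % of the peak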


def _build_optimizer(config: Config, total_steps: int) -> optax.GradientTransformation:
    schedule = _build_schedule(config, total_steps)

    name = config.optimizer.lower()
    if name == "adamw":
        base_opt: optax.GradientTransformation = optax.adamw(learning_rate=schedule)
    elif name == "adafactor":
        base_opt = optax.adafactor(learning_rate=schedule)
    elif name == "lion":
        base_opt = optax.lion(learning_rate=schedule)
    else:
        # Any other optimizer name falls back to plain Adam.
        base_opt = optax.adam(learning_rate=schedule)

    grad_clip = getattr(config, "grad_clip", 0.0)
    if grad_clip > 0:
        return optax.chain(optax.clip_by_global_norm(grad_clip), base_opt)
    return base_opt
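
# Example (illustrative): AdamW with global-norm clipping at 0.5.
#
#     config.optimizer, config.grad_clip = "adamw", 0.5
#     tx = _build_optimizer(config, total_steps=2_000)  # optax.GradientTransformation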


def _load_text(config: Config, data_path: str | None) -> str:
    if config.dataset_source == "huggingface":
        from datasets import load_dataset  # type: ignore[import]

        name = config.dataset_name
        subset = getattr(config, "dataset_config", "") or None
        split = getattr(config, "dataset_split", "train") or "train"
        field = getattr(config, "dataset_text_field", "text") or "text"

        load_kw: dict = {"split": split, "streaming": config.streaming}
        if subset:
            load_kw["name"] = subset

        raw = load_dataset(name, **load_kw)

        if config.streaming:
            # IterableDataset — materialise to a single string
            parts = [ex[field] for ex in raw if field in ex and ex[field]]
        else:
            parts = [t for t in raw[field] if t]

        if not parts:
            raise ConfigError(
                f"No text found in column '{field}' of '{name}'. "
                f"Set dataset_text_field to the correct column name."
            )
        return " ".join(parts)

    path = data_path or config.dataset_name
    if not path:
        raise ConfigError(
            "No data_path provided and config.dataset_name is empty. "
            "Pass data_path to Trainer.fit() or set dataset_name in the config."
        )
    if not os.path.exists(path):
        raise ConfigError(f"Data file not found: {path}")
    with open(path, encoding="utf-8") as f:
        return f.read()
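
# Example (illustrative; any Hub dataset with a text column works):
#
#     config.dataset_source = "huggingface"
#     config.dataset_name = "roneneldan/TinyStories"
#     text = _load_text(config, data_path=None)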


def _format_text(text: str) -> str:
    # Drop empty lines, then regroup the rest into blocks of three lines
    # separated by blank lines.
    lines = [line.rstrip() for line in text.split("\n") if line.strip()]
    blocks = ["\n".join(lines[i : i + 3]) for i in range(0, len(lines), 3)]
    return "\n\n".join(blocks) + "\n"
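
# For example, _format_text("a\nb\nc\nd") returns "a\nb\nc\n\nd\n".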


def _model_summary(model: Transformer, config: Config, optimizer: nnx.Optimizer) -> dict:
    params = nnx.state(model, nnx.Param)
    total = sum(x.size for x in jax.tree_util.tree_leaves(params))
    opt_state = nnx.state(optimizer)
    opt_params = sum(
        x.size for x in jax.tree_util.tree_leaves(opt_state) if isinstance(x, jax.Array)
    )
    bpp = 2 if getattr(config, "use_bf16", False) else 4  # bytes per parameter
    # Rough activation estimate: ~8 activation tensors per transformer block.
    act = config.batch_size * config.max_context * config.dim * config.num_blocks * 8 * bpp
    return {
        "total_params_M": round(total / 1e6, 2),
        "dtype": "bfloat16" if bpp == 2 else "float32",
        "weights_mem_MB": round(total * bpp / 1e6, 2),
        "optimizer_mem_MB": round(opt_params * bpp / 1e6, 2),
        "est_activations_MB": round(act / 1e6, 2),
    }
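
# Returned summary (illustrative values):
#
#     {"total_params_M": 12.58, "dtype": "float32", "weights_mem_MB": 50.33,
#      "optimizer_mem_MB": 100.66, "est_activations_MB": 402.65}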


def _cast_params(model: Transformer, dtype: jnp.dtype) -> None:
    # Cast only floating-point leaves; integer leaves keep their dtype.
    params = nnx.state(model, nnx.Param)
    nnx.update(
        model,
        jax.tree_util.tree_map(
            lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
            params,
        ),
    )
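
# e.g. _cast_params(model, jnp.bfloat16) halves weight memory
# (2 bytes per floating-point parameter instead of 4).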


def _save_weights(model: Transformer, path: str) -> None:
    state_dict = nnx.state(model, nnx.Param).to_pure_dict()
    with open(path, "wb") as f:
        f.write(flax.serialization.msgpack_serialize(state_dict))
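
# The matching load path (see the resume branch in Trainer.fit below):
#
#     import msgpack
#     from flax.serialization import _msgpack_ext_unpack
#     with open(path, "rb") as f:
#         state_dict = msgpack.unpackb(f.read(), ext_hook=_msgpack_ext_unpack,
#                                      strict_map_key=False)
#     nnx.update(model, state_dict)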


# ── Trainer ──────────────────────────────────────────────────────────────────

class Trainer:
    """
    High-level training interface for DantinoX.

    Parameters
    ----------
    config : Config
        Model and training configuration.

    Examples
    --------
    >>> trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
    >>> run_dir = trainer.fit("data/corpus.txt")
    """

    def __init__(self, config: Config) -> None:
        self.config = config

    def __repr__(self) -> str:
        return f"Trainer(config={self.config!r})"

    def fit(
        self,
        data_path: str | None = None,
        *,
        run_dir: str | None = None,
        wandb_project: str | None = None,
        resume: bool = False,
    ) -> str:
        """
        Train a model and save the checkpoint.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus. Falls back to ``config.dataset_name``.
        run_dir : str, optional
            Directory to write the checkpoint and logs. Auto-generated if omitted.
        wandb_project : str, optional
            If provided, metrics are logged to Weights & Biases.
        resume : bool
            If ``True`` and a previous checkpoint exists in ``run_dir``, training
            resumes from the saved step. Optimizer state is not preserved.

        Returns
        -------
        str
            Path to the run directory containing the saved checkpoint.

        Raises
        ------
        ConfigError
            If the data path is missing or the file cannot be found.
        """
        config = self.config

        if run_dir is None:
            run_dir = os.path.join(
                "runs", datetime.datetime.now().strftime("run_%Y%m%d_%H%M%S")
            )
        os.makedirs(run_dir, exist_ok=True)
        config.save_yaml(os.path.join(run_dir, "config.yaml"))

        log.info("Run directory: %s", run_dir)

        text = _format_text(_load_text(config, data_path))

        tok_path = os.path.join(run_dir, "tokenizer.json")
        if resume and os.path.exists(tok_path):
            tokenizer = load_tokenizer_from_file(tok_path)
            log.info("Resumed tokenizer from %s", tok_path)
        else:
            tokenizer = get_tokenizer(config.tokenizer_type)
            if config.tokenizer_type == "char":
                tokenizer.train_from_text(text)
            elif config.tokenizer_type == "bpe":
                tokenizer.train_from_text(text, vocab_size=config.vocab_size)
            tokenizer.save(tok_path)
            log.info("Tokenizer saved to %s", tok_path)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        # 90/10 train/validation split
        n = int(0.9 * len(full_data))
        train_data, val_data = full_data[:n], full_data[n:]

        tokens_per_step = config.batch_size * config.max_context
        steps_per_epoch = max(1, len(train_data) // tokens_per_step)
        total_steps = steps_per_epoch * config.epochs

        tx = _build_optimizer(config, total_steps)
        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)
            log.info("Model cast to bfloat16")

        # LoRA: only train adapter params; base weights are frozen nnx.Param
        wrt_type = LoRAParam if getattr(config, "use_lora", False) else nnx.Param
        optimizer = nnx.Optimizer(model, tx, wrt=wrt_type)

        # Multi-GPU: build a data-parallel mesh when more than one device is requested
        n_dev_cfg = getattr(config, "n_devices", 0)
        n_local = len(jax.local_devices())
        use_multi_gpu = (n_dev_cfg != 1) and n_local > 1
        mesh = make_mesh(n_dev_cfg) if use_multi_gpu else None
        if use_multi_gpu:
            n_dev = num_devices(mesh)  # type: ignore[arg-type]
            if config.batch_size % n_dev != 0:
                raise ConfigError(
                    f"batch_size ({config.batch_size}) must be divisible by "
                    f"n_devices ({n_dev}) for data-parallel training."
                )
            log.info("Data-parallel training on %d devices", n_dev)

        start_step = 0
        cursor_path = os.path.join(run_dir, "training_cursor.json")
        resume_weights = os.path.join(run_dir, "model_weights.msgpack")
        if resume and os.path.exists(cursor_path) and os.path.exists(resume_weights):
            with open(cursor_path) as cursor_f:
                cursor = json.load(cursor_f)
            start_step = int(cursor.get("step", 0)) + 1
            import msgpack
            from flax.serialization import _msgpack_ext_unpack
            with open(resume_weights, "rb") as weights_f:
                state_dict = msgpack.unpackb(
                    weights_f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
                )
            nnx.update(model, state_dict)
            log.info("Resumed training from step %d", start_step)

        summary = _model_summary(model, config, optimizer)
        with open(os.path.join(run_dir, "model_summary.json"), "w") as f:
            json.dump(summary, f, indent=4)
        log.info(
            "Model: %sM params | Est. VRAM: %sMB",
            summary["total_params_M"],
            summary["weights_mem_MB"] + summary["optimizer_mem_MB"] + summary["est_activations_MB"],
        )

        if wandb_project is not None:
            import wandb
            wandb.init(project=wandb_project, config=config.to_dict())  # type: ignore[attr-defined]

        micro_bs = config.batch_size // config.grad_accum
        _wrt = wrt_type  # captured in closure for JIT

        @nnx.jit
        def train_step(model, opt, metrics, full_x, full_y):
            # Gradient accumulation: split the full batch into grad_accum
            # micro-batches of micro_bs rows and average the gradients.
            xs = full_x.reshape(config.grad_accum, micro_bs, -1)
            ys = full_y.reshape(config.grad_accum, micro_bs, -1)

            def _loss(model, x, y):
                logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
                loss = compute_loss(logits, y)
                if getattr(model, "use_moe", False):
                    loss = loss + model.alpha_balance * bal
                return loss, bal

            grad_fn = nnx.value_and_grad(_loss, argnums=DiffState(0, _wrt), has_aux=True)
            acc = jax.tree_util.tree_map(jnp.zeros_like, nnx.state(model, _wrt))
            total_loss = jnp.array(0.0)
            total_bal = jnp.array(0.0)
            for i in range(config.grad_accum):
                (loss, bal), grads = grad_fn(model, xs[i], ys[i])
                acc = jax.tree_util.tree_map(
                    lambda a, g: a + g / config.grad_accum, acc, grads
                )
                total_loss += loss / config.grad_accum
                total_bal += bal / config.grad_accum
            opt.update(model, acc)
            metrics.update(loss=total_loss)
            return total_loss, total_bal

        @nnx.jit
        def eval_step(model, x, y):
            logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
            loss = compute_loss(logits, y)
            if getattr(model, "use_moe", False):
                loss = loss + model.alpha_balance * bal
            return loss, bal

        def estimate_loss(key):
            out: dict[str, float] = {}
            for split, d in [("train", train_data), ("val", val_data)]:
                losses, bals = [], []
                for _ in range(config.eval_iters):
                    key, sub = jax.random.split(key)
                    x, y = get_batch(d, 1, config.max_context, sub)
                    loss_val, b = eval_step(model, x, y)
                    losses.append(float(loss_val))
                    bals.append(float(b))
                out[split] = sum(losses) / len(losses)
                out[f"{split}_bal"] = sum(bals) / len(bals)
            return out, key

        log_path = os.path.join(run_dir, "training_log.csv")
        weights_path = os.path.join(run_dir, "model_weights.msgpack")
        best_weights_path = os.path.join(run_dir, "best_model_weights.msgpack")

        key = jax.random.PRNGKey(config.seed)
        metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))

        # Multi-GPU: replicate model/optimizer/metrics to all devices once
        if use_multi_gpu:
            assert mesh is not None
            state = replicate(nnx.state((model, optimizer, metrics)), mesh)
            nnx.update((model, optimizer, metrics), state)

        pbar = tqdm(
            range(start_step, total_steps),
            desc="Training",
            unit="step",
            dynamic_ncols=True,
            initial=start_step,
            total=total_steps,
        )
        t0 = time.time()

        patience = getattr(config, "patience", 0)
        best_val_loss = float("inf")
        no_improve = 0

        with open(log_path, "a", newline="") as log_f:
            log_w = csv.writer(log_f)
            if os.path.getsize(log_path) == 0:
                log_w.writerow(
                    ["step", "train_loss", "val_loss", "train_bal", "val_bal", "ms_per_step"]
                )
            try:
                for step in pbar:
                    key, sub = jax.random.split(key)
                    x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
                    if use_multi_gpu:
                        assert mesh is not None
                        x = shard_batch(x, mesh)
                        y = shard_batch(y, mesh)
                    train_step(model, optimizer, metrics, x, y)

                    if step % 50 == 0:
                        t1 = time.time()
                        dt = (t1 - t0) * 1000 / 50  # avg ms per step over the window
                        t0 = t1
                        losses, key = estimate_loss(key)
                        val_loss = losses["val"]
                        pbar.set_postfix(
                            train=f"{losses['train']:.4f}",
                            val=f"{val_loss:.4f}",
                        )
                        log.info(
                            "step %d/%d | train=%.4f val=%.4f bal=%.4f",
                            step, total_steps,
                            losses["train"], val_loss, losses["train_bal"],
                        )
                        log_w.writerow(
                            [
                                step,
                                float(losses["train"]),
                                float(val_loss),
                                float(losses["train_bal"]),
                                float(losses["val_bal"]),
                                round(dt, 2),
                            ]
                        )
                        log_f.flush()

                        # Periodic checkpoint for resume
                        _save_weights(model, weights_path)
                        with open(cursor_path, "w") as cf:
                            json.dump({"step": step}, cf)

                        # Best checkpoint tracking
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            no_improve = 0
                            _save_weights(model, best_weights_path)
                            log.info("New best val loss %.4f — saved best checkpoint", best_val_loss)
                        else:
                            no_improve += 1
                            if patience > 0 and no_improve >= patience:
                                log.info(
                                    "Early stopping at step %d (no improvement for %d evals)",
                                    step, patience,
                                )
                                break

                        if wandb_project is not None:
                            import wandb
                            wandb.log({"train_loss": losses["train"], "val_loss": val_loss, "step": step})  # type: ignore[attr-defined]
            finally:
                pbar.close()
                if wandb_project is not None:
                    import wandb
                    wandb.finish()  # type: ignore[attr-defined]

        _save_weights(model, weights_path)
        log.info("Checkpoint saved: %s", weights_path)
        return run_dir
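
    # Example (illustrative): resume an interrupted run in place.
    #
    #     trainer = Trainer(config)
    #     trainer.fit("data/corpus.txt", run_dir="runs/run_20260503_152400", resume=True)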

    def find_lr(
        self,
        data_path: str | None = None,
        *,
        min_lr: float = 1e-7,
        max_lr: float = 1.0,
        num_steps: int = 100,
        smoothing: float = 0.9,
    ) -> tuple[float, list[float], list[float]]:
        """
        LR range test (Smith, 2015).

        Trains for ``num_steps`` steps while exponentially increasing the
        learning rate from ``min_lr`` to ``max_lr``. Returns a tuple of
        ``(suggested_lr, lr_history, loss_history)``.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus.
        min_lr : float
            Starting learning rate (default 1e-7).
        max_lr : float
            Maximum learning rate (default 1.0).
        num_steps : int
            Number of steps in the sweep (default 100).
        smoothing : float
            Exponential smoothing factor for the loss curve (default 0.9).

        Returns
        -------
        tuple[float, list[float], list[float]]
            ``(suggested_lr, lr_history, loss_history)``
        """
        import math

        config = self.config
        text = _format_text(_load_text(config, data_path))

        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        train_data = full_data[: int(0.9 * len(full_data))]

        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)

        # lr(step) = min_lr * exp(step * log(max_lr / min_lr) / (num_steps - 1)):
        # a geometric ramp from min_lr to max_lr.
        log_multiplier = math.log(max_lr / min_lr) / max(1, num_steps - 1)

        def _lr_fn(step: jnp.ndarray) -> jnp.ndarray:
            return jnp.array(min_lr, jnp.float32) * jnp.exp(
                step.astype(jnp.float32) * jnp.array(log_multiplier, jnp.float32)
            )

        tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=_lr_fn))
        optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

        @nnx.jit
        def _step(model, opt, x, y):
            def loss_fn(m):
                logits, _, _ = m(x, use_cache=False, kv_caches=None, cache_index=0)
                return compute_loss(logits, y)

            loss, grads = nnx.value_and_grad(loss_fn)(model)
            opt.update(model, grads)
            return loss

        key = jax.random.PRNGKey(config.seed)
        lr_history: list[float] = []
        loss_history: list[float] = []
        smooth_loss = 0.0
        best_loss = float("inf")

        pbar = tqdm(range(num_steps), desc="LR finder", unit="step", dynamic_ncols=True)
        for step in pbar:
            key, sub = jax.random.split(key)
            x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
            loss_val = float(_step(model, optimizer, x, y))

            # EMA of the loss, started at 0 and bias-corrected below, as in
            # Adam's moment debiasing. (Seeding the EMA with the raw first
            # loss and then applying the correction would inflate early
            # readings by up to 1 / (1 - smoothing).)
            smooth_loss = smoothing * smooth_loss + (1 - smoothing) * loss_val
            debiased = smooth_loss / (1 - smoothing ** (step + 1))

            current_lr = float(_lr_fn(jnp.array(step)))
            lr_history.append(current_lr)
            loss_history.append(debiased)

            pbar.set_postfix(lr=f"{current_lr:.2e}", loss=f"{debiased:.4f}")

            if debiased < best_loss:
                best_loss = debiased
            # Abort the sweep once the smoothed loss blows past 4x the best seen.
            if debiased > 4 * best_loss:
                log.info("Loss diverging at step %d — stopping sweep early", step)
                break

        pbar.close()

        # Suggest the lr at the point of steepest descent of the loss curve.
        if len(loss_history) > 2:
            slopes = [loss_history[i + 1] - loss_history[i] for i in range(len(loss_history) - 1)]
            suggested_lr = lr_history[min(range(len(slopes)), key=lambda i: slopes[i])]
        else:
            suggested_lr = min_lr

        log.info(
            "LR finder done — suggested lr=%.2e (sweep range [%.2e, %.2e])",
            suggested_lr, min_lr, max_lr,
        )
        return suggested_lr, lr_history, loss_history
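
# Example (illustrative): run the sweep, then adopt the suggestion.
#
#     trainer = Trainer(config)
#     lr, lrs, losses = trainer.find_lr("data/corpus.txt", num_steps=200)
#     trainer.config.lr = lr
#     trainer.fit("data/corpus.txt")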