Coverage for dantinox / cli.py: 0%

251 statements  


1""" 

2dantinox CLI 

3============ 

4Entry point registered as the ``dantinox`` command by pyproject.toml. 

5 

6Subcommands 

7----------- 

8 train Train a model from a config file and a text corpus. 

9 generate Generate text from a trained checkpoint. 

10 sweep Run a W&B Bayesian hyperparameter sweep. 

11 benchmark Benchmark all (or selected) runs in a directory. 

12 

13Examples 

14-------- 

15 dantinox train --config configs/default_config.yaml --data_path data/corpus.txt 

16 dantinox generate --run_dir runs/run_20260101 --prompt "Nel mezzo del cammin " 

17 dantinox sweep --config configs/sweep.yaml --data_path data/corpus.txt 

18 dantinox benchmark --runs_dir runs --out_csv results.csv 

19""" 

from __future__ import annotations

import argparse
import dataclasses
import logging
import sys
from pathlib import Path

from core.config import Config
from dantinox import __version__


# Persistent XLA compilation cache — compiled GPU kernels are saved to disk so
# subsequent calls with the same model architecture skip recompilation entirely.
# Must be set before any JAX operation (lazy import keeps this safe).
def _init_jax_cache() -> None:
    import jax

    _cache = Path.home() / ".cache" / "jax_xla" / "dantinox"
    _cache.mkdir(parents=True, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", str(_cache))


# ─── helpers ────────────────────────────────────────────────────────────────

def _add_config_overrides(parser: argparse.ArgumentParser) -> None:
    """Add one --<field> flag for every Config field."""
    for field in dataclasses.fields(Config):
        flag = f"--{field.name}"
        if flag not in parser._option_string_actions:
            # Infer the argument type from the field default (falls back to str).
            # Caveat: a bool default yields type=bool, and argparse's bool()
            # treats any non-empty string as True.
            arg_type = type(field.default) if field.default is not dataclasses.MISSING else str
            parser.add_argument(flag, type=arg_type, default=None)


def _apply_overrides(config: Config, args: argparse.Namespace) -> Config:
    """Write any non-None CLI overrides onto the config object."""
    for field in dataclasses.fields(Config):
        val = getattr(args, field.name, None)
        if val is not None:
            setattr(config, field.name, val)
    return config
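
# How the override plumbing composes (a sketch; "n_layers" is only an
# illustrative field name, the real flags are whatever dataclasses.fields(Config)
# exposes):
#
#   dantinox train --config configs/default_config.yaml --n_layers 8
#
# _add_config_overrides registers --n_layers on the train subparser, and
# _apply_overrides then copies the parsed, non-None value onto the Config
# loaded from YAML before the Trainer ever sees it.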


# ─── subcommand handlers ────────────────────────────────────────────────────

def _cmd_train(args: argparse.Namespace) -> None:
    config = Config.from_yaml(args.config)
    config = _apply_overrides(config, args)

    from dantinox.trainer import Trainer
    trainer = Trainer(config)
    run_dir = trainer.fit(
        args.data_path,
        run_dir=getattr(args, "run_dir", None),
        wandb_project=getattr(args, "wandb_project", None),
        resume=getattr(args, "resume", False),
    )
    print(f"\nRun saved to: {run_dir}")
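
# Shape of the YAML consumed by --config (a sketch; the key names below are
# hypothetical and the accepted ones are whatever the Config dataclass defines):
#
#   n_layers: 6
#   d_model: 256
#   learning_rate: 0.0003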


def _cmd_generate(args: argparse.Namespace) -> None:
    _init_jax_cache()
    import time

    from dantinox.generator import Generator

    gen = Generator(args.run_dir, seed=args.seed)

    print(f"\nRun: {args.run_dir}")
    print(f"Prompt: {args.prompt}")
    print("-" * 40)

    sampling = dict(
        greedy=args.greedy,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
    )

    if args.stream:
        # Streaming: print the prompt first, then yield tokens as they arrive.
        print(args.prompt, end="", flush=True)
        t0 = time.time()
        n = 0
        for chunk in gen.stream(args.prompt, max_new_tokens=args.max_new_tokens, **sampling):
            print(chunk, end="", flush=True)
            n += 1
        elapsed = time.time() - t0
        print(f"\n{'-' * 40}")
        print(f"Generated {n} tokens in {elapsed:.2f}s ({n / elapsed:.1f} tok/s)")
    else:
        # Batch mode: warmup then timed generate.
        gen.generate(args.prompt, max_new_tokens=1)
        t0 = time.time()
        text = gen.generate(
            args.prompt,
            max_new_tokens=args.max_new_tokens,
            use_cache=not args.no_cache,
            **sampling,
        )
        elapsed = time.time() - t0
        prompt_tokens = len(gen.tokenizer.encode(args.prompt))
        new_tokens = len(gen.tokenizer.encode(text)) - prompt_tokens
        print(text)
        print("-" * 40)
        print(f"Generated {new_tokens} tokens in {elapsed:.2f}s "
              f"({new_tokens / elapsed:.1f} tok/s)")
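
# The same Generator API can be scripted directly (a sketch based only on the
# calls made in the handler above; the run directory and prompt are placeholders):
#
#   from dantinox.generator import Generator
#   gen = Generator("runs/run_20260101", seed=0)
#   text = gen.generate("Nel mezzo del cammin ", max_new_tokens=50, top_k=40)
#   for chunk in gen.stream("Nel mezzo del cammin ", max_new_tokens=50):
#       print(chunk, end="")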


def _cmd_sweep(args: argparse.Namespace) -> None:
    """Create a W&B sweep and run an agent that trains one model per sweep config."""
    try:
        import wandb
    except ImportError:
        print("wandb is not installed. Install it with: pip install wandb", file=sys.stderr)
        sys.exit(1)

    import yaml
    with open(args.sweep_config) as f:
        sweep_cfg = yaml.safe_load(f)

    project = getattr(args, "wandb_project", None) or "DantinoX"
    sweep_id = wandb.sweep(sweep_cfg, project=project)  # type: ignore[attr-defined]
    print(f"Sweep ID: {sweep_id} (project: {project})")

    def _agent_fn() -> None:
        from dantinox.trainer import Trainer

        run = wandb.init()  # type: ignore[attr-defined]
        wc = dict(run.config)

        base = Config.from_yaml(args.config) if args.config else Config()
        for k, v in wc.items():
            if hasattr(base, k):
                setattr(base, k, v)

        trainer = Trainer(base)
        trainer.fit(args.data_path, wandb_project=None)
        wandb.finish()  # type: ignore[attr-defined]

    wandb.agent(sweep_id, function=_agent_fn, count=getattr(args, "count", None))  # type: ignore[attr-defined]
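
# Minimal shape of the sweep YAML passed via --sweep_config (a sketch in the
# standard W&B sweep schema; the parameter names are hypothetical and must
# match Config fields for the agent above to apply them):
#
#   method: bayes
#   metric:
#     name: val_loss
#     goal: minimize
#   parameters:
#     learning_rate:
#       min: 0.0001
#       max: 0.01
#     dropout:
#       values: [0.0, 0.1, 0.2]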


def _cmd_find_lr(args: argparse.Namespace) -> None:
    config = Config.from_yaml(args.config)
    config = _apply_overrides(config, args)

    from dantinox.trainer import Trainer
    trainer = Trainer(config)
    suggested_lr, lr_hist, loss_hist = trainer.find_lr(
        args.data_path,
        min_lr=args.min_lr,
        max_lr=args.max_lr,
        num_steps=args.num_steps,
    )
    print(f"\nSuggested learning rate: {suggested_lr:.2e}")
    if args.plot:
        try:
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(8, 4))
            ax.plot(lr_hist, loss_hist)
            ax.axvline(suggested_lr, color="red", linestyle="--", label=f"suggested: {suggested_lr:.2e}")
            ax.set_xscale("log")
            ax.set_xlabel("Learning rate")
            ax.set_ylabel("Smoothed loss")
            ax.set_title("LR Range Test")
            ax.legend()
            plt.tight_layout()
            out = args.plot_out or "lr_finder.png"
            fig.savefig(out, dpi=150)
            print(f"Plot saved to: {out}")
        except ImportError:
            print("matplotlib not installed — skipping plot (pip install matplotlib)")
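
# For reference, an LR range test usually sweeps the learning rate
# geometrically between the two bounds, i.e. for step i of num_steps:
#
#   lr_i = min_lr * (max_lr / min_lr) ** (i / (num_steps - 1))
#
# and the suggestion is typically taken near the steepest drop of the smoothed
# loss. The exact schedule and heuristic used here live in Trainer.find_lr.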


def _cmd_push(args: argparse.Namespace) -> None:
    from dantinox.hub import push
    url = push(
        args.run_dir,
        args.repo,
        private=args.private,
        token=args.token,
        commit_message=args.message,
    )
    print(f"Uploaded to: {url}")


def _cmd_pull(args: argparse.Namespace) -> None:
    from dantinox.hub import pull
    run_dir = pull(
        args.repo,
        local_dir=args.local_dir,
        token=args.token,
        revision=args.revision,
    )
    print(f"Downloaded to: {run_dir}")
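
# Round-trip usage sketch for the Hub commands (repo id and paths are
# illustrative; the flags are the ones defined in _build_parser below):
#
#   dantinox push --run_dir runs/run_20260101 --repo my-org/dantinox-small
#   dantinox pull --repo my-org/dantinox-small --local_dir runs/from_hub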


def _cmd_plot(args: argparse.Namespace) -> None:
    from dantinox.plotting import Plotter
    groups = args.groups if args.groups else None
    plotter = Plotter(
        in_csv=args.in_csv,
        out_dir=args.out_dir,
    )
    results = plotter.run(groups=groups)
    total = sum(len(v) for v in results.values())
    print(f"\nDone — {total} figures written to {args.out_dir}/")


def _cmd_benchmark(args: argparse.Namespace) -> None:
    from dantinox.bench import BenchmarkRunner
    runner = BenchmarkRunner(args.runs_dir)
    run_names = args.runs if args.runs else None
    df = runner.run(run_names, out_csv=args.out_csv)

    if not df.empty:
        cols = ["run", "type", "params_m", "theoretical_cache_mb", "prefill_ms"]
        cols = [c for c in cols if c in df.columns]
        print("\n" + df[cols].to_string(index=False))


def _cmd_infbench(args: argparse.Namespace) -> None:
    """Delegate to benchmarks/run_all.py (subprocess keeps JAX state isolated)."""
    import subprocess

    run_all = Path(__file__).resolve().parent.parent / "benchmarks" / "run_all.py"
    if not run_all.exists():
        print(f"Error: {run_all} not found — is the repo intact?", file=sys.stderr)
        sys.exit(1)

    cmd = [sys.executable, str(run_all),
           "--out-csv", args.out_csv,
           "--out-dir", args.out_dir,
           "--n-warmup", str(args.n_warmup),
           "--n-trials", str(args.n_trials)]
    if args.groups:
        cmd += ["--groups"] + args.groups
    if getattr(args, "device", None):
        cmd += ["--device", args.device]
    if args.sweep_only:
        cmd += ["--sweep-only"]
    if args.plot_only:
        cmd += ["--plot-only"]
    if args.verbose:
        cmd += ["--verbose"]
    if getattr(args, "trained", False):
        cmd += ["--trained"]
    if getattr(args, "inference_off", False):
        cmd += ["--inference-off"]
    if getattr(args, "runs_dir", None):
        cmd += ["--runs-dir", args.runs_dir]
    if getattr(args, "trained_csv", None):
        cmd += ["--trained-csv", args.trained_csv]
    if getattr(args, "trained_plot", None):
        cmd += ["--trained-plot", args.trained_plot]
    if getattr(args, "batch_csv", None):
        cmd += ["--batch-csv", args.batch_csv]
    if getattr(args, "batch_sizes", None):
        cmd += ["--batch-sizes"] + [str(b) for b in args.batch_sizes]
    if getattr(args, "batch_seq_len", None):
        cmd += ["--batch-seq-len", str(args.batch_seq_len)]

    result = subprocess.run(cmd)
    sys.exit(result.returncode)


# ─── argument parser ────────────────────────────────────────────────────────

def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="dantinox",
        description="DantinoX — JAX/Flax Transformer library CLI",
    )
    parser.add_argument(
        "--version", action="version", version=f"%(prog)s {__version__}"
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # ── train ──────────────────────────────────────────────────────────────
    p_train = sub.add_parser("train", help="Train a model")
    p_train.add_argument("--config", default="configs/default_config.yaml",
                         help="Path to a YAML config file")
    p_train.add_argument("--data_path", help="Path to the training corpus")
    p_train.add_argument("--run_dir", help="Output run directory (auto-generated if omitted)")
    p_train.add_argument("--wandb_project", help="W&B project name for logging")
    p_train.add_argument("--resume", action="store_true",
                         help="Resume training from the last checkpoint in --run_dir")
    _add_config_overrides(p_train)

    # ── generate ───────────────────────────────────────────────────────────
    p_gen = sub.add_parser("generate", help="Generate text from a checkpoint")
    p_gen.add_argument("--run_dir", required=True, help="Run directory with config + weights")
    p_gen.add_argument("--prompt", default="Nel mezzo del cammin ", help="Input prompt")
    p_gen.add_argument("--max_new_tokens", type=int, default=150)
    p_gen.add_argument("--greedy", action="store_true")
    p_gen.add_argument("--top_k", type=int, default=None)
    p_gen.add_argument("--top_p", type=float, default=None)
    p_gen.add_argument("--temperature", type=float, default=1.0)
    p_gen.add_argument("--no_cache", action="store_true", help="Disable KV cache")
    p_gen.add_argument("--stream", action="store_true",
                       help="Stream tokens to stdout as they are produced")
    p_gen.add_argument("--seed", type=int, default=42)

    # ── sweep ──────────────────────────────────────────────────────────────
    p_sweep = sub.add_parser("sweep", help="Run a W&B hyperparameter sweep")
    p_sweep.add_argument("--sweep_config", default="configs/sweep.yaml",
                         help="W&B sweep YAML configuration")
    p_sweep.add_argument("--config", default="configs/default_config.yaml",
                         help="Base model config (overridden by sweep params)")
    p_sweep.add_argument("--data_path", required=True, help="Path to the training corpus")
    p_sweep.add_argument("--wandb_project", default="DantinoX")
    p_sweep.add_argument("--count", type=int, default=None,
                         help="Maximum number of sweep runs (default: unlimited)")

    # ── benchmark ──────────────────────────────────────────────────────────
    p_bench = sub.add_parser("benchmark", help="Benchmark run directories")
    p_bench.add_argument("--runs_dir", default="runs", help="Directory containing run sub-dirs")
    p_bench.add_argument("--runs", nargs="*", help="Specific run names to benchmark (default: all)")
    p_bench.add_argument("--out_csv", default=None, help="Write results to this CSV file")

    # ── find-lr ────────────────────────────────────────────────────────────
    p_flr = sub.add_parser("find-lr", help="Run the LR range test and suggest a learning rate")
    p_flr.add_argument("--config", default="configs/default_config.yaml",
                       help="Path to a YAML config file")
    p_flr.add_argument("--data_path", required=True, help="Path to the training corpus")
    p_flr.add_argument("--min_lr", type=float, default=1e-7, help="Start LR (default 1e-7)")
    p_flr.add_argument("--max_lr", type=float, default=1.0, help="End LR (default 1.0)")
    p_flr.add_argument("--num_steps", type=int, default=100, help="Sweep steps (default 100)")
    p_flr.add_argument("--plot", action="store_true", help="Save a loss-vs-LR PNG")
    p_flr.add_argument("--plot_out", default=None, help="Output PNG path (default: lr_finder.png)")
    _add_config_overrides(p_flr)

    # ── push ───────────────────────────────────────────────────────────────
    p_push = sub.add_parser("push", help="Upload a checkpoint to HuggingFace Hub")
    p_push.add_argument("--run_dir", required=True, help="Local run directory to upload")
    p_push.add_argument("--repo", required=True, help="Hub repo id (e.g. my-org/my-model)")
    p_push.add_argument("--private", action="store_true", help="Create a private repository")
    p_push.add_argument("--token", default=None, help="HuggingFace access token")
    p_push.add_argument("--message", default=None, help="Commit message")

    # ── pull ───────────────────────────────────────────────────────────────
    p_pull = sub.add_parser("pull", help="Download a checkpoint from HuggingFace Hub")
    p_pull.add_argument("--repo", required=True, help="Hub repo id (e.g. my-org/my-model)")
    p_pull.add_argument("--local_dir", default=None, help="Where to save the files")
    p_pull.add_argument("--token", default=None, help="HuggingFace access token")
    p_pull.add_argument("--revision", default=None, help="Branch, tag, or commit SHA")

    # ── infbench ───────────────────────────────────────────────────────────
    p_ib = sub.add_parser(
        "infbench",
        help="Run the full benchmark suite: inference sweep + optional trained-model pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=(
            "Benchmark pipeline:\n"
            "  Stage 1  benchmarks/inference_sweep.py     → CSV (random-model sweep)\n"
            "  Stage 2  benchmarks/plot_inference.py      → 21 PNG plots\n"
            "  Stage 3  benchmarks/trained_analysis.py    → CSV (real trained runs)\n"
            "  Stage 4  benchmarks/trained_batch_sweep.py → CSV (tps vs batch size)\n\n"
            "Stages 3-4 only run when --trained is passed.\n\n"
            "Examples:\n"
            "  dantinox infbench\n"
            "  dantinox infbench --trained\n"
            "  dantinox infbench --trained --inference-off\n"
            "  dantinox infbench --groups attention_type scale --n-trials 5\n"
            "  dantinox infbench --plot-only --out-csv results/inference_sweep.csv\n"
            "  dantinox infbench --device 1"
        ),
    )
    p_ib.add_argument("--out-csv", default="results/inference_sweep.csv", metavar="PATH",
                      help="CSV output path (default: results/inference_sweep.csv)")
    p_ib.add_argument("--out-dir", default="results/plots/", metavar="DIR",
                      help="Directory for plot PNGs (default: results/plots/)")
    p_ib.add_argument("--groups", nargs="+", metavar="GROUP",
                      help="Restrict sweep to these groups (default: all 13)")
    p_ib.add_argument("--n-warmup", type=int, default=3, metavar="N",
                      help="Warm-up reps per experiment (default: 3)")
    p_ib.add_argument("--n-trials", type=int, default=10, metavar="N",
                      help="Measured reps per experiment (default: 10)")
    p_ib.add_argument("--device", default=None, metavar="N",
                      help="CUDA device index for CUDA_VISIBLE_DEVICES (default: env)")
    p_ib.add_argument("--sweep-only", action="store_true",
                      help="Run sweep only, skip plotting")
    p_ib.add_argument("--plot-only", action="store_true",
                      help="Skip sweep, re-plot existing --out-csv")
    p_ib.add_argument("--verbose", action="store_true",
                      help="Print per-experiment metrics during the sweep")
    p_ib.add_argument("--trained", action="store_true",
                      help="Also run the trained-model pipeline (stages 3 and 4)")
    p_ib.add_argument("--inference-off", action="store_true",
                      help="Skip inference pipeline; requires --trained")
    p_ib.add_argument("--runs-dir", default="runs", metavar="DIR",
                      help="Directory of trained run subdirs (default: runs)")
    p_ib.add_argument("--trained-csv", default="results/benchmark_results.csv", metavar="PATH",
                      help="Output CSV for trained-model analysis (default: results/benchmark_results.csv)")
    p_ib.add_argument("--trained-plot", default="results/plots/trained_analysis.png", metavar="PATH",
                      help="Output PNG for trained-model analysis")
    p_ib.add_argument("--batch-csv", default="results/batch_sweep_results.csv", metavar="PATH",
                      help="Output CSV for batch sweep (default: results/batch_sweep_results.csv)")
    p_ib.add_argument("--batch-sizes", nargs="+", type=int, metavar="N",
                      help="Batch sizes for the batch sweep (default: 1 2 4 8 16 32 64)")
    p_ib.add_argument("--batch-seq-len", type=int, default=512, metavar="N",
                      help="Fixed sequence length for the batch sweep (default: 512)")

    # ── plot ───────────────────────────────────────────────────────────────
    p_plot = sub.add_parser("plot", help="Generate benchmark plots from a results CSV")
    p_plot.add_argument("--in_csv", default="benchmark_results.csv",
                        help="Benchmark results CSV (output of 'dantinox benchmark')")
    p_plot.add_argument("--out_dir", default="plots",
                        help="Output directory for PNG files (default: plots/)")
    p_plot.add_argument("--batch_csv", default=None,
                        help="Optional batch sweep CSV for the batch throughput figure")
    p_plot.add_argument("--groups", nargs="*",
                        metavar="GROUP",
                        help="Plot groups to generate: insights perf 3d 3d_dkv (default: all)")

    return parser


def main(argv: list[str] | None = None) -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
        datefmt="%H:%M:%S",
    )
    parser = _build_parser()
    args = parser.parse_args(argv)

    dispatch = {
        "train": _cmd_train,
        "generate": _cmd_generate,
        "sweep": _cmd_sweep,
        "benchmark": _cmd_benchmark,
        "infbench": _cmd_infbench,
        "plot": _cmd_plot,
        "find-lr": _cmd_find_lr,
        "push": _cmd_push,
        "pull": _cmd_pull,
    }
    dispatch[args.command](args)


if __name__ == "__main__":
    main()