1"""
2dantinox CLI
3============
4Entry point registered as the ``dantinox`` command by pyproject.toml.
6Subcommands
7-----------
8 train Train a model from a config file and a text corpus.
9 generate Generate text from a trained checkpoint.
10 sweep Run a W&B Bayesian hyperparameter sweep.
11 benchmark Benchmark all (or selected) runs in a directory.
13Examples
14--------
15 dantinox train --config configs/default_config.yaml --data_path data/corpus.txt
16 dantinox generate --run_dir runs/run_20260101 --prompt "Nel mezzo del cammin "
17 dantinox sweep --config configs/sweep.yaml --data_path data/corpus.txt
18 dantinox benchmark --runs_dir runs --out_csv results.csv
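    dantinox find-lr --data_path data/corpus.txt --plot
    dantinox push --run_dir runs/run_20260101 --repo my-org/my-model
    dantinox pull --repo my-org/my-model
    dantinox infbench --trained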
19"""

from __future__ import annotations

import argparse
import dataclasses
import logging
import sys
from pathlib import Path

from core.config import Config
from dantinox import __version__


# Persistent XLA compilation cache — compiled GPU kernels are saved to disk so
# subsequent calls with the same model architecture skip recompilation entirely.
# Must be set before any JAX operation (lazy import keeps this safe).
def _init_jax_cache() -> None:
    import jax
    _cache = Path.home() / ".cache" / "jax_xla" / "dantinox"
    _cache.mkdir(parents=True, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", str(_cache))
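
# _init_jax_cache() is called lazily from _cmd_generate below; the subcommand
# handlers likewise defer their heavy imports (Trainer, Generator, wandb, ...)
# to their function bodies.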

# ─── helpers ────────────────────────────────────────────────────────────────

def _add_config_overrides(parser: argparse.ArgumentParser) -> None:
    """Add one --<field> flag for every Config field."""
    for field in dataclasses.fields(Config):
        flag = f"--{field.name}"
        if flag not in parser._option_string_actions:
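            # NOTE: if a Config field's default is a bool, `type` resolves to bool
            # and argparse will call bool("...") on the string, which is True for
            # any non-empty value (including "False").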
            parser.add_argument(
                flag,
                type=type(field.default) if field.default is not dataclasses.MISSING else str,
                default=None,
            )


def _apply_overrides(config: Config, args: argparse.Namespace) -> Config:
    """Write any non-None CLI overrides onto the config object."""
    for field in dataclasses.fields(Config):
        val = getattr(args, field.name, None)
        if val is not None:
            setattr(config, field.name, val)
    return config


# ─── subcommand handlers ────────────────────────────────────────────────────

def _cmd_train(args: argparse.Namespace) -> None:
    config = Config.from_yaml(args.config)
    config = _apply_overrides(config, args)

    from dantinox.trainer import Trainer
    trainer = Trainer(config)
    run_dir = trainer.fit(
        args.data_path,
        run_dir=getattr(args, "run_dir", None),
        wandb_project=getattr(args, "wandb_project", None),
        resume=getattr(args, "resume", False),
    )
    print(f"\nRun saved to: {run_dir}")


def _cmd_generate(args: argparse.Namespace) -> None:
    _init_jax_cache()
    import time

    from dantinox.generator import Generator

    gen = Generator(args.run_dir, seed=args.seed)

    print(f"\nRun: {args.run_dir}")
    print(f"Prompt: {args.prompt}")
    print("-" * 40)

    sampling = dict(
        greedy=args.greedy,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
    )

    if args.stream:
        # Streaming: print the prompt first, then yield tokens as they arrive.
        print(args.prompt, end="", flush=True)
        t0 = time.time()
        n = 0
        for chunk in gen.stream(args.prompt, max_new_tokens=args.max_new_tokens, **sampling):
            print(chunk, end="", flush=True)
            n += 1
        elapsed = time.time() - t0
        print(f"\n{'-' * 40}")
        print(f"Generated {n} tokens in {elapsed:.2f}s ({n / elapsed:.1f} tok/s)")
    else:
        # Batch mode: a 1-token warmup call (so compilation is not counted),
        # then the timed generate.
        gen.generate(args.prompt, max_new_tokens=1)
        t0 = time.time()
        text = gen.generate(
            args.prompt,
            max_new_tokens=args.max_new_tokens,
            use_cache=not args.no_cache,
            **sampling,
        )
        elapsed = time.time() - t0
        prompt_tokens = len(gen.tokenizer.encode(args.prompt))
        new_tokens = len(gen.tokenizer.encode(text)) - prompt_tokens
        print(text)
        print("-" * 40)
        print(f"Generated {new_tokens} tokens in {elapsed:.2f}s "
              f"({new_tokens / elapsed:.1f} tok/s)")


def _cmd_sweep(args: argparse.Namespace) -> None:
    """Launch a W&B sweep agent using the existing train_sweep entry point."""
    try:
        import wandb
    except ImportError:
        print("wandb is not installed. Install it with: pip install wandb", file=sys.stderr)
        sys.exit(1)

    import yaml
    with open(args.sweep_config) as f:
        sweep_cfg = yaml.safe_load(f)

    project = getattr(args, "wandb_project", None) or "DantinoX"
    sweep_id = wandb.sweep(sweep_cfg, project=project)  # type: ignore[attr-defined]
    print(f"Sweep ID: {sweep_id} (project: {project})")

    def _agent_fn() -> None:
        from dantinox.trainer import Trainer

        run = wandb.init()  # type: ignore[attr-defined]
        wc = dict(run.config)

        base = Config.from_yaml(args.config) if args.config else Config()
        for k, v in wc.items():
            if hasattr(base, k):
                setattr(base, k, v)

        trainer = Trainer(base)
        trainer.fit(args.data_path, wandb_project=None)
        wandb.finish()  # type: ignore[attr-defined]
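
    # wandb.agent blocks here, running sweep trials until `count` is reached
    # (unlimited when --count is omitted).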
    wandb.agent(sweep_id, function=_agent_fn, count=getattr(args, "count", None))  # type: ignore[attr-defined]


def _cmd_find_lr(args: argparse.Namespace) -> None:
    config = Config.from_yaml(args.config)
    config = _apply_overrides(config, args)

    from dantinox.trainer import Trainer
    trainer = Trainer(config)
    suggested_lr, lr_hist, loss_hist = trainer.find_lr(
        args.data_path,
        min_lr=args.min_lr,
        max_lr=args.max_lr,
        num_steps=args.num_steps,
    )
    print(f"\nSuggested learning rate: {suggested_lr:.2e}")
    if args.plot:
        try:
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots(figsize=(8, 4))
            ax.plot(lr_hist, loss_hist)
            ax.axvline(suggested_lr, color="red", linestyle="--", label=f"suggested: {suggested_lr:.2e}")
            ax.set_xscale("log")
            ax.set_xlabel("Learning rate")
            ax.set_ylabel("Smoothed loss")
            ax.set_title("LR Range Test")
            ax.legend()
            plt.tight_layout()
            out = args.plot_out or "lr_finder.png"
            fig.savefig(out, dpi=150)
            print(f"Plot saved to: {out}")
        except ImportError:
            print("matplotlib not installed — skipping plot (pip install matplotlib)")


def _cmd_push(args: argparse.Namespace) -> None:
    from dantinox.hub import push
    url = push(
        args.run_dir,
        args.repo,
        private=args.private,
        token=args.token,
        commit_message=args.message,
    )
    print(f"Uploaded to: {url}")


def _cmd_pull(args: argparse.Namespace) -> None:
    from dantinox.hub import pull
    run_dir = pull(
        args.repo,
        local_dir=args.local_dir,
        token=args.token,
        revision=args.revision,
    )
    print(f"Downloaded to: {run_dir}")


def _cmd_plot(args: argparse.Namespace) -> None:
    from dantinox.plotting import Plotter
    groups = args.groups if args.groups else None
    plotter = Plotter(
        in_csv=args.in_csv,
        out_dir=args.out_dir,
    )
    results = plotter.run(groups=groups)
    total = sum(len(v) for v in results.values())
    print(f"\nDone — {total} figures written to {args.out_dir}/")


def _cmd_benchmark(args: argparse.Namespace) -> None:
    from dantinox.bench import BenchmarkRunner
    runner = BenchmarkRunner(args.runs_dir)
    run_names = args.runs if args.runs else None
    df = runner.run(run_names, out_csv=args.out_csv)

    if not df.empty:
        cols = ["run", "type", "params_m", "theoretical_cache_mb", "prefill_ms"]
        cols = [c for c in cols if c in df.columns]
        print("\n" + df[cols].to_string(index=False))


def _cmd_infbench(args: argparse.Namespace) -> None:
    """Delegate to benchmarks/run_all.py (subprocess keeps JAX state isolated)."""
    import subprocess
    from pathlib import Path
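
    # cli.py lives in <repo>/dantinox/, so two parents up is the repository root
    # (assumes an in-repo checkout; the exists() check below handles the case
    # where benchmarks/ is not present).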
    run_all = Path(__file__).resolve().parent.parent / "benchmarks" / "run_all.py"
    if not run_all.exists():
        print(f"Error: {run_all} not found — is the repo intact?", file=sys.stderr)
        sys.exit(1)

    cmd = [sys.executable, str(run_all),
           "--out-csv", args.out_csv,
           "--out-dir", args.out_dir,
           "--n-warmup", str(args.n_warmup),
           "--n-trials", str(args.n_trials)]
    if args.groups:
        cmd += ["--groups"] + args.groups
    if getattr(args, "device", None):
        cmd += ["--device", args.device]
    if args.sweep_only:
        cmd += ["--sweep-only"]
    if args.plot_only:
        cmd += ["--plot-only"]
    if args.verbose:
        cmd += ["--verbose"]
    if getattr(args, "trained", False):
        cmd += ["--trained"]
    if getattr(args, "inference_off", False):
        cmd += ["--inference-off"]
    if getattr(args, "runs_dir", None):
        cmd += ["--runs-dir", args.runs_dir]
    if getattr(args, "trained_csv", None):
        cmd += ["--trained-csv", args.trained_csv]
    if getattr(args, "trained_plot", None):
        cmd += ["--trained-plot", args.trained_plot]
    if getattr(args, "batch_csv", None):
        cmd += ["--batch-csv", args.batch_csv]
    if getattr(args, "batch_sizes", None):
        cmd += ["--batch-sizes"] + [str(b) for b in args.batch_sizes]
    if getattr(args, "batch_seq_len", None):
        cmd += ["--batch-seq-len", str(args.batch_seq_len)]

    result = subprocess.run(cmd)
    sys.exit(result.returncode)


# ─── argument parser ────────────────────────────────────────────────────────

def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="dantinox",
        description="DantinoX — JAX/Flax Transformer library CLI",
    )
    parser.add_argument(
        "--version", action="version", version=f"%(prog)s {__version__}"
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # ── train ──────────────────────────────────────────────────────────────
    p_train = sub.add_parser("train", help="Train a model")
    p_train.add_argument("--config", default="configs/default_config.yaml",
                         help="Path to a YAML config file")
    p_train.add_argument("--data_path", help="Path to the training corpus")
    p_train.add_argument("--run_dir", help="Output run directory (auto-generated if omitted)")
    p_train.add_argument("--wandb_project", help="W&B project name for logging")
    p_train.add_argument("--resume", action="store_true",
                         help="Resume training from the last checkpoint in --run_dir")
    _add_config_overrides(p_train)

    # ── generate ───────────────────────────────────────────────────────────
    p_gen = sub.add_parser("generate", help="Generate text from a checkpoint")
    p_gen.add_argument("--run_dir", required=True, help="Run directory with config + weights")
    p_gen.add_argument("--prompt", default="Nel mezzo del cammin ", help="Input prompt")
    p_gen.add_argument("--max_new_tokens", type=int, default=150)
    p_gen.add_argument("--greedy", action="store_true")
    p_gen.add_argument("--top_k", type=int, default=None)
    p_gen.add_argument("--top_p", type=float, default=None)
    p_gen.add_argument("--temperature", type=float, default=1.0)
    p_gen.add_argument("--no_cache", action="store_true", help="Disable KV cache")
    p_gen.add_argument("--stream", action="store_true",
                       help="Stream tokens to stdout as they are produced")
    p_gen.add_argument("--seed", type=int, default=42)

    # ── sweep ──────────────────────────────────────────────────────────────
    p_sweep = sub.add_parser("sweep", help="Run a W&B hyperparameter sweep")
    p_sweep.add_argument("--sweep_config", default="configs/sweep.yaml",
                         help="W&B sweep YAML configuration")
    p_sweep.add_argument("--config", default="configs/default_config.yaml",
                         help="Base model config (overridden by sweep params)")
    p_sweep.add_argument("--data_path", required=True, help="Path to the training corpus")
    p_sweep.add_argument("--wandb_project", default="DantinoX")
    p_sweep.add_argument("--count", type=int, default=None,
                         help="Maximum number of sweep runs (default: unlimited)")

    # ── benchmark ──────────────────────────────────────────────────────────
    p_bench = sub.add_parser("benchmark", help="Benchmark run directories")
    p_bench.add_argument("--runs_dir", default="runs", help="Directory containing run sub-dirs")
    p_bench.add_argument("--runs", nargs="*", help="Specific run names to benchmark (default: all)")
    p_bench.add_argument("--out_csv", default=None, help="Write results to this CSV file")

    # ── find-lr ────────────────────────────────────────────────────────────
    p_flr = sub.add_parser("find-lr", help="Run the LR range test and suggest a learning rate")
    p_flr.add_argument("--config", default="configs/default_config.yaml",
                       help="Path to a YAML config file")
    p_flr.add_argument("--data_path", required=True, help="Path to the training corpus")
    p_flr.add_argument("--min_lr", type=float, default=1e-7, help="Start LR (default 1e-7)")
    p_flr.add_argument("--max_lr", type=float, default=1.0, help="End LR (default 1.0)")
    p_flr.add_argument("--num_steps", type=int, default=100, help="Sweep steps (default 100)")
    p_flr.add_argument("--plot", action="store_true", help="Save a loss-vs-LR PNG")
    p_flr.add_argument("--plot_out", default=None, help="Output PNG path (default: lr_finder.png)")
    _add_config_overrides(p_flr)

    # ── push ───────────────────────────────────────────────────────────────
    p_push = sub.add_parser("push", help="Upload a checkpoint to HuggingFace Hub")
    p_push.add_argument("--run_dir", required=True, help="Local run directory to upload")
    p_push.add_argument("--repo", required=True, help="Hub repo id (e.g. my-org/my-model)")
    p_push.add_argument("--private", action="store_true", help="Create a private repository")
    p_push.add_argument("--token", default=None, help="HuggingFace access token")
    p_push.add_argument("--message", default=None, help="Commit message")

    # ── pull ───────────────────────────────────────────────────────────────
    p_pull = sub.add_parser("pull", help="Download a checkpoint from HuggingFace Hub")
    p_pull.add_argument("--repo", required=True, help="Hub repo id (e.g. my-org/my-model)")
    p_pull.add_argument("--local_dir", default=None, help="Where to save the files")
    p_pull.add_argument("--token", default=None, help="HuggingFace access token")
    p_pull.add_argument("--revision", default=None, help="Branch, tag, or commit SHA")

    # ── infbench ───────────────────────────────────────────────────────────
    p_ib = sub.add_parser(
        "infbench",
        help="Run the full benchmark suite: inference sweep + optional trained-model pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=(
            "Benchmark pipeline:\n"
            "  Stage 1  benchmarks/inference_sweep.py     → CSV (random-model sweep)\n"
            "  Stage 2  benchmarks/plot_inference.py      → 21 PNG plots\n"
            "  Stage 3  benchmarks/trained_analysis.py    → CSV (real trained runs)\n"
            "  Stage 4  benchmarks/trained_batch_sweep.py → CSV (tps vs batch size)\n\n"
            "Stages 3-4 only run when --trained is passed.\n\n"
            "Examples:\n"
            "  dantinox infbench\n"
            "  dantinox infbench --trained\n"
            "  dantinox infbench --trained --inference-off\n"
            "  dantinox infbench --groups attention_type scale --n-trials 5\n"
            "  dantinox infbench --plot-only --out-csv results/inference_sweep.csv\n"
            "  dantinox infbench --device 1"
        ),
    )
    p_ib.add_argument("--out-csv", default="results/inference_sweep.csv", metavar="PATH",
                      help="CSV output path (default: results/inference_sweep.csv)")
    p_ib.add_argument("--out-dir", default="results/plots/", metavar="DIR",
                      help="Directory for plot PNGs (default: results/plots/)")
    p_ib.add_argument("--groups", nargs="+", metavar="GROUP",
                      help="Restrict sweep to these groups (default: all 13)")
    p_ib.add_argument("--n-warmup", type=int, default=3, metavar="N",
                      help="Warm-up reps per experiment (default: 3)")
    p_ib.add_argument("--n-trials", type=int, default=10, metavar="N",
                      help="Measured reps per experiment (default: 10)")
    p_ib.add_argument("--device", default=None, metavar="N",
                      help="CUDA device index for CUDA_VISIBLE_DEVICES (default: env)")
    p_ib.add_argument("--sweep-only", action="store_true",
                      help="Run sweep only, skip plotting")
    p_ib.add_argument("--plot-only", action="store_true",
                      help="Skip sweep, re-plot existing --out-csv")
    p_ib.add_argument("--verbose", action="store_true",
                      help="Print per-experiment metrics during the sweep")
    p_ib.add_argument("--trained", action="store_true",
                      help="Also run the trained-model pipeline (stages 3 and 4)")
    p_ib.add_argument("--inference-off", action="store_true",
                      help="Skip inference pipeline; requires --trained")
    p_ib.add_argument("--runs-dir", default="runs", metavar="DIR",
                      help="Directory of trained run subdirs (default: runs)")
    p_ib.add_argument("--trained-csv", default="results/benchmark_results.csv", metavar="PATH",
                      help="Output CSV for trained-model analysis (default: results/benchmark_results.csv)")
    p_ib.add_argument("--trained-plot", default="results/plots/trained_analysis.png", metavar="PATH",
                      help="Output PNG for trained-model analysis")
    p_ib.add_argument("--batch-csv", default="results/batch_sweep_results.csv", metavar="PATH",
                      help="Output CSV for batch sweep (default: results/batch_sweep_results.csv)")
    p_ib.add_argument("--batch-sizes", nargs="+", type=int, metavar="N",
                      help="Batch sizes for the batch sweep (default: 1 2 4 8 16 32 64)")
    p_ib.add_argument("--batch-seq-len", type=int, default=512, metavar="N",
                      help="Fixed sequence length for the batch sweep (default: 512)")

    # ── plot ───────────────────────────────────────────────────────────────
    p_plot = sub.add_parser("plot", help="Generate benchmark plots from a results CSV")
    p_plot.add_argument("--in_csv", default="benchmark_results.csv",
                        help="Benchmark results CSV (output of 'dantinox benchmark')")
    p_plot.add_argument("--out_dir", default="plots",
                        help="Output directory for PNG files (default: plots/)")
    p_plot.add_argument("--batch_csv", default=None,
                        help="Optional batch sweep CSV for the batch throughput figure")
    p_plot.add_argument("--groups", nargs="*",
                        metavar="GROUP",
                        help="Plot groups to generate: insights perf 3d 3d_dkv (default: all)")

    return parser


def main(argv: list[str] | None = None) -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
        datefmt="%H:%M:%S",
    )
    parser = _build_parser()
    args = parser.parse_args(argv)

    dispatch = {
        "train": _cmd_train,
        "generate": _cmd_generate,
        "sweep": _cmd_sweep,
        "benchmark": _cmd_benchmark,
        "infbench": _cmd_infbench,
        "plot": _cmd_plot,
        "find-lr": _cmd_find_lr,
        "push": _cmd_push,
        "pull": _cmd_pull,
    }
    dispatch[args.command](args)


if __name__ == "__main__":
    main()