# DantinoX
"E quindi uscimmo a riveder le stelle."
A decoder-only Transformer library built from scratch in JAX and Flax NNX — MHA, GQA, MLA, MoE, Flash Attention, LoRA fine-tuning, multi-GPU sharding, and more, all from a single config.
Python 3.12+ · MLA · GQA · MHA · XLA-Native Flash Attention · bfloat16 · RMSNorm · LayerNorm · LoRA Fine-Tuning · Multi-GPU SPMD · HF Hub · pip install · MIT
- **Three Attention Families**
  MHA, GQA, and Multi-Head Latent Attention (MLA) with decoupled RoPE and full weight absorption — all switchable via a single config flag. No code edits required.
- **JAX-Native, XLA-First**
  Static KV cache via `dynamic_update_slice`, a `@jax.jit` training loop, and optional Flash Attention via `jax.nn.dot_product_attention` — zero dynamic shapes, zero recompilation.
- **Fully Benchmarked**
  Bayesian sweeps over 20+ hyperparameters logged to W&B. Results visualised in 2D and 3D across throughput, FLOPs, KV cache size, and latency.
- **Production-Ready**
  `Trainer`, `Generator`, `BenchmarkRunner` — bfloat16, gradient clipping, early stopping, `from_pretrained`, 4 LR schedules, streaming generation, and direct loading from the HuggingFace Hub (`Generator("owner/repo")`).
- **Research-Grade Math**
  MLA weight absorption, NTK-aware RoPE scaling (a sketch of the scaling rule follows this list), MoE load-balancing loss, and decoupled positional encoding — all implemented from first principles.
- **Zero-Code Configuration**
  Every component — attention type, normalisation, positional encoding, FFN — is a YAML field in a single `Config` dataclass. Toggle RMSNorm, Flash Attention, or MoE without touching source.
- **LoRA Fine-Tuning**
  Adapt any pre-trained checkpoint with `use_lora=True`. A custom `LoRAParam` variable type keeps base weights frozen; only rank-decomposed adapters are trained — ~0.1–0.5% of parameters.
- **Multi-GPU SPMD**
  Data-parallel training across any number of GPUs via JAX's SPMD sharding. Model weights are replicated, batches are sharded; XLA fuses the AllReduce automatically. Set `n_devices=4` and go.
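For the curious, here is a minimal sketch of the NTK-aware RoPE scaling rule mentioned above, written in plain NumPy. It is the standard textbook formula; the mapping from `rope_scale_factor` to `scale` below is an assumption, and DantinoX's own implementation may differ in detail.

import numpy as np

def ntk_scaled_inv_freq(head_dim: int, scale: float, base: float = 10000.0) -> np.ndarray:
    # NTK-aware scaling stretches the rotary base so the lowest frequencies rotate
    # more slowly, extending the usable context by roughly `scale`x without retraining:
    #   base' = base * scale ** (head_dim / (head_dim - 2))
    scaled_base = base * scale ** (head_dim / (head_dim - 2))
    return 1.0 / scaled_base ** (np.arange(0, head_dim, 2) / head_dim)

# With head_size=32 and rope_scale_factor=2.0 (as in the Quickstart below),
# scale=2.0 here corresponds to roughly a doubled effective context window.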
## Quickstart
git clone https://github.com/winstonsmith1897/DantinoX.git
cd DantinoX
conda create -n dantinox python=3.12 -y && conda activate dantinox
pip install -U "jax[cuda12]"
pip install -e ".[all]"
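Optionally, a one-line sanity check (plain JAX, not a DantinoX command) confirms the CUDA wheel sees your GPUs before training:

python -c "import jax; print(jax.devices())"  # should list one CUDA device per GPU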
from dantinox import Config, Trainer, Generator
# 1. Train — bfloat16, RMSNorm, Flash Attention, WSD schedule
config = Config(
dim=512, n_heads=16, head_size=32, num_blocks=8,
lr=3e-4, grad_clip=1.0, use_bf16=True,
norm_type="rmsnorm", # RMSNorm instead of LayerNorm
use_flash_attention=True, # fused scaled-dot-product (JAX ≥ 0.4.25)
lr_schedule="wsd", # warmup → stable → cosine decay
rope_scale_factor=2.0, # NTK-aware: ~2× effective context window
patience=5, # stop if val loss stalls for 5 evals
)
run_dir = Trainer(config).fit("data/corpus.txt")
# 2. Single-prompt generation
gen = Generator(run_dir)
print(gen.generate("Nel mezzo del cammin ", max_new_tokens=200))
# 3. Batched generation — one forward pass for all prompts
texts = gen.generate_batch(
["Nel mezzo", "Lasciate ogni speranza", "Per me si va"],
max_new_tokens=100, temperature=0.8,
)
# 4. Streaming generation — yield tokens as they are produced
for chunk in gen.stream("Nel mezzo del cammin ", max_new_tokens=150):
print(chunk, end="", flush=True)
# 5. Find the right learning rate before a full run
lr, lrs, losses = Trainer(config).find_lr("data/corpus.txt", num_steps=100)
print(f"Suggested LR: {lr:.2e}")
# 6. Load model directly for custom inference / fine-tuning
from core import Transformer
model = Transformer.from_pretrained(run_dir) # local
model = Transformer.from_pretrained("my-org/dantinox-dante") # HF Hub — downloads automatically
# 7. Load directly from the Hub — no pull step needed
gen_hub = Generator("my-org/dantinox-dante") # public repo
gen_prv = Generator("my-org/private", token="hf_…") # private repo
print(gen_hub.generate("Nel mezzo del cammin "))
# 8. LoRA fine-tuning — only adapter params are trained (~0.2% of total)
ft_config = Config.from_yaml(f"{run_dir}/config.yaml")
ft_config.use_lora = True
ft_config.lora_rank = 8
ft_config.lora_targets = "attention"
ft_run = Trainer(ft_config).fit("data/finetune.txt")
# 9. Multi-GPU data-parallel — set n_devices, everything else is automatic
config_4gpu = Config(dim=512, n_heads=16, head_size=32, num_blocks=8,
batch_size=256, n_devices=4)
Trainer(config_4gpu).fit("data/corpus.txt")
# 10. Push to HuggingFace Hub
from dantinox import push
push(run_dir, "my-org/dantinox-dante", private=False)
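For context on the `use_flash_attention=True` flag above: DantinoX routes attention through JAX's fused scaled-dot-product primitive. A standalone sketch of that primitive follows (not DantinoX internals; the shapes are illustrative only):

import jax
import jax.numpy as jnp

# Illustrative shapes: (batch, seq_len, n_heads, head_size)
q = k = v = jnp.ones((1, 128, 16, 32), dtype=jnp.bfloat16)

# Fused causal scaled-dot-product attention; XLA selects the backend kernel.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (1, 128, 16, 32)

The same workflows are also available from the `dantinox` CLI: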
# Train with bfloat16 and gradient clipping
dantinox train \
--config configs/default_config.yaml \
--data_path data/corpus.txt \
--use_bf16 True --grad_clip 1.0 --patience 5
# Resume an interrupted run
dantinox train --config configs/default_config.yaml \
--data_path data/corpus.txt \
--run_dir runs/run_20260101_120000 --resume
# Find the best learning rate before committing to a long run
dantinox find-lr \
--config configs/default_config.yaml \
--data_path data/corpus.txt \
--min_lr 1e-6 --max_lr 1e-2 --num_steps 100 --plot
# Generate text
dantinox generate \
--run_dir runs/run_20260101_120000 \
--prompt "Nel mezzo del cammin " \
--max_new_tokens 200 --temperature 0.8 --top_k 40
# Sweep (W&B Bayesian)
dantinox sweep --sweep_config configs/sweep.yaml --data_path data/corpus.txt
# Benchmark all runs, then plot
dantinox benchmark --runs_dir runs --out_csv benchmark_results.csv
dantinox plot --in_csv benchmark_results.csv --out_dir plots/
# Share your checkpoint on HuggingFace Hub
dantinox push --run_dir runs/run_20260101_120000 --repo my-org/dantinox-dante
dantinox pull --repo my-org/dantinox-dante --local_dir runs/pulled
## Why DantinoX?
Most "from-scratch" Transformer implementations stop at the forward pass. DantinoX goes further:
- **Correct XLA semantics** — static KV cache, no dynamic shapes, no recompilation at decode time (a minimal sketch of the pattern follows this list).
- **Real research features** — MLA weight absorption, NTK-aware RoPE scaling, MoE load balancing, Flash Attention — not demos, fully tested.
- **A library, not a script** — `Trainer`, `Generator`, `BenchmarkRunner`, and a CLI. `pip install` and go.
- **Production-ready fine-tuning** — LoRA adapters with a custom `LoRAParam` variable type: base weights frozen at the type level, not by manual filtering. Merge and export with one call.
- **Multi-GPU out of the box** — JAX SPMD sharding with `n_devices=N` in config. No pmap, no manual AllReduce — XLA handles it.
- **Auditable** — 86 tests, mypy clean, ruff clean, coverage report in the docs.
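A minimal, self-contained sketch of that static-cache idea (illustrative only, not DantinoX's actual cache code): the cache is allocated once at the maximum sequence length, and each decode step writes a single position with `jax.lax.dynamic_update_slice`, so no shape ever changes and the jitted step compiles exactly once.

import jax
import jax.numpy as jnp

def write_kv(cache: jnp.ndarray, new_kv: jnp.ndarray, pos: int) -> jnp.ndarray:
    # cache:  (batch, max_seq_len, n_heads, head_size), pre-allocated once
    # new_kv: (batch, 1, n_heads, head_size), the key (or value) for step `pos`
    return jax.lax.dynamic_update_slice(cache, new_kv, (0, pos, 0, 0))

cache = jnp.zeros((1, 256, 16, 32))
cache = jax.jit(write_kv)(cache, jnp.ones((1, 1, 16, 32)), 5)  # shape stays (1, 256, 16, 32)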
## Documentation
| Page | What you'll find |
|---|---|
| Core Architecture | Attention types, math, LoRA, multi-GPU, full config reference |
| Training & Sweeps | bfloat16, grad clipping, early stopping, resume, LR finder, LoRA fine-tuning, multi-GPU |
| Inference & Generation | Single, batch & streaming generation, KV-cache pipeline, sampling strategies |
| Benchmarks | MHA vs GQA vs MLA — throughput, cache size, FLOPs, 3D surfaces |
| Ablation Studies | Optimizer, MoE, positional encoding, regularization |
| API Reference | Trainer, Generator, LoRALinear, sharding utils, BenchmarkRunner, Hub |
## Project Structure
DantinoX/
├── dantinox/ # Public library API
│ ├── __init__.py # Top-level imports and __version__
│ ├── trainer.py # Trainer — training, gradient clipping, early stopping, LR finder
│ ├── generator.py # Generator — single, batch & streaming generation
│ ├── hub.py # push() / pull() — HuggingFace Hub integration
│ ├── bench.py # BenchmarkRunner — throughput / FLOPs benchmarks
│ ├── plotting.py # Plotter — automated plot generation
│ ├── exceptions.py # DantinoXError hierarchy
│ └── cli.py # dantinox CLI (train, generate, find-lr, push, pull, ...)
│
├── core/ # Internal implementation
│ ├── config.py # Config dataclass — single source of truth
│ ├── model.py # Transformer (+ from_pretrained), Block, RMSNorm
│ ├── attention.py # Attention kernels, Flash Attention, KV-cache logic
│ ├── output.py # ModelOutput NamedTuple
│ └── generation.py # Autoregressive inference engine
│
├── utils/
│ ├── tokenizer.py # CharTokenizer, BPETokenizer, save/load
│ └── helpers.py # Loss, batching utilities
│
├── configs/
│ ├── default_config.yaml # Standard training setup
│ └── sweep.yaml # W&B Bayesian sweep configuration
│
├── tests/ # pytest integration + unit tests
├── examples/
│ ├── quickstart.py # Train → generate end-to-end demo
│ └── DantinoX_Colab.ipynb # Colab notebook (GPU, HF dataset, Hub, LoRA)
│
└── pyproject.toml # pip install -e ".[all]"