
π”‡π”žπ”«π”±π”¦π”«π”¬π”›

"Ah JAX, vituperio delle genti..."
(Ah JAX, the shame of the people...)

A Transformer so "nano" it barely rhymes, implemented in JAX and Flax NNX. Built with sweat and XLA errors.



(Figure: DantinoX architecture diagram)


Overview: The DantinoX Project

"Nel mezzo del cammin di nostra vita mi ritrovai per una selva oscura, chΓ© la diritta via era smarrita."

DantinoX is a from-scratch implementation of a modern Large Language Model built natively in JAX and Flax NNX. The primary motivation behind this project is educational and exploratory: to understand the internal mechanics of current transformer architectures and to learn how to write efficient JAX code without constantly fighting XLA compilation errors.

To understand these constraints thoroughly, DantinoX implements the standard components of modern Deep Learning directly from the ground up:

  • Sparse Mixture of Experts (MoE) with Load Balancing Loss
  • Rotary Positional Embeddings (RoPE)
  • Grouped Query Attention (GQA)
  • Sliding Window & Attention Gating
  • Static KV Cache
  • Weight Tying
  • Gradient Checkpointing and Gradient Accumulation
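To give a flavor of what these components look like in JAX, here is a minimal sketch of Rotary Positional Embeddings (one of the items above). The function name and shapes are illustrative, not the project's actual `model.py` code: it rotates pairs of channels by position-dependent angles, so position 0 is left unchanged and per-token norms are preserved.

```python
import jax.numpy as jnp


def rope(x, base=10000.0):
    """Apply rotary positional embeddings to x of shape (seq_len, head_dim).

    Illustrative sketch: channels are split into two halves and each
    (x1, x2) pair is rotated by an angle that grows with the position
    and shrinks with the channel frequency index.
    """
    seq_len, dim = x.shape
    half = dim // 2
    # Frequencies decay geometrically across the channel dimension.
    inv_freq = 1.0 / (base ** (jnp.arange(half) / half))
    # Angle for every (position, frequency) pair: shape (seq_len, half).
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # 2-D rotation applied pairwise; a pure rotation preserves the norm.
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)
```

Because the rotation is a function of absolute position, the dot product between two rotated vectors depends only on their relative distance, which is the property that makes RoPE attractive for attention.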

Highly Customizable

Rather than a rigid production artifact, the codebase is designed to be highly customizable. The architecture is modular, allowing users to easily toggle between different configurationsβ€”such as switching between a standard Dense MLP and Sparse MoE routingβ€”to observe the direct impact on compute requirements and VRAM usage.
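As a sketch of what such a toggle could look like, here is a hypothetical config dataclass with a `use_moe` flag and a rough feed-forward parameter count. All field names (`n_embd`, `use_moe`, `n_experts`, ...) are assumptions for illustration, not the actual fields in `core/config.py`:

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Illustrative fields only; the project's real Config may differ.
    n_embd: int = 256
    n_head: int = 8
    n_kv_head: int = 2      # GQA: fewer KV heads than query heads
    use_moe: bool = False   # toggle: Dense MLP vs. Sparse MoE routing
    n_experts: int = 8
    top_k: int = 2


def ffn_params(cfg: Config, ffn_mult: int = 4) -> int:
    """Rough parameter count of one feed-forward block (hypothetical helper).

    A dense MLP has an up- and a down-projection; an MoE layer replicates
    that MLP once per expert, multiplying the parameter (and VRAM) cost
    while the per-token compute stays bounded by top_k routing.
    """
    dense = 2 * cfg.n_embd * (ffn_mult * cfg.n_embd)
    return cfg.n_experts * dense if cfg.use_moe else dense
```

Flipping `use_moe` in a config like this makes the trade-off concrete: with 8 experts the feed-forward parameter count grows 8x, while only `top_k` experts run per token.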

The final result is a functional, memory-efficient Transformer. It serves as a practical reference for resolving shape mismatches, managing GPU memory footprint, and successfully taming the XLA compiler.
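One recurring trick for taming XLA is the static KV cache mentioned above: preallocate the cache at its maximum sequence length and write each new entry in with a dynamic slice, so every decoding step has identical shapes and the step function compiles once. A minimal sketch (function name and shapes are assumptions, not the project's `generation.py` API):

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_cache(cache, new_kv, pos):
    """Write new_kv into a fixed-size cache at position pos.

    cache:  (max_seq, head_dim) preallocated buffer
    new_kv: (1, head_dim) the current step's key or value
    pos:    write position (may be a traced integer under jit)

    Because the cache shape never changes, XLA compiles this once and
    reuses it for every decoding step, instead of recompiling as the
    sequence grows.
    """
    return jax.lax.dynamic_update_slice(cache, new_kv, (pos, 0))
```

The alternative, concatenating new keys/values onto a growing array, changes the shape every step and triggers an XLA recompilation each time, which is exactly the kind of fight this codebase tries to avoid.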

"E quindi uscimmo a riveder le stelle."


Project Structure

DantinoX/
β”œβ”€β”€ core/                   # Core neural network logic
β”‚   β”œβ”€β”€ config.py           # Configuration parameters (Config Dataclass)
β”‚   β”œβ”€β”€ model.py            # Transformer architecture (Attention, MLP, MoE, Block)
β”‚   β”œβ”€β”€ generation.py       # Inference engine & static KV-Cache management
β”‚   └── __init__.py
β”‚
β”œβ”€β”€ configs/                # YAML configuration files
β”‚   β”œβ”€β”€ default_config.yaml # Standard training setup
β”‚   └── sweep.yaml          # Hyperparameter search config (W&B)
β”‚
β”œβ”€β”€ utils/                  # Utility functions
β”‚   β”œβ”€β”€ tokenizer.py        # Tokenizer management (Char-level & Byte-Level BPE)
β”‚   β”œβ”€β”€ helpers.py          # Loss functions, batching, sharding logic
β”‚   └── __init__.py
β”‚
β”œβ”€β”€ runs/                   # Training outputs (weights, logs, saved configs)
β”‚
β”œβ”€β”€ train.py                # Training script
β”œβ”€β”€ generate.py             # Text generation script
β”œβ”€β”€ requirements.txt        # Python dependencies
└── README.md               # Documentation