
API Reference

Auto-generated from source docstrings via mkdocstrings.


High-level API

The dantinox package exposes five classes and two functions that cover the full lifecycle — training, generation, benchmarking, plotting, and Hub sharing — without touching internal modules.
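A minimal end-to-end sketch tying the documented pieces together. It assumes Config, Trainer, Generator, and BenchmarkRunner are importable from the package top level, as the paragraph above implies; adjust the imports to match your install.

from dantinox import Config, Trainer, Generator, BenchmarkRunner

# Train from a YAML config; fit() returns the run directory it created.
config = Config.from_yaml("configs/default_config.yaml")
run_dir = Trainer(config).fit("data/corpus.txt")

# Load the checkpoint back and sample a continuation.
gen = Generator(run_dir)
print(gen.generate("Nel mezzo del cammin ", max_new_tokens=100))

# Benchmark every run under runs/ and inspect the results as a DataFrame.
df = BenchmarkRunner("runs").run()
print(df.head())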

Trainer

dantinox.trainer.Trainer

High-level training interface for DantinoX.

Parameters

config : Config
    Model and training configuration.

Examples

trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
run_dir = trainer.fit("data/corpus.txt")

Source code in dantinox/trainer.py
class Trainer:
    """
    High-level training interface for DantinoX.

    Parameters
    ----------
    config : Config
        Model and training configuration.

    Examples
    --------
    >>> trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
    >>> run_dir = trainer.fit("data/corpus.txt")
    """

    def __init__(self, config: Config) -> None:
        self.config = config

    def __repr__(self) -> str:
        return f"Trainer(config={self.config!r})"

    def fit(
        self,
        data_path: str | None = None,
        *,
        run_dir: str | None = None,
        wandb_project: str | None = None,
        resume: bool = False,
    ) -> str:
        """
        Train a model and save the checkpoint.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus. Falls back to ``config.dataset_name``.
        run_dir : str, optional
            Directory to write the checkpoint and logs. Auto-generated if omitted.
        wandb_project : str, optional
            If provided, metrics are logged to Weights & Biases.
        resume : bool
            If ``True`` and a previous checkpoint exists in ``run_dir``, training
            resumes from the saved step. Optimizer state is not preserved.

        Returns
        -------
        str
            Path to the run directory containing the saved checkpoint.

        Raises
        ------
        ConfigError
            If the data path is missing or the file cannot be found.
        """
        config = self.config

        if run_dir is None:
            run_dir = os.path.join(
                "runs", datetime.datetime.now().strftime("run_%Y%m%d_%H%M%S")
            )
        os.makedirs(run_dir, exist_ok=True)
        config.save_yaml(os.path.join(run_dir, "config.yaml"))

        log.info("Run directory: %s", run_dir)

        text = _format_text(_load_text(config, data_path))

        tok_path = os.path.join(run_dir, "tokenizer.json")
        if resume and os.path.exists(tok_path):
            tokenizer = load_tokenizer_from_file(tok_path)
            log.info("Resumed tokenizer from %s", tok_path)
        else:
            tokenizer = get_tokenizer(config.tokenizer_type)
            if config.tokenizer_type == "char":
                tokenizer.train_from_text(text)
            elif config.tokenizer_type == "bpe":
                tokenizer.train_from_text(text, vocab_size=config.vocab_size)
            tokenizer.save(tok_path)
            log.info("Tokenizer saved to %s", tok_path)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        n = int(0.9 * len(full_data))
        train_data, val_data = full_data[:n], full_data[n:]

        tokens_per_step = config.batch_size * config.max_context
        steps_per_epoch = max(1, len(train_data) // tokens_per_step)
        total_steps = steps_per_epoch * config.epochs

        tx = _build_optimizer(config, total_steps)
        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)
            log.info("Model cast to bfloat16")

        # LoRA: only train adapter params; base weights are frozen nnx.Param
        wrt_type = LoRAParam if getattr(config, "use_lora", False) else nnx.Param
        optimizer = nnx.Optimizer(model, tx, wrt=wrt_type)

        # Multi-GPU: build a data-parallel mesh when more than one device is requested
        n_dev_cfg = getattr(config, "n_devices", 0)
        import jax as _jax
        n_local = len(_jax.local_devices())
        use_multi_gpu = (n_dev_cfg != 1) and n_local > 1
        mesh = make_mesh(n_dev_cfg) if use_multi_gpu else None
        if use_multi_gpu:
            n_dev = num_devices(mesh)  # type: ignore[arg-type]
            if config.batch_size % n_dev != 0:
                raise ConfigError(
                    f"batch_size ({config.batch_size}) must be divisible by "
                    f"n_devices ({n_dev}) for data-parallel training."
                )
            log.info("Data-parallel training on %d devices", n_dev)

        start_step = 0
        cursor_path = os.path.join(run_dir, "training_cursor.json")
        resume_weights = os.path.join(run_dir, "model_weights.msgpack")
        if resume and os.path.exists(cursor_path) and os.path.exists(resume_weights):
            with open(cursor_path) as cursor_f:
                cursor = json.load(cursor_f)
            start_step = int(cursor.get("step", 0)) + 1
            import msgpack
            from flax.serialization import _msgpack_ext_unpack
            with open(resume_weights, "rb") as weights_f:
                state_dict = msgpack.unpackb(
                    weights_f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
                )
            nnx.update(model, state_dict)
            log.info("Resumed training from step %d", start_step)

        summary = _model_summary(model, config, optimizer)
        with open(os.path.join(run_dir, "model_summary.json"), "w") as f:
            json.dump(summary, f, indent=4)
        log.info(
            "Model: %sM params | Est. VRAM: %sMB",
            summary["total_params_M"],
            summary["weights_mem_MB"] + summary["optimizer_mem_MB"] + summary["est_activations_MB"],
        )

        if wandb_project is not None:
            import wandb
            wandb.init(project=wandb_project, config=config.to_dict())  # type: ignore[attr-defined]

        micro_bs = config.batch_size // config.grad_accum
        _wrt = wrt_type  # captured in closure for JIT

        @nnx.jit
        def train_step(model, opt, metrics, full_x, full_y):
            xs = full_x.reshape(config.grad_accum, micro_bs, -1)
            ys = full_y.reshape(config.grad_accum, micro_bs, -1)

            def _loss(model, x, y):
                logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
                loss = compute_loss(logits, y)
                if getattr(model, "use_moe", False):
                    loss = loss + model.alpha_balance * bal
                return loss, bal

            grad_fn = nnx.value_and_grad(_loss, argnums=DiffState(0, _wrt), has_aux=True)
            acc = jax.tree_util.tree_map(jnp.zeros_like, nnx.state(model, _wrt))
            total_loss = jnp.array(0.0)
            total_bal = jnp.array(0.0)
            for i in range(config.grad_accum):
                (loss, bal), grads = grad_fn(model, xs[i], ys[i])
                acc = jax.tree_util.tree_map(
                    lambda a, g: a + g / config.grad_accum, acc, grads
                )
                total_loss += loss / config.grad_accum
                total_bal += bal / config.grad_accum
            opt.update(model, acc)
            metrics.update(loss=total_loss)
            return total_loss, total_bal

        @nnx.jit
        def eval_step(model, x, y):
            logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
            loss = compute_loss(logits, y)
            if getattr(model, "use_moe", False):
                loss = loss + model.alpha_balance * bal
            return loss, bal

        def estimate_loss(key):
            out: dict[str, float] = {}
            for split, d in [("train", train_data), ("val", val_data)]:
                losses, bals = [], []
                for _ in range(config.eval_iters):
                    key, sub = jax.random.split(key)
                    x, y = get_batch(d, 1, config.max_context, sub)
                    loss_val, b = eval_step(model, x, y)
                    losses.append(float(loss_val))
                    bals.append(float(b))
                out[split] = sum(losses) / len(losses)
                out[f"{split}_bal"] = sum(bals) / len(bals)
            return out, key

        log_path = os.path.join(run_dir, "training_log.csv")
        weights_path = os.path.join(run_dir, "model_weights.msgpack")
        best_weights_path = os.path.join(run_dir, "best_model_weights.msgpack")

        key = jax.random.PRNGKey(config.seed)
        metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))

        # Multi-GPU: replicate model/optimizer/metrics to all devices once
        if use_multi_gpu:
            assert mesh is not None
            state = replicate(nnx.state((model, optimizer, metrics)), mesh)
            nnx.update((model, optimizer, metrics), state)

        pbar = tqdm(
            range(start_step, total_steps),
            desc="Training",
            unit="step",
            dynamic_ncols=True,
            initial=start_step,
            total=total_steps,
        )
        t0 = time.time()

        patience = getattr(config, "patience", 0)
        best_val_loss = float("inf")
        no_improve = 0

        with open(log_path, "a", newline="") as log_f:
            log_w = csv.writer(log_f)
            if os.path.getsize(log_path) == 0:
                log_w.writerow(
                    ["step", "train_loss", "val_loss", "train_bal", "val_bal", "ms_per_step"]
                )
            try:
                for step in pbar:
                    key, sub = jax.random.split(key)
                    x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
                    if use_multi_gpu:
                        assert mesh is not None
                        x = shard_batch(x, mesh)
                        y = shard_batch(y, mesh)
                    train_step(model, optimizer, metrics, x, y)

                    if step % 50 == 0:
                        t1 = time.time()
                        dt = (t1 - t0) * 1000 / 50
                        t0 = t1
                        losses, key = estimate_loss(key)
                        val_loss = losses["val"]
                        pbar.set_postfix(
                            train=f"{losses['train']:.4f}",
                            val=f"{val_loss:.4f}",
                        )
                        log.info(
                            "step %d/%d | train=%.4f val=%.4f bal=%.4f",
                            step, total_steps,
                            losses["train"], val_loss, losses["train_bal"],
                        )
                        log_w.writerow(
                            [
                                step,
                                float(losses["train"]),
                                float(val_loss),
                                float(losses["train_bal"]),
                                float(losses["val_bal"]),
                                round(dt, 2),
                            ]
                        )
                        log_f.flush()

                        # Periodic checkpoint for resume
                        _save_weights(model, weights_path)
                        with open(cursor_path, "w") as cf:
                            json.dump({"step": step}, cf)

                        # Best checkpoint tracking
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            no_improve = 0
                            _save_weights(model, best_weights_path)
                            log.info("New best val loss %.4f — saved best checkpoint", best_val_loss)
                        else:
                            no_improve += 1
                            if patience > 0 and no_improve >= patience:
                                log.info(
                                    "Early stopping at step %d (no improvement for %d evals)",
                                    step, patience,
                                )
                                break

                        if wandb_project is not None:
                            import wandb
                            wandb.log({"train_loss": losses["train"], "val_loss": val_loss, "step": step})  # type: ignore[attr-defined]
            finally:
                pbar.close()
                if wandb_project is not None:
                    import wandb
                    wandb.finish()  # type: ignore[attr-defined]

        _save_weights(model, weights_path)
        log.info("Checkpoint saved: %s", weights_path)
        return run_dir

    def find_lr(
        self,
        data_path: str | None = None,
        *,
        min_lr: float = 1e-7,
        max_lr: float = 1.0,
        num_steps: int = 100,
        smoothing: float = 0.9,
    ) -> tuple[float, list[float], list[float]]:
        """
        LR range test (Smith 2015).

        Trains for ``num_steps`` steps while exponentially increasing the
        learning rate from ``min_lr`` to ``max_lr``.  Returns a tuple of
        ``(suggested_lr, lr_history, loss_history)``.

        Parameters
        ----------
        data_path : str, optional
            Path to the training corpus.
        min_lr : float
            Starting learning rate (default 1e-7).
        max_lr : float
            Maximum learning rate (default 1.0).
        num_steps : int
            Number of steps in the sweep (default 100).
        smoothing : float
            Exponential smoothing factor for the loss curve (default 0.9).

        Returns
        -------
        tuple[float, list[float], list[float]]
            ``(suggested_lr, lr_history, loss_history)``
        """
        import math

        config = self.config
        text = _format_text(_load_text(config, data_path))

        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)

        config.vocab_size = tokenizer.vocab_size
        full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
        train_data = full_data[: int(0.9 * len(full_data))]

        rngs = nnx.Rngs(config.seed)
        model = Transformer(config, rngs=rngs)
        if getattr(config, "use_bf16", False):
            _cast_params(model, jnp.bfloat16)

        log_multiplier = math.log(max_lr / min_lr) / max(1, num_steps - 1)

        def _lr_fn(step: jnp.ndarray) -> jnp.ndarray:
            return jnp.array(min_lr, jnp.float32) * jnp.exp(
                step.astype(jnp.float32) * jnp.array(log_multiplier, jnp.float32)
            )

        tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=_lr_fn))
        optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

        @nnx.jit
        def _step(model, opt, x, y):
            def loss_fn(m):
                logits, _, _ = m(x, use_cache=False, kv_caches=None, cache_index=0)
                return compute_loss(logits, y)

            loss, grads = nnx.value_and_grad(loss_fn)(model)
            opt.update(model, grads)
            return loss

        key = jax.random.PRNGKey(config.seed)
        lr_history: list[float] = []
        loss_history: list[float] = []
        smooth_loss = 0.0
        best_loss = float("inf")

        pbar = tqdm(range(num_steps), desc="LR finder", unit="step", dynamic_ncols=True)
        for step in pbar:
            key, sub = jax.random.split(key)
            x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
            loss_val = float(_step(model, optimizer, x, y))

            smooth_loss = (
                loss_val if step == 0
                else smoothing * smooth_loss + (1 - smoothing) * loss_val
            )
            debiased = smooth_loss / (1 - smoothing ** (step + 1))

            current_lr = float(_lr_fn(jnp.array(step)))
            lr_history.append(current_lr)
            loss_history.append(debiased)

            pbar.set_postfix(lr=f"{current_lr:.2e}", loss=f"{debiased:.4f}")

            if debiased < best_loss:
                best_loss = debiased
            if debiased > 4 * best_loss:
                log.info("Loss diverging at step %d — stopping sweep early", step)
                break

        pbar.close()

        if len(loss_history) > 2:
            slopes = [loss_history[i + 1] - loss_history[i] for i in range(len(loss_history) - 1)]
            suggested_lr = lr_history[min(range(len(slopes)), key=lambda i: slopes[i])]
        else:
            suggested_lr = min_lr

        log.info(
            "LR finder done — suggested lr=%.2e (sweep range [%.2e, %.2e])",
            suggested_lr, min_lr, max_lr,
        )
        return suggested_lr, lr_history, loss_history

Functions

__init__

__init__(config: Config) -> None
Source code in dantinox/trainer.py
def __init__(self, config: Config) -> None:
    self.config = config

fit

fit(
    data_path: str | None = None,
    *,
    run_dir: str | None = None,
    wandb_project: str | None = None,
    resume: bool = False,
) -> str

Train a model and save the checkpoint.

Parameters

data_path : str, optional
    Path to the training corpus. Falls back to config.dataset_name.
run_dir : str, optional
    Directory to write the checkpoint and logs. Auto-generated if omitted.
wandb_project : str, optional
    If provided, metrics are logged to Weights & Biases.
resume : bool
    If True and a previous checkpoint exists in run_dir, training resumes
    from the saved step. Optimizer state is not preserved.

Returns

str
    Path to the run directory containing the saved checkpoint.

Raises

ConfigError
    If the data path is missing or the file cannot be found.
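For instance, a sketch of a resumable, W&B-logged run; the project name and paths are placeholders:

trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))

# First call creates runs/run_<timestamp>/ and writes the config, tokenizer,
# and weights there.
run_dir = trainer.fit("data/corpus.txt", wandb_project="dantinox-experiments")

# If the process was interrupted, pass the same run_dir with resume=True to
# continue from the last saved step (optimizer state is not preserved).
trainer.fit("data/corpus.txt", run_dir=run_dir, resume=True)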

Source code in dantinox/trainer.py
def fit(
    self,
    data_path: str | None = None,
    *,
    run_dir: str | None = None,
    wandb_project: str | None = None,
    resume: bool = False,
) -> str:
    """
    Train a model and save the checkpoint.

    Parameters
    ----------
    data_path : str, optional
        Path to the training corpus. Falls back to ``config.dataset_name``.
    run_dir : str, optional
        Directory to write the checkpoint and logs. Auto-generated if omitted.
    wandb_project : str, optional
        If provided, metrics are logged to Weights & Biases.
    resume : bool
        If ``True`` and a previous checkpoint exists in ``run_dir``, training
        resumes from the saved step. Optimizer state is not preserved.

    Returns
    -------
    str
        Path to the run directory containing the saved checkpoint.

    Raises
    ------
    ConfigError
        If the data path is missing or the file cannot be found.
    """
    config = self.config

    if run_dir is None:
        run_dir = os.path.join(
            "runs", datetime.datetime.now().strftime("run_%Y%m%d_%H%M%S")
        )
    os.makedirs(run_dir, exist_ok=True)
    config.save_yaml(os.path.join(run_dir, "config.yaml"))

    log.info("Run directory: %s", run_dir)

    text = _format_text(_load_text(config, data_path))

    tok_path = os.path.join(run_dir, "tokenizer.json")
    if resume and os.path.exists(tok_path):
        tokenizer = load_tokenizer_from_file(tok_path)
        log.info("Resumed tokenizer from %s", tok_path)
    else:
        tokenizer = get_tokenizer(config.tokenizer_type)
        if config.tokenizer_type == "char":
            tokenizer.train_from_text(text)
        elif config.tokenizer_type == "bpe":
            tokenizer.train_from_text(text, vocab_size=config.vocab_size)
        tokenizer.save(tok_path)
        log.info("Tokenizer saved to %s", tok_path)

    config.vocab_size = tokenizer.vocab_size
    full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
    n = int(0.9 * len(full_data))
    train_data, val_data = full_data[:n], full_data[n:]

    tokens_per_step = config.batch_size * config.max_context
    steps_per_epoch = max(1, len(train_data) // tokens_per_step)
    total_steps = steps_per_epoch * config.epochs

    tx = _build_optimizer(config, total_steps)
    rngs = nnx.Rngs(config.seed)
    model = Transformer(config, rngs=rngs)
    if getattr(config, "use_bf16", False):
        _cast_params(model, jnp.bfloat16)
        log.info("Model cast to bfloat16")

    # LoRA: only train adapter params; base weights are frozen nnx.Param
    wrt_type = LoRAParam if getattr(config, "use_lora", False) else nnx.Param
    optimizer = nnx.Optimizer(model, tx, wrt=wrt_type)

    # Multi-GPU: build a data-parallel mesh when more than one device is requested
    n_dev_cfg = getattr(config, "n_devices", 0)
    import jax as _jax
    n_local = len(_jax.local_devices())
    use_multi_gpu = (n_dev_cfg != 1) and n_local > 1
    mesh = make_mesh(n_dev_cfg) if use_multi_gpu else None
    if use_multi_gpu:
        n_dev = num_devices(mesh)  # type: ignore[arg-type]
        if config.batch_size % n_dev != 0:
            raise ConfigError(
                f"batch_size ({config.batch_size}) must be divisible by "
                f"n_devices ({n_dev}) for data-parallel training."
            )
        log.info("Data-parallel training on %d devices", n_dev)

    start_step = 0
    cursor_path = os.path.join(run_dir, "training_cursor.json")
    resume_weights = os.path.join(run_dir, "model_weights.msgpack")
    if resume and os.path.exists(cursor_path) and os.path.exists(resume_weights):
        with open(cursor_path) as cursor_f:
            cursor = json.load(cursor_f)
        start_step = int(cursor.get("step", 0)) + 1
        import msgpack
        from flax.serialization import _msgpack_ext_unpack
        with open(resume_weights, "rb") as weights_f:
            state_dict = msgpack.unpackb(
                weights_f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
            )
        nnx.update(model, state_dict)
        log.info("Resumed training from step %d", start_step)

    summary = _model_summary(model, config, optimizer)
    with open(os.path.join(run_dir, "model_summary.json"), "w") as f:
        json.dump(summary, f, indent=4)
    log.info(
        "Model: %sM params | Est. VRAM: %sMB",
        summary["total_params_M"],
        summary["weights_mem_MB"] + summary["optimizer_mem_MB"] + summary["est_activations_MB"],
    )

    if wandb_project is not None:
        import wandb
        wandb.init(project=wandb_project, config=config.to_dict())  # type: ignore[attr-defined]

    micro_bs = config.batch_size // config.grad_accum
    _wrt = wrt_type  # captured in closure for JIT

    @nnx.jit
    def train_step(model, opt, metrics, full_x, full_y):
        xs = full_x.reshape(config.grad_accum, micro_bs, -1)
        ys = full_y.reshape(config.grad_accum, micro_bs, -1)

        def _loss(model, x, y):
            logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
            loss = compute_loss(logits, y)
            if getattr(model, "use_moe", False):
                loss = loss + model.alpha_balance * bal
            return loss, bal

        grad_fn = nnx.value_and_grad(_loss, argnums=DiffState(0, _wrt), has_aux=True)
        acc = jax.tree_util.tree_map(jnp.zeros_like, nnx.state(model, _wrt))
        total_loss = jnp.array(0.0)
        total_bal = jnp.array(0.0)
        for i in range(config.grad_accum):
            (loss, bal), grads = grad_fn(model, xs[i], ys[i])
            acc = jax.tree_util.tree_map(
                lambda a, g: a + g / config.grad_accum, acc, grads
            )
            total_loss += loss / config.grad_accum
            total_bal += bal / config.grad_accum
        opt.update(model, acc)
        metrics.update(loss=total_loss)
        return total_loss, total_bal

    @nnx.jit
    def eval_step(model, x, y):
        logits, _, bal = model(x, use_cache=False, kv_caches=None, cache_index=0)
        loss = compute_loss(logits, y)
        if getattr(model, "use_moe", False):
            loss = loss + model.alpha_balance * bal
        return loss, bal

    def estimate_loss(key):
        out: dict[str, float] = {}
        for split, d in [("train", train_data), ("val", val_data)]:
            losses, bals = [], []
            for _ in range(config.eval_iters):
                key, sub = jax.random.split(key)
                x, y = get_batch(d, 1, config.max_context, sub)
                loss_val, b = eval_step(model, x, y)
                losses.append(float(loss_val))
                bals.append(float(b))
            out[split] = sum(losses) / len(losses)
            out[f"{split}_bal"] = sum(bals) / len(bals)
        return out, key

    log_path = os.path.join(run_dir, "training_log.csv")
    weights_path = os.path.join(run_dir, "model_weights.msgpack")
    best_weights_path = os.path.join(run_dir, "best_model_weights.msgpack")

    key = jax.random.PRNGKey(config.seed)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))

    # Multi-GPU: replicate model/optimizer/metrics to all devices once
    if use_multi_gpu:
        assert mesh is not None
        state = replicate(nnx.state((model, optimizer, metrics)), mesh)
        nnx.update((model, optimizer, metrics), state)

    pbar = tqdm(
        range(start_step, total_steps),
        desc="Training",
        unit="step",
        dynamic_ncols=True,
        initial=start_step,
        total=total_steps,
    )
    t0 = time.time()

    patience = getattr(config, "patience", 0)
    best_val_loss = float("inf")
    no_improve = 0

    with open(log_path, "a", newline="") as log_f:
        log_w = csv.writer(log_f)
        if os.path.getsize(log_path) == 0:
            log_w.writerow(
                ["step", "train_loss", "val_loss", "train_bal", "val_bal", "ms_per_step"]
            )
        try:
            for step in pbar:
                key, sub = jax.random.split(key)
                x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
                if use_multi_gpu:
                    assert mesh is not None
                    x = shard_batch(x, mesh)
                    y = shard_batch(y, mesh)
                train_step(model, optimizer, metrics, x, y)

                if step % 50 == 0:
                    t1 = time.time()
                    dt = (t1 - t0) * 1000 / 50
                    t0 = t1
                    losses, key = estimate_loss(key)
                    val_loss = losses["val"]
                    pbar.set_postfix(
                        train=f"{losses['train']:.4f}",
                        val=f"{val_loss:.4f}",
                    )
                    log.info(
                        "step %d/%d | train=%.4f val=%.4f bal=%.4f",
                        step, total_steps,
                        losses["train"], val_loss, losses["train_bal"],
                    )
                    log_w.writerow(
                        [
                            step,
                            float(losses["train"]),
                            float(val_loss),
                            float(losses["train_bal"]),
                            float(losses["val_bal"]),
                            round(dt, 2),
                        ]
                    )
                    log_f.flush()

                    # Periodic checkpoint for resume
                    _save_weights(model, weights_path)
                    with open(cursor_path, "w") as cf:
                        json.dump({"step": step}, cf)

                    # Best checkpoint tracking
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        no_improve = 0
                        _save_weights(model, best_weights_path)
                        log.info("New best val loss %.4f — saved best checkpoint", best_val_loss)
                    else:
                        no_improve += 1
                        if patience > 0 and no_improve >= patience:
                            log.info(
                                "Early stopping at step %d (no improvement for %d evals)",
                                step, patience,
                            )
                            break

                    if wandb_project is not None:
                        import wandb
                        wandb.log({"train_loss": losses["train"], "val_loss": val_loss, "step": step})  # type: ignore[attr-defined]
        finally:
            pbar.close()
            if wandb_project is not None:
                import wandb
                wandb.finish()  # type: ignore[attr-defined]

    _save_weights(model, weights_path)
    log.info("Checkpoint saved: %s", weights_path)
    return run_dir

find_lr

find_lr(
    data_path: str | None = None,
    *,
    min_lr: float = 1e-07,
    max_lr: float = 1.0,
    num_steps: int = 100,
    smoothing: float = 0.9,
) -> tuple[float, list[float], list[float]]

LR range test (Smith 2015).

Trains for num_steps steps while exponentially increasing the learning rate from min_lr to max_lr. Returns a tuple of (suggested_lr, lr_history, loss_history).

Parameters

data_path : str, optional
    Path to the training corpus.
min_lr : float
    Starting learning rate (default 1e-7).
max_lr : float
    Maximum learning rate (default 1.0).
num_steps : int
    Number of steps in the sweep (default 100).
smoothing : float
    Exponential smoothing factor for the loss curve (default 0.9).

Returns

tuple[float, list[float], list[float]]
    (suggested_lr, lr_history, loss_history)
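A short usage sketch; the matplotlib plotting is an illustrative addition, not part of the API:

import matplotlib.pyplot as plt

trainer = Trainer(Config.from_yaml("configs/default_config.yaml"))
suggested_lr, lrs, losses = trainer.find_lr("data/corpus.txt", num_steps=100)

# The customary read-out: smoothed loss vs. learning rate on a log-x axis,
# with the suggested value (steepest-descent point) marked.
plt.plot(lrs, losses)
plt.xscale("log")
plt.axvline(suggested_lr, linestyle="--")
plt.xlabel("learning rate")
plt.ylabel("smoothed loss")
plt.show()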

Source code in dantinox/trainer.py
def find_lr(
    self,
    data_path: str | None = None,
    *,
    min_lr: float = 1e-7,
    max_lr: float = 1.0,
    num_steps: int = 100,
    smoothing: float = 0.9,
) -> tuple[float, list[float], list[float]]:
    """
    LR range test (Smith 2015).

    Trains for ``num_steps`` steps while exponentially increasing the
    learning rate from ``min_lr`` to ``max_lr``.  Returns a tuple of
    ``(suggested_lr, lr_history, loss_history)``.

    Parameters
    ----------
    data_path : str, optional
        Path to the training corpus.
    min_lr : float
        Starting learning rate (default 1e-7).
    max_lr : float
        Maximum learning rate (default 1.0).
    num_steps : int
        Number of steps in the sweep (default 100).
    smoothing : float
        Exponential smoothing factor for the loss curve (default 0.9).

    Returns
    -------
    tuple[float, list[float], list[float]]
        ``(suggested_lr, lr_history, loss_history)``
    """
    import math

    config = self.config
    text = _format_text(_load_text(config, data_path))

    tokenizer = get_tokenizer(config.tokenizer_type)
    if config.tokenizer_type == "char":
        tokenizer.train_from_text(text)
    elif config.tokenizer_type == "bpe":
        tokenizer.train_from_text(text, vocab_size=config.vocab_size)

    config.vocab_size = tokenizer.vocab_size
    full_data = jnp.array(tokenizer.encode(text), dtype=jnp.int32)
    train_data = full_data[: int(0.9 * len(full_data))]

    rngs = nnx.Rngs(config.seed)
    model = Transformer(config, rngs=rngs)
    if getattr(config, "use_bf16", False):
        _cast_params(model, jnp.bfloat16)

    log_multiplier = math.log(max_lr / min_lr) / max(1, num_steps - 1)

    def _lr_fn(step: jnp.ndarray) -> jnp.ndarray:
        return jnp.array(min_lr, jnp.float32) * jnp.exp(
            step.astype(jnp.float32) * jnp.array(log_multiplier, jnp.float32)
        )

    tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=_lr_fn))
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    @nnx.jit
    def _step(model, opt, x, y):
        def loss_fn(m):
            logits, _, _ = m(x, use_cache=False, kv_caches=None, cache_index=0)
            return compute_loss(logits, y)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        opt.update(model, grads)
        return loss

    key = jax.random.PRNGKey(config.seed)
    lr_history: list[float] = []
    loss_history: list[float] = []
    smooth_loss = 0.0
    best_loss = float("inf")

    pbar = tqdm(range(num_steps), desc="LR finder", unit="step", dynamic_ncols=True)
    for step in pbar:
        key, sub = jax.random.split(key)
        x, y = get_batch(train_data, config.batch_size, config.max_context, sub)
        loss_val = float(_step(model, optimizer, x, y))

        smooth_loss = (
            loss_val if step == 0
            else smoothing * smooth_loss + (1 - smoothing) * loss_val
        )
        debiased = smooth_loss / (1 - smoothing ** (step + 1))

        current_lr = float(_lr_fn(jnp.array(step)))
        lr_history.append(current_lr)
        loss_history.append(debiased)

        pbar.set_postfix(lr=f"{current_lr:.2e}", loss=f"{debiased:.4f}")

        if debiased < best_loss:
            best_loss = debiased
        if debiased > 4 * best_loss:
            log.info("Loss diverging at step %d — stopping sweep early", step)
            break

    pbar.close()

    if len(loss_history) > 2:
        slopes = [loss_history[i + 1] - loss_history[i] for i in range(len(loss_history) - 1)]
        suggested_lr = lr_history[min(range(len(slopes)), key=lambda i: slopes[i])]
    else:
        suggested_lr = min_lr

    log.info(
        "LR finder done — suggested lr=%.2e (sweep range [%.2e, %.2e])",
        suggested_lr, min_lr, max_lr,
    )
    return suggested_lr, lr_history, loss_history

Generator

dantinox.generator.Generator

Loads a trained DantinoX checkpoint and generates text.

Accepts either a local run directory or a HuggingFace Hub repo ID — the checkpoint is downloaded automatically when needed.

Parameters

run_dir : str
    Local path produced by Trainer.fit() or a Hub repo ID such as
    "my-org/dantinox-dante".
seed : int
    RNG seed used for sampling (default 42).
token : str, optional
    HuggingFace access token for private repositories.
revision : str, optional
    Branch, tag, or commit SHA to download from the Hub.

Raises

CheckpointError
    If the checkpoint cannot be found locally or downloaded from the Hub.

Examples

gen = Generator("runs/run_20260101_120000")            # local
gen = Generator("my-org/dantinox-dante")               # HF Hub
gen = Generator("my-org/private-model", token="hf_…")  # private Hub
text = gen.generate("Nel mezzo del cammin ")
print(text)

Source code in dantinox/generator.py
class Generator:
    """
    Loads a trained DantinoX checkpoint and generates text.

    Accepts either a **local run directory** or a **HuggingFace Hub repo ID**
    — the checkpoint is downloaded automatically when needed.

    Parameters
    ----------
    run_dir : str
        Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
        as ``"my-org/dantinox-dante"``.
    seed : int
        RNG seed used for sampling (default 42).
    token : str, optional
        HuggingFace access token for private repositories.
    revision : str, optional
        Branch, tag, or commit SHA to download from the Hub.

    Raises
    ------
    CheckpointError
        If the checkpoint cannot be found locally or downloaded from the Hub.

    Examples
    --------
    >>> gen = Generator("runs/run_20260101_120000")          # local
    >>> gen = Generator("my-org/dantinox-dante")             # HF Hub
    >>> gen = Generator("my-org/private-model", token="hf_…")  # private Hub
    >>> text = gen.generate("Nel mezzo del cammin ")
    >>> print(text)
    """

    def __init__(
        self,
        run_dir: str,
        *,
        seed: int = 42,
        token: str | None = None,
        revision: str | None = None,
    ) -> None:
        from dantinox.hub import resolve_checkpoint

        self.seed = seed
        # Resolve once: download from Hub if needed, then use the local path
        local_dir = resolve_checkpoint(run_dir, token=token, revision=revision)
        self.run_dir = local_dir
        self.config, self.model, self.tokenizer = _load_checkpoint(local_dir, seed)

    def __repr__(self) -> str:
        attn = "MLA" if self.config.mla else ("GQA" if self.config.kv_heads < self.config.n_heads else "MHA")
        return f"Generator(run_dir={self.run_dir!r}, attn={attn}, seed={self.seed})"

    def _bpe_fix(self, text: str) -> str:
        if self.config.tokenizer_type == "bpe":
            for old, new in _BPE_REPLACEMENTS:
                text = text.replace(old, new)
        return text

    # ── Single-prompt generation ──────────────────────────────────────────────

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> str:
        """
        Generate text continuing from ``prompt``.

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Number of tokens to generate (default 150).
        greedy : bool
            Use greedy decoding instead of sampling (default False).
        top_k : int, optional
            Keep only the top-k logits before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache for faster generation (default True).

        Returns
        -------
        str
            The full generated string (prompt + continuation).
        """
        tokens = self.tokenizer.encode(prompt)
        x = jnp.array([tokens], dtype=jnp.int32)

        log.debug(
            "Generating %d tokens from prompt of %d tokens (greedy=%s, cache=%s)",
            max_new_tokens, len(tokens), greedy, use_cache,
        )

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        output.block_until_ready()

        return self._bpe_fix(self.tokenizer.decode(output[0].tolist()))

    # ── Batched generation ────────────────────────────────────────────────────

    def generate_batch(
        self,
        prompts: list[str],
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> list[str]:
        """
        Generate text for multiple prompts in a single batched forward pass.

        Shorter prompts are left-padded with zeros so all share the same
        sequence length.  This runs a true batch through the model, so
        throughput scales with GPU parallelism.

        Parameters
        ----------
        prompts : list[str]
            Input prefixes to generate from.
        max_new_tokens : int
            Tokens to generate per prompt (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering before sampling.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).
        use_cache : bool
            Enable KV-cache (default True).

        Returns
        -------
        list[str]
            Generated strings (prompt + continuation) in the same order as
            ``prompts``.
        """
        if not prompts:
            return []

        encoded = [self.tokenizer.encode(p) for p in prompts]
        max_len = max(len(e) for e in encoded)

        # Left-pad shorter prompts with zeros so all share the same start position.
        padded = [([0] * (max_len - len(e))) + e for e in encoded]
        x = jnp.array(padded, dtype=jnp.int32)  # [B, max_len]

        log.debug("Batch generating: B=%d max_prompt_len=%d max_new=%d", len(prompts), max_len, max_new_tokens)

        output = _generate(
            model=self.model,
            x=x,
            max_generations=max_new_tokens,
            greedy=greedy,
            seed=self.seed,
            use_cache=use_cache,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        output.block_until_ready()

        results = []
        for i, enc in enumerate(encoded):
            # Strip the left-padding: prompt starts at (max_len - len(enc))
            start = max_len - len(enc)
            tokens_out = output[i, start:].tolist()
            results.append(self._bpe_fix(self.tokenizer.decode(tokens_out)))
        return results

    # ── Streaming generation ──────────────────────────────────────────────────

    def stream(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 150,
        greedy: bool = False,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float = 1.0,
    ) -> Iterator[str]:
        """
        Stream generated tokens one at a time as they are produced.

        Uses the KV-cache path: the prompt is prefilled in one forward pass,
        then each new token is decoded individually.  Each ``yield`` returns
        the string for one generated token (may be a character or a BPE
        subword).

        Parameters
        ----------
        prompt : str
            The input prefix.
        max_new_tokens : int
            Maximum number of tokens to generate (default 150).
        greedy : bool
            Greedy decoding (default False).
        top_k : int, optional
            Top-k filtering.
        top_p : float, optional
            Nucleus sampling threshold.
        temperature : float
            Softmax temperature (default 1.0).

        Yields
        ------
        str
            Decoded string for each generated token.

        Examples
        --------
        >>> gen = Generator("runs/my_run")
        >>> for chunk in gen.stream("Nel mezzo", max_new_tokens=50):
        ...     print(chunk, end="", flush=True)
        """
        tokens = self.tokenizer.encode(prompt)
        T = len(tokens)
        max_ctx = self.config.max_context  # type: ignore[attr-defined]
        num_blocks = self.config.num_blocks  # type: ignore[attr-defined]

        # Build full-context input with prompt at the start.
        x = jnp.zeros((1, max_ctx), dtype=jnp.int32)
        x = x.at[0, :T].set(jnp.array(tokens, dtype=jnp.int32))

        init_kv_cache = tuple((None, None) for _ in range(num_blocks))
        key = jax.random.key(self.seed)

        # Prefill: one pass over the entire prompt, populate KV cache.
        logits, kv_cache = _stream_prefill(self.model, x, init_kv_cache)

        # Sample the first generated token from the last prompt position.
        tok_id, key = _sample_logit(logits[:, T - 1, :], key, greedy, temperature, top_k, top_p)
        yield self._bpe_fix(self.tokenizer.decode([tok_id]))

        # Autoregressive decode loop.
        for pos in range(T, T + max_new_tokens - 1):
            if pos >= max_ctx:
                break
            tok = jnp.array([[tok_id]], dtype=jnp.int32)
            logits, kv_cache = _stream_decode(self.model, tok, kv_cache, jnp.array(pos))
            tok_id, key = _sample_logit(logits[:, 0, :], key, greedy, temperature, top_k, top_p)
            yield self._bpe_fix(self.tokenizer.decode([tok_id]))

Functions

__init__

__init__(
    run_dir: str,
    *,
    seed: int = 42,
    token: str | None = None,
    revision: str | None = None,
) -> None
Source code in dantinox/generator.py
def __init__(
    self,
    run_dir: str,
    *,
    seed: int = 42,
    token: str | None = None,
    revision: str | None = None,
) -> None:
    from dantinox.hub import resolve_checkpoint

    self.seed = seed
    # Resolve once: download from Hub if needed, then use the local path
    local_dir = resolve_checkpoint(run_dir, token=token, revision=revision)
    self.run_dir = local_dir
    self.config, self.model, self.tokenizer = _load_checkpoint(local_dir, seed)

generate

generate(
    prompt: str,
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
    use_cache: bool = True,
) -> str

Generate text continuing from prompt.

Parameters

prompt : str
    The input prefix.
max_new_tokens : int
    Number of tokens to generate (default 150).
greedy : bool
    Use greedy decoding instead of sampling (default False).
top_k : int, optional
    Keep only the top-k logits before sampling.
top_p : float, optional
    Nucleus sampling threshold.
temperature : float
    Softmax temperature (default 1.0).
use_cache : bool
    Enable KV-cache for faster generation (default True).

Returns

str
    The full generated string (prompt + continuation).
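For example, nucleus sampling at a lowered temperature; the run directory and sampling values are illustrative:

gen = Generator("runs/run_20260101_120000")
text = gen.generate(
    "Nel mezzo del cammin ",
    max_new_tokens=200,
    top_p=0.9,        # nucleus sampling: keep the smallest token set covering 90% of the mass
    temperature=0.8,  # <1.0 sharpens the distribution toward likely tokens
)
print(text)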

Source code in dantinox/generator.py
def generate(
    self,
    prompt: str,
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
    use_cache: bool = True,
) -> str:
    """
    Generate text continuing from ``prompt``.

    Parameters
    ----------
    prompt : str
        The input prefix.
    max_new_tokens : int
        Number of tokens to generate (default 150).
    greedy : bool
        Use greedy decoding instead of sampling (default False).
    top_k : int, optional
        Keep only the top-k logits before sampling.
    top_p : float, optional
        Nucleus sampling threshold.
    temperature : float
        Softmax temperature (default 1.0).
    use_cache : bool
        Enable KV-cache for faster generation (default True).

    Returns
    -------
    str
        The full generated string (prompt + continuation).
    """
    tokens = self.tokenizer.encode(prompt)
    x = jnp.array([tokens], dtype=jnp.int32)

    log.debug(
        "Generating %d tokens from prompt of %d tokens (greedy=%s, cache=%s)",
        max_new_tokens, len(tokens), greedy, use_cache,
    )

    output = _generate(
        model=self.model,
        x=x,
        max_generations=max_new_tokens,
        greedy=greedy,
        seed=self.seed,
        use_cache=use_cache,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    output.block_until_ready()

    return self._bpe_fix(self.tokenizer.decode(output[0].tolist()))

generate_batch

generate_batch(
    prompts: list[str],
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
    use_cache: bool = True,
) -> list[str]

Generate text for multiple prompts in a single batched forward pass.

Shorter prompts are left-padded with zeros so all share the same sequence length. This runs a true batch through the model, so throughput scales with GPU parallelism.

Parameters

prompts : list[str]
    Input prefixes to generate from.
max_new_tokens : int
    Tokens to generate per prompt (default 150).
greedy : bool
    Greedy decoding (default False).
top_k : int, optional
    Top-k filtering before sampling.
top_p : float, optional
    Nucleus sampling threshold.
temperature : float
    Softmax temperature (default 1.0).
use_cache : bool
    Enable KV-cache (default True).

Returns

list[str]
    Generated strings (prompt + continuation) in the same order as prompts.
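For example (the prompts are illustrative):

gen = Generator("runs/run_20260101_120000")
prompts = [
    "Nel mezzo del cammin ",
    "Ahi quanto a dir qual era ",
]
# One batched forward pass; results come back in the same order as the prompts.
outputs = gen.generate_batch(prompts, max_new_tokens=100, top_k=50)
for out in outputs:
    print(out, "\n---")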

Source code in dantinox/generator.py
def generate_batch(
    self,
    prompts: list[str],
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
    use_cache: bool = True,
) -> list[str]:
    """
    Generate text for multiple prompts in a single batched forward pass.

    Shorter prompts are left-padded with zeros so all share the same
    sequence length.  This runs a true batch through the model, so
    throughput scales with GPU parallelism.

    Parameters
    ----------
    prompts : list[str]
        Input prefixes to generate from.
    max_new_tokens : int
        Tokens to generate per prompt (default 150).
    greedy : bool
        Greedy decoding (default False).
    top_k : int, optional
        Top-k filtering before sampling.
    top_p : float, optional
        Nucleus sampling threshold.
    temperature : float
        Softmax temperature (default 1.0).
    use_cache : bool
        Enable KV-cache (default True).

    Returns
    -------
    list[str]
        Generated strings (prompt + continuation) in the same order as
        ``prompts``.
    """
    if not prompts:
        return []

    encoded = [self.tokenizer.encode(p) for p in prompts]
    max_len = max(len(e) for e in encoded)

    # Left-pad shorter prompts with zeros so all share the same start position.
    padded = [([0] * (max_len - len(e))) + e for e in encoded]
    x = jnp.array(padded, dtype=jnp.int32)  # [B, max_len]

    log.debug("Batch generating: B=%d max_prompt_len=%d max_new=%d", len(prompts), max_len, max_new_tokens)

    output = _generate(
        model=self.model,
        x=x,
        max_generations=max_new_tokens,
        greedy=greedy,
        seed=self.seed,
        use_cache=use_cache,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    output.block_until_ready()

    results = []
    for i, enc in enumerate(encoded):
        # Strip the left-padding: prompt starts at (max_len - len(enc))
        start = max_len - len(enc)
        tokens_out = output[i, start:].tolist()
        results.append(self._bpe_fix(self.tokenizer.decode(tokens_out)))
    return results

stream

stream(
    prompt: str,
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
) -> Iterator[str]

Stream generated tokens one at a time as they are produced.

Uses the KV-cache path: the prompt is prefilled in one forward pass, then each new token is decoded individually. Each yield returns the string for one generated token (may be a character or a BPE subword).

Parameters

prompt : str
    The input prefix.
max_new_tokens : int
    Maximum number of tokens to generate (default 150).
greedy : bool
    Greedy decoding (default False).
top_k : int, optional
    Top-k filtering.
top_p : float, optional
    Nucleus sampling threshold.
temperature : float
    Softmax temperature (default 1.0).

Yields

str
    Decoded string for each generated token.

Examples

gen = Generator("runs/my_run")
for chunk in gen.stream("Nel mezzo", max_new_tokens=50):
    print(chunk, end="", flush=True)

Source code in dantinox/generator.py
def stream(
    self,
    prompt: str,
    *,
    max_new_tokens: int = 150,
    greedy: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0,
) -> Iterator[str]:
    """
    Stream generated tokens one at a time as they are produced.

    Uses the KV-cache path: the prompt is prefilled in one forward pass,
    then each new token is decoded individually.  Each ``yield`` returns
    the string for one generated token (may be a character or a BPE
    subword).

    Parameters
    ----------
    prompt : str
        The input prefix.
    max_new_tokens : int
        Maximum number of tokens to generate (default 150).
    greedy : bool
        Greedy decoding (default False).
    top_k : int, optional
        Top-k filtering.
    top_p : float, optional
        Nucleus sampling threshold.
    temperature : float
        Softmax temperature (default 1.0).

    Yields
    ------
    str
        Decoded string for each generated token.

    Examples
    --------
    >>> gen = Generator("runs/my_run")
    >>> for chunk in gen.stream("Nel mezzo", max_new_tokens=50):
    ...     print(chunk, end="", flush=True)
    """
    tokens = self.tokenizer.encode(prompt)
    T = len(tokens)
    max_ctx = self.config.max_context  # type: ignore[attr-defined]
    num_blocks = self.config.num_blocks  # type: ignore[attr-defined]

    # Build full-context input with prompt at the start.
    x = jnp.zeros((1, max_ctx), dtype=jnp.int32)
    x = x.at[0, :T].set(jnp.array(tokens, dtype=jnp.int32))

    init_kv_cache = tuple((None, None) for _ in range(num_blocks))
    key = jax.random.key(self.seed)

    # Prefill: one pass over the entire prompt, populate KV cache.
    logits, kv_cache = _stream_prefill(self.model, x, init_kv_cache)

    # Sample the first generated token from the last prompt position.
    tok_id, key = _sample_logit(logits[:, T - 1, :], key, greedy, temperature, top_k, top_p)
    yield self._bpe_fix(self.tokenizer.decode([tok_id]))

    # Autoregressive decode loop.
    for pos in range(T, T + max_new_tokens - 1):
        if pos >= max_ctx:
            break
        tok = jnp.array([[tok_id]], dtype=jnp.int32)
        logits, kv_cache = _stream_decode(self.model, tok, kv_cache, jnp.array(pos))
        tok_id, key = _sample_logit(logits[:, 0, :], key, greedy, temperature, top_k, top_p)
        yield self._bpe_fix(self.tokenizer.decode([tok_id]))

BenchmarkRunner

dantinox.bench.BenchmarkRunner

Benchmarks one or more DantinoX run directories.

Parameters

runs_dir : str
    Directory containing run sub-directories (default "runs").
seq_lens : list[int], optional
    Sequence lengths to test for throughput scaling.
batch_sizes : list[int], optional
    Batch sizes to test for memory/throughput scaling.

Examples

runner = BenchmarkRunner("runs")
df = runner.run()
df.to_csv("results.csv", index=False)

Source code in dantinox/bench.py
class BenchmarkRunner:
    """
    Benchmarks one or more DantinoX run directories.

    Parameters
    ----------
    runs_dir : str
        Directory containing run sub-directories (default ``"runs"``).
    seq_lens : list[int], optional
        Sequence lengths to test for throughput scaling.
    batch_sizes : list[int], optional
        Batch sizes to test for memory/throughput scaling.

    Examples
    --------
    >>> runner = BenchmarkRunner("runs")
    >>> df = runner.run()
    >>> df.to_csv("results.csv", index=False)
    """

    def __init__(
        self,
        runs_dir: str = "runs",
        *,
        seq_lens: Sequence[int] | None = None,
        batch_sizes: Sequence[int] | None = None,
    ) -> None:
        self.runs_dir = runs_dir
        if seq_lens is not None:
            global SEQ_LENS
            SEQ_LENS = list(seq_lens)
        if batch_sizes is not None:
            global BATCH_SIZES
            BATCH_SIZES = list(batch_sizes)

    def __repr__(self) -> str:
        return f"BenchmarkRunner(runs_dir={self.runs_dir!r})"

    def run(
        self,
        run_names: Sequence[str] | None = None,
        *,
        out_csv: str | None = None,
    ) -> Any:
        """
        Run benchmarks and return a DataFrame.

        Parameters
        ----------
        run_names : list[str], optional
            Subset of run names to evaluate. Benchmarks all runs if omitted.
        out_csv : str, optional
            Write results to this CSV path.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        BenchmarkError
            If the runs directory does not exist.
        """
        import pandas as pd

        if not os.path.isdir(self.runs_dir):
            raise BenchmarkError(f"Runs directory not found: {self.runs_dir}")

        if run_names is None:
            run_names = [
                d for d in os.listdir(self.runs_dir)
                if os.path.isdir(os.path.join(self.runs_dir, d))
            ]

        results = []
        for name in run_names:
            path = os.path.join(self.runs_dir, name)
            log.info("Benchmarking: %s", name)
            try:
                results.append(benchmark_run(path))
            except BenchmarkError as exc:
                log.error("  Skipped %s: %s", name, exc)
            except Exception as exc:
                log.error("  Unexpected error for %s: %s\n%s", name, exc, traceback.format_exc())

        df = pd.DataFrame(results)
        if out_csv:
            df.to_csv(out_csv, index=False)
            log.info("Saved benchmark results to %s", out_csv)
        return df

Functions

__init__

__init__(
    runs_dir: str = "runs",
    *,
    seq_lens: Sequence[int] | None = None,
    batch_sizes: Sequence[int] | None = None,
) -> None
Source code in dantinox/bench.py
def __init__(
    self,
    runs_dir: str = "runs",
    *,
    seq_lens: Sequence[int] | None = None,
    batch_sizes: Sequence[int] | None = None,
) -> None:
    self.runs_dir = runs_dir
    if seq_lens is not None:
        global SEQ_LENS
        SEQ_LENS = list(seq_lens)
    if batch_sizes is not None:
        global BATCH_SIZES
        BATCH_SIZES = list(batch_sizes)

run

run(
    run_names: Sequence[str] | None = None,
    *,
    out_csv: str | None = None,
) -> Any

Run benchmarks and return a DataFrame.

Parameters

run_names : list[str], optional
    Subset of run names to evaluate. Benchmarks all runs if omitted.
out_csv : str, optional
    Write results to this CSV path.

Returns

pandas.DataFrame

Raises

BenchmarkError
    If the runs directory does not exist.

Source code in dantinox/bench.py
def run(
    self,
    run_names: Sequence[str] | None = None,
    *,
    out_csv: str | None = None,
) -> Any:
    """
    Run benchmarks and return a DataFrame.

    Parameters
    ----------
    run_names : list[str], optional
        Subset of run names to evaluate. Benchmarks all runs if omitted.
    out_csv : str, optional
        Write results to this CSV path.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    BenchmarkError
        If the runs directory does not exist.
    """
    import pandas as pd

    if not os.path.isdir(self.runs_dir):
        raise BenchmarkError(f"Runs directory not found: {self.runs_dir}")

    if run_names is None:
        run_names = [
            d for d in os.listdir(self.runs_dir)
            if os.path.isdir(os.path.join(self.runs_dir, d))
        ]

    results = []
    for name in run_names:
        path = os.path.join(self.runs_dir, name)
        log.info("Benchmarking: %s", name)
        try:
            results.append(benchmark_run(path))
        except BenchmarkError as exc:
            log.error("  Skipped %s: %s", name, exc)
        except Exception as exc:
            log.error("  Unexpected error for %s: %s\n%s", name, exc, traceback.format_exc())

    df = pd.DataFrame(results)
    if out_csv:
        df.to_csv(out_csv, index=False)
        log.info("Saved benchmark results to %s", out_csv)
    return df

Plotter

dantinox.plotting.Plotter

Generates all DantinoX benchmark plots from a CSV produced by BenchmarkRunner.run.

Runs the four bundled plot modules (perf, insights, 3d, 3d_dkv) and writes 16 PNG files to out_dir.

Parameters

in_csv : str
    Path to benchmark_results.csv.
out_dir : str
    Directory where PNGs are written (created if absent).
batch_csv : str, optional
    Path to batch_sweep_results.csv for the batch-throughput plot.
    If omitted, that figure is replaced with a placeholder.

Raises

PlotError
    If the CSV is missing or a group name is invalid.

Examples

from dantinox import BenchmarkRunner, Plotter
BenchmarkRunner("runs").run(out_csv="benchmark_results.csv")
Plotter("benchmark_results.csv").run()
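
run() also accepts a subset of the four plot groups; a usage sketch that regenerates only the perf figures:

from dantinox import Plotter

done = Plotter("benchmark_results.csv", out_dir="plots").run(groups=["perf"])
print(done)   # {"perf": [<figure function names>]}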

Source code in dantinox/plotting.py
class Plotter:
    """
    Generates all DantinoX benchmark plots from a CSV produced by
    :meth:`~dantinox.BenchmarkRunner.run`.

    Runs the four bundled plot modules (``perf``, ``insights``, ``3d``,
    ``3d_dkv``) and writes 16 PNG files to *out_dir*.

    Parameters
    ----------
    in_csv : str
        Path to ``benchmark_results.csv``.
    out_dir : str
        Directory where PNGs are written (created if absent).
    batch_csv : str, optional
        Path to ``batch_sweep_results.csv`` for the batch-throughput plot.
        If omitted, that figure is replaced with a placeholder.

    Raises
    ------
    PlotError
        If the CSV is missing or a group name is invalid.

    Examples
    --------
    >>> from dantinox import BenchmarkRunner, Plotter
    >>> BenchmarkRunner("runs").run(out_csv="benchmark_results.csv")
    >>> Plotter("benchmark_results.csv").run()
    """

    def __init__(
        self,
        in_csv: str = "benchmark_results.csv",
        out_dir: str = "plots",
        *,
        batch_csv: str | None = None,
    ) -> None:
        self.in_csv    = in_csv
        self.out_dir   = out_dir
        self.batch_csv = batch_csv

    def __repr__(self) -> str:
        return f"Plotter(in_csv={self.in_csv!r}, out_dir={self.out_dir!r})"

    def run(self, groups: list[str] | None = None) -> dict[str, list[str]]:
        """
        Generate plots and save them as PNGs.

        Parameters
        ----------
        groups : list[str], optional
            Subset of ``["perf", "insights", "3d", "3d_dkv"]``.
            Generates all four if omitted.

        Returns
        -------
        dict[str, list[str]]
            Mapping of group name → list of figure function names that ran.

        Raises
        ------
        PlotError
            If the benchmark CSV is not found or a group name is invalid.
        """
        if not os.path.exists(self.in_csv):
            raise PlotError(
                f"Benchmark CSV not found: {self.in_csv}\n"
                "Run BenchmarkRunner.run(out_csv='benchmark_results.csv') first."
            )

        os.makedirs(self.out_dir, exist_ok=True)
        selected = list(groups) if groups else ALL_GROUPS
        unknown  = [g for g in selected if g not in _PLOT_GROUPS]
        if unknown:
            raise PlotError(
                f"Unknown plot group(s): {unknown}. Valid groups: {ALL_GROUPS}"
            )

        results: dict[str, list[str]] = {}
        for group in selected:
            log.info("[%s] generating plots…", group)
            try:
                done = _run_group(group, self.in_csv, self.out_dir, self.batch_csv)
                results[group] = done
                log.info("  %d figures written to %s/", len(done), self.out_dir)
            except PlotError:
                raise
            except Exception as exc:
                log.error("  group '%s' failed: %s", group, exc)
                results[group] = []

        return results

Functions

__init__

__init__(
    in_csv: str = "benchmark_results.csv",
    out_dir: str = "plots",
    *,
    batch_csv: str | None = None,
) -> None
Source code in dantinox/plotting.py
def __init__(
    self,
    in_csv: str = "benchmark_results.csv",
    out_dir: str = "plots",
    *,
    batch_csv: str | None = None,
) -> None:
    self.in_csv    = in_csv
    self.out_dir   = out_dir
    self.batch_csv = batch_csv

run

run(
    groups: list[str] | None = None,
) -> dict[str, list[str]]

Generate plots and save them as PNGs.

Parameters

groups : list[str], optional
    Subset of ["perf", "insights", "3d", "3d_dkv"]. Generates all four if omitted.

Returns

dict[str, list[str]]
    Mapping of group name → list of figure function names that ran.

Raises

PlotError
    If the benchmark CSV is not found or a group name is invalid.

Source code in dantinox/plotting.py
def run(self, groups: list[str] | None = None) -> dict[str, list[str]]:
    """
    Generate plots and save them as PNGs.

    Parameters
    ----------
    groups : list[str], optional
        Subset of ``["perf", "insights", "3d", "3d_dkv"]``.
        Generates all four if omitted.

    Returns
    -------
    dict[str, list[str]]
        Mapping of group name → list of figure function names that ran.

    Raises
    ------
    PlotError
        If the benchmark CSV is not found or a group name is invalid.
    """
    if not os.path.exists(self.in_csv):
        raise PlotError(
            f"Benchmark CSV not found: {self.in_csv}\n"
            "Run BenchmarkRunner.run(out_csv='benchmark_results.csv') first."
        )

    os.makedirs(self.out_dir, exist_ok=True)
    selected = list(groups) if groups else ALL_GROUPS
    unknown  = [g for g in selected if g not in _PLOT_GROUPS]
    if unknown:
        raise PlotError(
            f"Unknown plot group(s): {unknown}. Valid groups: {ALL_GROUPS}"
        )

    results: dict[str, list[str]] = {}
    for group in selected:
        log.info("[%s] generating plots…", group)
        try:
            done = _run_group(group, self.in_csv, self.out_dir, self.batch_csv)
            results[group] = done
            log.info("  %d figures written to %s/", len(done), self.out_dir)
        except PlotError:
            raise
        except Exception as exc:
            log.error("  group '%s' failed: %s", group, exc)
            results[group] = []

    return results

Hub

Push, pull, and directly load checkpoints from HuggingFace Hub.

Optional dependency

Install with pip install "dantinox[hub]" or pip install huggingface-hub.

Direct loading — no pull step needed

from dantinox import Generator
from core import Transformer

gen   = Generator("my-org/dantinox-dante")                    # downloads + loads
model = Transformer.from_pretrained("my-org/dantinox-dante")  # same, no tokenizer

dantinox.hub.resolve_checkpoint

resolve_checkpoint(
    path_or_repo: str,
    *,
    token: str | None = None,
    revision: str | None = None,
) -> str

Return a local directory path for path_or_repo.

If path_or_repo is an existing local directory it is returned unchanged. Otherwise it is treated as a HuggingFace Hub repo ID (e.g. "my-org/dantinox-dante") and the checkpoint is downloaded via pull() before returning the local cache path.

Parameters

path_or_repo:
    Local run directory or HuggingFace Hub repo ID.
token:
    HuggingFace access token for private repositories.
revision:
    Branch, tag, or commit SHA to download.

Returns

str
    Absolute path to a local directory suitable for passing to Generator(),
    Transformer.from_pretrained(), etc.
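
This lets a single code path accept either form. A minimal sketch (the repo ID and run directory are illustrative):

from dantinox.hub import resolve_checkpoint

local = resolve_checkpoint("runs/my_run")              # existing dir: returned as-is
cached = resolve_checkpoint("my-org/dantinox-dante")   # Hub ID: downloaded via pull()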

Source code in dantinox/hub.py
def resolve_checkpoint(
    path_or_repo: str,
    *,
    token: str | None = None,
    revision: str | None = None,
) -> str:
    """Return a local directory path for *path_or_repo*.

    If *path_or_repo* is an existing local directory it is returned unchanged.
    Otherwise it is treated as a HuggingFace Hub repo ID (e.g.
    ``"my-org/dantinox-dante"``) and the checkpoint is downloaded via
    :func:`pull` before returning the local cache path.

    Parameters
    ----------
    path_or_repo:
        Local run directory **or** HuggingFace Hub repo ID.
    token:
        HuggingFace access token for private repositories.
    revision:
        Branch, tag, or commit SHA to download.

    Returns
    -------
    str
        Absolute path to a local directory suitable for passing to
        ``Generator()``, ``Transformer.from_pretrained()``, etc.
    """
    if os.path.isdir(path_or_repo):
        return path_or_repo
    return pull(path_or_repo, token=token, revision=revision)

dantinox.hub.push

push(
    run_dir: str,
    repo_id: str,
    *,
    private: bool = False,
    token: str | None = None,
    commit_message: str | None = None,
) -> str

Upload a run directory to a HuggingFace Hub model repository.

Creates the repository if it does not exist. Only the core checkpoint files are uploaded (config.yaml, tokenizer.json, model_weights.msgpack, best_model_weights.msgpack, model_summary.json). Log files are excluded.

Parameters

run_dir : str
    Local path to a DantinoX run directory.
repo_id : str
    Hub repository in the form "owner/repo-name".
private : bool
    Create the repository as private (default False).
token : str, optional
    HuggingFace access token. Falls back to the HF_TOKEN environment variable
    or the cached login token.
commit_message : str, optional
    Commit message for the upload (auto-generated if omitted).

Returns

str
    URL of the Hub repository after the upload.

Raises

ImportError
    If huggingface_hub is not installed.
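
A usage sketch (the repo ID and run directory are illustrative):

from dantinox.hub import push

url = push(
    "runs/my_run",
    "my-org/dantinox-dante",
    private=True,
    commit_message="Initial checkpoint",
)
print(url)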

Source code in dantinox/hub.py
def push(
    run_dir: str,
    repo_id: str,
    *,
    private: bool = False,
    token: str | None = None,
    commit_message: str | None = None,
) -> str:
    """
    Upload a run directory to a HuggingFace Hub model repository.

    Creates the repository if it does not exist.  Only the core checkpoint
    files are uploaded (``config.yaml``, ``tokenizer.json``,
    ``model_weights.msgpack``, ``best_model_weights.msgpack``,
    ``model_summary.json``).  Log files are excluded.

    Parameters
    ----------
    run_dir : str
        Local path to a DantinoX run directory.
    repo_id : str
        Hub repository in the form ``"owner/repo-name"``.
    private : bool
        Create the repository as private (default False).
    token : str, optional
        HuggingFace access token.  Falls back to the ``HF_TOKEN``
        environment variable or the cached login token.
    commit_message : str, optional
        Commit message for the upload (auto-generated if omitted).

    Returns
    -------
    str
        URL of the Hub repository after the upload.

    Raises
    ------
    ImportError
        If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:
        raise ImportError(
            "huggingface_hub is required for Hub integration: "
            "pip install huggingface-hub"
        ) from exc

    import os
    msg = commit_message or f"Upload DantinoX checkpoint from {os.path.basename(run_dir)}"

    api = HfApi(token=token)
    api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True)

    url = api.upload_folder(
        folder_path=run_dir,
        repo_id=repo_id,
        repo_type="model",
        commit_message=msg,
        ignore_patterns=_UPLOAD_IGNORE,
    )

    log.info("Pushed %s%s", run_dir, url)
    return str(url)

dantinox.hub.pull

pull(
    repo_id: str,
    *,
    local_dir: str | None = None,
    token: str | None = None,
    revision: str | None = None,
) -> str

Download a DantinoX checkpoint from HuggingFace Hub.

Parameters

repo_id : str
    Hub repository in the form "owner/repo-name".
local_dir : str, optional
    Where to store the downloaded files. Defaults to the HuggingFace cache
    directory (~/.cache/huggingface/hub/...).
token : str, optional
    HuggingFace access token for private repositories.
revision : str, optional
    Git revision (branch, tag, or commit SHA) to download.

Returns

str
    Path to the local directory containing the checkpoint. Pass this directly
    to Generator(run_dir).

Raises

ImportError
    If huggingface_hub is not installed.
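
A usage sketch that pins a revision and loads the result (the repo ID is illustrative):

from dantinox import Generator
from dantinox.hub import pull

run_dir = pull("my-org/dantinox-dante", revision="main")
gen = Generator(run_dir)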

Source code in dantinox/hub.py
def pull(
    repo_id: str,
    *,
    local_dir: str | None = None,
    token: str | None = None,
    revision: str | None = None,
) -> str:
    """
    Download a DantinoX checkpoint from HuggingFace Hub.

    Parameters
    ----------
    repo_id : str
        Hub repository in the form ``"owner/repo-name"``.
    local_dir : str, optional
        Where to store the downloaded files.  Defaults to the HuggingFace
        cache directory (``~/.cache/huggingface/hub/...``).
    token : str, optional
        HuggingFace access token for private repositories.
    revision : str, optional
        Git revision (branch, tag, or commit SHA) to download.

    Returns
    -------
    str
        Path to the local directory containing the checkpoint.  Pass this
        directly to ``Generator(run_dir)``.

    Raises
    ------
    ImportError
        If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise ImportError(
            "huggingface_hub is required for Hub integration: "
            "pip install huggingface-hub"
        ) from exc

    run_dir: str = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=local_dir,
        token=token,
        revision=revision,
    )

    log.info("Pulled %s%s", repo_id, run_dir)
    return run_dir

Core Modules

Internal implementation. Import directly when you need low-level access.

Model Architecture

Core Transformer components — Transformer, Block, Attention (MHA/GQA/MLA), MoE, and MLP.

core.model

Classes

Transformer

Bases: Module

Source code in core/model.py
class Transformer(nnx.Module, pytree=False):
    def __init__(self, config: Config, rngs: nnx.Rngs):
        self.num_blocks: int     = config.num_blocks
        self.blocks: list        = [Block(config, rngs=rngs) for _ in range(self.num_blocks)]
        self.wte: nnx.Embed      = nnx.Embed(config.vocab_size, config.dim, rngs=rngs)
        self.weight_tying: bool  = config.weight_tying
        self.trainable_pos: bool = config.trainable_pos
        self.absolute_pos: bool  = config.absolute_pos
        self.max_context: int    = config.max_context
        self.gradient_checkpointing: bool = config.gradient_checkpointing
        self.ln_f: nnx.Module    = _build_norm(config, config.dim, rngs)
        self.emb_dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.use_moe: bool        = config.use_moe
        self.alpha_balance: float = config.alpha_balance

        if config.weight_tying:
            # Share the embedding Param so lm_head.kernel stays a tracked nnx.Variable.
            # Assigning embedding.T (a raw array) would silently drop it from NNX's
            # state graph and cause DynamicJaxprTracer errors inside @nnx.jit.
            self.lm_head: nnx.Linear | None = None
        else:
            self.lm_head = nnx.Linear(config.dim, config.vocab_size, rngs=rngs)
        if self.trainable_pos:
            self.wpe: nnx.Embed  = nnx.Embed(config.max_context, config.dim, rngs=rngs)
        elif self.absolute_pos:
            def _build_compute_absolute_pos(T: int, C: int) -> jnp.ndarray:
                pos = jnp.zeros((T, C))
                row = jnp.arange(T)
                col = jnp.arange(0, C, 2)
                k = 1.0 / (10000 ** (col / C))
                ratio = jnp.einsum('i,j->ij', row, k)
                pos = pos.at[:, 0::2].set(jnp.sin(ratio))
                pos = pos.at[:, 1::2].set(jnp.cos(ratio))
                return jnp.expand_dims(pos, axis=0)

            self.wpe: jnp.ndarray = _build_compute_absolute_pos(config.max_context, config.dim)  # type: ignore[assignment, no-redef]

    def __call__(self,
                 x: jnp.ndarray,
                 use_cache: bool,
                 kv_caches: tuple | None,
                 cache_index: int | None,
                 deterministic: bool = False) -> ModelOutput:

        B, T = x.shape
        x = self.wte(x)
        if kv_caches is None:
            kv_caches = tuple((None, None) for _ in range(self.num_blocks))
        if self.absolute_pos:
            wpe_slice = jax.lax.dynamic_slice_in_dim(
                self.wpe,  # type: ignore[arg-type]
                start_index=cache_index,  # type: ignore[arg-type]
                slice_size=T,
                axis=1
            )
            x = x + wpe_slice
        elif self.trainable_pos:
            x = x + self.wpe(jnp.arange(T))  # positions must be integer indices; x is float after wte

        x = self.emb_dropout(x, deterministic=deterministic)

        def block_fn(block_module: object, hidden_state: jnp.ndarray, kv_c: object, det: bool) -> tuple:
            return block_module(  # type: ignore[call-arg, operator]
                hidden_state,
                use_cache=use_cache,
                kv_cache=kv_c,
                cache_index=cache_index,
                deterministic=det
            )

        def _apply_block(bm: object, hs: jnp.ndarray, kvc: object) -> tuple:
            return block_fn(bm, hs, kvc, deterministic)

        if self.gradient_checkpointing and not use_cache:
            checkpointed_block = nnx.remat(_apply_block)
        else:
            checkpointed_block = _apply_block

        new_kv_caches = []
        balancing_loss_total = 0.0
        for i, block in enumerate(self.blocks):
            x, new_kv, balancing_loss = checkpointed_block(block, x, kv_caches[i] if kv_caches else None)
            new_kv_caches.append(new_kv)
            balancing_loss_total += balancing_loss

        x = self.ln_f(x)

        logits = x @ self.wte.embedding[...].T if self.weight_tying else self.lm_head(x)  # type: ignore[union-attr, misc]


        return ModelOutput(
            logits=logits,
            kv_caches=tuple(new_kv_caches),
            aux_loss=balancing_loss_total,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo: str,
        rngs: nnx.Rngs | None = None,
        *,
        best: bool = True,
        token: str | None = None,
        revision: str | None = None,
    ) -> Transformer:
        """Load a trained Transformer from a local directory or HuggingFace Hub.

        Parameters
        ----------
        path_or_repo:
            Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
            as ``"my-org/dantinox-dante"``.  The checkpoint is downloaded
            automatically when a Hub ID is given.
        rngs:
            PRNG state for initialisation. Defaults to ``nnx.Rngs(0)``.
        best:
            When ``True`` (default), loads ``best_model_weights.msgpack``
            if it exists, otherwise falls back to ``model_weights.msgpack``.
        token:
            HuggingFace access token for private repositories.
        revision:
            Branch, tag, or commit SHA to download from the Hub.
        """
        import contextlib
        import os

        import msgpack

        # Lazy import to avoid circular dependency (core ← dantinox)
        from dantinox.hub import resolve_checkpoint  # type: ignore[import]

        run_dir = resolve_checkpoint(path_or_repo, token=token, revision=revision)

        if rngs is None:
            rngs = nnx.Rngs(0)

        config = Config.from_yaml(os.path.join(run_dir, "config.yaml"))
        model = cls(config, rngs=rngs)

        weights_path = os.path.join(run_dir, "best_model_weights.msgpack")
        if not best or not os.path.exists(weights_path):
            weights_path = os.path.join(run_dir, "model_weights.msgpack")

        with open(weights_path, "rb") as f:
            raw = f.read()

        # Use the same private hook that trainer.py uses for consistency.
        _ext_hook: object = None
        with contextlib.suppress(ImportError):
            from flax.serialization import _msgpack_ext_unpack  # type: ignore[attr-defined]
            _ext_hook = _msgpack_ext_unpack

        state_dict = msgpack.unpackb(raw, ext_hook=_ext_hook, strict_map_key=False)
        nnx.update(model, state_dict)
        return model

Functions

from_pretrained classmethod

from_pretrained(
    path_or_repo: str,
    rngs: Rngs | None = None,
    *,
    best: bool = True,
    token: str | None = None,
    revision: str | None = None,
) -> Transformer

Load a trained Transformer from a local directory or HuggingFace Hub.

Parameters

path_or_repo:
    Local path produced by Trainer.fit() or a Hub repo ID such as
    "my-org/dantinox-dante". The checkpoint is downloaded automatically when
    a Hub ID is given.
rngs:
    PRNG state for initialisation. Defaults to nnx.Rngs(0).
best:
    When True (default), loads best_model_weights.msgpack if it exists,
    otherwise falls back to model_weights.msgpack.
token:
    HuggingFace access token for private repositories.
revision:
    Branch, tag, or commit SHA to download from the Hub.

Source code in core/model.py
@classmethod
def from_pretrained(
    cls,
    path_or_repo: str,
    rngs: nnx.Rngs | None = None,
    *,
    best: bool = True,
    token: str | None = None,
    revision: str | None = None,
) -> Transformer:
    """Load a trained Transformer from a local directory or HuggingFace Hub.

    Parameters
    ----------
    path_or_repo:
        Local path produced by ``Trainer.fit()`` **or** a Hub repo ID such
        as ``"my-org/dantinox-dante"``.  The checkpoint is downloaded
        automatically when a Hub ID is given.
    rngs:
        PRNG state for initialisation. Defaults to ``nnx.Rngs(0)``.
    best:
        When ``True`` (default), loads ``best_model_weights.msgpack``
        if it exists, otherwise falls back to ``model_weights.msgpack``.
    token:
        HuggingFace access token for private repositories.
    revision:
        Branch, tag, or commit SHA to download from the Hub.
    """
    import contextlib
    import os

    import msgpack

    # Lazy import to avoid circular dependency (core ← dantinox)
    from dantinox.hub import resolve_checkpoint  # type: ignore[import]

    run_dir = resolve_checkpoint(path_or_repo, token=token, revision=revision)

    if rngs is None:
        rngs = nnx.Rngs(0)

    config = Config.from_yaml(os.path.join(run_dir, "config.yaml"))
    model = cls(config, rngs=rngs)

    weights_path = os.path.join(run_dir, "best_model_weights.msgpack")
    if not best or not os.path.exists(weights_path):
        weights_path = os.path.join(run_dir, "model_weights.msgpack")

    with open(weights_path, "rb") as f:
        raw = f.read()

    # Use the same private hook that trainer.py uses for consistency.
    _ext_hook: object = None
    with contextlib.suppress(ImportError):
        from flax.serialization import _msgpack_ext_unpack  # type: ignore[attr-defined]
        _ext_hook = _msgpack_ext_unpack

    state_dict = msgpack.unpackb(raw, ext_hook=_ext_hook, strict_map_key=False)
    nnx.update(model, state_dict)
    return model
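
Putting it together, a minimal forward pass through a restored checkpoint. The run directory is illustrative, and cache_index is only consulted when absolute positional encodings are enabled:

import jax.numpy as jnp
from flax import nnx
from core import Transformer

model = Transformer.from_pretrained("runs/my_run", rngs=nnx.Rngs(0))

x = jnp.zeros((1, 16), dtype=jnp.int32)   # [batch, seq_len] token IDs
out = model(x, use_cache=False, kv_caches=None, cache_index=0, deterministic=True)
print(out.logits.shape)                   # (1, 16, vocab_size)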

Normalisation

RMSNorm is the alternative to nnx.LayerNorm selected when norm_type = "rmsnorm".

core.block.RMSNorm

Bases: Module

Root Mean Square Layer Normalisation (Zhang & Sennrich, 2019).

Faster than LayerNorm — no mean subtraction, no bias — with identical empirical performance on modern LLMs (LLaMA, Mistral, Gemma, …).

Source code in core/block.py
class RMSNorm(nnx.Module):
    """
    Root Mean Square Layer Normalisation (Zhang & Sennrich, 2019).

    Faster than LayerNorm — no mean subtraction, no bias — with identical
    empirical performance on modern LLMs (LLaMA, Mistral, Gemma, …).
    """

    def __init__(self, dim: int, *, eps: float = 1e-6, rngs: nnx.Rngs) -> None:
        self.scale = nnx.Param(jnp.ones(dim))
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        return (x / rms) * self.scale[...]
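
In short, RMSNorm computes y = x / sqrt(mean(x**2) + eps) * scale over the last axis. A quick usage sketch, assuming the core.block module path documented above:

import jax.numpy as jnp
from flax import nnx
from core.block import RMSNorm

norm = RMSNorm(512, rngs=nnx.Rngs(0))
x = jnp.ones((2, 8, 512))
y = norm(x)            # normalised to unit RMS, then scaled
print(y.shape)         # (2, 8, 512)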

Model Output

Transformer.__call__ returns a ModelOutput NamedTuple — supports both attribute access and positional unpacking.

core.output.ModelOutput

Bases: NamedTuple

Named return type for Transformer.__call__.

Supports both attribute access and positional unpacking, so existing code that destructures the tuple continues to work unchanged:

# Named (preferred)
out = model(x, ...)
loss = cross_entropy(out.logits, targets) + cfg.alpha * out.aux_loss

# Positional (backward-compatible)
logits, kv_caches, aux_loss = model(x, ...)
Source code in core/output.py
class ModelOutput(NamedTuple):
    """
    Named return type for ``Transformer.__call__``.

    Supports both attribute access and positional unpacking so existing
    code that destructures the tuple continues to work unchanged::

        # Named (preferred)
        out = model(x, ...)
        loss = cross_entropy(out.logits, targets) + cfg.alpha * out.aux_loss

        # Positional (backward-compatible)
        logits, kv_caches, aux_loss = model(x, ...)
    """

    logits: jnp.ndarray
    """Token logits with shape ``[batch, seq_len, vocab_size]``."""

    kv_caches: tuple
    """Per-layer KV (or compressed-latent) caches; ``None`` entries when ``use_cache=False``."""

    aux_loss: float
    """MoE load-balancing auxiliary loss (``0.0`` for dense models)."""

Attributes

aux_loss instance-attribute

aux_loss: float

MoE load-balancing auxiliary loss (0.0 for dense models).

kv_caches instance-attribute

kv_caches: tuple

Per-layer KV (or compressed-latent) caches; None entries when use_cache=False.

logits instance-attribute

logits: ndarray

Token logits with shape [batch, seq_len, vocab_size].


LoRA Adapters

LoRAParam is a distinct NNX variable type that freezes base weights at the type level. LoRALinear is a drop-in replacement for nnx.Linear with a trainable low-rank delta.

core.lora.LoRAParam

Bases: Variable

Trainable LoRA variable — distinct type so base nnx.Param weights stay frozen.

Source code in core/lora.py
class LoRAParam(nnx.Variable):
    """Trainable LoRA variable — distinct type so base nnx.Param weights stay frozen."""
    pass

core.lora.LoRALinear

Bases: Module

Drop-in replacement for nnx.Linear with frozen base weight and trainable low-rank delta.

The effective weight is W_eff = W_base + (alpha / rank) * A @ B, where A is initialised with scaled Gaussian noise and B with zeros, so the adapter contributes nothing at initialisation.

Source code in core/lora.py
class LoRALinear(nnx.Module):
    """Drop-in replacement for nnx.Linear with frozen base weight and trainable low-rank delta.

    The effective weight is  W_eff = W_base + (alpha / rank) * A @ B,
    where A is initialised with scaled Gaussian noise and B with zeros,
    so the adapter contributes nothing at initialisation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        rank: int = 8,
        alpha: float = 16.0,
        dropout_rate: float = 0.0,
        use_bias: bool = False,
        rngs: nnx.Rngs,
    ) -> None:
        self.base = nnx.Linear(in_features, out_features, use_bias=use_bias, rngs=rngs)
        self.scale = alpha / rank

        key = rngs.params()
        k_a, k_b = jax.random.split(key)
        self.lora_A = LoRAParam(
            jax.random.normal(k_a, (in_features, rank)) / math.sqrt(in_features)
        )
        self.lora_B = LoRAParam(jnp.zeros((rank, out_features)))

        self.dropout: nnx.Dropout | None = (
            nnx.Dropout(dropout_rate, rngs=rngs) if dropout_rate > 0.0 else None
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
        out = self.base(x)
        delta = x @ self.lora_A[...]
        if self.dropout is not None:
            delta = self.dropout(delta, deterministic=deterministic)
        return out + (delta @ self.lora_B[...]) * self.scale

    def merge_weights(self) -> jnp.ndarray:
        """Return fused kernel W + (alpha/r) * A @ B for export or deployment."""
        return self.base.kernel[...] + self.scale * (self.lora_A[...] @ self.lora_B[...])

Functions

__init__

__init__(
    in_features: int,
    out_features: int,
    *,
    rank: int = 8,
    alpha: float = 16.0,
    dropout_rate: float = 0.0,
    use_bias: bool = False,
    rngs: Rngs,
) -> None
Source code in core/lora.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    rank: int = 8,
    alpha: float = 16.0,
    dropout_rate: float = 0.0,
    use_bias: bool = False,
    rngs: nnx.Rngs,
) -> None:
    self.base = nnx.Linear(in_features, out_features, use_bias=use_bias, rngs=rngs)
    self.scale = alpha / rank

    key = rngs.params()
    k_a, k_b = jax.random.split(key)
    self.lora_A = LoRAParam(
        jax.random.normal(k_a, (in_features, rank)) / math.sqrt(in_features)
    )
    self.lora_B = LoRAParam(jnp.zeros((rank, out_features)))

    self.dropout: nnx.Dropout | None = (
        nnx.Dropout(dropout_rate, rngs=rngs) if dropout_rate > 0.0 else None
    )

__call__

__call__(
    x: ndarray, deterministic: bool = False
) -> jnp.ndarray
Source code in core/lora.py
def __call__(self, x: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
    out = self.base(x)
    delta = x @ self.lora_A[...]
    if self.dropout is not None:
        delta = self.dropout(delta, deterministic=deterministic)
    return out + (delta @ self.lora_B[...]) * self.scale

merge_weights

merge_weights() -> jnp.ndarray

Return fused kernel W + (alpha/r) * A @ B for export or deployment.

Source code in core/lora.py
def merge_weights(self) -> jnp.ndarray:
    """Return fused kernel W + (alpha/r) * A @ B for export or deployment."""
    return self.base.kernel[...] + self.scale * (self.lora_A[...] @ self.lora_B[...])
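
A usage sketch: the zero-initialised B makes the adapter a no-op at construction, and merge_weights() fuses the delta for export. The module path is as documented above:

import jax.numpy as jnp
from flax import nnx
from core.lora import LoRALinear

layer = LoRALinear(512, 512, rank=8, alpha=16.0, rngs=nnx.Rngs(0))
x = jnp.ones((4, 512))

assert jnp.allclose(layer(x), layer.base(x))   # no-op at init (B == 0)
fused = layer.merge_weights()                  # (512, 512) kernel for deployment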

Sharding Utilities

SPMD data-parallel helpers built on jax.sharding. Pass n_devices in Config to activate automatically, or call these directly for custom sharding strategies.

core.sharding

Functions

make_mesh

make_mesh(n_devices: int = 0) -> Mesh

Create a 1-D data-parallel mesh.

Parameters

n_devices:
    Number of devices to use. 0 (default) means all available local devices.

Source code in core/sharding.py
def make_mesh(n_devices: int = 0) -> Mesh:
    """Create a 1-D data-parallel mesh.

    Parameters
    ----------
    n_devices:
        Number of devices to use.  0 (default) means all available local devices.
    """
    devices = jax.local_devices()
    if 0 < n_devices < len(devices):
        devices = devices[:n_devices]
    return Mesh(np.array(devices), axis_names=("data",))

replicate

replicate(pytree: _T, mesh: Mesh) -> _T

Copy pytree to every device in mesh (no sharding on any axis).

Source code in core/sharding.py
def replicate(pytree: _T, mesh: Mesh) -> _T:
    """Copy *pytree* to every device in *mesh* (no sharding on any axis)."""
    sharding = NamedSharding(mesh, P())
    return jax.device_put(pytree, sharding)

shard_batch

shard_batch(pytree: _T, mesh: Mesh) -> _T

Shard pytree along its leading (batch) axis across all devices in mesh.

Source code in core/sharding.py
def shard_batch(pytree: _T, mesh: Mesh) -> _T:
    """Shard *pytree* along its leading (batch) axis across all devices in *mesh*."""
    sharding = NamedSharding(mesh, P("data"))
    return jax.device_put(pytree, sharding)

num_devices

num_devices(mesh: Mesh) -> int

Return the total number of devices in mesh.

Source code in core/sharding.py
def num_devices(mesh: Mesh) -> int:
    """Return the total number of devices in *mesh*."""
    return mesh.size
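
Together these cover the standard data-parallel recipe: replicate the parameters, shard the batch. A sketch, assuming the batch size divides the number of devices:

import jax.numpy as jnp
from core.sharding import make_mesh, num_devices, replicate, shard_batch

mesh = make_mesh()                             # 1-D "data" mesh over local devices
print(num_devices(mesh))

batch = jnp.zeros((128, 512), dtype=jnp.int32)
batch = shard_batch(batch, mesh)               # split along the leading (batch) axis
# params = replicate(params, mesh)             # full copy on every device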

Configuration

The Config dataclass is the single source of truth for all architectural and training hyperparameters.

core.config

Classes

Config dataclass

Source code in core/config.py
@dataclass
class Config:
    # ── Model Architecture ───────────────────────────────────────────────────
    dim: int = 512
    n_heads: int = 16
    head_size: int = 32
    num_blocks: int = 20
    vocab_size: int = 200
    max_context: int = 512
    kv_heads: int = 4
    weight_tying: bool = True
    activation: str = "gelu"
    gradient_checkpointing: bool = True
    dropout_rate: float = 0.15
    use_swiglu: bool = True

    # ── MoE ─────────────────────────────────────────────────────────────────
    use_moe: bool = False
    n_experts: int = 4
    top_k_mlp: int = 2
    expansion: int = 4
    alpha_balance: float = 0.1

    # ── Attention & Positional ───────────────────────────────────────────────
    use_rotary_pos: bool = True
    trainable_pos: bool = False
    absolute_pos: bool = False
    sliding_window: bool = False
    context_window: int = 4
    no_sink: bool = True
    use_flash_attention: bool = False

    # ── Multi-Head Latent Attention (MLA) ────────────────────────────────────
    mla: bool = False
    inference: bool = False
    down_dim_q: int = 256
    down_dim_kv: int = 256
    rope_dim: int = 32

    # ── Tokenizer ───────────────────────────────────────────────────────────
    tokenizer_type: str = "char"
    tokenizer_path: str | None = None

    # ── Dataset ─────────────────────────────────────────────────────────────
    dataset_source: str = "local"
    dataset_name: str = ""
    dataset_config: str = ""        # HF subset/config name (e.g. "en" for allenai/c4)
    dataset_text_field: str = "text"  # column that contains the raw text
    dataset_split: str = "train"    # HF split to load
    streaming: bool = False

    # ── Training & Optimisation ──────────────────────────────────────────────
    lr: float = 0.005
    batch_size: int = 128
    grad_accum: int = 16
    seed: int = 42
    optimizer: str = "adamw"
    epochs: int = 1000
    warmup_steps: int = 420
    grad_clip: float = 1.0
    patience: int = 0
    use_bf16: bool = False

    # ── Normalisation ────────────────────────────────────────────────────────
    norm_type: str = "layernorm"     # "layernorm" | "rmsnorm"

    # ── RoPE scaling ─────────────────────────────────────────────────────────
    rope_scale_factor: float = 1.0   # >1 compresses frequencies for long-ctx (NTK-aware)

    # ── LR schedule ──────────────────────────────────────────────────────────
    lr_schedule: str = "cosine"      # "cosine" | "linear" | "constant" | "wsd"

    # ── LoRA fine-tuning ──────────────────────────────────────────────────────
    use_lora: bool = False
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_targets: str = "attention"  # "attention" | "mlp" | "all"

    # ── Multi-GPU data-parallel ───────────────────────────────────────────────
    n_devices: int = 0               # 0 = use all available local devices

    # ── Logging & Metrics ───────────────────────────────────────────────────
    eval_iters: int = 20
    log_file: str = "training_log.csv"
    summary_file: str = "model_summary.json"

    def __post_init__(self) -> None:
        if self.kv_heads is None:
            self.kv_heads = self.n_heads // 4

        if self.dim != self.n_heads * self.head_size:
            raise ValueError(
                f"dim ({self.dim}) must equal n_heads * head_size "
                f"({self.n_heads} * {self.head_size} = {self.n_heads * self.head_size})"
            )
        if self.n_heads % self.kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by kv_heads ({self.kv_heads})"
            )
        if self.mla and self.use_rotary_pos and self.rope_dim > self.head_size:
            raise ValueError(
                f"rope_dim ({self.rope_dim}) must be <= head_size ({self.head_size}) "
                "when using MLA with rotary positional encoding"
            )
        if self.norm_type not in ("layernorm", "rmsnorm"):
            raise ValueError(
                f"norm_type must be 'layernorm' or 'rmsnorm', got {self.norm_type!r}"
            )
        if self.lr_schedule not in ("cosine", "linear", "constant", "wsd"):
            raise ValueError(
                f"lr_schedule must be 'cosine', 'linear', 'constant', or 'wsd', "
                f"got {self.lr_schedule!r}"
            )
        if self.rope_scale_factor <= 0:
            raise ValueError(
                f"rope_scale_factor must be > 0, got {self.rope_scale_factor}"
            )
        if self.lora_targets not in ("attention", "mlp", "all"):
            raise ValueError(
                f"lora_targets must be 'attention', 'mlp', or 'all', got {self.lora_targets!r}"
            )
        if self.lora_rank < 1:
            raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}")
        if self.n_devices < 0:
            raise ValueError(f"n_devices must be >= 0, got {self.n_devices}")

    def __repr__(self) -> str:
        attn = "MLA" if self.mla else ("GQA" if self.kv_heads < self.n_heads else "MHA")
        moe = "+MoE" if self.use_moe else ""
        return (
            f"Config(dim={self.dim}, heads={self.n_heads}, blocks={self.num_blocks}, "
            f"ctx={self.max_context}, attn={attn}{moe})"
        )

    # ── Serialisation ────────────────────────────────────────────────────────

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict of all config fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Config:
        """Construct a Config from a plain dict, ignoring unknown keys."""
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in valid})

    @classmethod
    def from_yaml(cls, path: str) -> Config:
        """Load a Config from a YAML file (flat or sectioned)."""
        with open(path) as f:
            raw = yaml.safe_load(f)

        flat: dict[str, Any] = {}
        for v in raw.values():
            if isinstance(v, dict):
                flat.update(v)
        if not flat:
            flat = raw

        return cls.from_dict(flat)

    def save_yaml(self, path: str) -> None:
        """Write the config to a YAML file."""
        with open(path, "w") as f:
            yaml.dump(self.to_dict(), f)

Functions

from_dict classmethod

from_dict(d: dict[str, Any]) -> Config

Construct a Config from a plain dict, ignoring unknown keys.

Source code in core/config.py
@classmethod
def from_dict(cls, d: dict[str, Any]) -> Config:
    """Construct a Config from a plain dict, ignoring unknown keys."""
    valid = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in d.items() if k in valid})

from_yaml classmethod

from_yaml(path: str) -> Config

Load a Config from a YAML file (flat or sectioned).

Source code in core/config.py
@classmethod
def from_yaml(cls, path: str) -> Config:
    """Load a Config from a YAML file (flat or sectioned)."""
    with open(path) as f:
        raw = yaml.safe_load(f)

    flat: dict[str, Any] = {}
    for v in raw.values():
        if isinstance(v, dict):
            flat.update(v)
    if not flat:
        flat = raw

    return cls.from_dict(flat)

save_yaml

save_yaml(path: str) -> None

Write the config to a YAML file.

Source code in core/config.py
def save_yaml(self, path: str) -> None:
    """Write the config to a YAML file."""
    with open(path, "w") as f:
        yaml.dump(self.to_dict(), f)

to_dict

to_dict() -> dict[str, Any]

Return a plain dict of all config fields.

Source code in core/config.py
def to_dict(self) -> dict[str, Any]:
    """Return a plain dict of all config fields."""
    return asdict(self)
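
A round-trip sketch; note that __post_init__ enforces dim == n_heads * head_size, so the values below are chosen to satisfy it (the file name is arbitrary):

from core.config import Config

cfg = Config(dim=256, n_heads=8, head_size=32, num_blocks=6)
cfg.save_yaml("my_config.yaml")

restored = Config.from_yaml("my_config.yaml")
assert restored.to_dict() == cfg.to_dict()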

Generation Engine

Autoregressive inference with static KV-cache management, jax.lax.fori_loop token loop, and sampling strategies (greedy, Top-K, Top-P).

core.generation


Tokenizers

Character-level and Byte-Level BPE tokenizers with save/load support.

utils.tokenizer

Classes

Tokenizer

Bases: Protocol

Source code in utils/tokenizer.py
class Tokenizer(Protocol):
    vocab_size: int

    def encode(self, s: str) -> list[int]: ...
    def decode(self, tokens: list[int]) -> str: ...
    def train_from_text(self, text: str, **kwargs: Any) -> None: ...
    def save(self, path: str) -> None: ...

CharTokenizer

Source code in utils/tokenizer.py
class CharTokenizer:
    def __init__(self) -> None:
        self.stoi: dict[str, int] = {}
        self.itos: dict[int, str] = {}
        self.vocab_size: int = 0

    def train_from_text(self, text: str, **kwargs: Any) -> None:
        chars = sorted(set(text))
        self.vocab_size = len(chars)
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

    def encode(self, s: str) -> list[int]:
        return [self.stoi[c] for c in s]

    def decode(self, tokens: list[int]) -> str:
        return ''.join(self.itos[i] for i in tokens)

    def save(self, path: str) -> None:
        payload = {"type": "char", "vocab": self.stoi}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)

BPETokenizer

Source code in utils/tokenizer.py
class BPETokenizer:
    def __init__(self) -> None:
        from tokenizers import Tokenizer, models, pre_tokenizers
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
        self.vocab_size: int = 0

    def train_from_text(self, text: str, vocab_size: int = 1000, **kwargs: Any) -> None:
        from tokenizers import trainers
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
        )
        self.tokenizer.train_from_iterator([text], trainer=trainer)
        self.vocab_size = self.tokenizer.get_vocab_size()

    def encode(self, s: str) -> list[int]:
        return self.tokenizer.encode(s).ids

    def decode(self, tokens: list[int]) -> str:
        return self.tokenizer.decode(tokens)

    def save(self, path: str) -> None:
        payload = {
            "type": "bpe",
            "vocab_size": self.vocab_size,
            "tokenizer": self.tokenizer.to_str(),
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)

Functions

get_tokenizer

get_tokenizer(tokenizer_type: str) -> Tokenizer
Source code in utils/tokenizer.py
def get_tokenizer(tokenizer_type: str) -> Tokenizer:
    if tokenizer_type == "char":
        return CharTokenizer()
    elif tokenizer_type == "bpe":
        return BPETokenizer()
    else:
        raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")

load_tokenizer_from_file

load_tokenizer_from_file(path: str) -> Tokenizer

Load a tokenizer that was previously saved with tokenizer.save().

Source code in utils/tokenizer.py
def load_tokenizer_from_file(path: str) -> Tokenizer:
    """Load a tokenizer that was previously saved with ``tokenizer.save()``."""
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)

    tok_type = payload["type"]
    if tok_type == "char":
        char_tok = CharTokenizer()
        char_tok.stoi = {k: int(v) for k, v in payload["vocab"].items()}
        char_tok.itos = {int(v): k for k, v in payload["vocab"].items()}
        char_tok.vocab_size = len(char_tok.stoi)
        return char_tok
    if tok_type == "bpe":
        from tokenizers import Tokenizer as HFTokenizer
        bpe_tok = BPETokenizer()
        bpe_tok.tokenizer = HFTokenizer.from_str(payload["tokenizer"])
        bpe_tok.vocab_size = payload["vocab_size"]
        return bpe_tok
    raise ValueError(f"Unknown tokenizer type in file: {tok_type!r}")

CLI Reference

The dantinox command provides eight subcommands:

Subcommand Description
train Train a model from a config and corpus
generate Generate text from a checkpoint
find-lr Run the LR range test and suggest a learning rate
push Upload a checkpoint to HuggingFace Hub
pull Download a checkpoint from HuggingFace Hub
sweep Run a W&B Bayesian hyperparameter sweep
benchmark Benchmark throughput and FLOPs for run directories
plot Generate figures from benchmark results
Common invocations:

dantinox --version
dantinox --help
dantinox train --help
dantinox find-lr --help
dantinox push --help

train

dantinox train
  --config PATH          YAML config file (default: configs/default_config.yaml)
  --data_path PATH       Training corpus
  --run_dir PATH         Output directory (auto-generated if omitted)
  --wandb_project NAME   W&B project for logging
  --resume               Resume from last checkpoint in --run_dir
  --<field> VALUE        Override any Config field (e.g. --lr 3e-4 --use_bf16 True)

generate

dantinox generate
  --run_dir PATH         Run directory with config + weights (required)
  --prompt TEXT          Input prefix (default: "Nel mezzo del cammin ")
  --max_new_tokens N     Tokens to generate (default: 150)
  --greedy               Greedy decoding
  --temperature FLOAT    Softmax temperature (default: 1.0)
  --top_k INT            Top-k sampling
  --top_p FLOAT          Nucleus sampling threshold
  --no_cache             Disable KV cache
  --seed INT             RNG seed (default: 42)

find-lr

dantinox find-lr
  --config PATH          YAML config file
  --data_path PATH       Training corpus (required)
  --min_lr FLOAT         Start LR (default: 1e-7)
  --max_lr FLOAT         End LR (default: 1.0)
  --num_steps INT        Sweep steps (default: 100)
  --plot                 Save a lr_finder.png loss curve
  --plot_out PATH        Custom output path for the PNG
  --<field> VALUE        Override any Config field

push

dantinox push
  --run_dir PATH         Local run directory to upload (required)
  --repo NAME            Hub repo id, e.g. my-org/my-model (required)
  --private              Create a private repository
  --token TOKEN          HuggingFace access token
  --message TEXT         Commit message

pull

dantinox pull
  --repo NAME            Hub repo id (required)
  --local_dir PATH       Where to save the files
  --token TOKEN          HuggingFace access token
  --revision REF         Branch, tag, or commit SHA