Coverage for core / output.py: 100%
10 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 16:17 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 16:17 +0200
1from __future__ import annotations
3from typing import NamedTuple
5import jax.numpy as jnp
class ModelOutput(NamedTuple):
    """Named return type for ``Transformer.__call__``.

    Supports both attribute access and positional unpacking, so existing
    code that destructures the tuple continues to work unchanged::

        # Named (preferred)
        out = model(x, ...)
        loss = cross_entropy(out.logits, targets) + cfg.alpha * out.aux_loss

        # Positional (backward-compatible)
        logits, kv_caches, aux_loss = model(x, ...)

    The field order ``(logits, kv_caches, aux_loss)`` is what positional
    unpacking relies on; keep it stable when editing this class.
    """

    logits: jnp.ndarray
    """Token logits with shape ``[batch, seq_len, vocab_size]``."""

    kv_caches: tuple
    """Per-layer KV (or compressed-latent) caches.

    Entries are ``None`` when ``use_cache=False``.
    """

    aux_loss: float
    """MoE load-balancing auxiliary loss (``0.0`` for dense models)."""