Coverage for core / output.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 16:17 +0200

1from __future__ import annotations 

2 

3from typing import NamedTuple 

4 

5import jax.numpy as jnp 

6 

7 

class ModelOutput(NamedTuple):
    """Bundle the three values produced by ``Transformer.__call__``.

    Because this is a ``NamedTuple``, an instance can be read either by
    field name or by position. Call sites that unpack the result as a plain
    3-tuple therefore keep working::

        # Access by name (preferred)
        out = model(x, ...)
        loss = cross_entropy(out.logits, targets) + cfg.alpha * out.aux_loss

        # Access by position (backward-compatible)
        logits, kv_caches, aux_loss = model(x, ...)

    The field order ``(logits, kv_caches, aux_loss)`` is part of the public
    contract, because positional unpacking depends on it.
    """

    # Token logits, shaped ``[batch, seq_len, vocab_size]``.
    logits: jnp.ndarray

    # One cache entry per layer (KV tensors or compressed latents). Entries
    # are ``None`` when the model is called with ``use_cache=False``.
    kv_caches: tuple

    # Auxiliary MoE load-balancing loss; ``0.0`` for dense models.
    # NOTE(review): annotated as ``float``, but if the model computes this
    # with jnp ops it is likely a 0-d array at runtime. Confirm against
    # ``Transformer.__call__``.
    aux_loss: float