Coverage for core / generation.py: 61%
96 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-05-05 11:22 +0200
1from collections.abc import Callable
3import jax
4import jax.numpy as jnp
5from flax import nnx
# Signature shared by every decoding strategy: (probabilities, optional PRNG key) -> token ids.
DecodeFunc = Callable[[jnp.ndarray, jax.Array | None], jnp.ndarray]
10def _greedy_decode(v, key=None):
11 return jnp.argmax(v, axis=-1, keepdims=True)
13def _sampling_decode(v, key):
14 return jax.random.categorical(key, jnp.log(v + 1e-10), axis=-1)
def decode(
    probs: jnp.ndarray,
    decoding_func: DecodeFunc,
    key: jax.Array | None
) -> jnp.ndarray:
    """Run ``decoding_func`` on ``probs`` and normalise the result shape.

    Sampling strategies may return a 1-D array; greedy decoding already
    carries a trailing singleton axis. Either way the caller gets a
    trailing axis of size one appended when it is missing.
    """
    token = decoding_func(probs, key)
    # ``ndim`` is static under jit, so this branch is resolved at trace time.
    return token[:, None] if token.ndim == 1 else token
@nnx.jit(static_argnames=['decoding_func', 'use_cache', 'top_k', 'top_p', 'temperature'])
def _generate_toks(
    model: nnx.Module,
    x: jnp.ndarray,
    key: jax.Array | None,
    start_pos: int | jax.Array,
    max_generations: int | jax.Array,
    decoding_func: DecodeFunc,
    use_cache: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0
) -> jnp.ndarray:
    """Autoregressively fill ``x`` with ``max_generations`` generated token ids.

    ``x`` is a pre-padded (batch, max_context) buffer; columns
    [start_pos, start_pos + max_generations) are overwritten functionally
    via ``.at[...].set``. ``decoding_func``, ``use_cache``, ``top_k``,
    ``top_p`` and ``temperature`` are static jit arguments, while
    ``start_pos``/``max_generations`` may be traced arrays so varying them
    does not trigger recompilation. ``key is None`` selects the
    deterministic (greedy) path throughout.
    """

    def __apply_top_k(probs, decoding_func, key, top_k):
        # Restrict sampling to the k most probable tokens, renormalised
        # so the truncated distribution sums to 1 per row.
        top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k, axis=-1)
        top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)

        # One independent subkey per batch row; return the advanced key.
        new_key, subkey = jax.random.split(key)
        batch_keys = jax.random.split(subkey, probs.shape[0])

        def sample_from_top_k(p, k, i):
            # Sample a position within the truncated distribution, then map
            # it back to the original vocabulary index.
            sample = decode(probs=p, decoding_func=decoding_func, key=k)
            return i[sample]

        toks = jax.vmap(sample_from_top_k)(top_k_probs, batch_keys, top_k_indices)
        return toks, new_key

    def __apply_top_p(probs, decoding_func, key, top_p):
        # Sort probabilities descending (argsort is ascending, hence the flip).
        sorted_indices = jnp.argsort(probs, axis=-1)[:, ::-1]
        sorted_probs = jnp.take_along_axis(probs, sorted_indices, axis=-1)

        new_key, subkey = jax.random.split(key)
        batch_keys = jax.random.split(subkey, probs.shape[0])

        def sample_from_top_p(p_sorted, k, idx_sorted, top_p_val):
            cumulative_probs = jnp.cumsum(p_sorted, axis=-1)
            # Keep tokens whose cumulative mass *excluding themselves* is
            # below top_p; this always keeps at least the first token.
            mask = (cumulative_probs - p_sorted) < top_p_val
            masked_probs = jnp.where(mask, p_sorted, 0.0)
            masked_probs = masked_probs / jnp.sum(masked_probs)

            # Sample within the nucleus, then map back to vocabulary index.
            sample_idx = decode(probs=masked_probs, decoding_func=decoding_func, key=k)
            return idx_sorted[sample_idx]

        toks = jax.vmap(sample_from_top_p, in_axes=(0, 0, 0, None))(
            sorted_probs, batch_keys, sorted_indices, top_p
        )
        return toks, new_key

    def generate_with_kv_cache(i, val):
        # Decode step: feed only the previous token at position i-1; the
        # model call returns a 3-tuple (logits, updated kv cache, aux).
        x, tok, kv_cache, k = val
        last_logits, new_kv_cache, _ = model(tok, use_cache, kv_cache, i-1, deterministic=True)
        x, k, next_tok_id = _get_tok_id(i, x, k, last_logits[:, -1, :])
        return x, next_tok_id, new_kv_cache, k

    def prefill_or_no_cache(i, val):
        # Full-sequence forward pass; the next token is decoded from the
        # logits at position i-1 (the last prompt position).
        x, kv_cache, _, k = val
        logits, new_kv_cache, _ = model(x, use_cache, kv_cache, 0, deterministic=True)
        x, k, tok = _get_tok_id(i, x, k, logits[:, i-1, :])
        return x, new_kv_cache, tok, k

    def _get_tok_id(i, x, k, last_logits):
        # Temperature-scale logits, convert to probabilities, then pick a
        # token per row using the configured strategy. Precedence:
        # greedy (no key) > top-k > top-p > plain categorical sampling.
        last_logits = last_logits / temperature
        probs = jax.nn.softmax(last_logits, axis=-1)

        if k is None:
            tok = decode(probs=probs, decoding_func=decoding_func, key=k)
        elif top_k is not None:
            tok, k = __apply_top_k(probs=probs, decoding_func=decoding_func, key=k, top_k=top_k)
        elif top_p is not None:
            tok, k = __apply_top_p(probs=probs, decoding_func=decoding_func, key=k, top_p=top_p)
        else:
            new_key, subkey = jax.random.split(k)
            batch_keys = jax.random.split(subkey, probs.shape[0])

            def sample_base(p, ky):
                return decode(probs=p, decoding_func=decoding_func, key=ky)

            tok = jax.vmap(sample_base)(probs, batch_keys)
            k = new_key
        # Normalise to (batch, 1) and write the new id into column i.
        tok = tok.reshape(-1, 1)
        x = x.at[:, i].set(tok[:, 0])
        return x, k, tok

    num_blocks: int = model.num_blocks  # type: ignore[attr-defined]
    # Empty per-block (key, value) cache placeholders; dummy_tok only seeds
    # the cached-path loop carry with the right shape/dtype.
    init_kv_cache = tuple((None, None) for _ in range(num_blocks))
    dummy_tok = jnp.zeros((x.shape[0], 1), dtype=jnp.int32)

    if use_cache is False:
        # kv_cache is never updated in this path, so keep it out of the carry
        # to avoid passing Python None values through jax.lax.fori_loop.
        def prefill_no_cache(i, val):
            _x, _k = val
            # Rebuild the throwaway None cache each step instead of
            # carrying it through the loop.
            _cache = tuple((None, None) for _ in range(num_blocks))
            logits, _, _ = model(_x, False, _cache, 0, deterministic=True)
            _x, _k, _ = _get_tok_id(i, _x, _k, logits[:, i - 1, :])
            return _x, _k

        x, _ = jax.lax.fori_loop(
            lower=start_pos,
            upper=start_pos + max_generations,
            body_fun=prefill_no_cache,
            init_val=(x, key),
        )
    else:
        # Prefill once over the whole prompt to build the cache, then decode
        # the remaining tokens one at a time.
        x, kv_cache, tok, key = prefill_or_no_cache(start_pos,
                                                    (x, init_kv_cache, dummy_tok, key))
        x, _, _, _ = jax.lax.fori_loop(lower=start_pos + 1,
                                       upper=start_pos + max_generations,
                                       body_fun=generate_with_kv_cache,
                                       init_val=(x, tok, kv_cache, key))
    return x
def generate(
    model: nnx.Module,
    x: jnp.ndarray,
    max_generations: int,
    greedy: bool = False,
    seed: int = 42,
    use_cache: bool = True,
    top_p: float | None = None,
    top_k: int | None = None,
    temperature: float = 1.0) -> jnp.ndarray:
    """Generate up to ``max_generations`` tokens after the prompt ``x``.

    The generation budget is clipped so prompt plus new tokens never
    exceed ``model.max_context``. ``greedy=True`` decodes
    deterministically; otherwise sampling is seeded with ``seed``.
    Returns the prompt concatenated with the generated ids.
    """
    batch, prompt_len = x.shape
    budget = min(model.max_context, prompt_len + max_generations) - prompt_len  # type: ignore[attr-defined]
    if budget <= 0:
        return x

    # Fixed-size buffer: prompt at the front, zeros where tokens will land.
    padded = jnp.zeros((batch, model.max_context), dtype=x.dtype)  # type: ignore[attr-defined]
    padded = padded.at[:, :prompt_len].set(x)

    decoding_func: DecodeFunc
    if greedy:
        key, decoding_func = None, _greedy_decode
    else:
        key, decoding_func = jax.random.key(seed), _sampling_decode

    # start_pos / max_generations go in as JAX scalars rather than Python
    # ints: ints are static under @nnx.jit, so each distinct value would
    # force a separate compilation — e.g. a warmup call generating one
    # token would compile a different kernel than the real 200-token call,
    # blowing up the apparent tok/s.
    out = _generate_toks(model,
                         padded,
                         key=key,
                         start_pos=jnp.array(prompt_len, dtype=jnp.int32),
                         max_generations=jnp.array(budget, dtype=jnp.int32),
                         decoding_func=decoding_func,
                         use_cache=use_cache,
                         top_p=top_p,
                         top_k=top_k,
                         temperature=temperature)
    return out[:, :prompt_len + budget]