Coverage for core / generation.py: 61%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 11:22 +0200

1from collections.abc import Callable 

2 

3import jax 

4import jax.numpy as jnp 

5from flax import nnx 

6 

# Signature shared by every decoding strategy: maps a probability tensor
# (plus an optional PRNG key; None for the greedy path) to token ids.
DecodeFunc = Callable[[jnp.ndarray, jax.Array | None], jnp.ndarray]

8 

9 

10def _greedy_decode(v, key=None): 

11 return jnp.argmax(v, axis=-1, keepdims=True) 

12 

13def _sampling_decode(v, key): 

14 return jax.random.categorical(key, jnp.log(v + 1e-10), axis=-1) 

15 

def decode(
    probs: jnp.ndarray,
    decoding_func: DecodeFunc,
    key: jax.Array | None
    ) -> jnp.ndarray:
    """Run `decoding_func` on `probs` and return the result as at least 2-D.

    A flat (batch,)-shaped output is promoted to a (batch, 1) column so
    callers can rely on a trailing token axis.
    """
    sampled = decoding_func(probs, key)
    return sampled[:, None] if sampled.ndim == 1 else sampled

28 

@nnx.jit(static_argnames=['decoding_func', 'use_cache', 'top_k', 'top_p', 'temperature'])
def _generate_toks(
    model: nnx.Module,
    x: jnp.ndarray,
    key: jax.Array | None,
    start_pos: int | jax.Array,
    max_generations: int | jax.Array,
    decoding_func: DecodeFunc,
    use_cache: bool = False,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float = 1.0
    ) -> jnp.ndarray:
    """Jitted autoregressive generation loop.

    Writes `max_generations` sampled token ids into `x` starting at column
    `start_pos` and returns the updated buffer.

    Args:
        model: called as ``model(tokens, use_cache, kv_cache, pos,
            deterministic=True)`` returning ``(logits, kv_cache, _)``;
            must expose ``num_blocks``.
        x: pre-padded token buffer — assumes shape (batch, max_context),
            as prepared by `generate` (TODO confirm for other callers).
        key: PRNG key, or None for fully greedy decoding.
        start_pos: first column to fill (traced scalar so differing prompt
            lengths do not each trigger a recompilation).
        max_generations: number of tokens to produce (traced scalar).
        decoding_func: decoding strategy; static, part of the jit cache key.
        use_cache: when True, prefill once then feed single tokens through
            the model's KV cache.
        top_k: optional truncated sampling over the k most likely tokens
            (takes priority over top_p; requires a non-None key).
        top_p: optional nucleus sampling threshold (requires a non-None key).
        temperature: softmax temperature applied to the last-step logits.

    Returns:
        The buffer `x` with generated ids written in place of the padding.
    """

    def __apply_top_k(probs, decoding_func, key, top_k):
        # Restrict sampling to the k most likely tokens, renormalised.
        # NOTE: jax.lax.top_k takes no `axis` argument — it always operates
        # on the last axis — so the previous `axis=-1` kwarg was a TypeError.
        top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)
        top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)

        new_key, subkey = jax.random.split(key)
        batch_keys = jax.random.split(subkey, probs.shape[0])

        def sample_from_top_k(p, k, i):
            # Sample an index into the top-k slice, then map it back to the
            # original vocabulary id.
            sample = decode(probs=p, decoding_func=decoding_func, key=k)
            return i[sample]

        toks = jax.vmap(sample_from_top_k)(top_k_probs, batch_keys, top_k_indices)
        return toks, new_key

    def __apply_top_p(probs, decoding_func, key, top_p):
        # Nucleus sampling: keep the smallest probability-sorted prefix of
        # the vocabulary whose cumulative mass reaches top_p.
        sorted_indices = jnp.argsort(probs, axis=-1)[:, ::-1]
        sorted_probs = jnp.take_along_axis(probs, sorted_indices, axis=-1)

        new_key, subkey = jax.random.split(key)
        batch_keys = jax.random.split(subkey, probs.shape[0])

        def sample_from_top_p(p_sorted, k, idx_sorted, top_p_val):
            cumulative_probs = jnp.cumsum(p_sorted, axis=-1)
            # `cum - p < top_p` keeps each token whose mass *starts* below
            # the threshold, guaranteeing at least one surviving token.
            mask = (cumulative_probs - p_sorted) < top_p_val
            masked_probs = jnp.where(mask, p_sorted, 0.0)
            masked_probs = masked_probs / jnp.sum(masked_probs)

            sample_idx = decode(probs=masked_probs, decoding_func=decoding_func, key=k)
            return idx_sorted[sample_idx]

        toks = jax.vmap(sample_from_top_p, in_axes=(0, 0, 0, None))(
            sorted_probs, batch_keys, sorted_indices, top_p
        )
        return toks, new_key

    def generate_with_kv_cache(i, val):
        # Decode step with KV cache: feed only the previously sampled token.
        x, tok, kv_cache, k = val
        last_logits, new_kv_cache, _ = model(tok, use_cache, kv_cache, i-1, deterministic=True)
        x, k, next_tok_id = _get_tok_id(i, x, k, last_logits[:, -1, :])
        return x, next_tok_id, new_kv_cache, k

    def prefill_or_no_cache(i, val):
        # Full-sequence forward pass; logits at position i-1 predict position i.
        x, kv_cache, _, k = val
        logits, new_kv_cache, _ = model(x, use_cache, kv_cache, 0, deterministic=True)
        x, k, tok = _get_tok_id(i, x, k, logits[:, i-1, :])
        return x, new_kv_cache, tok, k

    def _get_tok_id(i, x, k, last_logits):
        # Turn last-step logits into one token id per batch row and write
        # it at column i of the buffer.
        last_logits = last_logits / temperature
        probs = jax.nn.softmax(last_logits, axis=-1)

        if k is None:
            # Greedy path: no randomness, top_k/top_p are ignored.
            tok = decode(probs=probs, decoding_func=decoding_func, key=k)
        elif top_k is not None:
            tok, k = __apply_top_k(probs=probs, decoding_func=decoding_func, key=k, top_k=top_k)
        elif top_p is not None:
            tok, k = __apply_top_p(probs=probs, decoding_func=decoding_func, key=k, top_p=top_p)
        else:
            # Plain ancestral sampling with one subkey per batch row.
            new_key, subkey = jax.random.split(k)
            batch_keys = jax.random.split(subkey, probs.shape[0])

            def sample_base(p, ky):
                return decode(probs=p, decoding_func=decoding_func, key=ky)

            tok = jax.vmap(sample_base)(probs, batch_keys)
            k = new_key
        tok = tok.reshape(-1, 1)
        x = x.at[:, i].set(tok[:, 0])
        return x, k, tok

    num_blocks: int = model.num_blocks  # type: ignore[attr-defined]
    init_kv_cache = tuple((None, None) for _ in range(num_blocks))
    dummy_tok = jnp.zeros((x.shape[0], 1), dtype=jnp.int32)

    if use_cache is False:
        # kv_cache is never updated in this path, so keep it out of the carry
        # to avoid passing Python None values through jax.lax.fori_loop.
        def prefill_no_cache(i, val):
            _x, _k = val
            _cache = tuple((None, None) for _ in range(num_blocks))
            logits, _, _ = model(_x, False, _cache, 0, deterministic=True)
            _x, _k, _ = _get_tok_id(i, _x, _k, logits[:, i - 1, :])
            return _x, _k

        x, _ = jax.lax.fori_loop(
            lower=start_pos,
            upper=start_pos + max_generations,
            body_fun=prefill_no_cache,
            init_val=(x, key),
        )
    else:
        # One prefill pass over the whole prompt, then token-by-token
        # decoding through the KV cache.
        x, kv_cache, tok, key = prefill_or_no_cache(start_pos,
                                                    (x, init_kv_cache, dummy_tok, key))
        x, _, _, _ = jax.lax.fori_loop(lower=start_pos + 1,
                                       upper=start_pos + max_generations,
                                       body_fun=generate_with_kv_cache,
                                       init_val=(x, tok, kv_cache, key))
    return x

141 

142 

def generate(
    model: nnx.Module,
    x: jnp.ndarray,
    max_generations: int,
    greedy: bool = False,
    seed: int = 42,
    use_cache: bool = True,
    top_p: float | None = None,
    top_k: int | None = None,
    temperature: float = 1.0) -> jnp.ndarray:
    """Autoregressively extend `x` with up to `max_generations` tokens.

    Generation is clamped to the model's context window; the prompt is
    padded into a fixed-size buffer so the jitted inner loop compiles once.
    """

    batch, prompt_len = x.shape
    # Never generate past max_context; bail out if the prompt already fills it.
    budget = min(model.max_context, prompt_len + max_generations) - prompt_len  # type: ignore[attr-defined]

    if budget <= 0:
        return x

    # Fixed (batch, max_context) buffer: one shape, one XLA compilation.
    buf = jnp.zeros((batch, model.max_context), dtype=x.dtype)  # type: ignore[attr-defined]
    buf = buf.at[:, :prompt_len].set(x)

    chooser: DecodeFunc
    if greedy:
        rng = None
        chooser = _greedy_decode
    else:
        rng = jax.random.key(seed)
        chooser = _sampling_decode

    # Pass start_pos and max_generations as JAX arrays (not Python ints) so
    # @nnx.jit treats them as dynamic traced values. A Python int is static
    # from JAX's perspective, which would trigger a separate compilation for
    # every distinct (start_pos, max_generations) pair — e.g. the warmup call
    # with max_new_tokens=1 would compile a different kernel than the real
    # call with max_new_tokens=200, blowing up the apparent tok/s.
    out = _generate_toks(model,
                         buf,
                         key=rng,
                         start_pos=jnp.array(prompt_len, dtype=jnp.int32),
                         max_generations=jnp.array(budget, dtype=jnp.int32),
                         decoding_func=chooser,
                         use_cache=use_cache,
                         top_p=top_p,
                         top_k=top_k,
                         temperature=temperature)

    # Trim the padding tail, keeping prompt plus everything generated.
    return out[:, :prompt_len + budget]