Coverage for dantinox/bench.py: 19% of 204 statements (coverage.py v7.13.5, created at 2026-05-02 12:25 +0200)

from __future__ import annotations

import dataclasses
import logging
import os
import time
import traceback
from collections.abc import Sequence
from typing import Any

import jax
import jax.numpy as jnp
import msgpack
import numpy as np
from flax import nnx
from flax.serialization import _msgpack_ext_unpack

from core.config import Config
from core.model import Transformer
from dantinox.exceptions import BenchmarkError

log = logging.getLogger(__name__)

SEQ_LENS = [64, 128, 256, 512]
BATCH_SIZES = [1, 4, 16, 64, 128, 256]
FIXED_SEQ = 256
N_WARMUP = 3
N_MEASURE = 20
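
# Timing protocol: JAX dispatches work asynchronously, so every measurement
# below synchronizes with jax.block_until_ready() and runs N_WARMUP untimed
# steps first, keeping compilation time out of the timed region.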


@nnx.jit
def _decode_step(model: Transformer, tok: jnp.ndarray, cache: tuple | None, idx: int) -> tuple:
    return model(tok, use_cache=True, kv_caches=cache, cache_index=idx)


@nnx.jit
def _prefill_step(model: Transformer, prompt: jnp.ndarray) -> tuple:
    return model(prompt, use_cache=False, kv_caches=None, cache_index=0)


def _load_config(run_path: str) -> Config:
    import yaml

    cfg_path = os.path.join(run_path, "config.yaml")
    if not os.path.exists(cfg_path):
        raise BenchmarkError(f"Config not found: {cfg_path}")
    with open(cfg_path) as f:
        raw = yaml.safe_load(f)
    # Flatten one level of nesting ({section: {key: val}} -> {key: val});
    # fall back to the raw mapping if the file has no nested sections.
    flat: dict = {}
    for section in raw.values():
        if isinstance(section, dict):
            flat.update(section)
    if not flat:
        flat = raw
    return Config.from_dict(flat)


def _detect_vocab(state_dict: dict, dim: int) -> int | None:
    def _get(d: object, key: str) -> object:
        if not isinstance(d, dict):
            return None
        v = d.get(key)
        if v is None:
            v = d.get(key.encode() if isinstance(key, str) else key)
        return v

    def _unwrap(obj: object) -> Any:
        if isinstance(obj, dict):
            for k in ("value", "raw_value", b"value", b"raw_value"):
                if k in obj:
                    return obj[k]
        return obj

    wte = _get(state_dict, "wte")
    if wte is None:
        return None
    emb = _unwrap(_get(wte, "embedding"))
    if emb is None or not hasattr(emb, "shape") or emb.ndim != 2:
        return None
    if emb.shape[1] == dim:
        return int(emb.shape[0])
    if emb.shape[0] == dim:
        return int(emb.shape[1])
    return None
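# _detect_vocab example (shapes illustrative): with dim=768, an embedding
# stored as (50257, 768) yields 50257; a transposed (768, 50257) layout also
# yields 50257; if neither axis matches dim, the function returns None.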


def _load_model(run_path: str, config: Config) -> Transformer:
    weights_path = os.path.join(run_path, "model_weights.msgpack")
    if not os.path.exists(weights_path):
        raise BenchmarkError(f"Weights not found: {weights_path}")
    with open(weights_path, "rb") as f:
        state_dict = msgpack.unpackb(
            f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
        )
    # Prefer the vocab size actually stored in the checkpoint over the
    # config value when the two disagree.
    actual_vocab = _detect_vocab(state_dict, config.dim)
    if actual_vocab is not None and actual_vocab != config.vocab_size:
        config = dataclasses.replace(config, vocab_size=actual_vocab)
    model = Transformer(config, rngs=nnx.Rngs(42))
    nnx.update(model, state_dict)
    return model


def _attn_type(config: Config) -> str:
    if getattr(config, "mla", False):
        return "MLA"
    if getattr(config, "kv_heads", config.n_heads) < config.n_heads:
        return "GQA"
    return "MHA"


def _theoretical_cache_mb(config: Config) -> float:
    # Full-context KV-cache size in MB, assuming float32 (4 bytes/element).
    S = config.max_context
    if getattr(config, "mla", False):
        # MLA caches one compressed KV latent plus the decoupled RoPE keys per token.
        per_layer = S * (getattr(config, "down_dim_kv", 0) + getattr(config, "rope_dim", 0)) * 4
    else:
        # MHA/GQA cache K and V per KV head at head_dim = dim // n_heads.
        per_layer = 2 * S * getattr(config, "kv_heads", config.n_heads) * (config.dim // config.n_heads) * 4
    return per_layer * config.num_blocks / 1e6
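# Worked example (illustrative numbers): max_context=1024, n_heads=16,
# kv_heads=4, dim=1024 (head_dim=64), num_blocks=12, no MLA:
#   per_layer = 2 * 1024 * 4 * 64 * 4 bytes = 2,097,152 bytes
#   total     = 2,097,152 * 12 / 1e6 ≈ 25.17 MB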


def _val_loss(run_path: str) -> float | None:
    import pandas as pd

    log_csv = os.path.join(run_path, "training_log.csv")
    if not os.path.exists(log_csv):
        return None
    try:
        df = pd.read_csv(log_csv)
        return float(df["val_loss"].dropna().iloc[-1])
    except Exception:
        return None


def _xla_costs(fn: Any, *args: object) -> tuple[float, float]:
    try:
        costs = fn.lower(*args).cost_analysis()
        # cost_analysis() may return a dict or, in some JAX versions,
        # a list of per-program dicts; handle both.
        if isinstance(costs, list):
            flops = sum(c.get("flops", 0) for c in costs)
            mem = sum(c.get("bytes accessed", 0) for c in costs)
        else:
            flops = costs.get("flops", float("nan"))
            mem = costs.get("bytes accessed", float("nan"))
        return float(flops), float(mem)
    except Exception:
        return float("nan"), float("nan")
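# Usage sketch: _xla_costs(_decode_step, model, tok, None, 0) lowers the
# jitted step for those concrete arguments (tracing only, no device
# execution) and reads the "flops" and "bytes accessed" estimates from
# XLA's cost model, falling back to (nan, nan) when neither is available.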


def benchmark_run(run_path: str) -> dict:
    """Benchmark a single run directory, returning a metrics dict."""
    config = _load_config(run_path)
    model = _load_model(run_path, config)

    prompt_len = min(config.max_context, 256)
    prompt = jnp.ones((1, prompt_len), dtype=jnp.int32)
    tok = jnp.ones((1, 1), dtype=jnp.int32)

    # Prefill latency: the first call compiles; only the second, cached call is timed.
    jax.block_until_ready(_prefill_step(model, prompt))
    t0 = time.perf_counter()
    jax.block_until_ready(_prefill_step(model, prompt))
    prefill_ms = (time.perf_counter() - t0) * 1000

    # Decode throughput: sequence-length scaling at batch size 1
    tps_by_seqlen: dict[int, float] = {}
    for seq in SEQ_LENS:
        if seq > config.max_context:
            tps_by_seqlen[seq] = float("nan")
            continue
        try:
            for _ in range(N_WARMUP):
                _decode_step(model, tok, None, seq)
            jax.block_until_ready(_decode_step(model, tok, None, seq))
            t0 = time.perf_counter()
            for _ in range(N_MEASURE):
                _decode_step(model, tok, None, seq)
            jax.block_until_ready(_decode_step(model, tok, None, seq))
            tps_by_seqlen[seq] = N_MEASURE / (time.perf_counter() - t0)
            log.debug("seq=%d: %.1f tok/s", seq, tps_by_seqlen[seq])
        except Exception as exc:
            log.warning("seq=%d failed: %s", seq, exc)
            tps_by_seqlen[seq] = float("nan")

    # Decode throughput: batch scaling at FIXED_SEQ
    tps_by_batch: dict[int, float] = {}
    max_batch = 0
    for bs in BATCH_SIZES:
        if config.max_context < FIXED_SEQ:
            tps_by_batch[bs] = float("nan")
            continue
        tok_b = jnp.ones((bs, 1), dtype=jnp.int32)
        try:
            for _ in range(N_WARMUP):
                _decode_step(model, tok_b, None, FIXED_SEQ)
            jax.block_until_ready(_decode_step(model, tok_b, None, FIXED_SEQ))
            t0 = time.perf_counter()
            for _ in range(N_MEASURE):
                _decode_step(model, tok_b, None, FIXED_SEQ)
            jax.block_until_ready(_decode_step(model, tok_b, None, FIXED_SEQ))
            tps_by_batch[bs] = N_MEASURE * bs / (time.perf_counter() - t0)
            max_batch = bs
            log.debug("bs=%d: %.1f tok/s", bs, tps_by_batch[bs])
        except Exception as exc:
            log.warning("bs=%d OOM/failed: %s", bs, exc)
            tps_by_batch[bs] = float("nan")
            break
    for bs in BATCH_SIZES:
        tps_by_batch.setdefault(bs, float("nan"))

    # FLOPs via XLA cost analysis
    mid_idx = min(config.max_context // 2, config.max_context - 1)
    decode_flops, decode_bytes = _xla_costs(_decode_step, model, tok, None, mid_idx)
    prefill_flops, prefill_bytes = _xla_costs(_prefill_step, model, prompt)

    def _safe_div(a: float, b: float) -> float:
        return round(a / max(b, 1), 4) if not (np.isnan(a) or np.isnan(b)) else float("nan")

    def _safe_round(v: float, n: float) -> float:
        return round(v / n, 4) if not np.isnan(v) else float("nan")

    # Throughput at the longest supported sequence length; default=-1 guards
    # against configs whose max_context is shorter than every SEQ_LENS entry.
    best_seq = max((s for s in SEQ_LENS if s <= config.max_context), default=-1)
    best_tps = tps_by_seqlen.get(best_seq, float("nan"))
    decode_gflops = _safe_round(decode_flops, 1e9)
    prefill_gflops = _safe_round(prefill_flops, 1e9)

    _, model_state = nnx.split(model)
    params_m = sum(
        x.size for x in jax.tree_util.tree_leaves(model_state) if hasattr(x, "size")
    ) / 1e6

    return {
        "run": os.path.basename(run_path),
        "type": _attn_type(config),
        "params_m": params_m,
        "moe": getattr(config, "use_moe", False),
        "num_blocks": config.num_blocks,
        "dim": config.dim,
        "n_heads": config.n_heads,
        "kv_heads": getattr(config, "kv_heads", config.n_heads),
        "max_context": config.max_context,
        "down_dim_kv": getattr(config, "down_dim_kv", None),
        "theoretical_cache_mb": round(_theoretical_cache_mb(config), 2),
        "prefill_ms": prefill_ms,
        "val_loss": _val_loss(run_path),
        "decode_gflops": decode_gflops,
        "prefill_gflops": prefill_gflops,
        "decode_arith_int": _safe_div(decode_flops, decode_bytes),
        "prefill_arith_int": _safe_div(prefill_flops, prefill_bytes),
        "decode_tflops_s": (
            round(decode_gflops * best_tps / 1e3, 4)
            if not (np.isnan(decode_gflops) or np.isnan(best_tps))
            else float("nan")
        ),
        "max_batch_survived": max_batch,
        **{f"tps_{s}": tps_by_seqlen.get(s, float("nan")) for s in SEQ_LENS},
        **{f"tps_bs{b}": tps_by_batch.get(b, float("nan")) for b in BATCH_SIZES},
    }
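# Direct use (sketch, hypothetical run name): benchmark_run("runs/gqa_small")
# returns one flat metrics dict ("run", "type", "params_m", "tps_64", ...),
# i.e. a single row of the DataFrame that BenchmarkRunner.run() assembles.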


class BenchmarkRunner:
    """
    Benchmarks one or more DantinoX run directories.

    Parameters
    ----------
    runs_dir : str
        Directory containing run sub-directories (default ``"runs"``).
    seq_lens : list[int], optional
        Sequence lengths to test for throughput scaling.
    batch_sizes : list[int], optional
        Batch sizes to test for memory/throughput scaling.

    Examples
    --------
    >>> runner = BenchmarkRunner("runs")
    >>> df = runner.run()
    >>> df.to_csv("results.csv", index=False)
    """

    def __init__(
        self,
        runs_dir: str = "runs",
        *,
        seq_lens: Sequence[int] | None = None,
        batch_sizes: Sequence[int] | None = None,
    ) -> None:
        self.runs_dir = runs_dir
        # Overriding seq_lens/batch_sizes rebinds the module-level defaults,
        # so the override applies process-wide, not per instance.
        if seq_lens is not None:
            global SEQ_LENS
            SEQ_LENS = list(seq_lens)
        if batch_sizes is not None:
            global BATCH_SIZES
            BATCH_SIZES = list(batch_sizes)

    def __repr__(self) -> str:
        return f"BenchmarkRunner(runs_dir={self.runs_dir!r})"

    def run(
        self,
        run_names: Sequence[str] | None = None,
        *,
        out_csv: str | None = None,
    ) -> Any:
        """
        Run benchmarks and return a DataFrame.

        Parameters
        ----------
        run_names : list[str], optional
            Subset of run names to evaluate. Benchmarks all runs if omitted.
        out_csv : str, optional
            Write results to this CSV path.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        BenchmarkError
            If the runs directory does not exist.
        """
        import pandas as pd

        if not os.path.isdir(self.runs_dir):
            raise BenchmarkError(f"Runs directory not found: {self.runs_dir}")

        if run_names is None:
            run_names = [
                d for d in os.listdir(self.runs_dir)
                if os.path.isdir(os.path.join(self.runs_dir, d))
            ]

        results = []
        for name in run_names:
            path = os.path.join(self.runs_dir, name)
            log.info("Benchmarking: %s", name)
            try:
                results.append(benchmark_run(path))
            except BenchmarkError as exc:
                log.error(" Skipped %s: %s", name, exc)
            except Exception as exc:
                log.error(" Unexpected error for %s: %s\n%s", name, exc, traceback.format_exc())

        df = pd.DataFrame(results)
        if out_csv:
            df.to_csv(out_csv, index=False)
            log.info("Saved benchmark results to %s", out_csv)
        return df
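

# A minimal command-line entry point, sketched from the docstring example
# above; the flag names (--runs-dir, --out) are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Benchmark DantinoX run directories")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument("--out", default=None, help="optional CSV output path")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    runner = BenchmarkRunner(args.runs_dir)
    df = runner.run(out_csv=args.out)
    print(df.to_string(index=False))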