Coverage for dantinox/bench.py: 19% (204 statements)

from __future__ import annotations

import dataclasses
import logging
import os
import time
import traceback
from collections.abc import Sequence
from typing import Any

import jax
import jax.numpy as jnp
import msgpack
import numpy as np
from flax import nnx
from flax.serialization import _msgpack_ext_unpack

from core.config import Config
from core.model import Transformer
from dantinox.exceptions import BenchmarkError

log = logging.getLogger(__name__)

SEQ_LENS = [64, 128, 256, 512]
BATCH_SIZES = [1, 4, 16, 64, 128, 256]
FIXED_SEQ = 256
N_WARMUP = 3
N_MEASURE = 20
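
# Each timing run below performs N_WARMUP untimed steps (absorbing one-off JIT
# compilation cost) before averaging over N_MEASURE timed steps. FIXED_SEQ is
# the cache position used when sweeping batch size instead of sequence length.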


@nnx.jit
def _decode_step(model: Transformer, tok: jnp.ndarray, cache: tuple | None, idx: int) -> tuple:
    return model(tok, use_cache=True, kv_caches=cache, cache_index=idx)


@nnx.jit
def _prefill_step(model: Transformer, prompt: jnp.ndarray) -> tuple:
    return model(prompt, use_cache=False, kv_caches=None, cache_index=0)
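
# Note: the benchmark always calls _decode_step with cache=None and a fixed
# cache_index, so each measurement is the cost of a single decode step at
# position idx rather than a stateful end-to-end generation loop.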


def _load_config(run_path: str) -> Config:
    import yaml

    cfg_path = os.path.join(run_path, "config.yaml")
    if not os.path.exists(cfg_path):
        raise BenchmarkError(f"Config not found: {cfg_path}")
    with open(cfg_path) as f:
        raw = yaml.safe_load(f)
    # Flatten one level of section nesting; fall back to the raw mapping when
    # the file has no nested sections at all.
    flat: dict = {}
    for section in raw.values():
        if isinstance(section, dict):
            flat.update(section)
    if not flat:
        flat = raw
    return Config.from_dict(flat)
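
# For example (illustrative only; the real schema lives in core.config), a
# config.yaml such as
#
#     model:
#       dim: 512
#       n_heads: 8
#     training:
#       lr: 3e-4
#
# is flattened to {"dim": 512, "n_heads": 8, "lr": 3e-4} before being handed
# to Config.from_dict.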


def _detect_vocab(state_dict: dict, dim: int) -> int | None:
    """Infer the vocab size from the embedding table in a raw state dict."""

    def _get(d: object, key: str) -> object:
        if not isinstance(d, dict):
            return None
        v = d.get(key)
        if v is None:
            # msgpack may have deserialized the keys as bytes rather than str.
            v = d.get(key.encode())
        return v

    def _unwrap(obj: object) -> Any:
        if isinstance(obj, dict):
            for k in ("value", "raw_value", b"value", b"raw_value"):
                if k in obj:
                    return obj[k]
        return obj

    wte = _get(state_dict, "wte")
    if wte is None:
        return None
    emb = _unwrap(_get(wte, "embedding"))
    if emb is None or not hasattr(emb, "shape") or emb.ndim != 2:
        return None
    # The embedding is (vocab, dim) or (dim, vocab); whichever axis is not
    # the model dimension is the vocabulary size.
    if emb.shape[1] == dim:
        return int(emb.shape[0])
    if emb.shape[0] == dim:
        return int(emb.shape[1])
    return None


def _load_model(run_path: str, config: Config) -> Transformer:
    weights_path = os.path.join(run_path, "model_weights.msgpack")
    if not os.path.exists(weights_path):
        raise BenchmarkError(f"Weights not found: {weights_path}")
    with open(weights_path, "rb") as f:
        state_dict = msgpack.unpackb(
            f.read(), ext_hook=_msgpack_ext_unpack, strict_map_key=False
        )
    # Prefer the vocab size actually present in the checkpoint over the
    # value stored in the config, so stale config files still load.
    actual_vocab = _detect_vocab(state_dict, config.dim)
    if actual_vocab is not None and actual_vocab != config.vocab_size:
        config = dataclasses.replace(config, vocab_size=actual_vocab)
    model = Transformer(config, rngs=nnx.Rngs(42))
    nnx.update(model, state_dict)
    return model


def _attn_type(config: Config) -> str:
    if getattr(config, "mla", False):
        return "MLA"
    if getattr(config, "kv_heads", config.n_heads) < config.n_heads:
        return "GQA"
    return "MHA"
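
# MLA = multi-head latent attention (compressed KV cache), GQA = grouped-query
# attention (fewer KV heads than query heads), MHA = standard multi-head
# attention.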


def _theoretical_cache_mb(config: Config) -> float:
    """KV-cache size at full context, assuming fp32 (4 bytes per element)."""
    S = config.max_context
    if getattr(config, "mla", False):
        # MLA caches one compressed latent plus the decoupled RoPE key dims.
        per_layer = S * (getattr(config, "down_dim_kv", 0) + getattr(config, "rope_dim", 0)) * 4
    else:
        # MHA/GQA cache full K and V tensors: 2 * S * kv_heads * head_dim.
        head_dim = config.dim // config.n_heads
        per_layer = 2 * S * getattr(config, "kv_heads", config.n_heads) * head_dim * 4
    return per_layer * config.num_blocks / 1e6
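
# Worked example (hypothetical GQA config): max_context=1024, kv_heads=4,
# dim=512, n_heads=8 -> head_dim=64. Per layer: 2 * 1024 * 4 * 64 * 4 bytes
# ≈ 2.10 MB, so a 12-block model would report ≈ 25.2 MB.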


def _val_loss(run_path: str) -> float | None:
    import pandas as pd

    log_csv = os.path.join(run_path, "training_log.csv")
    if not os.path.exists(log_csv):
        return None
    try:
        df = pd.read_csv(log_csv)
        # Last recorded validation loss wins.
        return float(df["val_loss"].dropna().iloc[-1])
    except Exception:
        # Missing column, empty log, parse error: treat as unavailable.
        return None


def _xla_costs(fn: Any, *args: object) -> tuple[float, float]:
    """Return (flops, bytes accessed) from XLA's cost analysis, or NaNs."""
    try:
        costs = fn.lower(*args).cost_analysis()
        # Depending on the JAX version, cost_analysis() returns either a
        # single dict or a list of per-computation dicts.
        if isinstance(costs, list):
            flops = sum(c.get("flops", 0) for c in costs)
            mem = sum(c.get("bytes accessed", 0) for c in costs)
        else:
            flops = costs.get("flops", float("nan"))
            mem = costs.get("bytes accessed", float("nan"))
        return float(flops), float(mem)
    except Exception:
        return float("nan"), float("nan")
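
# These are XLA's static estimates computed from the lowered program, not
# measured hardware counters; treat them as order-of-magnitude figures.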


def benchmark_run(run_path: str) -> dict:
    """Benchmark a single run directory, returning a metrics dict."""
    config = _load_config(run_path)
    model = _load_model(run_path, config)

    prompt_len = min(config.max_context, 256)
    prompt = jnp.ones((1, prompt_len), dtype=jnp.int32)
    tok = jnp.ones((1, 1), dtype=jnp.int32)

    # Prefill latency: one untimed call to compile, then one timed call.
    jax.block_until_ready(_prefill_step(model, prompt))
    t0 = time.perf_counter()
    jax.block_until_ready(_prefill_step(model, prompt))
    prefill_ms = (time.perf_counter() - t0) * 1000

    # Decode throughput — sequence length scaling @ BS=1
    tps_by_seqlen: dict[int, float] = {}
    for seq in SEQ_LENS:
        if seq > config.max_context:
            tps_by_seqlen[seq] = float("nan")
            continue
        try:
            for _ in range(N_WARMUP):
                jax.block_until_ready(_decode_step(model, tok, None, seq))
            t0 = time.perf_counter()
            # Dispatch N_MEASURE steps asynchronously, then block on the last
            # result so the timed region covers exactly N_MEASURE steps.
            out = None
            for _ in range(N_MEASURE):
                out = _decode_step(model, tok, None, seq)
            jax.block_until_ready(out)
            tps_by_seqlen[seq] = N_MEASURE / (time.perf_counter() - t0)
            log.debug("seq=%d: %.1f tok/s", seq, tps_by_seqlen[seq])
        except Exception as exc:
            log.warning("seq=%d failed: %s", seq, exc)
            tps_by_seqlen[seq] = float("nan")

    # Decode throughput — batch scaling @ FIXED_SEQ
    tps_by_batch: dict[int, float] = {}
    max_batch = 0
    for bs in BATCH_SIZES:
        if config.max_context < FIXED_SEQ:
            tps_by_batch[bs] = float("nan")
            continue
        tok_b = jnp.ones((bs, 1), dtype=jnp.int32)
        try:
            for _ in range(N_WARMUP):
                jax.block_until_ready(_decode_step(model, tok_b, None, FIXED_SEQ))
            t0 = time.perf_counter()
            out = None
            for _ in range(N_MEASURE):
                out = _decode_step(model, tok_b, None, FIXED_SEQ)
            jax.block_until_ready(out)
            tps_by_batch[bs] = N_MEASURE * bs / (time.perf_counter() - t0)
            max_batch = bs
            log.debug("bs=%d: %.1f tok/s", bs, tps_by_batch[bs])
        except Exception as exc:
            # Batch sizes grow monotonically, so the first failure (usually
            # an OOM) means all larger sizes would fail too.
            log.warning("bs=%d OOM/failed: %s", bs, exc)
            tps_by_batch[bs] = float("nan")
            break
    for bs in BATCH_SIZES:
        tps_by_batch.setdefault(bs, float("nan"))

    # FLOPs via XLA cost analysis, taken at a mid-context cache position.
    mid_idx = min(config.max_context // 2, config.max_context - 1)
    decode_flops, decode_bytes = _xla_costs(_decode_step, model, tok, None, mid_idx)
    prefill_flops, prefill_bytes = _xla_costs(_prefill_step, model, prompt)

    def _safe_div(a: float, b: float) -> float:
        return round(a / max(b, 1), 4) if not (np.isnan(a) or np.isnan(b)) else float("nan")

    def _safe_round(v: float, n: float) -> float:
        return round(v / n, 4) if not np.isnan(v) else float("nan")

    # Throughput at the longest measured sequence length; the default guards
    # against max_context being shorter than every entry in SEQ_LENS.
    best_tps = tps_by_seqlen.get(
        max((s for s in SEQ_LENS if s <= config.max_context), default=SEQ_LENS[0]),
        float("nan"),
    )
    decode_gflops = _safe_round(decode_flops, 1e9)
    prefill_gflops = _safe_round(prefill_flops, 1e9)

    _, model_state = nnx.split(model)
    params_m = sum(
        x.size for x in jax.tree_util.tree_leaves(model_state) if hasattr(x, "size")
    ) / 1e6

    return {
        "run": os.path.basename(run_path),
        "type": _attn_type(config),
        "params_m": params_m,
        "moe": getattr(config, "use_moe", False),
        "num_blocks": config.num_blocks,
        "dim": config.dim,
        "n_heads": config.n_heads,
        "kv_heads": getattr(config, "kv_heads", config.n_heads),
        "max_context": config.max_context,
        "down_dim_kv": getattr(config, "down_dim_kv", None),
        "theoretical_cache_mb": round(_theoretical_cache_mb(config), 2),
        "prefill_ms": prefill_ms,
        "val_loss": _val_loss(run_path),
        "decode_gflops": decode_gflops,
        "prefill_gflops": prefill_gflops,
        "decode_arith_int": _safe_div(decode_flops, decode_bytes),
        "prefill_arith_int": _safe_div(prefill_flops, prefill_bytes),
        "decode_tflops_s": (
            round(decode_gflops * best_tps / 1e3, 4)
            if not (np.isnan(decode_gflops) or np.isnan(best_tps))
            else float("nan")
        ),
        "max_batch_survived": max_batch,
        **{f"tps_{s}": tps_by_seqlen.get(s, float("nan")) for s in SEQ_LENS},
        **{f"tps_bs{b}": tps_by_batch.get(b, float("nan")) for b in BATCH_SIZES},
    }
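
# "Arithmetic intensity" here is FLOPs per byte accessed: single-token decode
# is typically memory-bound (low intensity, dominated by weight and KV-cache
# reads), while prefill over a whole prompt is far more compute-bound.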


class BenchmarkRunner:
    """
    Benchmarks one or more DantinoX run directories.

    Parameters
    ----------
    runs_dir : str
        Directory containing run sub-directories (default ``"runs"``).
    seq_lens : list[int], optional
        Sequence lengths to test for throughput scaling.
    batch_sizes : list[int], optional
        Batch sizes to test for memory/throughput scaling.

    Examples
    --------
    >>> runner = BenchmarkRunner("runs")
    >>> df = runner.run()
    >>> df.to_csv("results.csv", index=False)
    """

    def __init__(
        self,
        runs_dir: str = "runs",
        *,
        seq_lens: Sequence[int] | None = None,
        batch_sizes: Sequence[int] | None = None,
    ) -> None:
        self.runs_dir = runs_dir
        # Note: these overrides mutate the module-level sweep lists, because
        # benchmark_run reads SEQ_LENS/BATCH_SIZES globally. They therefore
        # affect every runner in the process, not just this instance.
        if seq_lens is not None:
            global SEQ_LENS
            SEQ_LENS = list(seq_lens)
        if batch_sizes is not None:
            global BATCH_SIZES
            BATCH_SIZES = list(batch_sizes)

    def __repr__(self) -> str:
        return f"BenchmarkRunner(runs_dir={self.runs_dir!r})"

    def run(
        self,
        run_names: Sequence[str] | None = None,
        *,
        out_csv: str | None = None,
    ) -> Any:
        """
        Run benchmarks and return a DataFrame.

        Parameters
        ----------
        run_names : list[str], optional
            Subset of run names to evaluate. Benchmarks all runs if omitted.
        out_csv : str, optional
            Write results to this CSV path.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        BenchmarkError
            If the runs directory does not exist.
        """
        import pandas as pd

        if not os.path.isdir(self.runs_dir):
            raise BenchmarkError(f"Runs directory not found: {self.runs_dir}")

        if run_names is None:
            run_names = [
                d for d in os.listdir(self.runs_dir)
                if os.path.isdir(os.path.join(self.runs_dir, d))
            ]

        results = []
        for name in run_names:
            path = os.path.join(self.runs_dir, name)
            log.info("Benchmarking: %s", name)
            try:
                results.append(benchmark_run(path))
            except BenchmarkError as exc:
                log.error(" Skipped %s: %s", name, exc)
            except Exception as exc:
                log.error(" Unexpected error for %s: %s\n%s", name, exc, traceback.format_exc())

        df = pd.DataFrame(results)
        if out_csv:
            df.to_csv(out_csv, index=False)
            log.info("Saved benchmark results to %s", out_csv)
        return df
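

if __name__ == "__main__":
    # Minimal manual entry point, mirroring the class docstring example. The
    # output path "benchmark_results.csv" is illustrative, not a project
    # convention.
    logging.basicConfig(level=logging.INFO)
    BenchmarkRunner("runs").run(out_csv="benchmark_results.csv")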