1""" 

2plot_perf.py — performance comparison: memory, KV cache, speed, FLOPs. 

3 

4Figures produced 

5 perf_1_cache_breakdown.png — KV cache vs model size (theoretical, all types) 

6 perf_2_seqlen_throughput.png — tok/s vs sequence length (existing data, bs=1) 

7 perf_3_flops_vs_cache.png — analytical decode FLOPs vs KV cache (Pareto) 

8 perf_4_batch_throughput.png — tok/s vs batch size (needs batch_sweep_results.csv) 

9 perf_5_prefill.png — prefill latency vs model size 

10 

11Run with existing data: 

12 python3 plot_perf.py 

13 

14Run after batch sweep: 

15 CUDA_VISIBLE_DEVICES=0 python3 benchmark_batch_sweep.py 

16 python3 plot_perf.py 

17""" 

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

IN_CSV = "benchmark_results.csv"
BATCH_CSV = "batch_sweep_results.csv"
OUT_DIR = "plots"
DPI = 180

TYPE_COLORS = {"MLA": "#3A86FF", "GQA": "#FF6B35", "MHA": "#2DC653"}
TYPE_ORDER = ["MLA", "GQA", "MHA"]
SEQ_LENS = [64, 128, 256, 512]

# ─────────────────────────────────────────────────────────────────────────────
def _save(fig, name: str):
    os.makedirs(OUT_DIR, exist_ok=True)
    path = os.path.join(OUT_DIR, name)
    fig.savefig(path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {path}")

def load() -> pd.DataFrame:
    df = pd.read_csv(IN_CSV)
    num_cols = ["params_m", "val_loss", "theoretical_cache_mb",
                "measured_cache_mb", "prefill_ms"] + [f"tps_{s}" for s in SEQ_LENS]
    for c in num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    # head_size derived from architecture
    df["head_size"] = (df["dim"] / df["n_heads"]).round().astype(int)
    return df

def load_batch() -> pd.DataFrame | None:
    if not os.path.exists(BATCH_CSV):
        return None
    df = pd.read_csv(BATCH_CSV)
    for c in ["tps", "cache_mb_total", "batch_size", "params_m",
              "theoretical_cache_mb"]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    return df
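
# Note on the batch-sweep schema (inferred from fig4_batch_throughput below,
# not from a written spec): besides the numeric columns coerced above, the
# CSV is expected to carry "type", "dim", "num_blocks", "seq_len" and a
# boolean "oom" flag per row.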

# ─────────────────────────────────────────────────────────────────────────────
# Analytical FLOPs for one decode step (T=1, context length S, batch B=1)
# ─────────────────────────────────────────────────────────────────────────────
def _decode_flops(row: pd.Series, S: int = 256) -> float:
    """Approximate decode FLOPs (×2 counted, i.e. MACs×2) for one token."""
    dim = int(row["dim"])
    n_heads = int(row["n_heads"])
    kv_heads = int(row["kv_heads"])
    h = int(row["head_size"])  # = dim // n_heads
    nb = int(row["num_blocks"])
    exp = 4  # MLP expansion factor; SwiGLU ≈ same total

    if row["type"] == "MLA":
        # Fall back to defaults when the column is absent or NaN. A bare
        # `x or default` would let NaN through, since NaN is truthy.
        down_dim_q = row.get("down_dim_q")
        if pd.isna(down_dim_q):
            down_dim_q = dim // 2
        down_dim_kv = row.get("down_dim_kv")
        if pd.isna(down_dim_kv):
            down_dim_kv = dim // 4
        rope_dim = max(16, int(down_dim_kv) // 4)
        down_dim_q = int(down_dim_q)
        down_dim_kv = int(down_dim_kv)
        # Per decode step, per batch=1:
        attn = (
            2 * dim * down_dim_q                          # down_q
            + 2 * dim * rope_dim * 2                      # q_pe + k_pe
            + 2 * n_heads * down_dim_q * down_dim_kv * h  # weight-weight attn_proj (per layer)
            + 2 * n_heads * down_dim_q * down_dim_kv      # q × attn_proj
            + 2 * n_heads * S * down_dim_kv               # scores vs compressed cache
            + 2 * n_heads * S * rope_dim                  # RoPE attention
            + 2 * n_heads * h * down_dim_kv * dim         # W_vo weight-weight (per layer)
            + 2 * n_heads * down_dim_kv * dim             # output einsum
        )
    else:
        attn = (
            2 * dim * dim                 # Q proj (n_heads × h = dim)
            + 2 * dim * kv_heads * h * 2  # K + V proj
            + 2 * n_heads * S * h         # QK^T
            + 2 * n_heads * S * h         # Attn × V
            + 2 * dim * dim               # O proj
        )

    mlp = 2 * dim * exp * dim * 3  # SwiGLU: two up projections + one down

    return (attn + mlp) * nb


def _prefill_flops(row: pd.Series) -> float:
    """Prefill FLOPs scale quadratically with sequence length."""
    T = int(row.get("max_context", 512))
    # Rough O(T²) estimate: summing decode cost over positions 1..T is close
    # to T steps at the average context length T/2 (attention is linear in S).
    return _decode_flops(row, S=T // 2) * T
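
# Worked example (hypothetical config, for intuition only): a dense MHA model
# with dim=512, n_heads = kv_heads = 8, head_size=64, num_blocks=8 at S=256:
#   attn ≈ 2·512² (Q) + 2·512·8·64·2 (K+V) + 2·8·256·64·2 (QK^T and Attn×V)
#          + 2·512² (O) ≈ 2.6M FLOPs per layer,
#   mlp  = 2·512·4·512·3 ≈ 6.3M FLOPs per layer,
# so _decode_flops ≈ (2.6M + 6.3M) × 8 layers ≈ 71M FLOPs per decoded token.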

# ─────────────────────────────────────────────────────────────────────────────
# Figure 1 — KV cache vs model parameters, split by type + Dense/MoE
# ─────────────────────────────────────────────────────────────────────────────
def fig1_cache_breakdown(df: pd.DataFrame):
    sub = df.dropna(subset=["params_m", "theoretical_cache_mb"]).copy()

    fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
    fig.patch.set_facecolor("#F8F9FA")

    # Left: scatter params vs cache (triangle = MoE, circle = Dense)
    ax = axes[0]
    ax.set_facecolor("#F8F9FA")
    for t in TYPE_ORDER:
        for moe, grp in sub[sub["type"] == t].groupby("moe"):
            mk = "^" if moe else "o"
            ax.scatter(grp["params_m"], grp["theoretical_cache_mb"],
                       s=60, c=TYPE_COLORS[t], marker=mk,
                       edgecolors="white", lw=0.8, alpha=0.85, zorder=4)
        # empty scatter so the legend gets exactly one entry per type
        ax.scatter([], [], c=TYPE_COLORS[t], s=80, label=t)

    ax.set_xlabel("Parameters (M)", fontsize=12)
    ax.set_ylabel("KV Cache (MB) @ 512 tokens", fontsize=12)
    ax.set_title("Params vs KV Cache — MLA decoupled",
                 fontsize=12, fontweight="bold")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.legend(fontsize=10)
    ax.grid(True, which="both", alpha=0.2, ls="--")
    ax.spines[["top", "right"]].set_visible(False)

    # Right: bar chart — cache per type (Dense only), split by num_blocks
    ax2 = axes[1]
    ax2.set_facecolor("#F8F9FA")
    dense = sub[~sub["moe"]]
    agg = dense.groupby(["type", "num_blocks"])["theoretical_cache_mb"].mean().reset_index()
    nb_vals = sorted(agg["num_blocks"].unique())
    w = 0.8 / len(nb_vals)
    x = np.arange(len(TYPE_ORDER))

    for i, nb in enumerate(nb_vals):
        sub2 = agg[agg["num_blocks"] == nb]
        vals = [sub2[sub2["type"] == t]["theoretical_cache_mb"].mean()
                if not sub2[sub2["type"] == t].empty else 0 for t in TYPE_ORDER]
        off = (i - len(nb_vals) / 2 + 0.5) * w
        bars = ax2.bar(x + off, vals, w, color=[TYPE_COLORS[t] for t in TYPE_ORDER],
                       alpha=0.55 + 0.45 * i / max(len(nb_vals) - 1, 1),
                       edgecolor="white", lw=0.8, label=f"{nb} layers", zorder=3)
        for bar, v in zip(bars, vals):
            if v > 0:
                ax2.text(bar.get_x() + bar.get_width() / 2, v * 1.04,
                         f"{v:.1f}", ha="center", va="bottom", fontsize=8)

    ax2.set_xticks(x)
    ax2.set_xticklabels(TYPE_ORDER, fontsize=12)
    ax2.set_ylabel("KV Cache (MB) @ 512 tokens", fontsize=12)
    ax2.set_title("Dense models: cache scales with depth\n(MLA remains smallest at every depth)",
                  fontsize=11, fontweight="bold")
    ax2.legend(title="num_blocks", fontsize=9, framealpha=0.9)
    ax2.grid(axis="y", alpha=0.2, ls="--", zorder=0)
    ax2.spines[["top", "right", "left"]].set_visible(False)
    ax2.tick_params(left=False)

    fig.suptitle("KV Cache Footprint by Attention Type",
                 fontsize=14, fontweight="bold", y=1.01)
    plt.tight_layout()
    _save(fig, "perf_1_cache_breakdown.png")

# ─────────────────────────────────────────────────────────────────────────────
# Figure 2 — Decode throughput vs sequence length (bs=1, existing data)
# ─────────────────────────────────────────────────────────────────────────────
def fig2_seqlen_throughput(df: pd.DataFrame):
    dense = df[~df["moe"]].copy()

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    fig.patch.set_facecolor("#F8F9FA")

    # Left: median tps per type vs seq_len, with interquartile band
    ax = axes[0]
    ax.set_facecolor("#F8F9FA")
    for t in TYPE_ORDER:
        sub = dense[dense["type"] == t]
        if sub.empty:
            continue
        ys_med = [sub[f"tps_{s}"].median() for s in SEQ_LENS]
        ys_p25 = [sub[f"tps_{s}"].quantile(0.25) for s in SEQ_LENS]
        ys_p75 = [sub[f"tps_{s}"].quantile(0.75) for s in SEQ_LENS]
        ax.plot(SEQ_LENS, ys_med, marker="o", color=TYPE_COLORS[t],
                lw=2.5, label=t, zorder=4)
        ax.fill_between(SEQ_LENS, ys_p25, ys_p75,
                        color=TYPE_COLORS[t], alpha=0.15, zorder=3)

    ax.set_xlabel("Context / sequence length (tokens)", fontsize=12)
    ax.set_ylabel("Tokens / sec (batch=1)", fontsize=12)
    ax.set_title("Decode Throughput vs Sequence Length\n(batch=1, all Dense models)",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=11, framealpha=0.9)
    ax.grid(alpha=0.2, ls="--")
    ax.spines[["top", "right"]].set_visible(False)

    # Annotation: why MLA is slower at bs=1
    ax.text(0.97, 0.97,
            "At bs=1, GPU is underutilized.\n"
            "MLA has more ops per step\n"
            "(up-projections + fused einsum).\n"
            "Cache savings don't matter yet.",
            transform=ax.transAxes, ha="right", va="top",
            fontsize=9, color="#555", style="italic",
            bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="#ccc", lw=1))

    # Right: normalized (MLA = 1.0 at each seq_len)
    ax2 = axes[1]
    ax2.set_facecolor("#F8F9FA")
    mla_med = {s: dense[dense["type"] == "MLA"][f"tps_{s}"].median() for s in SEQ_LENS}
    for t in TYPE_ORDER:
        sub = dense[dense["type"] == t]
        if sub.empty:
            continue
        ys = [sub[f"tps_{s}"].median() / mla_med[s] for s in SEQ_LENS]
        ax2.plot(SEQ_LENS, ys, marker="o", color=TYPE_COLORS[t],
                 lw=2.5, label=t, zorder=4)

    ax2.axhline(1.0, color=TYPE_COLORS["MLA"], lw=1.2, ls="--", alpha=0.6)
    ax2.set_xlabel("Context / sequence length (tokens)", fontsize=12)
    ax2.set_ylabel("Relative throughput (MLA = 1.0)", fontsize=12)
    ax2.set_title("Relative Throughput — MLA baseline\n"
                  "(below 1 = slower than MLA, above 1 = faster)",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=11, framealpha=0.9)
    ax2.grid(alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle("Decode Speed: Sequence Length Scaling (Batch=1)",
                 fontsize=14, fontweight="bold", y=1.01)
    plt.tight_layout()
    _save(fig, "perf_2_seqlen_throughput.png")

# ─────────────────────────────────────────────────────────────────────────────
# Figure 3 — Analytical FLOPs vs KV cache (Pareto front)
# ─────────────────────────────────────────────────────────────────────────────
def fig3_flops_vs_cache(df: pd.DataFrame):
    sub = df.dropna(subset=["theoretical_cache_mb"]).copy()
    sub["decode_flops_m"] = sub.apply(lambda r: _decode_flops(r, S=256), axis=1) / 1e6

    fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
    fig.patch.set_facecolor("#F8F9FA")

    for ax_idx, (moe_val, title_sfx) in enumerate([(False, "Dense"), (True, "MoE")]):
        ax = axes[ax_idx]
        ax.set_facecolor("#F8F9FA")
        grp = sub[sub["moe"] == moe_val]
        if grp.empty:
            ax.text(0.5, 0.5, "No data", ha="center", va="center",
                    transform=ax.transAxes)
            ax.set_title(title_sfx)
            continue

        # Scatter in the cache/FLOPs plane (lower-left is better)
        for t in TYPE_ORDER:
            pts = grp[grp["type"] == t]
            if pts.empty:
                continue
            sz = (pts["params_m"] / pts["params_m"].max() * 250).clip(lower=25)
            ax.scatter(pts["theoretical_cache_mb"], pts["decode_flops_m"],
                       s=sz, c=TYPE_COLORS[t], edgecolors="white", lw=0.8,
                       alpha=0.85, zorder=4, label=t)

        # Annotate centroid per type
        for t in TYPE_ORDER:
            pts = grp[grp["type"] == t]
            if pts.empty:
                continue
            cx = pts["theoretical_cache_mb"].median()
            cy = pts["decode_flops_m"].median()
            ax.annotate(t, (cx, cy), fontsize=11, fontweight="bold",
                        color=TYPE_COLORS[t],
                        xytext=(6, 4), textcoords="offset points")

        ax.set_xlabel("KV Cache (MB) @ 512 tokens ← lower is better", fontsize=12)
        ax.set_ylabel("Decode FLOPs (M, analytical, S=256) ← lower is better", fontsize=12)
        ax.set_title(f"{title_sfx} models — FLOPs vs Cache tradeoff\n"
                     "(bubble size ∝ params)",
                     fontsize=12, fontweight="bold")
        ax.legend(fontsize=10, framealpha=0.9)
        ax.grid(alpha=0.2, ls="--")
        ax.spines[["top", "right"]].set_visible(False)

        # Ideal direction arrow
        ax.annotate("", xy=(0.08, 0.08), xytext=(0.28, 0.28),
                    xycoords="axes fraction",
                    arrowprops=dict(arrowstyle="-|>", color="#888", lw=1.4,
                                    mutation_scale=14))
        ax.text(0.06, 0.06, "ideal", transform=ax.transAxes,
                fontsize=8, color="#888", style="italic")

    fig.suptitle(
        "FLOPs vs Cache Pareto: MLA Pays More Compute for Less Memory\n"
        "(analytical decode FLOPs per token, context S=256)",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "perf_3_flops_vs_cache.png")

# ─────────────────────────────────────────────────────────────────────────────
# Figure 4 — Throughput vs batch size (requires batch_sweep_results.csv)
# ─────────────────────────────────────────────────────────────────────────────
def fig4_batch_throughput(bdf: pd.DataFrame):
    # Drop runs that hit out-of-memory (guard in case the column is absent)
    if "oom" in bdf.columns:
        bdf = bdf[~bdf["oom"]].copy()

    dims = sorted(bdf["dim"].unique())
    n_dims = len(dims)

    fig, axes = plt.subplots(1, n_dims, figsize=(8 * n_dims, 6.5), squeeze=False)
    fig.patch.set_facecolor("#F8F9FA")

    for col, dim in enumerate(dims):
        ax = axes[0, col]
        ax.set_facecolor("#F8F9FA")
        sub = bdf[bdf["dim"] == dim]

        for t in TYPE_ORDER:
            pts = sub[sub["type"] == t].sort_values("batch_size")
            if pts.empty:
                continue
            ax.plot(pts["batch_size"], pts["tps"],
                    marker="o", color=TYPE_COLORS[t], lw=2.5, label=t, zorder=4)
            # Mark max achieved tps
            best = pts.loc[pts["tps"].idxmax()]
            ax.scatter(best["batch_size"], best["tps"],
                       s=120, c=TYPE_COLORS[t], edgecolors="black",
                       lw=1.2, zorder=5)

        ax.set_xlabel("Batch size", fontsize=12)
        ax.set_ylabel("Tokens / sec (aggregate)", fontsize=12)
        nb = int(sub["num_blocks"].iloc[0]) if not sub.empty else "?"
        ax.set_title(f"dim={dim}, {nb} layers\n"
                     f"seq_len={int(sub['seq_len'].iloc[0]) if not sub.empty else '?'} tokens",
                     fontsize=12, fontweight="bold")
        ax.legend(fontsize=11, framealpha=0.9)
        ax.set_xscale("log", base=2)
        ax.set_xticks(sorted(bdf["batch_size"].unique()))
        ax.set_xticklabels([str(b) for b in sorted(bdf["batch_size"].unique())])
        ax.grid(alpha=0.2, ls="--")
        ax.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "Decode Throughput vs Batch Size\n"
        "MLA's smaller KV cache allows more sequences per GPU "
        "→ crossover at larger batches",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "perf_4_batch_throughput.png")

def fig4_missing():
    """Placeholder if batch sweep hasn't been run yet."""
    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor("#F8F9FA")
    ax.set_facecolor("#F8F9FA")
    ax.text(0.5, 0.6,
            "Run the batch sweep first:",
            ha="center", va="center", fontsize=14, fontweight="bold",
            transform=ax.transAxes)
    ax.text(0.5, 0.42,
            "CUDA_VISIBLE_DEVICES=0 python3 benchmark_batch_sweep.py",
            ha="center", va="center", fontsize=12, family="monospace",
            color="#3A86FF", transform=ax.transAxes)
    ax.text(0.5, 0.25,
            "Then re-run plot_perf.py",
            ha="center", va="center", fontsize=11, color="#555",
            transform=ax.transAxes)
    ax.axis("off")
    ax.set_title("Figure 4: Throughput vs Batch Size — data pending",
                 fontsize=13, fontweight="bold")
    _save(fig, "perf_4_batch_throughput.png")

# ─────────────────────────────────────────────────────────────────────────────
# Figure 5 — Prefill latency + theoretical KV cache per context length
# ─────────────────────────────────────────────────────────────────────────────
def fig5_prefill(df: pd.DataFrame):
    dense = df[~df["moe"]].dropna(subset=["prefill_ms", "params_m"]).copy()

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    fig.patch.set_facecolor("#F8F9FA")

    # Left: prefill ms vs params_m, coloured by type
    ax = axes[0]
    ax.set_facecolor("#F8F9FA")
    for t in TYPE_ORDER:
        pts = dense[dense["type"] == t]
        if pts.empty:
            continue
        ax.scatter(pts["params_m"], pts["prefill_ms"],
                   s=70, c=TYPE_COLORS[t], edgecolors="white", lw=0.8,
                   alpha=0.85, zorder=4, label=t)
        # Trend line (linear fit in log-log space)
        if len(pts) >= 3:
            lx = np.log10(pts["params_m"])
            ly = np.log10(pts["prefill_ms"])
            c = np.polyfit(lx, ly, 1)
            xf = np.logspace(lx.min() - 0.1, lx.max() + 0.1, 50)
            ax.plot(xf, 10**np.polyval(c, np.log10(xf)),
                    color=TYPE_COLORS[t], lw=2, ls="--", alpha=0.6, zorder=3)

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Model parameters (M)", fontsize=12)
    ax.set_ylabel("Prefill latency (ms) @ max_context tokens", fontsize=12)
    ax.set_title("Prefill Latency vs Model Size\n(Dense, batch=1)",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=11, framealpha=0.9)
    ax.grid(True, which="both", alpha=0.2, ls="--")
    ax.spines[["top", "right"]].set_visible(False)
    ax.text(0.97, 0.05,
            "MLA prefill is slower: it runs\nup-projections for every token\n"
            "rather than storing compressed KV.",
            transform=ax.transAxes, ha="right", va="bottom",
            fontsize=9, color="#555", style="italic",
            bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="#ccc", lw=1))

    # Right: theoretical cache vs context length for representative models
    ax2 = axes[1]
    ax2.set_facecolor("#F8F9FA")
    ctx = np.array([512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
    # Use median bytes-per-token per type, extrapolated linearly with context
    for t in TYPE_ORDER:
        sub = dense[dense["type"] == t]
        if sub.empty:
            continue
        bpt_median = (sub["theoretical_cache_mb"] * 1e6 / sub["max_context"]).median()
        cache_gb = bpt_median * ctx / 1e9
        ax2.plot(ctx / 1024, cache_gb, color=TYPE_COLORS[t], lw=2.5, label=t,
                 marker="o", markersize=5)

    for vram, lbl, col in [(80, "80 GB (A100/H100)", "#CC3311"),
                           (24, "24 GB (RTX)", "#EE7733")]:
        ax2.axhline(vram, color=col, lw=1.3, ls="-.", alpha=0.8)
        ax2.text(ctx[-1] / 1024 * 0.97, vram * 1.06, lbl,
                 ha="right", fontsize=9, color=col)

    ax2.set_xscale("log")
    ax2.set_yscale("log")
    ax2.set_xlabel("Context length (K tokens)", fontsize=12)
    ax2.set_ylabel("KV Cache per sequence (GB)", fontsize=12)
    ax2.set_title("KV Cache Growth with Context\n"
                  "(theoretical, median bytes-per-token from experiments)",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=11, framealpha=0.9)
    ax2.grid(True, which="both", alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle("Prefill Cost & KV Cache Scaling",
                 fontsize=14, fontweight="bold", y=1.01)
    plt.tight_layout()
    _save(fig, "perf_5_prefill.png")
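
# Sanity check for the bytes-per-token medians above (assumption: 2-byte
# fp16/bf16 cache entries, which this repo may or may not use): for MHA/GQA
# the theoretical bytes per cached token are roughly
#     2 (K and V) × num_blocks × kv_heads × head_size × 2 bytes,
# and for MLA roughly (down_dim_kv + rope_dim) × num_blocks × 2 bytes,
# matching the compressed-cache terms used in _decode_flops.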

# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    df = load()
    bdf = load_batch()
    print(f"Loaded {len(df)} benchmark runs")
    if bdf is not None:
        print(f"Loaded {len(bdf)} batch-sweep rows")
    else:
        print("No batch_sweep_results.csv found — fig4 will be a placeholder.")

    print("\nGenerating performance figures...")
    fig1_cache_breakdown(df)
    fig2_seqlen_throughput(df)
    fig3_flops_vs_cache(df)

    if bdf is not None and not bdf.empty:
        fig4_batch_throughput(bdf)
    else:
        fig4_missing()

    fig5_prefill(df)
    print("Done.")

    # Quick summary
    print("\n── Dense model summary ──")
    dense = df[~df["moe"]].copy()
    dense["decode_flops_m"] = dense.apply(lambda r: _decode_flops(r, S=256), axis=1) / 1e6
    cols = ["params_m", "theoretical_cache_mb", "prefill_ms",
            "decode_flops_m"] + [f"tps_{s}" for s in SEQ_LENS]
    cols = [c for c in cols if c in dense.columns]
    print(dense.groupby("type")[cols].mean().round(2).to_string())