Coverage for dantinox / plots / plot_insights.py: 0%

204 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 18:33 +0200

1""" 

2plot_insights.py — three high-signal plots from benchmark_results.csv. 

3 

4 insight_1_pareto.png — quality-cache Pareto front (MLA dominates) 

5 insight_2_serving.png — aggregate tok/s at fixed VRAM budget 

6 insight_3_mla_dial.png — MLA's down_dim_kv quality-cache tradeoff knob 

7 

8Run: python3 plot_insights.py 

9""" 

10 

11import os 

12import numpy as np 

13import pandas as pd 

14import matplotlib.pyplot as plt 

15import matplotlib.patches as mpatches 

16from matplotlib.lines import Line2D 

17from scipy.spatial import ConvexHull 

18 

IN_CSV = "benchmark_results.csv"  # benchmark output consumed by every figure
OUT_DIR = "plots"                 # destination directory (created on demand by _save)
DPI = 180                         # raster resolution for the saved PNGs
# One fixed color per attention architecture, shared across all figures.
TYPE_COLORS = {"MLA": "#3A86FF", "GQA": "#FF6B35", "MHA": "#2DC653"}
# Draw / legend order for the architectures.
TYPE_ORDER = ["MLA", "GQA", "MHA"]

24 

25 

def _save(fig, name):
    """Write *fig* to OUT_DIR/name at DPI resolution, close it, and print the path."""
    target = os.path.join(OUT_DIR, name)
    os.makedirs(OUT_DIR, exist_ok=True)
    fig.savefig(target, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {target}")

32 

33 

def load():
    """Read IN_CSV and coerce the known numeric columns (bad cells become NaN)."""
    df = pd.read_csv(IN_CSV)
    numeric_cols = ("params_m", "val_loss", "theoretical_cache_mb",
                    "tps_64", "tps_128", "tps_256", "tps_512")
    for col in numeric_cols:
        if col not in df.columns:
            continue
        df[col] = pd.to_numeric(df[col], errors="coerce")
    return df

41 

42 

43def _pareto_front(xs, ys): 

44 """Return indices of points on the lower-left Pareto front (minimize both).""" 

45 pts = sorted(enumerate(zip(xs, ys)), key=lambda p: p[1][0]) 

46 front, best_y = [], float("inf") 

47 for idx, (x, y) in pts: 

48 if y <= best_y: 

49 front.append(idx) 

50 best_y = y 

51 return front 

52 

53 

54# ───────────────────────────────────────────────────────────────────────────── 

55# Figure 1 — Quality-Cache Pareto front 

56# ───────────────────────────────────────────────────────────────────────────── 

def fig1_pareto(df):
    """
    x: KV cache (MB) — lower is better
    y: val_loss — lower is better
    Lower-left corner is ideal. MLA should populate the Pareto front.

    Pareto-dominated points are shown semi-transparent; the Pareto frontier
    is highlighted with a step-function line.

    Fix: the per-point alpha previously compared indices from a per-type
    frame rebuilt with reset_index(drop=True) against the positional indices
    returned by _pareto_front(), so the wrong points were highlighted.
    All index arithmetic now happens in one positional index space.
    """
    sub = df.dropna(subset=["theoretical_cache_mb", "val_loss", "params_m"]).copy()
    dense = sub[~sub["moe"]]

    fig, axes = plt.subplots(1, 2, figsize=(17, 7))
    fig.patch.set_facecolor("#F8F9FA")

    for ax_i, (data, title_sfx) in enumerate([(dense, "Dense"), (sub, "Dense + MoE")]):
        ax = axes[ax_i]
        ax.set_facecolor("#F8F9FA")

        # Align labels with positions: _pareto_front() returns positional
        # indices into xs_all/ys_all.
        data = data.reset_index(drop=True)
        xs_all = data["theoretical_cache_mb"].values
        ys_all = data["val_loss"].values
        pareto_idx = set(_pareto_front(xs_all, ys_all))

        for t in TYPE_ORDER:
            pts = data[data["type"] == t]  # labels remain positional w.r.t. data
            if pts.empty:
                continue

            for moe_v, grp in pts.groupby("moe"):
                mk = "^" if moe_v else "o"  # triangles = MoE, circles = dense
                # Marker area scales with parameter count, floored at 25 pt².
                g_sz = (grp["params_m"] / sub["params_m"].max() * 400).clip(lower=25)
                # grp.index are positions into `data`, directly comparable with
                # pareto_idx: front points opaque, dominated points faint.
                alphas = [0.90 if i in pareto_idx else 0.22 for i in grp.index]
                for (_, row), a, s in zip(grp.iterrows(), alphas, g_sz):
                    ax.scatter(row["theoretical_cache_mb"], row["val_loss"],
                               s=s, c=TYPE_COLORS[t], marker=mk,
                               edgecolors="white", lw=0.7, alpha=a, zorder=4)

        # Draw Pareto step-function
        pareto_pts = sorted(
            [(xs_all[i], ys_all[i]) for i in pareto_idx],
            key=lambda p: p[0]
        )
        if pareto_pts:
            px = [p[0] for p in pareto_pts]
            py = [p[1] for p in pareto_pts]
            # step line: move right then down
            step_x, step_y = [px[0]], [py[0]]
            for i in range(1, len(px)):
                step_x += [px[i], px[i]]
                step_y += [py[i - 1], py[i]]
            ax.plot(step_x, step_y, color="#222222", lw=2, ls="-",
                    alpha=0.65, zorder=5, label="Pareto front")
            # Identify which architecture types sit on the front
            front_types = set(data.iloc[i]["type"] for i in pareto_idx)
            ax.text(step_x[0] * 1.05, step_y[0] * 0.997,
                    f"Pareto front\n({', '.join(sorted(front_types))})",
                    fontsize=9, color="#222", style="italic",
                    va="top")

        # Ideal direction arrow (points toward the lower-left corner)
        ax.annotate("", xy=(0.06, 0.06), xytext=(0.22, 0.22),
                    xycoords="axes fraction",
                    arrowprops=dict(arrowstyle="-|>", color="#888",
                                    lw=1.5, mutation_scale=14))
        ax.text(0.04, 0.055, "ideal", transform=ax.transAxes,
                fontsize=9, color="#888", style="italic")

        ax.set_xlabel("KV Cache (MB) @ 512 tokens ← lower is better", fontsize=12)
        ax.set_ylabel("Validation Loss (NLL) ← lower is better", fontsize=12)
        ax.set_title(f"{title_sfx} — Quality vs Cache Pareto\n"
                     "(semi-transparent = Pareto-dominated)",
                     fontsize=12, fontweight="bold")
        ax.grid(alpha=0.2, ls="--")
        ax.spines[["top", "right"]].set_visible(False)

        handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in TYPE_ORDER
                   if not data[data["type"] == t].empty]
        handles += [
            Line2D([0], [0], marker="o", color="#888", ls="None",
                   markersize=8, markeredgecolor="white", label="Dense"),
            Line2D([0], [0], marker="^", color="#888", ls="None",
                   markersize=8, markeredgecolor="white", label="MoE"),
            Line2D([0], [0], color="#222", lw=2, label="Pareto front"),
        ]
        ax.legend(handles=handles, fontsize=9, framealpha=0.9,
                  loc="upper right", ncol=2)

    fig.suptitle(
        "MLA Pareto-Dominates MHA: Better Quality AND Smaller KV Cache",
        fontsize=14, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_1_pareto.png")

157 

158 

159# ───────────────────────────────────────────────────────────────────────────── 

160# Figure 2 — Aggregate serving throughput at fixed VRAM budget 

161# ───────────────────────────────────────────────────────────────────────────── 

def fig2_serving(df):
    """
    Key insight: a faster-per-sequence model is NOT always best for serving.
    When VRAM is fixed, a model with a smaller KV cache fits more concurrent
    sequences, which can outweigh its per-sequence latency.

    Metric: total_tps(budget) = (budget_MB / cache_per_seq_MB) × tps_per_seq

    We compute the MEDIAN tps and cache across all runs of each type (Dense).
    Then we sweep VRAM budgets from 500 MB to 80 GB.

    Reads columns: moe, type, tps_512, theoretical_cache_mb.
    Writes plots/insight_2_serving.png via _save().
    """
    # Dense runs only, with both throughput and cache measurements present.
    dense = df[~df["moe"]].dropna(subset=["tps_512", "theoretical_cache_mb"])

    # Median per-sequence stats per type, plus IQR bounds for the bands.
    stats = dense.groupby("type").agg(
        tps=("tps_512", "median"),
        cache_mb=("theoretical_cache_mb", "median"),
        tps_lo=("tps_512", lambda x: x.quantile(0.25)),
        tps_hi=("tps_512", lambda x: x.quantile(0.75)),
        cache_lo=("theoretical_cache_mb", lambda x: x.quantile(0.25)),
        cache_hi=("theoretical_cache_mb", lambda x: x.quantile(0.75)),
    )

    # Log-spaced budget sweep, 500 MB .. 80 GB; values are in MB.
    budgets_mb = np.logspace(np.log10(500), np.log10(80_000), 300)

    fig, axes = plt.subplots(1, 2, figsize=(17, 7))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: total tok/s vs budget ─────────────────────────────────────────
    ax = axes[0]
    ax.set_facecolor("#F8F9FA")

    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        # Aggregate throughput = (sequences that fit in the budget) × tps each.
        total = (budgets_mb / cache) * tps
        # x in GB, y in thousands of tok/s.
        ax.plot(budgets_mb / 1024, total / 1000,
                color=TYPE_COLORS[t], lw=2.8, label=t, zorder=4)

        # Uncertainty band (IQR of tps and cache): the optimistic bound pairs
        # high tps with low cache, the pessimistic bound the opposite.
        t_lo = stats.loc[t, "tps_lo"]
        t_hi = stats.loc[t, "tps_hi"]
        c_lo = stats.loc[t, "cache_lo"]
        c_hi = stats.loc[t, "cache_hi"]
        tot_lo = (budgets_mb / c_hi) * t_lo
        tot_hi = (budgets_mb / c_lo) * t_hi
        ax.fill_between(budgets_mb / 1024, tot_lo / 1000, tot_hi / 1000,
                        color=TYPE_COLORS[t], alpha=0.12, zorder=3)

    # GPU memory reference lines (x-axis is already in GB, hence plain gb).
    for gb, lbl in [(24, "24 GB\n(RTX)"), (40, "40 GB\n(A100)"), (80, "80 GB\n(H100)")]:
        ax.axvline(gb, color="#888", lw=1, ls="--", alpha=0.6)
        ax.text(gb * 1.02, ax.get_ylim()[1] if ax.get_ylim()[1] > 0 else 1,
                lbl, fontsize=8, color="#888", va="top")

    ax.set_xscale("log")
    ax.set_xlabel("KV Cache VRAM budget (GB)", fontsize=12)
    ax.set_ylabel("Aggregate throughput (k tok/s)", fontsize=12)
    ax.set_title("Total Serving Throughput at Fixed VRAM Budget\n"
                 "(median per-sequence tps × max concurrent sequences)",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=11, framealpha=0.9)
    ax.grid(True, which="both", alpha=0.2, ls="--")
    ax.spines[["top", "right"]].set_visible(False)

    # ── Right: ratio vs MHA baseline ────────────────────────────────────────
    ax2 = axes[1]
    ax2.set_facecolor("#F8F9FA")
    # NOTE(review): assumes an "MHA" group exists — stats.loc raises KeyError
    # otherwise; confirm the benchmark CSV always includes MHA runs.
    mha_total = (budgets_mb / stats.loc["MHA", "cache_mb"]) * stats.loc["MHA", "tps"]

    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        ratio = ((budgets_mb / cache) * tps) / mha_total
        ax2.plot(budgets_mb / 1024, ratio,
                 color=TYPE_COLORS[t], lw=2.8, label=t, zorder=4)

    # Baseline at 1.0× (MHA vs itself) plus 2.0× / 2.5× guide rules.
    ax2.axhline(1.0, color=TYPE_COLORS["MHA"], lw=1.2, ls="--", alpha=0.6)
    ax2.axhline(2.0, color="#CCCCCC", lw=0.8, ls=":", alpha=0.8)
    ax2.axhline(2.5, color="#CCCCCC", lw=0.8, ls=":", alpha=0.8)

    # Annotate at 80 GB
    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        # The budget term cancels in the ratio; 80 GB is just the anchor point.
        val_at_80 = ((80_000 / cache) * tps) / ((80_000 / stats.loc["MHA","cache_mb"]) * stats.loc["MHA","tps"])
        ax2.annotate(f"{t}: {val_at_80:.1f}×",
                     xy=(80, val_at_80),
                     xytext=(-50, 8), textcoords="offset points",
                     fontsize=11, fontweight="bold", color=TYPE_COLORS[t],
                     arrowprops=dict(arrowstyle="-", color=TYPE_COLORS[t], lw=1))

    for gb, lbl in [(24, "24 GB"), (40, "40 GB"), (80, "80 GB")]:
        ax2.axvline(gb, color="#888", lw=1, ls="--", alpha=0.6)
        # get_xaxis_transform: x in data coords, y (0.05) in axes fraction.
        ax2.text(gb * 1.02, 0.05, lbl, fontsize=8, color="#888",
                 transform=ax2.get_xaxis_transform())

    ax2.set_xscale("log")
    ax2.set_xlabel("KV Cache VRAM budget (GB)", fontsize=12)
    ax2.set_ylabel("Throughput ratio (MHA = 1.0×)", fontsize=12)
    ax2.set_title("Relative Serving Advantage vs MHA\n"
                  "(above 1.0 = more total tokens/s for the same VRAM)",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=11, framealpha=0.9)
    ax2.grid(True, which="both", alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "MLA's Smaller Cache Enables 2–3× More Total Throughput at Fixed VRAM\n"
        "(even though MLA is slower per-sequence, it fits more concurrent users)",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_2_serving.png")

282 

283 

284# ───────────────────────────────────────────────────────────────────────────── 

285# Figure 3 — MLA's down_dim_kv dial: quality vs cache 

286# ───────────────────────────────────────────────────────────────────────────── 

def fig3_mla_dial(df):
    """
    MLA has a unique hyperparameter: down_dim_kv.
    Increasing it improves quality but grows the KV cache linearly.
    MHA and GQA have no equivalent knob — their cache is fixed by dim and kv_heads.

    Left panel — down_dim_kv vs val_loss (quality cost of compression)
    Right panel — down_dim_kv vs cache (linear relationship)
    Both panels: MHA/GQA reference bands for context.

    Writes plots/insight_3_mla_dial.png via _save(); no-op (with a message)
    when no down_dim_kv values are present.
    """
    # Dense runs split into MLA vs the rest (the reference architectures).
    mla = df[(df["type"] == "MLA") & (~df["moe"])].copy()
    other = df[(df["type"] != "MLA") & (~df["moe"])].copy()

    # NOTE(review): assumes a "down_dim_kv" column exists (NaN on non-MLA
    # rows); this line raises KeyError if the column is missing entirely.
    if mla["down_dim_kv"].isna().all():
        print("[fig3] No down_dim_kv data — skipping.")
        return

    # Aggregate MLA per down_dim_kv: mean curve plus min–max envelope.
    agg = mla.groupby("down_dim_kv").agg(
        val_loss_mean=("val_loss", "mean"),
        val_loss_min=("val_loss", "min"),
        val_loss_max=("val_loss", "max"),
        cache_mean=("theoretical_cache_mb", "mean"),
        cache_min=("theoretical_cache_mb", "min"),
        cache_max=("theoretical_cache_mb", "max"),
        n=("val_loss", "count"),
    ).reset_index()

    # MHA / GQA reference: overall median val_loss and cache across Dense runs
    ref_mha_loss = other[other["type"] == "MHA"]["val_loss"].median()
    ref_gqa_loss = other[other["type"] == "GQA"]["val_loss"].median()
    ref_mha_cache = other[other["type"] == "MHA"]["theoretical_cache_mb"].median()
    ref_gqa_cache = other[other["type"] == "GQA"]["theoretical_cache_mb"].median()

    dkv = agg["down_dim_kv"].values

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6.5))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: quality ───────────────────────────────────────────────────────
    ax1.set_facecolor("#F8F9FA")

    # MHA / GQA reference bands: dashed median line with ±3% shading.
    for val, col, lbl in [
        (ref_mha_loss, TYPE_COLORS["MHA"], "MHA median"),
        (ref_gqa_loss, TYPE_COLORS["GQA"], "GQA median"),
    ]:
        ax1.axhline(val, color=col, lw=2, ls="--", alpha=0.75, label=lbl)
        ax1.fill_between(
            [dkv.min() * 0.8, dkv.max() * 1.2],
            val * 0.97, val * 1.03,
            color=col, alpha=0.08
        )

    # MLA mean line + min–max band.
    ax1.plot(dkv, agg["val_loss_mean"], color=TYPE_COLORS["MLA"],
             lw=2.8, marker="o", markersize=8, label="MLA (mean)", zorder=5)
    ax1.fill_between(dkv, agg["val_loss_min"], agg["val_loss_max"],
                     color=TYPE_COLORS["MLA"], alpha=0.18, zorder=3,
                     label="MLA (min–max range)")

    # Annotate best MLA point (lowest mean loss over the dial sweep).
    best = agg.loc[agg["val_loss_mean"].idxmin()]
    ax1.scatter(best["down_dim_kv"], best["val_loss_mean"],
                s=180, c=TYPE_COLORS["MLA"], edgecolors="black", lw=2, zorder=6)
    ax1.annotate(f"best:\ndkv={int(best['down_dim_kv'])}\nloss={best['val_loss_mean']:.3f}",
                 (best["down_dim_kv"], best["val_loss_mean"]),
                 xytext=(14, -28), textcoords="offset points",
                 fontsize=9, color=TYPE_COLORS["MLA"],
                 arrowprops=dict(arrowstyle="-", color=TYPE_COLORS["MLA"], lw=1))

    ax1.set_xlabel("down_dim_kv (MLA latent KV dimension)", fontsize=12)
    ax1.set_ylabel("Validation Loss (NLL) ← lower is better", fontsize=12)
    ax1.set_title("Quality vs down_dim_kv\n"
                  "MHA / GQA have no equivalent tuning knob",
                  fontsize=12, fontweight="bold")
    ax1.legend(fontsize=10, framealpha=0.9, loc="upper right")
    ax1.grid(alpha=0.2, ls="--")
    ax1.spines[["top", "right"]].set_visible(False)

    # ── Right: cache ────────────────────────────────────────────────────────
    ax2.set_facecolor("#F8F9FA")

    # Reference bands again, with wider ±15% shading for cache.
    for val, col, lbl in [
        (ref_mha_cache, TYPE_COLORS["MHA"], "MHA median"),
        (ref_gqa_cache, TYPE_COLORS["GQA"], "GQA median"),
    ]:
        ax2.axhline(val, color=col, lw=2, ls="--", alpha=0.75, label=lbl)
        ax2.fill_between(
            [dkv.min() * 0.8, dkv.max() * 1.2],
            val * 0.85, val * 1.15,
            color=col, alpha=0.08
        )

    ax2.plot(dkv, agg["cache_mean"], color=TYPE_COLORS["MLA"],
             lw=2.8, marker="o", markersize=8, label="MLA (mean)", zorder=5)
    ax2.fill_between(dkv, agg["cache_min"], agg["cache_max"],
                     color=TYPE_COLORS["MLA"], alpha=0.18, zorder=3,
                     label="MLA (min–max range)")

    # Annotate the crossover points
    for ref_val, ref_col, ref_lbl in [
        (ref_mha_cache, TYPE_COLORS["MHA"], "MHA"),
        (ref_gqa_cache, TYPE_COLORS["GQA"], "GQA"),
    ]:
        # Find where MLA cache crosses the reference: the largest dkv whose
        # mean cache is still at or below the reference line.
        crosses = agg[agg["cache_mean"] <= ref_val]
        if not crosses.empty:
            cross_dkv = crosses["down_dim_kv"].max()
            ax2.axvline(cross_dkv, color=ref_col, lw=1.2, ls=":", alpha=0.8)
            ax2.text(cross_dkv, ref_val * 0.6,
                     f"MLA = {ref_lbl}\ncache at\ndkv={int(cross_dkv)}",
                     fontsize=8, color=ref_col, ha="center",
                     bbox=dict(boxstyle="round,pad=0.3", fc="white",
                               ec=ref_col, lw=0.8, alpha=0.85))

    ax2.set_xlabel("down_dim_kv (MLA latent KV dimension)", fontsize=12)
    ax2.set_ylabel("KV Cache (MB) @ 512 tokens", fontsize=12)
    ax2.set_title("Cache Size vs down_dim_kv\n"
                  "Linear relationship — easy to predict at any scale",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=10, framealpha=0.9, loc="upper left")
    ax2.grid(alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "MLA's Unique Tuning Dial: down_dim_kv Controls the Quality-Cache Tradeoff\n"
        "MHA / GQA are fixed — MLA lets you choose where on the curve to land",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_3_mla_dial.png")

419 

420 

421# ───────────────────────────────────────────────────────────────────────────── 

if __name__ == "__main__":
    # Script entry point: load the benchmark CSV once, then render all three
    # insight figures into OUT_DIR.
    df = load()
    print(f"Loaded {len(df)} runs")
    print("\nGenerating insight figures...")
    fig1_pareto(df)    # quality vs KV-cache Pareto front
    fig2_serving(df)   # aggregate throughput at a fixed VRAM budget
    fig3_mla_dial(df)  # MLA down_dim_kv quality/cache dial
    print("Done.")

429 print("Done.")