Coverage for dantinox / plots / plot_insights.py: 0%
204 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-05-02 18:33 +0200
"""
plot_insights.py — three high-signal plots from benchmark_results.csv.

  insight_1_pareto.png   — quality-cache Pareto front (MLA dominates)
  insight_2_serving.png  — aggregate tok/s at fixed VRAM budget
  insight_3_mla_dial.png — MLA's down_dim_kv quality-cache tradeoff knob

Run: python3 plot_insights.py
"""
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
from scipy.spatial import ConvexHull  # NOTE(review): unused in this file — confirm before removing

# Input CSV produced by the benchmark harness, and where figures are written.
IN_CSV = "benchmark_results.csv"
OUT_DIR = "plots"
DPI = 180
# One fixed color per attention type, and a stable plotting/legend order.
TYPE_COLORS = {"MLA": "#3A86FF", "GQA": "#FF6B35", "MHA": "#2DC653"}
TYPE_ORDER = ["MLA", "GQA", "MHA"]
def _save(fig, name):
    """Save *fig* as OUT_DIR/name at DPI, close it, and log the path."""
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, name)
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out_path}")
def load():
    """Read IN_CSV into a DataFrame, coercing known numeric columns.

    Non-numeric entries in those columns become NaN (errors="coerce");
    columns absent from the CSV are simply skipped.
    """
    numeric_cols = ("params_m", "val_loss", "theoretical_cache_mb",
                    "tps_64", "tps_128", "tps_256", "tps_512")
    frame = pd.read_csv(IN_CSV)
    for col in (c for c in numeric_cols if c in frame.columns):
        frame[col] = pd.to_numeric(frame[col], errors="coerce")
    return frame
43def _pareto_front(xs, ys):
44 """Return indices of points on the lower-left Pareto front (minimize both)."""
45 pts = sorted(enumerate(zip(xs, ys)), key=lambda p: p[1][0])
46 front, best_y = [], float("inf")
47 for idx, (x, y) in pts:
48 if y <= best_y:
49 front.append(idx)
50 best_y = y
51 return front
# ─────────────────────────────────────────────────────────────────────────────
# Figure 1 — Quality-Cache Pareto front
# ─────────────────────────────────────────────────────────────────────────────
def fig1_pareto(df):
    """
    Quality-vs-cache Pareto scatter.

    x: KV cache (MB) — lower is better
    y: val_loss — lower is better
    Lower-left corner is ideal. MLA should populate the Pareto front.

    Pareto-dominated points are shown semi-transparent; the Pareto frontier
    is highlighted with a step-function line.

    Bug fix: Pareto membership is computed over positional indices into
    `data`, but the per-type sub-frames were previously re-indexed locally
    (reset_index per type), so the opaque/faded flags landed on the wrong
    rows. All frames now share one positional index space.
    """
    sub = df.dropna(subset=["theoretical_cache_mb", "val_loss", "params_m"]).copy()
    dense = sub[~sub["moe"]]

    fig, axes = plt.subplots(1, 2, figsize=(17, 7))
    fig.patch.set_facecolor("#F8F9FA")

    for ax_i, (data, title_sfx) in enumerate([(dense, "Dense"), (sub, "Dense + MoE")]):
        ax = axes[ax_i]
        ax.set_facecolor("#F8F9FA")

        # Positional index so row labels match positions in xs_all/ys_all
        # (and therefore the indices _pareto_front returns).
        data = data.reset_index(drop=True)
        xs_all = data["theoretical_cache_mb"].values
        ys_all = data["val_loss"].values
        pareto_idx = set(_pareto_front(xs_all, ys_all))

        for t in TYPE_ORDER:
            pts = data[data["type"] == t]  # keeps positional labels intact
            if pts.empty:
                continue

            # Dense runs get circles, MoE runs triangles.
            for moe_v, grp in pts.groupby("moe"):
                mk = "^" if moe_v else "o"
                # Marker area scales with parameter count (floor of 25 pt²).
                g_sz = (grp["params_m"] / sub["params_m"].max() * 400).clip(lower=25)
                # Pareto-optimal rows are opaque; dominated rows are faded.
                alphas = [0.90 if i in pareto_idx else 0.22 for i in grp.index]
                for (_, row), a, s in zip(grp.iterrows(), alphas, g_sz):
                    ax.scatter(row["theoretical_cache_mb"], row["val_loss"],
                               s=s, c=TYPE_COLORS[t], marker=mk,
                               edgecolors="white", lw=0.7, alpha=a, zorder=4)

        # Draw the Pareto front as a staircase: move right, then down.
        pareto_pts = sorted(
            [(xs_all[i], ys_all[i]) for i in pareto_idx],
            key=lambda p: p[0]
        )
        if pareto_pts:
            px = [p[0] for p in pareto_pts]
            py = [p[1] for p in pareto_pts]
            step_x, step_y = [px[0]], [py[0]]
            for i in range(1, len(px)):
                step_x += [px[i], px[i]]
                step_y += [py[i - 1], py[i]]
            ax.plot(step_x, step_y, color="#222222", lw=2, ls="-",
                    alpha=0.65, zorder=5, label="Pareto front")
            # Which attention types actually sit on the front?
            front_types = set(data.iloc[i]["type"] for i in pareto_idx)
            ax.text(step_x[0] * 1.05, step_y[0] * 0.997,
                    f"Pareto front\n({', '.join(sorted(front_types))})",
                    fontsize=9, color="#222", style="italic",
                    va="top")

        # Arrow pointing toward the ideal (low-cache, low-loss) corner.
        ax.annotate("", xy=(0.06, 0.06), xytext=(0.22, 0.22),
                    xycoords="axes fraction",
                    arrowprops=dict(arrowstyle="-|>", color="#888",
                                    lw=1.5, mutation_scale=14))
        ax.text(0.04, 0.055, "ideal", transform=ax.transAxes,
                fontsize=9, color="#888", style="italic")

        ax.set_xlabel("KV Cache (MB) @ 512 tokens ← lower is better", fontsize=12)
        ax.set_ylabel("Validation Loss (NLL) ← lower is better", fontsize=12)
        ax.set_title(f"{title_sfx} — Quality vs Cache Pareto\n"
                     "(semi-transparent = Pareto-dominated)",
                     fontsize=12, fontweight="bold")
        ax.grid(alpha=0.2, ls="--")
        ax.spines[["top", "right"]].set_visible(False)

        # Legend: color patches per type present, plus marker/line keys.
        handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in TYPE_ORDER
                   if not data[data["type"] == t].empty]
        handles += [
            Line2D([0], [0], marker="o", color="#888", ls="None",
                   markersize=8, markeredgecolor="white", label="Dense"),
            Line2D([0], [0], marker="^", color="#888", ls="None",
                   markersize=8, markeredgecolor="white", label="MoE"),
            Line2D([0], [0], color="#222", lw=2, label="Pareto front"),
        ]
        ax.legend(handles=handles, fontsize=9, framealpha=0.9,
                  loc="upper right", ncol=2)

    fig.suptitle(
        "MLA Pareto-Dominates MHA: Better Quality AND Smaller KV Cache",
        fontsize=14, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_1_pareto.png")
# ─────────────────────────────────────────────────────────────────────────────
# Figure 2 — Aggregate serving throughput at fixed VRAM budget
# ─────────────────────────────────────────────────────────────────────────────
def fig2_serving(df):
    """
    Key insight: a faster-per-sequence model is NOT always best for serving.
    When VRAM is fixed, a model with a smaller KV cache fits more concurrent
    sequences, which can outweigh its per-sequence latency.

    Metric: total_tps(budget) = (budget_MB / cache_per_seq_MB) × tps_per_seq

    We compute the MEDIAN tps and cache across all runs of each type (Dense).
    Then we sweep VRAM budgets from 500 MB to 80 GB.

    Left panel:  absolute aggregate tok/s vs budget, with IQR bands.
    Right panel: the same curves as a ratio against the MHA baseline.
    """
    # Dense runs only; rows must have both throughput and cache numbers.
    dense = df[~df["moe"]].dropna(subset=["tps_512", "theoretical_cache_mb"])

    # Median per-sequence stats per type; quartiles feed the uncertainty bands.
    stats = dense.groupby("type").agg(
        tps=("tps_512", "median"),
        cache_mb=("theoretical_cache_mb", "median"),
        tps_lo=("tps_512", lambda x: x.quantile(0.25)),
        tps_hi=("tps_512", lambda x: x.quantile(0.75)),
        cache_lo=("theoretical_cache_mb", lambda x: x.quantile(0.25)),
        cache_hi=("theoretical_cache_mb", lambda x: x.quantile(0.75)),
    )

    # 300 log-spaced budgets spanning 500 MB .. 80 GB (values are in MB).
    budgets_mb = np.logspace(np.log10(500), np.log10(80_000), 300)

    fig, axes = plt.subplots(1, 2, figsize=(17, 7))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: total tok/s vs budget ─────────────────────────────────────────
    ax = axes[0]
    ax.set_facecolor("#F8F9FA")

    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        # budget / cache = max concurrent sequences; × per-seq tps = total.
        total = (budgets_mb / cache) * tps
        ax.plot(budgets_mb / 1024, total / 1000,
                color=TYPE_COLORS[t], lw=2.8, label=t, zorder=4)

        # Uncertainty band (IQR of tps and cache): pessimistic bound pairs
        # high-cache with low-tps quartiles; optimistic bound the reverse.
        t_lo = stats.loc[t, "tps_lo"]
        t_hi = stats.loc[t, "tps_hi"]
        c_lo = stats.loc[t, "cache_lo"]
        c_hi = stats.loc[t, "cache_hi"]
        tot_lo = (budgets_mb / c_hi) * t_lo
        tot_hi = (budgets_mb / c_lo) * t_hi
        ax.fill_between(budgets_mb / 1024, tot_lo / 1000, tot_hi / 1000,
                        color=TYPE_COLORS[t], alpha=0.12, zorder=3)

    # GPU memory reference lines (x axis is GB).
    for gb, lbl in [(24, "24 GB\n(RTX)"), (40, "40 GB\n(A100)"), (80, "80 GB\n(H100)")]:
        ax.axvline(gb, color="#888", lw=1, ls="--", alpha=0.6)
        ax.text(gb * 1.02, ax.get_ylim()[1] if ax.get_ylim()[1] > 0 else 1,
                lbl, fontsize=8, color="#888", va="top")

    ax.set_xscale("log")
    ax.set_xlabel("KV Cache VRAM budget (GB)", fontsize=12)
    ax.set_ylabel("Aggregate throughput (k tok/s)", fontsize=12)
    ax.set_title("Total Serving Throughput at Fixed VRAM Budget\n"
                 "(median per-sequence tps × max concurrent sequences)",
                 fontsize=12, fontweight="bold")
    ax.legend(fontsize=11, framealpha=0.9)
    ax.grid(True, which="both", alpha=0.2, ls="--")
    ax.spines[["top", "right"]].set_visible(False)

    # ── Right: ratio vs MHA baseline ────────────────────────────────────────
    ax2 = axes[1]
    ax2.set_facecolor("#F8F9FA")
    # NOTE(review): assumes an "MHA" group exists — raises KeyError otherwise.
    mha_total = (budgets_mb / stats.loc["MHA", "cache_mb"]) * stats.loc["MHA", "tps"]

    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        ratio = ((budgets_mb / cache) * tps) / mha_total
        ax2.plot(budgets_mb / 1024, ratio,
                 color=TYPE_COLORS[t], lw=2.8, label=t, zorder=4)

    # Guide lines: parity with MHA, plus 2.0× / 2.5× reference levels.
    ax2.axhline(1.0, color=TYPE_COLORS["MHA"], lw=1.2, ls="--", alpha=0.6)
    ax2.axhline(2.0, color="#CCCCCC", lw=0.8, ls=":", alpha=0.8)
    ax2.axhline(2.5, color="#CCCCCC", lw=0.8, ls=":", alpha=0.8)

    # Annotate each type's ratio at the 80 GB budget (80_000 MB).
    for t in TYPE_ORDER:
        if t not in stats.index:
            continue
        tps = stats.loc[t, "tps"]
        cache = stats.loc[t, "cache_mb"]
        val_at_80 = ((80_000 / cache) * tps) / ((80_000 / stats.loc["MHA","cache_mb"]) * stats.loc["MHA","tps"])
        ax2.annotate(f"{t}: {val_at_80:.1f}×",
                     xy=(80, val_at_80),
                     xytext=(-50, 8), textcoords="offset points",
                     fontsize=11, fontweight="bold", color=TYPE_COLORS[t],
                     arrowprops=dict(arrowstyle="-", color=TYPE_COLORS[t], lw=1))

    for gb, lbl in [(24, "24 GB"), (40, "40 GB"), (80, "80 GB")]:
        ax2.axvline(gb, color="#888", lw=1, ls="--", alpha=0.6)
        # get_xaxis_transform: x in data coords, y in axes-fraction coords.
        ax2.text(gb * 1.02, 0.05, lbl, fontsize=8, color="#888",
                 transform=ax2.get_xaxis_transform())

    ax2.set_xscale("log")
    ax2.set_xlabel("KV Cache VRAM budget (GB)", fontsize=12)
    ax2.set_ylabel("Throughput ratio (MHA = 1.0×)", fontsize=12)
    ax2.set_title("Relative Serving Advantage vs MHA\n"
                  "(above 1.0 = more total tokens/s for the same VRAM)",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=11, framealpha=0.9)
    ax2.grid(True, which="both", alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "MLA's Smaller Cache Enables 2–3× More Total Throughput at Fixed VRAM\n"
        "(even though MLA is slower per-sequence, it fits more concurrent users)",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_2_serving.png")
# ─────────────────────────────────────────────────────────────────────────────
# Figure 3 — MLA's down_dim_kv dial: quality vs cache
# ─────────────────────────────────────────────────────────────────────────────
def fig3_mla_dial(df):
    """
    MLA has a unique hyperparameter: down_dim_kv.
    Increasing it improves quality but grows the KV cache linearly.
    MHA and GQA have no equivalent knob — their cache is fixed by dim and kv_heads.

    Left panel — down_dim_kv vs val_loss (quality cost of compression)
    Right panel — down_dim_kv vs cache (linear relationship)
    Both panels: MHA/GQA reference bands for context.
    """
    # Dense-only split: MLA runs vs. everything else (the reference types).
    mla = df[(df["type"] == "MLA") & (~df["moe"])].copy()
    other = df[(df["type"] != "MLA") & (~df["moe"])].copy()

    # Nothing to plot without the MLA-specific column populated.
    if mla["down_dim_kv"].isna().all():
        print("[fig3] No down_dim_kv data — skipping.")
        return

    # Aggregate MLA per down_dim_kv: mean line plus min–max envelope.
    agg = mla.groupby("down_dim_kv").agg(
        val_loss_mean=("val_loss", "mean"),
        val_loss_min =("val_loss", "min"),
        val_loss_max =("val_loss", "max"),
        cache_mean =("theoretical_cache_mb", "mean"),
        cache_min =("theoretical_cache_mb", "min"),
        cache_max =("theoretical_cache_mb", "max"),
        n =("val_loss", "count"),
    ).reset_index()

    # MHA / GQA reference: overall median val_loss and cache across Dense runs
    ref_mha_loss = other[other["type"] == "MHA"]["val_loss"].median()
    ref_gqa_loss = other[other["type"] == "GQA"]["val_loss"].median()
    ref_mha_cache = other[other["type"] == "MHA"]["theoretical_cache_mb"].median()
    ref_gqa_cache = other[other["type"] == "GQA"]["theoretical_cache_mb"].median()

    dkv = agg["down_dim_kv"].values

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6.5))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: quality ────────────────────────────────────────────────────────
    ax1.set_facecolor("#F8F9FA")

    # MHA / GQA reference: dashed median line plus a ±3% shaded band.
    for val, col, lbl in [
        (ref_mha_loss, TYPE_COLORS["MHA"], "MHA median"),
        (ref_gqa_loss, TYPE_COLORS["GQA"], "GQA median"),
    ]:
        ax1.axhline(val, color=col, lw=2, ls="--", alpha=0.75, label=lbl)
        ax1.fill_between(
            [dkv.min() * 0.8, dkv.max() * 1.2],
            val * 0.97, val * 1.03,
            color=col, alpha=0.08
        )

    # MLA mean line + min–max band across runs at each down_dim_kv.
    ax1.plot(dkv, agg["val_loss_mean"], color=TYPE_COLORS["MLA"],
             lw=2.8, marker="o", markersize=8, label="MLA (mean)", zorder=5)
    ax1.fill_between(dkv, agg["val_loss_min"], agg["val_loss_max"],
                     color=TYPE_COLORS["MLA"], alpha=0.18, zorder=3,
                     label="MLA (min–max range)")

    # Highlight the best (lowest mean loss) MLA configuration.
    best = agg.loc[agg["val_loss_mean"].idxmin()]
    ax1.scatter(best["down_dim_kv"], best["val_loss_mean"],
                s=180, c=TYPE_COLORS["MLA"], edgecolors="black", lw=2, zorder=6)
    ax1.annotate(f"best:\ndkv={int(best['down_dim_kv'])}\nloss={best['val_loss_mean']:.3f}",
                 (best["down_dim_kv"], best["val_loss_mean"]),
                 xytext=(14, -28), textcoords="offset points",
                 fontsize=9, color=TYPE_COLORS["MLA"],
                 arrowprops=dict(arrowstyle="-", color=TYPE_COLORS["MLA"], lw=1))

    ax1.set_xlabel("down_dim_kv (MLA latent KV dimension)", fontsize=12)
    ax1.set_ylabel("Validation Loss (NLL) ← lower is better", fontsize=12)
    ax1.set_title("Quality vs down_dim_kv\n"
                  "MHA / GQA have no equivalent tuning knob",
                  fontsize=12, fontweight="bold")
    ax1.legend(fontsize=10, framealpha=0.9, loc="upper right")
    ax1.grid(alpha=0.2, ls="--")
    ax1.spines[["top", "right"]].set_visible(False)

    # ── Right: cache ─────────────────────────────────────────────────────────
    ax2.set_facecolor("#F8F9FA")

    # Reference bands again (wider ±15% band for the cache axis).
    for val, col, lbl in [
        (ref_mha_cache, TYPE_COLORS["MHA"], "MHA median"),
        (ref_gqa_cache, TYPE_COLORS["GQA"], "GQA median"),
    ]:
        ax2.axhline(val, color=col, lw=2, ls="--", alpha=0.75, label=lbl)
        ax2.fill_between(
            [dkv.min() * 0.8, dkv.max() * 1.2],
            val * 0.85, val * 1.15,
            color=col, alpha=0.08
        )

    ax2.plot(dkv, agg["cache_mean"], color=TYPE_COLORS["MLA"],
             lw=2.8, marker="o", markersize=8, label="MLA (mean)", zorder=5)
    ax2.fill_between(dkv, agg["cache_min"], agg["cache_max"],
                     color=TYPE_COLORS["MLA"], alpha=0.18, zorder=3,
                     label="MLA (min–max range)")

    # Annotate the crossover points
    for ref_val, ref_col, ref_lbl in [
        (ref_mha_cache, TYPE_COLORS["MHA"], "MHA"),
        (ref_gqa_cache, TYPE_COLORS["GQA"], "GQA"),
    ]:
        # Largest down_dim_kv whose mean cache is still at or below the reference.
        crosses = agg[agg["cache_mean"] <= ref_val]
        if not crosses.empty:
            cross_dkv = crosses["down_dim_kv"].max()
            ax2.axvline(cross_dkv, color=ref_col, lw=1.2, ls=":", alpha=0.8)
            ax2.text(cross_dkv, ref_val * 0.6,
                     f"MLA = {ref_lbl}\ncache at\ndkv={int(cross_dkv)}",
                     fontsize=8, color=ref_col, ha="center",
                     bbox=dict(boxstyle="round,pad=0.3", fc="white",
                               ec=ref_col, lw=0.8, alpha=0.85))

    ax2.set_xlabel("down_dim_kv (MLA latent KV dimension)", fontsize=12)
    ax2.set_ylabel("KV Cache (MB) @ 512 tokens", fontsize=12)
    ax2.set_title("Cache Size vs down_dim_kv\n"
                  "Linear relationship — easy to predict at any scale",
                  fontsize=12, fontweight="bold")
    ax2.legend(fontsize=10, framealpha=0.9, loc="upper left")
    ax2.grid(alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "MLA's Unique Tuning Dial: down_dim_kv Controls the Quality-Cache Tradeoff\n"
        "MHA / GQA are fixed — MLA lets you choose where on the curve to land",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "insight_3_mla_dial.png")
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: load the benchmark CSV and render all three figures.
    df = load()
    print(f"Loaded {len(df)} runs")
    print("\nGenerating insight figures...")
    fig1_pareto(df)
    fig2_serving(df)
    fig3_mla_dial(df)
    print("Done.")
429 print("Done.")