# NOTE: this file was recovered from a coverage.py v7.13.5 HTML report
# (dantinox/plots/plot_3d.py, 232 statements, 0% covered); report chrome removed.
1#!/home/marco.simoni/miniconda3/bin/python3
2"""
3plot_3d.py — 3D visualisations of the MLA vs GQA vs MHA tradeoff.
5Run with: /home/marco.simoni/miniconda3/bin/python3 plot_3d.py
7Figures:
8 3d_1_cache_surface.png — KV-cache surface per type (params × seq_len → cache GB)
9 3d_2_quality_cube.png — 3D scatter: quality × cache × params
10 3d_3_efficiency_cube.png — 3D scatter: speed × cache_efficiency × quality
11 3d_4_serving_surface.png — aggregate serving throughput: VRAM budget × seq_len
12"""
14import os, sys
15import numpy as np
16import pandas as pd
17import matplotlib
18matplotlib.use("Agg")
19import matplotlib.pyplot as plt
20from mpl_toolkits.mplot3d import Axes3D # registers the 3d projection
21from mpl_toolkits.mplot3d.art3d import Poly3DCollection
22import matplotlib.patches as mpatches
23from matplotlib.lines import Line2D
# ── I/O configuration ────────────────────────────────────────────────────────
IN_CSV = "benchmark_results.csv"   # benchmark runs produced elsewhere
OUT_DIR = "plots"                  # all figures are written here
DPI = 210                          # raster resolution for saved PNGs

# ── Per-attention-type styling shared by every figure ────────────────────────
TYPE_COLORS = {"MLA": "#3A86FF", "GQA": "#FF6B35", "MHA": "#2DC653"}
TYPE_ORDER = ["MLA", "GQA", "MHA"]
# NOTE(review): MOE_ALPHA appears unused in this file — confirm before removing.
MOE_ALPHA = {False: 0.88, True: 0.40}
MOE_MARKER = {False: "o", True: "^"}   # dense → circle, MoE → triangle

# ── Grids for the analytic surfaces (Figs 1 & 4) ─────────────────────────────
SEQ_GRID = np.array([512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
PARAMS_GRID = np.logspace(np.log10(3), np.log10(200), 60)   # 3M–200M, log-spaced
# (capacity GB, colour, label) for GPU VRAM reference planes/lines
VRAM_GB = [(80, "#CC3311", "80 GB H100"),
           (24, "#EE7733", "24 GB RTX")]
40# ─── helpers ─────────────────────────────────────────────────────────────────
def _save(fig, name):
    """Write *fig* to ``OUT_DIR/name`` at the module DPI, then release it."""
    os.makedirs(OUT_DIR, exist_ok=True)
    path = os.path.join(OUT_DIR, name)
    fig.savefig(path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {path}")
def _ax3(fig, pos, elev=22, azim=-50, fc="#F8F9FA"):
    """Add a 3-D subplot at *pos* with the given camera angles and face colour."""
    axis = fig.add_subplot(*pos, projection="3d")
    axis.view_init(elev=elev, azim=azim)
    axis.set_facecolor(fc)
    return axis
def load():
    """Read the benchmark CSV and derive per-token cache size and head size."""
    df = pd.read_csv(IN_CSV)
    numeric_cols = ("params_m", "theoretical_cache_mb", "val_loss",
                    "tps_512", "tps_64", "max_context", "num_blocks",
                    "n_heads", "kv_heads", "dim", "down_dim_kv")
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    # Bytes per cached token: MB → bytes, spread over the model's max context.
    df["bpt"] = df["theoretical_cache_mb"] * 1e6 / df["max_context"]
    # Nullable Int64 keeps rows with missing dim/n_heads representable.
    df["head_size"] = (df["dim"] / df["n_heads"]).round().astype("Int64")
    return df
68def _bpt_fn(df, t):
69 """
70 Returns bpt(params_array) → array (bytes per token, as float).
71 MHA/GQA: log-log fit bpt ~ a * params^b.
72 MLA: median bpt (decoupled from params).
73 """
74 sub = df[(df["type"] == t) & (~df["moe"])].dropna(subset=["params_m", "bpt"])
75 if sub.empty:
76 return lambda p: np.zeros_like(p, dtype=float)
77 if t == "MLA":
78 m = float(sub["bpt"].median())
79 return lambda p: np.full(np.asarray(p).shape, m)
80 lx, ly = np.log10(sub["params_m"].values), np.log10(sub["bpt"].values)
81 b, la = np.polyfit(lx, ly, 1)
82 a = 10 ** la
83 return lambda p: a * np.asarray(p, dtype=float) ** b
86def _decode_flops_m(row, S=256):
87 """Approximate analytical decode FLOPs (M) for one token."""
88 dim = int(row["dim"]); n_heads = int(row["n_heads"])
89 kv = int(row["kv_heads"]); h = int(row["head_size"]); nb = int(row["num_blocks"])
90 exp = 4
91 mlp = 2 * dim * exp * dim * 3
92 if row["type"] == "MLA":
93 dq = dim // 2
94 dkv = int(row["down_dim_kv"]) if not pd.isna(row["down_dim_kv"]) else dim // 4
95 rope = max(16, dkv // 4)
96 attn = (2*dim*dq + 4*dim*rope
97 + 2*n_heads*dq*dkv*h + 2*n_heads*dq*dkv
98 + 2*n_heads*S*dkv + 2*n_heads*S*rope
99 + 2*n_heads*h*dkv*dim + 2*n_heads*dkv*dim)
100 else:
101 attn = (2*dim*dim + 4*dim*kv*h
102 + 4*n_heads*S*h + 2*dim*dim)
103 return (attn + mlp) * nb / 1e6
106# ─────────────────────────────────────────────────────────────────────────────
107# Fig 1 — KV-cache surface: params × seq_len → cache (GB)
108# Three subplots + one combined, same z-scale.
109# ─────────────────────────────────────────────────────────────────────────────
def fig1_cache_surface(df):
    """Fig 1 — KV-cache surfaces: (params × seq_len) → cache GB per type.

    Top row: one surface per attention type; bottom: all three overlaid on a
    shared z-scale, with translucent VRAM-budget planes for reference.
    """
    dense = df[~df["moe"]]  # fits are done on dense (non-MoE) runs only
    P2D, S2D = np.meshgrid(PARAMS_GRID, SEQ_GRID)

    # Evaluate each type's bytes-per-token model on the grid → cache GB.
    surfaces = {}
    z_global = 0.0
    for t in TYPE_ORDER:
        fn = _bpt_fn(dense, t)
        Z = fn(P2D.ravel()).reshape(P2D.shape) * S2D / 1e9
        surfaces[t] = Z
        z_global = max(z_global, float(np.nanmax(Z)))
    z_cap = min(z_global, 300)  # cap the z-axis so MHA doesn't flatten the rest

    fig = plt.figure(figsize=(22, 10))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Three separate subplots (top row) ───────────────────────────────────
    for col, t in enumerate(TYPE_ORDER):
        ax = _ax3(fig, (2, 4, col + 1), elev=24, azim=-52)
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]

        # Surface (x and y plotted in log10 space; y rescaled to K-tokens)
        ax.plot_surface(np.log10(P2D), np.log10(S2D / 1024), Z,
                        color=c, alpha=0.55, linewidth=0, antialiased=True)

        # Projected contour on the floor (z=0)
        ax.contourf(np.log10(P2D), np.log10(S2D / 1024), Z,
                    zdir="z", offset=0,
                    levels=np.linspace(0, z_cap, 12),
                    cmap="Blues" if t == "MLA" else
                         "Oranges" if t == "GQA" else "Greens",
                    alpha=0.35)

        # VRAM planes: horizontal rectangles at each budget below the cap
        lp = np.log10(PARAMS_GRID[[0, -1]])
        ls = np.log10(SEQ_GRID[[0, -1]] / 1024)
        for vram, vc, vlbl in VRAM_GB:
            if vram > z_cap: continue
            verts = [[(lp[0], ls[0], vram), (lp[1], ls[0], vram),
                      (lp[1], ls[1], vram), (lp[0], ls[1], vram)]]
            poly = Poly3DCollection(verts, alpha=0.20)
            poly.set_facecolor(vc); poly.set_edgecolor(vc)
            ax.add_collection3d(poly)
            ax.text(lp[1], ls[0], vram + 2, vlbl,
                    color=vc, fontsize=7, fontweight="bold")

        # Axis ticks & labels (shared helper keeps all panels consistent)
        _set_cache_axes(ax, z_cap)
        ax.set_title(t, fontsize=14, fontweight="bold", color=c, pad=6)

    # ── Combined plot (bottom row, spans all columns) ────────────────────────
    ax_c = _ax3(fig, (2, 1, 2), elev=26, azim=-46)

    for t in TYPE_ORDER:
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]
        ax_c.plot_surface(np.log10(P2D), np.log10(S2D / 1024), Z,
                          color=c, alpha=0.32, linewidth=0, antialiased=True)
        # Sparse wireframe on top so overlapping surfaces stay distinguishable
        ax_c.plot_wireframe(np.log10(P2D), np.log10(S2D / 1024), Z,
                            color=c, alpha=0.18, linewidth=0.5,
                            rstride=3, cstride=3)

    lp = np.log10(PARAMS_GRID[[0, -1]])
    ls = np.log10(SEQ_GRID[[0, -1]] / 1024)
    for vram, vc, vlbl in VRAM_GB:
        if vram > z_cap: continue
        verts = [[(lp[0], ls[0], vram), (lp[1], ls[0], vram),
                  (lp[1], ls[1], vram), (lp[0], ls[1], vram)]]
        poly = Poly3DCollection(verts, alpha=0.22)
        poly.set_facecolor(vc); poly.set_edgecolor(vc)
        ax_c.add_collection3d(poly)
        ax_c.text(lp[0], ls[1], vram + 3, f" {vlbl}",
                  color=vc, fontsize=10, fontweight="bold", va="bottom")

    _set_cache_axes(ax_c, z_cap)
    ax_c.set_title("All types — same scale", fontsize=12, fontweight="bold", pad=6)

    # Proxy legend: surfaces/planes don't auto-generate legend entries
    handles = [mpatches.Patch(color=TYPE_COLORS[t], alpha=0.7, label=t)
               for t in TYPE_ORDER]
    for vram, vc, vlbl in VRAM_GB:
        handles.append(mpatches.Patch(color=vc, alpha=0.45, label=vlbl))
    ax_c.legend(handles=handles, fontsize=10, loc="upper left",
                framealpha=0.9, bbox_to_anchor=(-0.02, 1.0))

    fig.suptitle(
        "KV Cache (GB) = f(Model Parameters, Context Length)\n"
        "MLA surface is flat along the params axis — "
        "growing the model doesn't inflate its cache",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_1_cache_surface.png")
def _set_cache_axes(ax, z_cap):
    """Apply shared tick positions/labels and z-limit to a cache-surface axis."""
    param_marks = [3, 10, 30, 100]
    seq_marks = [0.512, 1, 4, 16, 64, 128]
    ax.set_xticks([np.log10(v) for v in param_marks])
    ax.set_xticklabels([f"{v}M" for v in param_marks], fontsize=7)
    ax.set_yticks([np.log10(v) for v in seq_marks])
    ax.set_yticklabels([f"{int(v)}K" if v >= 1 else "512" for v in seq_marks],
                       fontsize=7)
    ax.set_xlabel("Model parameters", fontsize=9, labelpad=6)
    ax.set_ylabel("Context length", fontsize=9, labelpad=6)
    ax.set_zlabel("KV Cache per seq (GB)", fontsize=9, labelpad=8)
    ax.set_zlim(0, z_cap)
    ax.tick_params(axis="z", labelsize=7)
220# ─────────────────────────────────────────────────────────────────────────────
221# Fig 2 — 3D scatter: params × KV-cache × val_loss
222# The "quality cube": where do types land in the 3-way tradeoff?
223# ─────────────────────────────────────────────────────────────────────────────
def fig2_quality_cube(df):
    """Fig 2 — 3-D scatter of params × KV cache × val_loss, Dense vs MoE panels."""
    sub = df.dropna(subset=["params_m", "theoretical_cache_mb", "val_loss"]).copy()

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    for pi, (moe_v, title) in enumerate([(False, "Dense"), (True, "MoE")]):
        ax = _ax3(fig, (1, 2, pi + 1), elev=20, azim=-55)
        grp = sub[sub["moe"] == moe_v]

        for t in TYPE_ORDER:
            pts = grp[grp["type"] == t]
            if pts.empty: continue
            ax.scatter(
                pts["params_m"],
                pts["theoretical_cache_mb"],
                pts["val_loss"],
                c=TYPE_COLORS[t], s=55,
                marker=MOE_MARKER[moe_v],
                edgecolors="white", linewidths=0.6,
                alpha=0.85, zorder=4, label=t
            )

        # Ideal corner annotation — uses pooled (Dense+MoE) mins from `sub`,
        # presumably so both panels mark the same corner; confirm if intended.
        ax.text(sub["params_m"].min(), sub["theoretical_cache_mb"].min(),
                sub["val_loss"].min() - 0.04,
                "← ideal", fontsize=9, color="#555", style="italic")

        ax.set_xlabel("Parameters (M)", fontsize=10, labelpad=6)
        ax.set_ylabel("KV Cache (MB)", fontsize=10, labelpad=6)
        ax.set_zlabel("Validation Loss ↓", fontsize=10, labelpad=8)
        ax.set_title(f"{title} — Quality · Cache · Size cube",
                     fontsize=12, fontweight="bold", pad=8)
        ax.tick_params(labelsize=8)

        # Legend only for types actually present in this panel
        handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in TYPE_ORDER
                   if not grp[grp["type"] == t].empty]
        ax.legend(handles=handles, fontsize=10, framealpha=0.9,
                  loc="upper right")

    fig.suptitle(
        "The Quality–Cache–Size Cube\n"
        "Lower-left-front corner = ideal: fewer params, smaller cache, better quality",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_2_quality_cube.png")
273# ─────────────────────────────────────────────────────────────────────────────
274# Fig 3 — 3D scatter: throughput × cache-efficiency × quality
275# All three axes: higher = better.
276# cache-efficiency = tps / cache_mb (tok/s per MB of KV cache)
277# Upper-right-front corner = Pareto-optimal.
278# ─────────────────────────────────────────────────────────────────────────────
def fig3_efficiency_cube(df):
    """Fig 3 — 3-D scatter: throughput × cache-efficiency × quality.

    All three axes are higher-is-better; cache_eff = tok/s per MB of KV cache,
    quality = 1 / val_loss. Bubble size scales with parameter count.
    """
    sub = df.dropna(subset=["tps_512", "theoretical_cache_mb", "val_loss"]).copy()
    sub["cache_eff"] = sub["tps_512"] / sub["theoretical_cache_mb"]
    sub["quality"] = 1.0 / sub["val_loss"]  # higher = better

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    for pi, (moe_v, title) in enumerate([(False, "Dense"), (True, "MoE")]):
        ax = _ax3(fig, (1, 2, pi + 1), elev=22, azim=-48)
        grp = sub[sub["moe"] == moe_v]

        for t in TYPE_ORDER:
            pts = grp[grp["type"] == t]
            if pts.empty: continue
            # Bubble size normalized against the pooled max so panels compare
            sz = (pts["params_m"] / sub["params_m"].max() * 300).clip(lower=20)
            ax.scatter(
                pts["tps_512"],
                pts["cache_eff"],
                pts["quality"],
                c=TYPE_COLORS[t], s=sz,
                marker=MOE_MARKER[moe_v],
                edgecolors="white", linewidths=0.6,
                alpha=0.85, zorder=4, label=t
            )

        # Ideal corner (per-panel maxima, nudged just off the extreme point)
        ax.text(grp["tps_512"].max() * 0.95,
                grp["cache_eff"].max() * 0.95,
                grp["quality"].max() * 1.01,
                "ideal ↗", fontsize=9, color="#555", style="italic")

        ax.set_xlabel("Decode Throughput (tok/s) ↑", fontsize=9, labelpad=6)
        ax.set_ylabel("Cache Efficiency\n(tok/s per MB cache) ↑", fontsize=9, labelpad=8)
        ax.set_zlabel("Quality (1/val_loss) ↑", fontsize=9, labelpad=8)
        ax.set_title(f"{title} — Efficiency cube\n(bubble size ∝ params)",
                     fontsize=11, fontweight="bold", pad=8)
        ax.tick_params(labelsize=8)

        handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in TYPE_ORDER
                   if not grp[grp["type"] == t].empty]
        ax.legend(handles=handles, fontsize=10, framealpha=0.9, loc="upper left")

    fig.suptitle(
        "Efficiency Cube: Speed × Cache Efficiency × Quality\n"
        "All three axes ↑ = better. Upper-right-front = Pareto-optimal.",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_3_efficiency_cube.png")
331# ─────────────────────────────────────────────────────────────────────────────
332# Fig 4 — Serving surface: VRAM budget × seq_len → aggregate tok/s
333# Shows at what (budget, context) each type breaks down.
334# ─────────────────────────────────────────────────────────────────────────────
def fig4_serving_surface(df):
    """Fig 4 — serving surfaces: (VRAM budget × seq_len) → aggregate k tok/s.

    Per type, estimates how many sequences fit in a KV-cache budget (using the
    type's median bytes-per-token) and multiplies by its median per-sequence
    throughput at 512 context.
    """
    dense = df[~df["moe"]]

    budgets_gb = np.logspace(np.log10(1), np.log10(80), 60)  # 1–80 GB
    seq_lens = SEQ_GRID

    B2D, S2D = np.meshgrid(budgets_gb, seq_lens)

    fig = plt.figure(figsize=(22, 8))
    fig.patch.set_facecolor("#F8F9FA")

    surfaces = {}
    z_cap = 0.0  # running max over all surfaces, capped below

    for t in TYPE_ORDER:
        # NOTE(review): fn_bpt is computed but never used in this function —
        # the median bpt below is used instead; confirm before cleanup.
        fn_bpt = _bpt_fn(dense, t)
        # Median measured throughput at 512 context for this type
        tps_512_med = float(dense[dense["type"] == t]["tps_512"].median())

        # bpt at MEDIAN params (the representative model)
        bpt_med = float(dense[dense["type"] == t]["bpt"].median())

        # Cache per seq (GB) at each seq_len
        cache_per_seq_gb = bpt_med * S2D / 1e9

        # Max concurrent sequences in budget
        n_seq = B2D / cache_per_seq_gb

        # Throughput approximation: tps barely changes with seqlen in our
        # data, so tps_512_med is used as a flat per-sequence estimate.
        total_tps = n_seq * tps_512_med / 1e3  # k tok/s

        # Zero out the region where even one sequence doesn't fit the budget
        total_tps = np.where(cache_per_seq_gb > B2D, 0.0, total_tps)
        surfaces[t] = total_tps
        z_cap = max(z_cap, float(np.nanmax(total_tps)))

    z_cap = min(z_cap, 5000)  # k tok/s cap

    for col, t in enumerate(TYPE_ORDER):
        ax = _ax3(fig, (1, 3, col + 1), elev=26, azim=-52)
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]

        ax.plot_surface(np.log10(B2D), np.log10(S2D / 1024), Z,
                        color=c, alpha=0.58, linewidth=0, antialiased=True)

        # Floor projection
        ax.contourf(np.log10(B2D), np.log10(S2D / 1024), Z,
                    zdir="z", offset=0,
                    levels=np.linspace(0, z_cap, 10),
                    cmap="Blues" if t == "MLA" else
                         "Oranges" if t == "GQA" else "Greens",
                    alpha=0.40)

        # Budget reference lines: slice the surface at the column nearest each
        # VRAM budget and draw it as a highlighted curve
        ls = np.log10(seq_lens[[0, -1]] / 1024)
        for gb, vc, vlbl in VRAM_GB:
            x_v = np.log10(gb)
            z_line = np.minimum(surfaces[t][:, np.argmin(np.abs(budgets_gb - gb))], z_cap)
            ax.plot([x_v] * len(seq_lens), np.log10(seq_lens / 1024), z_line,
                    color=vc, lw=2.2, alpha=0.9, zorder=5)
            ax.text(x_v, ls[0], z_line[0] + z_cap * 0.02,
                    vlbl, color=vc, fontsize=7, fontweight="bold")

        # Axes
        b_ticks = [1, 5, 10, 40, 80]
        s_ticks = [0.512, 1, 4, 16, 64, 128]
        ax.set_xticks([np.log10(v) for v in b_ticks])
        ax.set_xticklabels([f"{v}GB" for v in b_ticks], fontsize=7)
        ax.set_yticks([np.log10(v) for v in s_ticks])
        ax.set_yticklabels(["512" if v < 1 else f"{int(v)}K" for v in s_ticks], fontsize=7)
        ax.set_zlim(0, z_cap)
        ax.tick_params(axis="z", labelsize=7)
        ax.set_xlabel("KV-cache VRAM budget", fontsize=9, labelpad=6)
        ax.set_ylabel("Context length", fontsize=9, labelpad=6)
        ax.set_zlabel("Aggregate throughput (k tok/s)", fontsize=9, labelpad=8)
        ax.set_title(t, fontsize=14, fontweight="bold", color=c, pad=6)

    fig.suptitle(
        "Aggregate Serving Throughput = f(VRAM Budget, Context Length)\n"
        "The coloured lines mark throughput at 24 GB / 80 GB GPU. "
        "MLA's surface stays higher across every (budget, context) point.",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_4_serving_surface.png")
425# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Run relative to the script's own directory so IN_CSV and OUT_DIR resolve
    # regardless of where the interpreter was launched from.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    df = load()
    print(f"Loaded {len(df)} runs")
    print("\nGenerating 3D figures...")
    fig1_cache_surface(df)
    fig2_quality_cube(df)
    fig3_efficiency_cube(df)
    fig4_serving_surface(df)
    print("Done.")