Coverage for dantinox / plots / plot_3d.py: 0%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 18:33 +0200

1#!/home/marco.simoni/miniconda3/bin/python3 

2""" 

3plot_3d.py — 3D visualisations of the MLA vs GQA vs MHA tradeoff. 

4 

5Run with: /home/marco.simoni/miniconda3/bin/python3 plot_3d.py 

6 

7Figures: 

8 3d_1_cache_surface.png — KV-cache surface per type (params × seq_len → cache GB) 

9 3d_2_quality_cube.png — 3D scatter: quality × cache × params 

10 3d_3_efficiency_cube.png — 3D scatter: speed × cache_efficiency × quality 

11 3d_4_serving_surface.png — aggregate serving throughput: VRAM budget × seq_len 

12""" 

13 

14import os, sys 

15import numpy as np 

16import pandas as pd 

17import matplotlib 

18matplotlib.use("Agg") 

19import matplotlib.pyplot as plt 

20from mpl_toolkits.mplot3d import Axes3D # registers the 3d projection 

21from mpl_toolkits.mplot3d.art3d import Poly3DCollection 

22import matplotlib.patches as mpatches 

23from matplotlib.lines import Line2D 

24 

# Input CSV of benchmark runs; output directory and resolution for figures.
IN_CSV = "benchmark_results.csv"
OUT_DIR = "plots"
DPI = 210

# Shared colour / ordering / marker conventions used by every figure.
TYPE_COLORS = {"MLA": "#3A86FF", "GQA": "#FF6B35", "MHA": "#2DC653"}
TYPE_ORDER = ["MLA", "GQA", "MHA"]
MOE_ALPHA = {False: 0.88, True: 0.40}   # dense vs MoE point transparency
MOE_MARKER = {False: "o", True: "^"}    # dense = circle, MoE = triangle

# Evaluation grids for the analytic surfaces: context lengths in tokens and
# model sizes in millions of parameters (log-spaced 3M..200M, 60 points).
SEQ_GRID = np.array([512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
PARAMS_GRID = np.logspace(np.log10(3), np.log10(200), 60)
# (capacity GB, colour, label) triples for VRAM reference planes/lines.
VRAM_GB = [(80, "#CC3311", "80 GB H100"),
           (24, "#EE7733", "24 GB RTX")]

38 

39 

40# ─── helpers ───────────────────────────────────────────────────────────────── 

def _save(fig, name):
    """Render *fig* to OUT_DIR/<name> (creating OUT_DIR on demand), then close it."""
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, name)
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out_path}")

47 

48 

def _ax3(fig, pos, elev=22, azim=-50, fc="#F8F9FA"):
    """Add a 3D subplot at *pos* with a fixed camera angle and face colour."""
    axis = fig.add_subplot(*pos, projection="3d")
    axis.view_init(elev=elev, azim=azim)
    axis.set_facecolor(fc)
    return axis

54 

55 

def load():
    """Read IN_CSV, coerce metric columns to numeric, and derive helper columns."""
    df = pd.read_csv(IN_CSV)
    numeric_cols = ("params_m", "theoretical_cache_mb", "val_loss",
                    "tps_512", "tps_64", "max_context", "num_blocks",
                    "n_heads", "kv_heads", "dim", "down_dim_kv")
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    # bpt: KV-cache bytes per token; head_size: per-head width (nullable int,
    # since dim/n_heads may be NaN after coercion).
    df["bpt"] = df["theoretical_cache_mb"] * 1e6 / df["max_context"]
    df["head_size"] = (df["dim"] / df["n_heads"]).round().astype("Int64")
    return df

66 

67 

68def _bpt_fn(df, t): 

69 """ 

70 Returns bpt(params_array) → array (bytes per token, as float). 

71 MHA/GQA: log-log fit bpt ~ a * params^b. 

72 MLA: median bpt (decoupled from params). 

73 """ 

74 sub = df[(df["type"] == t) & (~df["moe"])].dropna(subset=["params_m", "bpt"]) 

75 if sub.empty: 

76 return lambda p: np.zeros_like(p, dtype=float) 

77 if t == "MLA": 

78 m = float(sub["bpt"].median()) 

79 return lambda p: np.full(np.asarray(p).shape, m) 

80 lx, ly = np.log10(sub["params_m"].values), np.log10(sub["bpt"].values) 

81 b, la = np.polyfit(lx, ly, 1) 

82 a = 10 ** la 

83 return lambda p: a * np.asarray(p, dtype=float) ** b 

84 

85 

86def _decode_flops_m(row, S=256): 

87 """Approximate analytical decode FLOPs (M) for one token.""" 

88 dim = int(row["dim"]); n_heads = int(row["n_heads"]) 

89 kv = int(row["kv_heads"]); h = int(row["head_size"]); nb = int(row["num_blocks"]) 

90 exp = 4 

91 mlp = 2 * dim * exp * dim * 3 

92 if row["type"] == "MLA": 

93 dq = dim // 2 

94 dkv = int(row["down_dim_kv"]) if not pd.isna(row["down_dim_kv"]) else dim // 4 

95 rope = max(16, dkv // 4) 

96 attn = (2*dim*dq + 4*dim*rope 

97 + 2*n_heads*dq*dkv*h + 2*n_heads*dq*dkv 

98 + 2*n_heads*S*dkv + 2*n_heads*S*rope 

99 + 2*n_heads*h*dkv*dim + 2*n_heads*dkv*dim) 

100 else: 

101 attn = (2*dim*dim + 4*dim*kv*h 

102 + 4*n_heads*S*h + 2*dim*dim) 

103 return (attn + mlp) * nb / 1e6 

104 

105 

106# ───────────────────────────────────────────────────────────────────────────── 

107# Fig 1 — KV-cache surface: params × seq_len → cache (GB) 

108# Three subplots + one combined, same z-scale. 

109# ───────────────────────────────────────────────────────────────────────────── 

def fig1_cache_surface(df):
    """
    Fig 1: KV-cache surfaces, cache_GB = f(params, seq_len) — one subplot per
    attention type (top row) plus a combined overlay (bottom row), all sharing
    a single z-scale. Saved as plots/3d_1_cache_surface.png.
    """
    # Bytes-per-token models are fitted on dense (non-MoE) runs only.
    dense = df[~df["moe"]]
    P2D, S2D = np.meshgrid(PARAMS_GRID, SEQ_GRID)

    # Evaluate each type's surface and track the global max for a shared z-scale.
    surfaces = {}
    z_global = 0.0
    for t in TYPE_ORDER:
        fn = _bpt_fn(dense, t)
        # bytes/token x tokens -> GB
        Z = fn(P2D.ravel()).reshape(P2D.shape) * S2D / 1e9
        surfaces[t] = Z
        z_global = max(z_global, float(np.nanmax(Z)))
    z_cap = min(z_global, 300)  # clamp z so the tallest surface stays readable

    fig = plt.figure(figsize=(22, 10))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Three separate subplots (top row) ─────────────────────────────────
    for col, t in enumerate(TYPE_ORDER):
        ax = _ax3(fig, (2, 4, col + 1), elev=24, azim=-52)
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]

        # Surface (params and context axes are log10-scaled)
        ax.plot_surface(np.log10(P2D), np.log10(S2D / 1024), Z,
                        color=c, alpha=0.55, linewidth=0, antialiased=True)

        # Projected contour on the floor (z=0)
        ax.contourf(np.log10(P2D), np.log10(S2D / 1024), Z,
                    zdir="z", offset=0,
                    levels=np.linspace(0, z_cap, 12),
                    cmap="Blues" if t == "MLA" else
                         "Oranges" if t == "GQA" else "Greens",
                    alpha=0.35)

        # Horizontal VRAM reference planes (skipped when above the z cap)
        lp = np.log10(PARAMS_GRID[[0, -1]])
        ls = np.log10(SEQ_GRID[[0, -1]] / 1024)
        for vram, vc, vlbl in VRAM_GB:
            if vram > z_cap: continue
            verts = [[(lp[0], ls[0], vram), (lp[1], ls[0], vram),
                      (lp[1], ls[1], vram), (lp[0], ls[1], vram)]]
            poly = Poly3DCollection(verts, alpha=0.20)
            poly.set_facecolor(vc); poly.set_edgecolor(vc)
            ax.add_collection3d(poly)
            ax.text(lp[1], ls[0], vram + 2, vlbl,
                    color=vc, fontsize=7, fontweight="bold")

        # Axis ticks & labels (shared scheme)
        _set_cache_axes(ax, z_cap)
        ax.set_title(t, fontsize=14, fontweight="bold", color=c, pad=6)

    # ── Combined plot (bottom row, spans all columns) ──────────────────────
    ax_c = _ax3(fig, (2, 1, 2), elev=26, azim=-46)

    for t in TYPE_ORDER:
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]
        ax_c.plot_surface(np.log10(P2D), np.log10(S2D / 1024), Z,
                          color=c, alpha=0.32, linewidth=0, antialiased=True)
        # Light wireframe so overlapping translucent surfaces stay legible
        ax_c.plot_wireframe(np.log10(P2D), np.log10(S2D / 1024), Z,
                            color=c, alpha=0.18, linewidth=0.5,
                            rstride=3, cstride=3)

    # Same VRAM reference planes on the combined axes
    lp = np.log10(PARAMS_GRID[[0, -1]])
    ls = np.log10(SEQ_GRID[[0, -1]] / 1024)
    for vram, vc, vlbl in VRAM_GB:
        if vram > z_cap: continue
        verts = [[(lp[0], ls[0], vram), (lp[1], ls[0], vram),
                  (lp[1], ls[1], vram), (lp[0], ls[1], vram)]]
        poly = Poly3DCollection(verts, alpha=0.22)
        poly.set_facecolor(vc); poly.set_edgecolor(vc)
        ax_c.add_collection3d(poly)
        ax_c.text(lp[0], ls[1], vram + 3, f" {vlbl}",
                  color=vc, fontsize=10, fontweight="bold", va="bottom")

    _set_cache_axes(ax_c, z_cap)
    ax_c.set_title("All types — same scale", fontsize=12, fontweight="bold", pad=6)

    # Proxy legend for combined plot (3D surfaces don't yield legend entries)
    handles = [mpatches.Patch(color=TYPE_COLORS[t], alpha=0.7, label=t)
               for t in TYPE_ORDER]
    for vram, vc, vlbl in VRAM_GB:
        handles.append(mpatches.Patch(color=vc, alpha=0.45, label=vlbl))
    ax_c.legend(handles=handles, fontsize=10, loc="upper left",
                framealpha=0.9, bbox_to_anchor=(-0.02, 1.0))

    fig.suptitle(
        "KV Cache (GB) = f(Model Parameters, Context Length)\n"
        "MLA surface is flat along the params axis — "
        "growing the model doesn't inflate its cache",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_1_cache_surface.png")

204 

205 

def _set_cache_axes(ax, z_cap):
    """Apply the shared tick/label scheme for the log-scaled cache-surface axes."""
    param_ticks = [3, 10, 30, 100]          # millions of parameters
    seq_ticks = [0.512, 1, 4, 16, 64, 128]  # context length, K tokens
    ax.set_xticks([np.log10(t) for t in param_ticks])
    ax.set_xticklabels([f"{t}M" for t in param_ticks], fontsize=7)
    ax.set_yticks([np.log10(t) for t in seq_ticks])
    seq_labels = ["512" if t < 1 else f"{int(t)}K" for t in seq_ticks]
    ax.set_yticklabels(seq_labels, fontsize=7)
    ax.set_xlabel("Model parameters", fontsize=9, labelpad=6)
    ax.set_ylabel("Context length", fontsize=9, labelpad=6)
    ax.set_zlabel("KV Cache per seq (GB)", fontsize=9, labelpad=8)
    ax.set_zlim(0, z_cap)
    ax.tick_params(axis="z", labelsize=7)

218 

219 

220# ───────────────────────────────────────────────────────────────────────────── 

221# Fig 2 — 3D scatter: params × KV-cache × val_loss 

222# The "quality cube": where do types land in the 3-way tradeoff? 

223# ───────────────────────────────────────────────────────────────────────────── 

def fig2_quality_cube(df):
    """
    Fig 2: 3D scatter of params x KV cache x val_loss ("quality cube"),
    Dense and MoE panels side by side. Saved as plots/3d_2_quality_cube.png.
    """
    usable = df.dropna(subset=["params_m", "theoretical_cache_mb", "val_loss"]).copy()

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    for panel_idx, (is_moe, panel_name) in enumerate([(False, "Dense"), (True, "MoE")]):
        ax = _ax3(fig, (1, 2, panel_idx + 1), elev=20, azim=-55)
        panel = usable[usable["moe"] == is_moe]

        for arch in TYPE_ORDER:
            rows = panel[panel["type"] == arch]
            if rows.empty:
                continue
            ax.scatter(rows["params_m"],
                       rows["theoretical_cache_mb"],
                       rows["val_loss"],
                       c=TYPE_COLORS[arch], s=55,
                       marker=MOE_MARKER[is_moe],
                       edgecolors="white", linewidths=0.6,
                       alpha=0.85, zorder=4, label=arch)

        # Mark the best-possible corner: few params, tiny cache, low loss.
        ax.text(usable["params_m"].min(), usable["theoretical_cache_mb"].min(),
                usable["val_loss"].min() - 0.04,
                "← ideal", fontsize=9, color="#555", style="italic")

        ax.set_xlabel("Parameters (M)", fontsize=10, labelpad=6)
        ax.set_ylabel("KV Cache (MB)", fontsize=10, labelpad=6)
        ax.set_zlabel("Validation Loss ↓", fontsize=10, labelpad=8)
        ax.set_title(f"{panel_name} — Quality · Cache · Size cube",
                     fontsize=12, fontweight="bold", pad=8)
        ax.tick_params(labelsize=8)

        # Legend only for types actually present in this panel.
        legend_patches = []
        for arch in TYPE_ORDER:
            if not panel[panel["type"] == arch].empty:
                legend_patches.append(mpatches.Patch(color=TYPE_COLORS[arch], label=arch))
        ax.legend(handles=legend_patches, fontsize=10, framealpha=0.9,
                  loc="upper right")

    fig.suptitle(
        "The Quality–Cache–Size Cube\n"
        "Lower-left-front corner = ideal: fewer params, smaller cache, better quality",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_2_quality_cube.png")

271 

272 

273# ───────────────────────────────────────────────────────────────────────────── 

274# Fig 3 — 3D scatter: throughput × cache-efficiency × quality 

275# All three axes: higher = better. 

276# cache-efficiency = tps / cache_mb (tok/s per MB of KV cache) 

277# Upper-right-front corner = Pareto-optimal. 

278# ───────────────────────────────────────────────────────────────────────────── 

def fig3_efficiency_cube(df):
    """
    Fig 3: 3D scatter of decode speed x cache efficiency x quality (all axes:
    higher = better); bubble size tracks params. Saved as
    plots/3d_3_efficiency_cube.png.
    """
    usable = df.dropna(subset=["tps_512", "theoretical_cache_mb", "val_loss"]).copy()
    usable["cache_eff"] = usable["tps_512"] / usable["theoretical_cache_mb"]
    usable["quality"] = 1.0 / usable["val_loss"]  # invert loss so higher = better

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    for panel_idx, (is_moe, panel_name) in enumerate([(False, "Dense"), (True, "MoE")]):
        ax = _ax3(fig, (1, 2, panel_idx + 1), elev=22, azim=-48)
        panel = usable[usable["moe"] == is_moe]

        for arch in TYPE_ORDER:
            rows = panel[panel["type"] == arch]
            if rows.empty:
                continue
            # Bubble size proportional to params, floored at 20pt.
            bubble = (rows["params_m"] / usable["params_m"].max() * 300).clip(lower=20)
            ax.scatter(rows["tps_512"],
                       rows["cache_eff"],
                       rows["quality"],
                       c=TYPE_COLORS[arch], s=bubble,
                       marker=MOE_MARKER[is_moe],
                       edgecolors="white", linewidths=0.6,
                       alpha=0.85, zorder=4, label=arch)

        # Annotate the Pareto-optimal corner.
        ax.text(panel["tps_512"].max() * 0.95,
                panel["cache_eff"].max() * 0.95,
                panel["quality"].max() * 1.01,
                "ideal ↗", fontsize=9, color="#555", style="italic")

        ax.set_xlabel("Decode Throughput (tok/s) ↑", fontsize=9, labelpad=6)
        ax.set_ylabel("Cache Efficiency\n(tok/s per MB cache) ↑", fontsize=9, labelpad=8)
        ax.set_zlabel("Quality (1/val_loss) ↑", fontsize=9, labelpad=8)
        ax.set_title(f"{panel_name} — Efficiency cube\n(bubble size ∝ params)",
                     fontsize=11, fontweight="bold", pad=8)
        ax.tick_params(labelsize=8)

        # Legend only for types actually present in this panel.
        legend_patches = []
        for arch in TYPE_ORDER:
            if not panel[panel["type"] == arch].empty:
                legend_patches.append(mpatches.Patch(color=TYPE_COLORS[arch], label=arch))
        ax.legend(handles=legend_patches, fontsize=10, framealpha=0.9, loc="upper left")

    fig.suptitle(
        "Efficiency Cube: Speed × Cache Efficiency × Quality\n"
        "All three axes ↑ = better. Upper-right-front = Pareto-optimal.",
        fontsize=13, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_3_efficiency_cube.png")

329 

330 

331# ───────────────────────────────────────────────────────────────────────────── 

332# Fig 4 — Serving surface: VRAM budget × seq_len → aggregate tok/s 

333# Shows at what (budget, context) each type breaks down. 

334# ───────────────────────────────────────────────────────────────────────────── 

def fig4_serving_surface(df):
    """
    Fig 4: aggregate serving throughput surface, k tok/s = f(VRAM budget,
    context length), one subplot per type. Throughput is modelled as the
    number of concurrent sequences whose KV cache fits in the budget times a
    flat per-sequence decode rate. Saved as plots/3d_4_serving_surface.png.
    """
    dense = df[~df["moe"]]

    budgets_gb = np.logspace(np.log10(1), np.log10(80), 60)  # 1–80 GB
    seq_lens = SEQ_GRID

    B2D, S2D = np.meshgrid(budgets_gb, seq_lens)

    fig = plt.figure(figsize=(22, 8))
    fig.patch.set_facecolor("#F8F9FA")

    surfaces = {}
    z_cap = 0.0

    for t in TYPE_ORDER:
        # Hoist the per-type selection (was recomputed for each median; the
        # original also built an unused _bpt_fn closure here — removed).
        sub = dense[dense["type"] == t]

        # Median decode throughput at seq 512. Measured tps barely moves with
        # seq_len in our data, so it serves as a flat per-sequence estimate.
        tps_512_med = float(sub["tps_512"].median())

        # bpt at MEDIAN params (the representative model of this type).
        bpt_med = float(sub["bpt"].median())

        # KV cache per sequence (GB) at each context length.
        cache_per_seq_gb = bpt_med * S2D / 1e9

        # Max concurrent sequences that fit in the budget.
        n_seq = B2D / cache_per_seq_gb

        # Aggregate throughput, reported in k tok/s.
        total_tps = n_seq * tps_512_med / 1e3

        # Zero out points where even a single sequence doesn't fit.
        total_tps = np.where(cache_per_seq_gb > B2D, 0.0, total_tps)
        surfaces[t] = total_tps
        z_cap = max(z_cap, float(np.nanmax(total_tps)))

    z_cap = min(z_cap, 5000)  # cap the shared z-scale (k tok/s)

    for col, t in enumerate(TYPE_ORDER):
        ax = _ax3(fig, (1, 3, col + 1), elev=26, azim=-52)
        Z = np.minimum(surfaces[t], z_cap)
        c = TYPE_COLORS[t]

        ax.plot_surface(np.log10(B2D), np.log10(S2D / 1024), Z,
                        color=c, alpha=0.58, linewidth=0, antialiased=True)

        # Floor projection
        ax.contourf(np.log10(B2D), np.log10(S2D / 1024), Z,
                    zdir="z", offset=0,
                    levels=np.linspace(0, z_cap, 10),
                    cmap="Blues" if t == "MLA" else
                         "Oranges" if t == "GQA" else "Greens",
                    alpha=0.40)

        # Budget reference lines: throughput along the grid column nearest
        # each reference budget (24 GB and 80 GB).
        ls = np.log10(seq_lens[[0, -1]] / 1024)
        for gb, vc, vlbl in VRAM_GB:
            x_v = np.log10(gb)
            z_line = np.minimum(surfaces[t][:, np.argmin(np.abs(budgets_gb - gb))], z_cap)
            ax.plot([x_v] * len(seq_lens), np.log10(seq_lens / 1024), z_line,
                    color=vc, lw=2.2, alpha=0.9, zorder=5)
            ax.text(x_v, ls[0], z_line[0] + z_cap * 0.02,
                    vlbl, color=vc, fontsize=7, fontweight="bold")

        # Axes
        b_ticks = [1, 5, 10, 40, 80]
        s_ticks = [0.512, 1, 4, 16, 64, 128]
        ax.set_xticks([np.log10(v) for v in b_ticks])
        ax.set_xticklabels([f"{v}GB" for v in b_ticks], fontsize=7)
        ax.set_yticks([np.log10(v) for v in s_ticks])
        ax.set_yticklabels(["512" if v < 1 else f"{int(v)}K" for v in s_ticks], fontsize=7)
        ax.set_zlim(0, z_cap)
        ax.tick_params(axis="z", labelsize=7)
        ax.set_xlabel("KV-cache VRAM budget", fontsize=9, labelpad=6)
        ax.set_ylabel("Context length", fontsize=9, labelpad=6)
        ax.set_zlabel("Aggregate throughput (k tok/s)", fontsize=9, labelpad=8)
        ax.set_title(t, fontsize=14, fontweight="bold", color=c, pad=6)

    fig.suptitle(
        "Aggregate Serving Throughput = f(VRAM Budget, Context Length)\n"
        "The coloured lines mark throughput at 24 GB / 80 GB GPU. "
        "MLA's surface stays higher across every (budget, context) point.",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_4_serving_surface.png")

423 

424 

425# ───────────────────────────────────────────────────────────────────────────── 

if __name__ == "__main__":
    # Run from the script's own directory so relative CSV/plot paths resolve.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    df = load()
    print(f"Loaded {len(df)} runs")
    print("\nGenerating 3D figures...")
    for make_figure in (fig1_cache_surface, fig2_quality_cube,
                        fig3_efficiency_cube, fig4_serving_surface):
        make_figure(df)
    print("Done.")