Coverage for dantinox / plots / plot_3d_dkv.py: 0%

249 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 18:33 +0200

1#!/home/marco.simoni/miniconda3/bin/python3 

2""" 

3plot_3d_dkv.py — 3D surfaces focused on down_dim_kv (MLA's latent KV dimension). 

4 

5Run with: /home/marco.simoni/miniconda3/bin/python3 plot_3d_dkv.py 

6 

7Figures: 

8 3d_5_dkv_cache_seqlen.png — MLA: down_dim_kv × seq_len → KV cache 

9 (with GQA/MHA fixed reference planes) 

10 3d_6_kv_decoupling.png — kv_dim_eff × params_m → cache 

11 (MLA cluster is flat; MHA/GQA cluster rises) 

12 3d_7_mla_quality.png — MLA: down_dim_kv × params_m → val_loss quality landscape 

13 3d_8_dkv_num_blocks.png — MLA: down_dim_kv × num_blocks → cache at multiple seq_lens 

14""" 

15 

16import os 

17import numpy as np 

18import pandas as pd 

19from scipy.interpolate import griddata 

20import matplotlib 

21matplotlib.use("Agg") 

22import matplotlib.pyplot as plt 

23from mpl_toolkits.mplot3d import Axes3D 

24from mpl_toolkits.mplot3d.art3d import Poly3DCollection 

25import matplotlib.patches as mpatches 

26from matplotlib.lines import Line2D 

27import matplotlib.cm as cm 

28 

IN_CSV = "benchmark_results.csv"  # benchmark table read by load()
OUT_DIR = "plots"                 # figures are written here (created on demand)
DPI = 210

# Colour per attention type; dict order doubles as the canonical plot order.
TYPE_COLORS = {
    "MLA": "#3A86FF",
    "GQA": "#FF6B35",
    "MHA": "#2DC653",
}
TYPE_ORDER = list(TYPE_COLORS)

# Context lengths for surfaces: 512 * 2**k tokens, i.e. 512 ... 128K.
SEQ_GRID = 512 * 2 ** np.arange(9)
NB_REF = 12  # representative num_blocks for surfaces

# (capacity in GB, colour, label) for the VRAM reference planes.
VRAM_GB = [
    (80, "#CC3311", "80 GB H100"),
    (24, "#EE7733", "24 GB RTX"),
]

39 

40 

def _save(fig, name):
    """Write *fig* as OUT_DIR/<name> at DPI, close it, and print the path."""
    target = os.path.join(OUT_DIR, name)
    os.makedirs(OUT_DIR, exist_ok=True)
    fig.savefig(target, dpi=DPI, bbox_inches="tight")
    plt.close(fig)  # free the figure; the script builds several in a row
    print(f" Saved {target}")

47 

48 

def _ax3(fig, pos, elev=24, azim=-52, fc="#F8F9FA"):
    """Add a 3D subplot to *fig* and return it.

    *pos* is unpacked into ``fig.add_subplot`` (e.g. ``(1, 2, 1)``); the
    axes gets background colour *fc* and camera angles *elev* / *azim*.
    """
    axis3d = fig.add_subplot(*pos, projection="3d")
    axis3d.view_init(elev=elev, azim=azim)
    axis3d.set_facecolor(fc)
    return axis3d

54 

55 

def load(path=None):
    """Read the benchmark CSV and derive each run's effective KV dimension.

    Args:
        path: CSV file to read; defaults to the module-level ``IN_CSV``.

    Returns:
        pandas.DataFrame: the raw table with the known numeric columns coerced
        to numbers (unparseable cells become NaN), ``moe`` normalised to a
        bool column (missing column / unrecognised values -> False), and an
        added ``kv_dim_eff`` column (NaN wherever it cannot be computed).
    """
    df = pd.read_csv(IN_CSV if path is None else path)
    num_cols = ["params_m", "theoretical_cache_mb", "val_loss",
                "tps_512", "max_context", "num_blocks",
                "n_heads", "kv_heads", "dim", "down_dim_kv"]
    for c in num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # The figure functions filter with ``~df["moe"]``, which is only a valid
    # mask on a real bool column: on an object column holding NaN, ``~True``
    # evaluates to -2. Accept common truthy spellings; everything else is dense.
    if "moe" in df.columns:
        df["moe"] = (df["moe"].astype(str).str.strip().str.lower()
                     .isin(["true", "1", "1.0", "yes"]))
    else:
        df["moe"] = False

    # Effective KV dimension: the quantity that directly sets the cache.
    # cache_bytes = kv_dim_eff × num_blocks × seq_len × 4 (fp32), with
    # theoretical_cache_mb taken at seq_len = max_context.
    kv_dim_eff = (
        df["theoretical_cache_mb"] * 1e6
        / (df["num_blocks"] * df["max_context"] * 4)
    )
    # A zero num_blocks/max_context yields ±inf, which survives dropna() and
    # poisons the downstream fits and medians — store NaN instead.
    df["kv_dim_eff"] = kv_dim_eff.replace([np.inf, -np.inf], np.nan)
    return df

71 

72 

73# ───────────────────────────────────────────────────────────────────────────── 

74# Fig 5 — MLA surface: down_dim_kv × seq_len → KV cache (GB) 

75# GQA and MHA shown as horizontal reference planes. 

76# ───────────────────────────────────────────────────────────────────────────── 

def fig5_dkv_cache_seqlen(df):
    """Plot MLA's KV cache as a surface over down_dim_kv × context length (Fig 5).

    For MLA, cache scales with (down_dim_kv + rope_dim) × num_blocks × seq_len.
    kv_dim_eff(down_dim_kv) is fitted linearly from the dense MLA runs (the
    intercept absorbs rope_dim) and evaluated at num_blocks=NB_REF over
    SEQ_GRID. GQA/MHA are drawn as reference lines at their median
    kv_dim_eff (runs with num_blocks == NB_REF if any, else all runs of that
    type); a type with no usable runs is omitted.

    The y axis is log10(seq_len / 1024): 512 tokens sits at log10(0.5) and a
    tick labelled "NK" means N × 1024 tokens.

    Args:
        df: Frame returned by ``load()``.

    Writes ``3d_5_dkv_cache_seqlen.png`` via ``_save``; returns None. The
    figure is skipped (with a printed note) when fewer than two distinct
    down_dim_kv values are available, since the linear fit needs two.
    """
    mla = df[(df["type"] == "MLA") & (~df["moe"])].dropna(subset=["down_dim_kv", "kv_dim_eff"])
    dense = df[~df["moe"]]

    def _y(seq_len):
        # Token count(s) -> y coordinate on the log-scaled context axis.
        return np.log10(np.asarray(seq_len, dtype=float) / 1024)

    # Fit kv_dim_eff vs down_dim_kv for MLA (linear — nearly perfect)
    agg_mla = mla.groupby("down_dim_kv")["kv_dim_eff"].mean()
    if len(agg_mla) < 2:
        print(" Skipped 3d_5_dkv_cache_seqlen.png: need >= 2 distinct MLA down_dim_kv values")
        return
    dkv_pts = agg_mla.index.values.astype(float)
    kve_pts = agg_mla.values
    slope, intercept = np.polyfit(dkv_pts, kve_pts, 1)  # kv_dim_eff ≈ slope*dkv + intercept

    # Surface grid: up to at least 256, widened so no measured dkv is clipped.
    dkv_range = np.linspace(dkv_pts.min(), max(256.0, dkv_pts.max()), 60)
    S_range = SEQ_GRID
    DKV2D, S2D = np.meshgrid(dkv_range, S_range)
    Z_mla = (slope * DKV2D + intercept) * NB_REF * S2D * 4 / 1e9  # cache in GB

    # GQA/MHA reference cache (GB) per SEQ_GRID entry, at num_blocks=NB_REF
    ref_cache = {}
    for t in ("GQA", "MHA"):
        sub = dense[(dense["type"] == t) & (dense["num_blocks"] == NB_REF)]
        if sub.empty:
            sub = dense[dense["type"] == t]
        kve_med = sub["kv_dim_eff"].median()
        if pd.notna(kve_med):
            ref_cache[t] = float(kve_med) * NB_REF * S_range * 4 / 1e9

    # ── Figure ────────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(15, 8))
    fig.patch.set_facecolor("#F8F9FA")
    ax = _ax3(fig, (1, 1, 1), elev=26, azim=-55)

    # MLA surface
    ax.plot_surface(DKV2D, _y(S2D), Z_mla,
                    color=TYPE_COLORS["MLA"], alpha=0.55,
                    linewidth=0, antialiased=True, zorder=3)

    # Actual MLA runs at their real seq_len = max_context
    run_cache_gb = (mla["kv_dim_eff"] * mla["num_blocks"] * mla["max_context"] * 4 / 1e9).to_numpy()
    ax.scatter(mla["down_dim_kv"].to_numpy(), _y(mla["max_context"]), run_cache_gb,
               c=TYPE_COLORS["MLA"], s=30, edgecolors="white", lw=0.5,
               alpha=0.80, zorder=6)

    # GQA / MHA reference lines (constant in the down_dim_kv direction)
    dkv_lims = [dkv_range[0], dkv_range[-1]]
    y_seq = _y(S_range)
    for t, cache_gb in ref_cache.items():
        vcol = TYPE_COLORS[t]
        for y_val, z_val in zip(y_seq, cache_gb):
            ax.plot(dkv_lims, [y_val] * 2, [z_val] * 2,
                    color=vcol, lw=1.2, alpha=0.55, zorder=4)
        # One bold line at the shortest grid context (SEQ_GRID[0])
        ax.plot(dkv_lims, [y_seq[0]] * 2, [cache_gb[0]] * 2,
                color=vcol, lw=2.5, alpha=0.85, zorder=5)

    # Points the fit was made from (validation), evaluated at 512 tokens.
    # 512 tokens lives at log10(512/1024), the same y as the surface uses.
    y_512 = float(_y(512))
    ax.plot(dkv_pts, [y_512] * len(dkv_pts),
            kve_pts * NB_REF * 512 * 4 / 1e9,
            "o", color=TYPE_COLORS["MLA"], markersize=6,
            markeredgecolor="white", lw=0, zorder=7)

    # VRAM planes (only those the surface actually reaches)
    y_lo, y_hi = y_seq[0], y_seq[-1]
    z_top = float(np.nanmax(Z_mla))
    drawn_vram = []
    for vram, vc, vlbl in VRAM_GB:
        if vram > z_top:
            continue
        verts = [[(dkv_lims[0], y_lo, vram), (dkv_lims[1], y_lo, vram),
                  (dkv_lims[1], y_hi, vram), (dkv_lims[0], y_hi, vram)]]
        poly = Poly3DCollection(verts, alpha=0.18)
        poly.set_facecolor(vc)
        poly.set_edgecolor(vc)
        ax.add_collection3d(poly)
        ax.text(dkv_lims[1], y_hi, vram + 0.3, f" {vlbl}",
                color=vc, fontsize=9, fontweight="bold")
        drawn_vram.append((vc, vlbl))

    # Axes: ticks in tokens, placed with the same transform as the data
    s_ticks = [512, 1024, 4096, 16384, 65536, 131072]
    ax.set_yticks(_y(s_ticks))
    ax.set_yticklabels([f"{s // 1024}K" if s >= 1024 else str(s) for s in s_ticks], fontsize=8)
    ax.set_xlabel("down_dim_kv (MLA latent KV dim)", fontsize=10, labelpad=8)
    ax.set_ylabel("Context length", fontsize=10, labelpad=8)
    ax.set_zlabel("KV Cache per seq (GB)", fontsize=10, labelpad=10)
    ax.tick_params(axis="x", labelsize=8)
    ax.tick_params(axis="z", labelsize=8)

    # Annotation on the surface
    ax.text2D(0.02, 0.92,
              "MLA surface: cache = f(down_dim_kv, seq_len)\n"
              f"kv_dim_eff ≈ {slope:.2f} × down_dim_kv + {intercept:.1f} (num_blocks={NB_REF})\n"
              "GQA/MHA lines: their cache is fixed regardless of down_dim_kv",
              transform=ax.transAxes, fontsize=8.5, color="#333",
              bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="#ccc", lw=0.8, alpha=0.9))

    handles = [mpatches.Patch(color=TYPE_COLORS["MLA"], alpha=0.6, label="MLA surface")]
    handles += [Line2D([0], [0], color=TYPE_COLORS[t], lw=2.5,
                       label=f"{t} reference (num_blocks={NB_REF})")
                for t in ref_cache]
    handles += [mpatches.Patch(color=vc, alpha=0.4, label=vlbl) for vc, vlbl in drawn_vram]
    ax.legend(handles=handles, fontsize=9, framealpha=0.9,
              loc="upper left", bbox_to_anchor=(-0.01, 1.0))

    ax.set_title(
        "MLA: down_dim_kv × Context Length → KV Cache\n"
        "Slide down_dim_kv to land below GQA or MHA cache levels at any context",
        fontsize=12, fontweight="bold", pad=12
    )
    plt.tight_layout()
    _save(fig, "3d_5_dkv_cache_seqlen.png")

196 

197 

198# ───────────────────────────────────────────────────────────────────────────── 

199# Fig 6 — kv_dim_eff × params_m → KV cache: the decoupling plot 

200# MLA: kv_dim_eff is flat vs params (free parameter) 

201# GQA/MHA: kv_dim_eff rises with params (locked to architecture) 

202# ───────────────────────────────────────────────────────────────────────────── 

def fig6_kv_decoupling(df):
    """Plot kv_dim_eff against model size per attention type (Fig 6).

    Left: 3D scatter of params_m × kv_dim_eff × KV cache. The cache axis is
    evaluated at a fixed 512-token context for both the runs and the per-type
    log-linear trend lines (trend cache uses the type's median num_blocks).
    Right: the 2D params_m vs kv_dim_eff projection with the same trends.

    Only dense (non-MoE) runs with params_m, kv_dim_eff and
    theoretical_cache_mb present are used; a trend is drawn only for types
    with at least 4 runs.

    Args:
        df: Frame returned by ``load()``.

    Writes ``3d_6_kv_decoupling.png`` via ``_save``; returns None, or skips
    the figure (with a printed note) when no run qualifies.
    """
    dense = df[~df["moe"]].dropna(subset=["params_m", "kv_dim_eff", "theoretical_cache_mb"])
    if dense.empty:
        print(" Skipped 3d_6_kv_decoupling.png: no dense runs with kv_dim_eff")
        return
    seq_ref = 512  # context length (tokens) for the left panel's cache axis

    def _log_fit(pts):
        # (slope, intercept) of kv_dim_eff vs log10(params_m); None if < 4 runs.
        if len(pts) < 4:
            return None
        return np.polyfit(np.log10(pts["params_m"]), pts["kv_dim_eff"], 1)

    by_type = {t: dense[dense["type"] == t] for t in TYPE_ORDER}
    present = [t for t in TYPE_ORDER if not by_type[t].empty]
    fits = {t: _log_fit(by_type[t]) for t in present}

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: 3D scatter kv_dim_eff × params_m × cache ──────────────────────
    ax = _ax3(fig, (1, 2, 1), elev=20, azim=-50)

    for t in present:
        pts = by_type[t]
        # theoretical_cache_mb is taken at each run's max_context; rescale to
        # seq_ref tokens so the points share the z axis with the trend lines.
        cache_ref = pts["kv_dim_eff"] * pts["num_blocks"] * seq_ref * 4 / 1e6
        ax.scatter(pts["params_m"], pts["kv_dim_eff"], cache_ref,
                   c=TYPE_COLORS[t], s=60,
                   edgecolors="white", lw=0.6,
                   alpha=0.85, zorder=4, label=t)

    # Trend lines (fit per type), cache = kv_dim_eff × median num_blocks × seq_ref × 4
    params_fit = np.linspace(dense["params_m"].min(), dense["params_m"].max(), 40)
    for t in present:
        if fits[t] is None:
            continue
        b, a = fits[t]
        kve_fit = b * np.log10(params_fit) + a
        nb_med = float(by_type[t]["num_blocks"].median())
        cache_fit = kve_fit * nb_med * seq_ref * 4 / 1e6
        ax.plot(params_fit, kve_fit, cache_fit,
                color=TYPE_COLORS[t], lw=2.2, ls="--", alpha=0.65, zorder=3)

    ax.set_xlabel("Model parameters (M)", fontsize=9, labelpad=6)
    ax.set_ylabel("kv_dim_eff\n(bytes / token / layer / 4)", fontsize=9, labelpad=8)
    ax.set_zlabel("KV Cache (MB) @ 512 tok", fontsize=9, labelpad=8)
    ax.tick_params(labelsize=7)
    ax.set_title("kv_dim_eff × Params × Cache\n"
                 "MLA: flat vs params — GQA/MHA: rises",
                 fontsize=11, fontweight="bold", pad=8)
    handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in present]
    ax.legend(handles=handles, fontsize=10, framealpha=0.9)

    # ── Right: 2D projection kv_dim_eff vs params (the core message) ─────────
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_facecolor("#F8F9FA")

    for t in present:
        pts = by_type[t]
        ax2.scatter(pts["params_m"], pts["kv_dim_eff"],
                    c=TYPE_COLORS[t], s=65, edgecolors="white", lw=0.7,
                    alpha=0.85, zorder=4, label=t)
        if fits[t] is not None:
            b, a = fits[t]
            lx = np.log10(pts["params_m"])
            xf = np.logspace(lx.min() - 0.05, lx.max() + 0.05, 80)
            ax2.plot(xf, b * np.log10(xf) + a,
                     color=TYPE_COLORS[t], lw=2.2, ls="--", alpha=0.65)

    # Annotate slopes (MLA's kv_dim_eff also carries the decoupled rope_dim)
    ax2.text(0.98, 0.96,
             "MHA: kv_dim_eff ∝ dim (locked to width)\n"
             "GQA: kv_dim_eff ∝ dim × (kv_heads/n_heads)\n"
             "MLA: kv_dim_eff ≈ down_dim_kv + rope_dim (free param)\n\n"
             "→ MLA's cache doesn't grow with the model",
             transform=ax2.transAxes, ha="right", va="top",
             fontsize=9.5, color="#333", style="italic",
             bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="#ccc", lw=1, alpha=0.92))

    ax2.set_xscale("log")
    ax2.set_xlabel("Model parameters (M)", fontsize=11)
    ax2.set_ylabel("kv_dim_eff (effective KV dim per layer)", fontsize=11)
    ax2.set_title("The Decoupling: kv_dim_eff vs Model Size\n"
                  "MHA/GQA locked to architecture — MLA freely chosen",
                  fontsize=11, fontweight="bold")
    ax2.legend(fontsize=10, framealpha=0.9)
    ax2.grid(alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "kv_dim_eff: The Root Cause of MLA's Cache Advantage\n"
        "For MHA/GQA it is an architectural constant. For MLA it is a free hyperparameter.",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_6_kv_decoupling.png")

293 

294 

295# ───────────────────────────────────────────────────────────────────────────── 

296# Fig 7 — MLA: down_dim_kv × params_m → val_loss (quality landscape) 

297# Both levers affect quality; this surface shows the full tradeoff. 

298# ───────────────────────────────────────────────────────────────────────────── 

def fig7_mla_quality(df):
    """Plot the dense-MLA validation-loss landscape over params_m × down_dim_kv (Fig 7).

    Left: linearly interpolated 3D surface (coloured by loss) plus the actual
    runs, with horizontal planes at the median val_loss of dense MHA and GQA
    runs. Right: top-down scatter with contour lines of the same interpolation.

    Args:
        df: Frame returned by ``load()``.

    Writes ``3d_7_mla_quality.png`` via ``_save``; returns None. With no
    usable MLA runs the figure is skipped. When the runs cannot be
    triangulated (fewer than 3 points, or all collinear), only the scatters
    are drawn. A reference type with no val_loss data is omitted.
    """
    mla = df[(df["type"] == "MLA") & (~df["moe"])].dropna(
        subset=["down_dim_kv", "params_m", "val_loss"]
    ).copy()
    if mla.empty:
        print(" Skipped 3d_7_mla_quality.png: no dense MLA runs with val_loss")
        return

    # Reference: dense MHA/GQA median val_loss (what MLA competes against)
    ref = df[~df["moe"]].dropna(subset=["val_loss"])
    refs = []
    for t in ("MHA", "GQA"):
        med = ref.loc[ref["type"] == t, "val_loss"].median()
        if pd.notna(med):
            refs.append((float(med), TYPE_COLORS[t], f"{t} median"))

    xs = mla["params_m"].values
    ys = mla["down_dim_kv"].values
    zs = mla["val_loss"].values

    # Dense interpolated surface
    xi = np.linspace(xs.min(), xs.max(), 50)
    yi = np.linspace(ys.min(), ys.max(), 50)
    XI, YI = np.meshgrid(xi, yi)
    try:
        ZI = griddata((xs, ys), zs, (XI, YI), method="linear")
    except (ValueError, RuntimeError):
        # Qhull cannot triangulate < 3 or collinear points (QhullError is a
        # RuntimeError subclass): fall back to scatter-only panels.
        ZI = np.full(XI.shape, np.nan)
    has_surface = bool(np.isfinite(ZI).any())
    if has_surface:
        vmin, vmax = float(np.nanmin(ZI)), float(np.nanmax(ZI))
    else:
        vmin, vmax = float(zs.min()), float(zs.max())

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: 3D surface ─────────────────────────────────────────────────────
    ax = _ax3(fig, (1, 2, 1), elev=22, azim=-48)

    # Interpolated surface coloured by quality
    if has_surface:
        ax.plot_surface(XI, YI, ZI,
                        facecolors=cm.RdYlGn_r((ZI - vmin) / (vmax - vmin + 1e-9)),
                        alpha=0.55, linewidth=0, antialiased=True, zorder=3)

    # Actual data scatter
    ax.scatter(xs, ys, zs, c=zs, cmap="RdYlGn_r",
               vmin=vmin, vmax=vmax,
               s=55, edgecolors="white", lw=0.6, alpha=0.90, zorder=6)

    # MHA/GQA reference planes
    xl = [xs.min(), xs.max()]
    yl = [ys.min(), ys.max()]
    for ref_val, rcol, rlbl in refs:
        verts = [[(xl[0], yl[0], ref_val), (xl[1], yl[0], ref_val),
                  (xl[1], yl[1], ref_val), (xl[0], yl[1], ref_val)]]
        poly = Poly3DCollection(verts, alpha=0.20)
        poly.set_facecolor(rcol)
        poly.set_edgecolor(rcol)
        ax.add_collection3d(poly)
        ax.text(xl[1], yl[1], ref_val + 0.01, f" {rlbl}",
                color=rcol, fontsize=8, fontweight="bold")

    ax.set_xlabel("Parameters (M)", fontsize=9, labelpad=6)
    ax.set_ylabel("down_dim_kv", fontsize=9, labelpad=6)
    ax.set_zlabel("Validation Loss ↓", fontsize=9, labelpad=8)
    ax.tick_params(labelsize=7)
    ax.set_title("MLA Quality Landscape\n(surface colour: red=worse, green=better)",
                 fontsize=11, fontweight="bold", pad=8)

    # Colorbar proxy (plot_surface with facecolors carries no colormap)
    sm = cm.ScalarMappable(cmap="RdYlGn_r",
                           norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])
    plt.colorbar(sm, ax=ax, shrink=0.5, pad=0.1, label="val_loss")

    # ── Right: top-down heatmap (cleaner reading) ─────────────────────────────
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_facecolor("#F8F9FA")

    # Scatter (actual runs)
    sc = ax2.scatter(xs, ys, c=zs, cmap="RdYlGn_r",
                     vmin=vmin, vmax=vmax,
                     s=120, edgecolors="white", lw=1, alpha=0.92, zorder=4)

    # Contour lines on the interpolated grid (needs a non-constant surface)
    if has_surface and vmax > vmin:
        CS = ax2.contour(XI, YI, ZI, levels=8, cmap="RdYlGn_r",
                         vmin=vmin, vmax=vmax, alpha=0.55, linewidths=1.2)
        ax2.clabel(CS, fmt="%.3f", fontsize=8, inline=True)

    # MHA/GQA reference values as text labels beside the plot
    for i, (ref_val, rcol, rlbl) in enumerate(refs):
        ax2.text(xs.max() * 1.01, ys.max() * (0.98 - i * 0.08),
                 f"{rlbl}: {ref_val:.3f}",
                 ha="left", fontsize=8, color=rcol, fontweight="bold")

    ax2.set_xlabel("Parameters (M)", fontsize=11)
    ax2.set_ylabel("down_dim_kv", fontsize=11)
    ax2.set_title("Top-down View: Quality Heatmap\n"
                  "(green = lower loss = better quality)",
                  fontsize=11, fontweight="bold")
    ax2.grid(alpha=0.15, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)
    plt.colorbar(sc, ax=ax2, label="Validation Loss ↓")

    fig.suptitle(
        "MLA Quality Landscape: down_dim_kv × Model Size → Validation Loss\n"
        "Both levers matter — but params_m drives quality more than down_dim_kv",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_7_mla_quality.png")

400 

401 

402# ───────────────────────────────────────────────────────────────────────────── 

403# Fig 8 — MLA: down_dim_kv × num_blocks → KV cache at multiple seq_lens 

404# Two independent knobs that both drive cache independently. 

405# ───────────────────────────────────────────────────────────────────────────── 

def fig8_dkv_numblocks(df):
    """Plot MLA KV cache over down_dim_kv × num_blocks at four context lengths (Fig 8).

    kv_dim_eff is fitted linearly against down_dim_kv from the dense MLA runs,
    then cache (GB) = kv_dim_eff × num_blocks × seq_len × 4 / 1e9 is drawn as
    one surface per seq_len, with a floor contour, VRAM planes the surface
    reaches, and the runs whose num_blocks lies on the plotted grid.

    Args:
        df: Frame returned by ``load()``.

    Writes ``3d_8_dkv_num_blocks.png`` via ``_save``; returns None. The
    figure is skipped (with a printed note) when fewer than two distinct
    down_dim_kv values are available for the fit.
    """
    mla = df[(df["type"] == "MLA") & (~df["moe"])].dropna(
        subset=["down_dim_kv", "num_blocks", "kv_dim_eff"]
    )

    # Fit: kv_dim_eff ≈ slope * down_dim_kv + intercept
    agg = mla.groupby("down_dim_kv")["kv_dim_eff"].mean()
    if len(agg) < 2:
        print(" Skipped 3d_8_dkv_num_blocks.png: need >= 2 distinct MLA down_dim_kv values")
        return
    dkv_pts = agg.index.values.astype(float)
    slope, intercept = np.polyfit(dkv_pts, agg.values, 1)

    # Default span 32..256, widened (never narrowed) to cover measured values.
    dkv_range = np.linspace(min(32.0, dkv_pts.min()), max(256.0, dkv_pts.max()), 50)
    nb_range = np.array([4, 8, 12, 16, 24, 32])  # include realistic LLM depths

    DKV2D, NB2D = np.meshgrid(dkv_range, nb_range)
    kve_2d = slope * DKV2D + intercept
    dkv_lims = (dkv_range[0], dkv_range[-1])
    nb_lims = (nb_range[0], nb_range[-1])

    # Runs whose depth lies on the plotted num_blocks grid; each is shown with
    # its own kv_dim_eff evaluated at the panel's seq_len.
    on_grid = mla[mla["num_blocks"].isin(nb_range)]

    seq_levels = [512, 4096, 32768, 131072]
    seq_labels = ["512", "4K", "32K", "128K"]
    colors_seq = ["#3A86FF", "#8338EC", "#FF006E", "#FB5607"]

    fig = plt.figure(figsize=(18, 9))
    fig.patch.set_facecolor("#F8F9FA")

    # ── One subplot per seq_len ───────────────────────────────────────────────
    for pi, (seq_len, seq_lbl, seq_col) in enumerate(zip(seq_levels, seq_labels, colors_seq)):
        ax = _ax3(fig, (1, 4, pi + 1), elev=28, azim=-55)
        Z = kve_2d * NB2D * seq_len * 4 / 1e9  # GB
        z_max = float(np.nanmax(Z))

        ax.plot_surface(DKV2D, NB2D, Z,
                        color=seq_col, alpha=0.50,
                        linewidth=0, antialiased=True, zorder=3)

        # Floor heatmap (contour levels must be increasing, so need z_max > 0)
        if z_max > 0:
            ax.contourf(DKV2D, NB2D, Z, zdir="z", offset=0,
                        levels=np.linspace(0, z_max, 10),
                        cmap="Blues", alpha=0.40)

        # VRAM planes the surface actually reaches
        for vram, vc, vlbl in VRAM_GB:
            if vram > z_max:
                continue
            verts = [[(dkv_lims[0], nb_lims[0], vram),
                      (dkv_lims[1], nb_lims[0], vram),
                      (dkv_lims[1], nb_lims[1], vram),
                      (dkv_lims[0], nb_lims[1], vram)]]
            poly = Poly3DCollection(verts, alpha=0.22)
            poly.set_facecolor(vc)
            poly.set_edgecolor(vc)
            ax.add_collection3d(poly)
            ax.text(dkv_lims[1], nb_lims[0], vram + z_max * 0.02,
                    f" {vlbl}", color=vc, fontsize=7, fontweight="bold")

        # Actual runs on the num_blocks grid, evaluated at this seq_len
        if not on_grid.empty:
            z_pts = on_grid["kv_dim_eff"] * on_grid["num_blocks"] * seq_len * 4 / 1e9
            ax.scatter(on_grid["down_dim_kv"].to_numpy(), on_grid["num_blocks"].to_numpy(),
                       z_pts.to_numpy(),
                       c="white", s=22, edgecolors=seq_col, lw=1.2,
                       alpha=0.85, zorder=7)

        ax.set_xlabel("down_dim_kv", fontsize=8, labelpad=5)
        ax.set_ylabel("num_blocks", fontsize=8, labelpad=5)
        ax.set_zlabel("KV Cache (GB)", fontsize=8, labelpad=7)
        ax.tick_params(labelsize=7)
        ax.set_title(f"seq = {seq_lbl}", fontsize=12,
                     fontweight="bold", color=seq_col, pad=6)

    fig.suptitle(
        "MLA KV Cache = f(down_dim_kv, num_blocks) — shown at four context lengths\n"
        "Both knobs are tunable: reduce down_dim_kv OR use fewer layers to save cache",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_8_dkv_num_blocks.png")

481 

482 

483# ───────────────────────────────────────────────────────────────────────────── 

484if __name__ == "__main__": 

485 os.chdir(os.path.dirname(os.path.abspath(__file__))) 

486 df = load() 

487 print(f"Loaded {len(df)} runs") 

488 print("\nGenerating down_dim_kv 3D figures...") 

489 fig5_dkv_cache_seqlen(df) 

490 fig6_kv_decoupling(df) 

491 fig7_mla_quality(df) 

492 fig8_dkv_numblocks(df) 

493 print("Done.")