Coverage for dantinox / plots / plot_3d_dkv.py: 0%
249 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 18:33 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 18:33 +0200
1#!/home/marco.simoni/miniconda3/bin/python3
2"""
3plot_3d_dkv.py — 3D surfaces focused on down_dim_kv (MLA's latent KV dimension).
5Run with: /home/marco.simoni/miniconda3/bin/python3 plot_3d_dkv.py
7Figures:
8 3d_5_dkv_cache_seqlen.png — MLA: down_dim_kv × seq_len → KV cache
9 (with GQA/MHA fixed reference planes)
10 3d_6_kv_decoupling.png — kv_dim_eff × params_m → cache
11 (MLA cluster is flat; MHA/GQA cluster rises)
12 3d_7_mla_quality.png — MLA: down_dim_kv × params_m → val_loss quality landscape
13 3d_8_dkv_num_blocks.png — MLA: down_dim_kv × num_blocks → cache at multiple seq_lens
14"""
16import os
17import numpy as np
18import pandas as pd
19from scipy.interpolate import griddata
20import matplotlib
21matplotlib.use("Agg")
22import matplotlib.pyplot as plt
23from mpl_toolkits.mplot3d import Axes3D
24from mpl_toolkits.mplot3d.art3d import Poly3DCollection
25import matplotlib.patches as mpatches
26from matplotlib.lines import Line2D
27import matplotlib.cm as cm
# Input table, output directory and export resolution used by every figure.
IN_CSV = "benchmark_results.csv"
OUT_DIR = "plots"
DPI = 210

# One colour per attention variant, and the order in which variants are drawn.
TYPE_COLORS = {
    "MLA": "#3A86FF",
    "GQA": "#FF6B35",
    "MHA": "#2DC653",
}
TYPE_ORDER = ["MLA", "GQA", "MHA"]

# Context lengths (tokens) sampled along the sequence axis: 512 doubled
# eight times, i.e. 512, 1024, ..., 131072.
SEQ_GRID = np.array([512 << k for k in range(9)])
NB_REF = 12  # representative num_blocks for surfaces

# (capacity in GB, colour, label) for each VRAM reference plane.
VRAM_GB = [
    (80, "#CC3311", "80 GB H100"),
    (24, "#EE7733", "24 GB RTX"),
]
def _save(fig, name):
    """Write *fig* to ``OUT_DIR/name``, close it and report the path.

    The output directory is created if it does not exist yet. The figure is
    saved at ``DPI`` with a tight bounding box and then closed through
    ``plt.close``.

    Args:
        fig: figure object exposing ``savefig``.
        name: file name (including extension) inside ``OUT_DIR``.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    target = os.path.join(OUT_DIR, name)
    fig.savefig(target, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {target}")
49def _ax3(fig, pos, elev=24, azim=-52, fc="#F8F9FA"):
50 ax = fig.add_subplot(*pos, projection="3d")
51 ax.set_facecolor(fc)
52 ax.view_init(elev=elev, azim=azim)
53 return ax
def load():
    """Read ``IN_CSV`` and add the derived ``kv_dim_eff`` column.

    Every known numeric column that is present in the file is coerced with
    ``pd.to_numeric(errors="coerce")``, so unparsable cells become NaN.

    Returns:
        The DataFrame read from ``IN_CSV`` with the extra ``kv_dim_eff``
        column.
    """
    frame = pd.read_csv(IN_CSV)

    numeric_columns = (
        "params_m", "theoretical_cache_mb", "val_loss",
        "tps_512", "max_context", "num_blocks",
        "n_heads", "kv_heads", "dim", "down_dim_kv",
    )
    for column in numeric_columns:
        if column not in frame.columns:
            continue
        frame[column] = pd.to_numeric(frame[column], errors="coerce")

    # Effective KV dimension: the quantity that directly sets the cache.
    #   cache_bytes = kv_dim_eff × num_blocks × seq_len × 4   (fp32)
    # Inverted here, taking theoretical_cache_mb × 1e6 as the byte count at
    # seq_len = max_context.
    cache_bytes = frame["theoretical_cache_mb"] * 1e6
    bytes_per_unit_dim = frame["num_blocks"] * frame["max_context"] * 4
    frame["kv_dim_eff"] = cache_bytes / bytes_per_unit_dim
    return frame
73# ─────────────────────────────────────────────────────────────────────────────
74# Fig 5 — MLA surface: down_dim_kv × seq_len → KV cache (GB)
75# GQA and MHA shown as horizontal reference planes.
76# ─────────────────────────────────────────────────────────────────────────────
def fig5_dkv_cache_seqlen(df):
    """
    Render ``3d_5_dkv_cache_seqlen.png``: MLA cache over down_dim_kv × context.

    For MLA, cache scales with (down_dim_kv + rope_dim) × num_blocks × seq_len.
    kv_dim_eff(down_dim_kv) is fitted from actual run data (it absorbs rope_dim).
    GQA/MHA reference: their median kv_dim_eff at num_blocks=NB_REF, falling
    back to all runs of that type when none has num_blocks == NB_REF.

    The context axis is ``log10(seq_len / 1024)``, so 512 tokens sits at
    ``log10(0.5)``.

    Args:
        df: run table; the columns read are ``type``, ``moe``,
            ``down_dim_kv``, ``kv_dim_eff``, ``num_blocks`` and
            ``max_context``. ``moe`` is used as a boolean mask —
            presumably a bool column; verify against the CSV.
    """
    dense = df[~df["moe"]]
    mla = dense[dense["type"] == "MLA"].dropna(subset=["down_dim_kv", "kv_dim_eff"])

    # Linear fit of the mean kv_dim_eff per down_dim_kv value:
    #   kv_dim_eff ≈ slope * down_dim_kv + intercept
    agg_mla = mla.groupby("down_dim_kv")["kv_dim_eff"].mean()
    dkv_pts = agg_mla.index.values.astype(float)
    kve_pts = agg_mla.values
    slope, intercept = np.polyfit(dkv_pts, kve_pts, 1)

    # Surface grid
    dkv_range = np.linspace(dkv_pts.min(), 256, 60)
    S_range = SEQ_GRID
    log_s_all = np.log10(S_range / 1024)
    DKV2D, S2D = np.meshgrid(dkv_range, S_range)

    kve_2d = slope * DKV2D + intercept
    Z_mla = kve_2d * NB_REF * S2D * 4 / 1e9  # cache in GB

    # GQA/MHA reference cache (GB) per context length, at num_blocks=NB_REF
    ref_cache = {}
    for t in ["GQA", "MHA"]:
        sub = dense[(dense["type"] == t) & (dense["num_blocks"] == NB_REF)]
        if sub.empty:
            sub = dense[dense["type"] == t]
        ref_cache[t] = float(sub["kv_dim_eff"].median()) * NB_REF * S_range * 4 / 1e9

    # ── Figure ────────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(15, 8))
    fig.patch.set_facecolor("#F8F9FA")
    ax = _ax3(fig, (1, 1, 1), elev=26, azim=-55)

    # MLA surface
    ax.plot_surface(DKV2D, np.log10(S2D / 1024), Z_mla,
                    color=TYPE_COLORS["MLA"], alpha=0.55,
                    linewidth=0, antialiased=True, zorder=3)

    # Actual MLA runs at their real seq_len=max_context, drawn in one call
    # instead of one scatter artist per row.
    run_cache_gb = mla["kv_dim_eff"] * mla["num_blocks"] * mla["max_context"] * 4 / 1e9
    ax.scatter(mla["down_dim_kv"].to_numpy(dtype=float),
               np.log10(mla["max_context"].to_numpy(dtype=float) / 1024),
               run_cache_gb.to_numpy(dtype=float),
               c=TYPE_COLORS["MLA"], s=30, edgecolors="white", lw=0.5,
               alpha=0.80, zorder=6)

    # Position of 512 tokens on the log10(seq_len / 1024) axis. Using 0.512
    # here would offset the markers from the surface edge.
    y_512 = np.log10(512 / 1024)

    # GQA / MHA reference lines (constant in down_dim_kv direction)
    dkv_lims = [dkv_range[0], dkv_range[-1]]
    log_s = log_s_all[[0, -1]]
    for t, vcol in [("GQA", TYPE_COLORS["GQA"]), ("MHA", TYPE_COLORS["MHA"])]:
        for y_val, z_val in zip(log_s_all, ref_cache[t]):
            ax.plot(dkv_lims, [y_val] * 2, [z_val, z_val],
                    color=vcol, lw=1.2, alpha=0.55, zorder=4)

        # One bold reference line at 512 tokens
        ax.plot(dkv_lims, [y_512] * 2,
                [ref_cache[t][0]] * 2,
                color=vcol, lw=2.5, alpha=0.85, zorder=5,
                label=f"{t} cache @ any dkv (num_blocks={NB_REF})")

    # Per-down_dim_kv means behind the fitted line, shown at 512 tokens
    ax.plot(dkv_pts, [y_512] * len(dkv_pts),
            kve_pts * NB_REF * 512 * 4 / 1e9,
            "o", color=TYPE_COLORS["MLA"], markersize=6,
            markeredgecolor="white", lw=0, zorder=7,
            label="MLA measured (512 tokens)")

    # VRAM planes: only those within the surface's z-range are drawn, and
    # only those get a legend entry below.
    z_now = float(np.nanmax(Z_mla))
    drawn_vram = []
    for vram, vc, vlbl in VRAM_GB:
        if vram > z_now:
            continue
        verts = [[(dkv_lims[0], log_s[0], vram), (dkv_lims[1], log_s[0], vram),
                  (dkv_lims[1], log_s[1], vram), (dkv_lims[0], log_s[1], vram)]]
        poly = Poly3DCollection(verts, alpha=0.18)
        poly.set_facecolor(vc)
        poly.set_edgecolor(vc)
        ax.add_collection3d(poly)
        ax.text(dkv_lims[1], log_s[1], vram + 0.3, f" {vlbl}",
                color=vc, fontsize=9, fontweight="bold")
        drawn_vram.append((vc, vlbl))

    # Axes (tick values are seq_len / 1024, so 512 tokens is 0.5)
    s_ticks = [0.5, 1, 4, 16, 64, 128]
    ax.set_yticks([np.log10(v) for v in s_ticks])
    ax.set_yticklabels(["512" if v < 1 else f"{int(v)}K" for v in s_ticks], fontsize=8)
    ax.set_xlabel("down_dim_kv (MLA latent KV dim)", fontsize=10, labelpad=8)
    ax.set_ylabel("Context length", fontsize=10, labelpad=8)
    ax.set_zlabel("KV Cache per seq (GB)", fontsize=10, labelpad=10)
    ax.tick_params(axis="x", labelsize=8)
    ax.tick_params(axis="z", labelsize=8)

    # Annotation on the surface
    ax.text2D(0.02, 0.92,
              f"MLA surface: cache = f(down_dim_kv, seq_len)\n"
              f"kv_dim_eff ≈ {slope:.2f} × down_dim_kv + {intercept:.1f} (num_blocks={NB_REF})\n"
              "GQA/MHA lines: their cache is fixed regardless of down_dim_kv",
              transform=ax.transAxes, fontsize=8.5, color="#333",
              bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="#ccc", lw=0.8, alpha=0.9))

    handles = [
        mpatches.Patch(color=TYPE_COLORS["MLA"], alpha=0.6, label="MLA surface"),
        Line2D([0], [0], color=TYPE_COLORS["GQA"], lw=2.5,
               label=f"GQA reference (num_blocks={NB_REF})"),
        Line2D([0], [0], color=TYPE_COLORS["MHA"], lw=2.5,
               label=f"MHA reference (num_blocks={NB_REF})"),
    ] + [mpatches.Patch(color=vc, alpha=0.4, label=vlbl) for vc, vlbl in drawn_vram]
    ax.legend(handles=handles, fontsize=9, framealpha=0.9,
              loc="upper left", bbox_to_anchor=(-0.01, 1.0))

    ax.set_title(
        "MLA: down_dim_kv × Context Length → KV Cache\n"
        "Slide down_dim_kv to land below GQA or MHA cache levels at any context",
        fontsize=12, fontweight="bold", pad=12
    )
    plt.tight_layout()
    _save(fig, "3d_5_dkv_cache_seqlen.png")
198# ─────────────────────────────────────────────────────────────────────────────
199# Fig 6 — kv_dim_eff × params_m → KV cache: the decoupling plot
200# MLA: kv_dim_eff is flat vs params (free parameter)
201# GQA/MHA: kv_dim_eff rises with params (locked to architecture)
202# ─────────────────────────────────────────────────────────────────────────────
def fig6_kv_decoupling(df):
    """Render ``3d_6_kv_decoupling.png``: kv_dim_eff against model size.

    Left panel: 3D scatter of params_m × kv_dim_eff × KV cache at 512 tokens,
    with one dashed trend line per attention type. Right panel: the 2D
    projection kv_dim_eff vs params_m on a log x-axis with the same
    per-type log-linear fits.

    Cache values are evaluated at 512 tokens as
    ``kv_dim_eff × num_blocks × 512 × 4 / 1e6`` for both the scatter and the
    trend lines, matching the z-axis label.

    Args:
        df: run table; the columns read are ``type``, ``moe``, ``params_m``,
            ``kv_dim_eff``, ``theoretical_cache_mb`` and ``num_blocks``.
            ``moe`` is used as a boolean mask — presumably a bool column;
            verify against the CSV.
    """
    dense = df[~df["moe"]].dropna(subset=["params_m", "kv_dim_eff", "theoretical_cache_mb"])

    # Split once per type and fit kv_dim_eff ≈ b * log10(params_m) + a for
    # types with at least 4 runs; both panels reuse these results.
    groups = {}
    fits = {}
    for t in TYPE_ORDER:
        pts = dense[dense["type"] == t]
        if pts.empty:
            continue
        groups[t] = pts
        if len(pts) >= 4:
            b, a = np.polyfit(np.log10(pts["params_m"]), pts["kv_dim_eff"], 1)
            fits[t] = (b, a)

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: 3D scatter kv_dim_eff × params_m × cache ──────────────────────
    ax = _ax3(fig, (1, 2, 1), elev=20, azim=-50)

    for t, pts in groups.items():
        # Cache at 512 tokens. theoretical_cache_mb is the cache at each
        # run's max_context (see how kv_dim_eff is derived), so it is not
        # comparable with the 512-token trend lines and axis label.
        cache_512_mb = pts["kv_dim_eff"] * pts["num_blocks"] * 512 * 4 / 1e6
        ax.scatter(pts["params_m"], pts["kv_dim_eff"],
                   cache_512_mb,
                   c=TYPE_COLORS[t], s=60,
                   edgecolors="white", lw=0.6,
                   alpha=0.85, zorder=4, label=t)

    # Trend lines (fit per type), sampled over the full params range
    params_fit = np.linspace(dense["params_m"].min(), dense["params_m"].max(), 40)
    for t, (b, a) in fits.items():
        kve_fit = b * np.log10(params_fit) + a

        # Cache at fixed seq_len=512, using the median num_blocks of the type
        nb_med = float(groups[t]["num_blocks"].median())
        cache_fit = kve_fit * nb_med * 512 * 4 / 1e6

        ax.plot(params_fit, kve_fit, cache_fit,
                color=TYPE_COLORS[t], lw=2.2, ls="--", alpha=0.65, zorder=3)

    ax.set_xlabel("Model parameters (M)", fontsize=9, labelpad=6)
    ax.set_ylabel("kv_dim_eff\n(bytes / token / layer / 4)", fontsize=9, labelpad=8)
    ax.set_zlabel("KV Cache (MB) @ 512 tok", fontsize=9, labelpad=8)
    ax.tick_params(labelsize=7)
    ax.set_title("kv_dim_eff × Params × Cache\n"
                 "MLA: flat vs params — GQA/MHA: rises",
                 fontsize=11, fontweight="bold", pad=8)
    handles = [mpatches.Patch(color=TYPE_COLORS[t], label=t) for t in TYPE_ORDER]
    ax.legend(handles=handles, fontsize=10, framealpha=0.9)

    # ── Right: 2D projection kv_dim_eff vs params (the core message) ─────────
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_facecolor("#F8F9FA")

    for t, pts in groups.items():
        ax2.scatter(pts["params_m"], pts["kv_dim_eff"],
                    c=TYPE_COLORS[t], s=65, edgecolors="white", lw=0.7,
                    alpha=0.85, zorder=4, label=t)
        if t in fits:
            b, a = fits[t]
            lx = np.log10(pts["params_m"])
            # Extend the line 0.05 decades beyond the data on both sides
            xf = np.logspace(lx.min() - 0.05, lx.max() + 0.05, 80)
            ax2.plot(xf, b * np.log10(xf) + a,
                     color=TYPE_COLORS[t], lw=2.2, ls="--", alpha=0.65)

    # Explanatory note
    ax2.text(0.98, 0.96,
             "MHA: kv_dim_eff ∝ dim (locked to width)\n"
             "GQA: kv_dim_eff ∝ dim × (kv_heads/n_heads)\n"
             "MLA: kv_dim_eff = down_dim_kv (free param)\n\n"
             "→ MLA's cache doesn't grow with the model",
             transform=ax2.transAxes, ha="right", va="top",
             fontsize=9.5, color="#333", style="italic",
             bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="#ccc", lw=1, alpha=0.92))

    ax2.set_xscale("log")
    ax2.set_xlabel("Model parameters (M)", fontsize=11)
    ax2.set_ylabel("kv_dim_eff (effective KV dim per layer)", fontsize=11)
    ax2.set_title("The Decoupling: kv_dim_eff vs Model Size\n"
                  "MHA/GQA locked to architecture — MLA freely chosen",
                  fontsize=11, fontweight="bold")
    ax2.legend(fontsize=10, framealpha=0.9)
    ax2.grid(alpha=0.2, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)

    fig.suptitle(
        "kv_dim_eff: The Root Cause of MLA's Cache Advantage\n"
        "For MHA/GQA it is an architectural constant. For MLA it is a free hyperparameter.",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_6_kv_decoupling.png")
295# ─────────────────────────────────────────────────────────────────────────────
296# Fig 7 — MLA: down_dim_kv × params_m → val_loss (quality landscape)
297# Both levers affect quality; this surface shows the full tradeoff.
298# ─────────────────────────────────────────────────────────────────────────────
def fig7_mla_quality(df):
    """Render ``3d_7_mla_quality.png``: MLA val_loss over params_m × down_dim_kv.

    Left panel: linearly interpolated 3D surface coloured by val_loss, the
    actual runs as scatter points, and one translucent plane per reference
    type (median val_loss of non-MoE MHA and GQA runs). Right panel:
    top-down scatter with contour lines of the same interpolated grid and
    the reference medians as text labels.

    A reference type with no val_loss data (NaN median) is left out of both
    panels rather than drawn as a NaN plane / "nan" label.

    Args:
        df: run table; the columns read are ``type``, ``moe``,
            ``down_dim_kv``, ``params_m`` and ``val_loss``. ``moe`` is used
            as a boolean mask — presumably a bool column; verify against
            the CSV.
    """
    mla = df[(df["type"] == "MLA") & (~df["moe"])].dropna(
        subset=["down_dim_kv", "params_m", "val_loss"]
    ).copy()

    # Reference: Dense GQA/MHA median val_loss (what MLA competes against)
    ref = df[~df["moe"]].dropna(subset=["val_loss"])
    ref_mha = float(ref[ref["type"] == "MHA"]["val_loss"].median())
    ref_gqa = float(ref[ref["type"] == "GQA"]["val_loss"].median())
    references = [
        (ref_val, rcol, rlbl)
        for ref_val, rcol, rlbl in [(ref_mha, TYPE_COLORS["MHA"], "MHA median"),
                                    (ref_gqa, TYPE_COLORS["GQA"], "GQA median")]
        if not np.isnan(ref_val)  # median of an empty selection is NaN
    ]

    xs = mla["params_m"].values
    ys = mla["down_dim_kv"].values
    zs = mla["val_loss"].values

    # Dense interpolated surface; cells outside the convex hull of the runs
    # come back as NaN from the linear interpolation.
    xi = np.linspace(xs.min(), xs.max(), 50)
    yi = np.linspace(ys.min(), ys.max(), 50)
    XI, YI = np.meshgrid(xi, yi)
    ZI = griddata((xs, ys), zs, (XI, YI), method="linear")

    fig = plt.figure(figsize=(16, 8))
    fig.patch.set_facecolor("#F8F9FA")

    # ── Left: 3D surface ─────────────────────────────────────────────────────
    ax = _ax3(fig, (1, 2, 1), elev=22, azim=-48)

    # Interpolated surface coloured by quality; the 1e-9 keeps the
    # normalisation finite when all interpolated values are equal.
    vmin, vmax = float(np.nanmin(ZI)), float(np.nanmax(ZI))
    ax.plot_surface(XI, YI, ZI,
                    facecolors=cm.RdYlGn_r((ZI - vmin) / (vmax - vmin + 1e-9)),
                    alpha=0.55, linewidth=0, antialiased=True, zorder=3)

    # Actual data scatter
    ax.scatter(xs, ys, zs, c=zs, cmap="RdYlGn_r",
               vmin=vmin, vmax=vmax,
               s=55, edgecolors="white", lw=0.6, alpha=0.90, zorder=6)

    # MHA/GQA reference planes
    xl = [xs.min(), xs.max()]
    yl = [ys.min(), ys.max()]
    for ref_val, rcol, rlbl in references:
        verts = [[(xl[0], yl[0], ref_val), (xl[1], yl[0], ref_val),
                  (xl[1], yl[1], ref_val), (xl[0], yl[1], ref_val)]]
        poly = Poly3DCollection(verts, alpha=0.20)
        poly.set_facecolor(rcol)
        poly.set_edgecolor(rcol)
        ax.add_collection3d(poly)
        ax.text(xl[1], yl[1], ref_val + 0.01, f" {rlbl}",
                color=rcol, fontsize=8, fontweight="bold")

    ax.set_xlabel("Parameters (M)", fontsize=9, labelpad=6)
    ax.set_ylabel("down_dim_kv", fontsize=9, labelpad=6)
    ax.set_zlabel("Validation Loss ↓", fontsize=9, labelpad=8)
    ax.tick_params(labelsize=7)
    ax.set_title("MLA Quality Landscape\n(surface colour: red=worse, green=better)",
                 fontsize=11, fontweight="bold", pad=8)

    # Colorbar proxy: the surface uses explicit facecolors, so the colorbar
    # is driven by a standalone mappable with the same colormap and range.
    sm = cm.ScalarMappable(cmap="RdYlGn_r",
                           norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])
    plt.colorbar(sm, ax=ax, shrink=0.5, pad=0.1, label="val_loss")

    # ── Right: top-down heatmap (cleaner reading) ─────────────────────────────
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_facecolor("#F8F9FA")

    # Scatter (actual runs)
    sc = ax2.scatter(xs, ys, c=zs, cmap="RdYlGn_r",
                     vmin=vmin, vmax=vmax,
                     s=120, edgecolors="white", lw=1, alpha=0.92, zorder=4)

    # Contour lines on interpolated grid
    CS = ax2.contour(XI, YI, ZI, levels=8, cmap="RdYlGn_r",
                     vmin=vmin, vmax=vmax, alpha=0.55, linewidths=1.2)
    ax2.clabel(CS, fmt="%.3f", fontsize=8, inline=True)

    # MHA/GQA reference medians as text labels just right of the data,
    # stacked downwards from the top.
    for i, (ref_val, rcol, rlbl) in enumerate(references):
        ax2.text(xs.max() * 1.01, ys.max() * (0.98 - i * 0.08),
                 f"{rlbl}: {ref_val:.3f}",
                 ha="left", fontsize=8, color=rcol, fontweight="bold")

    ax2.set_xlabel("Parameters (M)", fontsize=11)
    ax2.set_ylabel("down_dim_kv", fontsize=11)
    ax2.set_title("Top-down View: Quality Heatmap\n"
                  "(green = lower loss = better quality)",
                  fontsize=11, fontweight="bold")
    ax2.grid(alpha=0.15, ls="--")
    ax2.spines[["top", "right"]].set_visible(False)
    plt.colorbar(sc, ax=ax2, label="Validation Loss ↓")

    fig.suptitle(
        "MLA Quality Landscape: down_dim_kv × Model Size → Validation Loss\n"
        "Both levers matter — but params_m drives quality more than down_dim_kv",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_7_mla_quality.png")
402# ─────────────────────────────────────────────────────────────────────────────
403# Fig 8 — MLA: down_dim_kv × num_blocks → KV cache at multiple seq_lens
404# Two independent knobs that both drive cache independently.
405# ─────────────────────────────────────────────────────────────────────────────
def fig8_dkv_numblocks(df):
    """Render ``3d_8_dkv_num_blocks.png``: MLA cache over down_dim_kv × num_blocks.

    One 3D panel per context length (512, 4K, 32K, 128K). Each panel shows
    the modelled cache surface ``kv_dim_eff(down_dim_kv) × num_blocks ×
    seq_len × 4 / 1e9`` (GB), a contour projection on the floor, any VRAM
    plane that lies within the surface's z-range, and the measured MLA runs
    re-evaluated at that panel's context length.

    Args:
        df: run table; the columns read are ``type``, ``moe``,
            ``down_dim_kv``, ``num_blocks`` and ``kv_dim_eff``. ``moe`` is
            used as a boolean mask — presumably a bool column; verify
            against the CSV.
    """
    mla = df[(df["type"] == "MLA") & (~df["moe"])].dropna(
        subset=["down_dim_kv", "num_blocks", "kv_dim_eff"]
    )

    # Fit: kv_dim_eff ≈ slope * down_dim_kv + intercept
    agg = mla.groupby("down_dim_kv")["kv_dim_eff"].mean()
    slope, intercept = np.polyfit(agg.index.values.astype(float), agg.values, 1)

    dkv_range = np.linspace(32, 256, 50)
    nb_range = np.array([4, 8, 12, 16, 24, 32])  # include realistic LLM depths

    DKV2D, NB2D = np.meshgrid(dkv_range, nb_range)
    kve_2d = slope * DKV2D + intercept

    # Footprint shared by every VRAM plane in every panel
    dkv_lims = [dkv_range[0], dkv_range[-1]]
    nb_lims = [nb_range[0], nb_range[-1]]

    # Measured runs whose depth lies on the plotted num_blocks grid. The
    # selection does not depend on the panel, so it is done once here.
    nb_int = mla["num_blocks"].astype(int)
    on_grid = nb_int.isin(nb_range)
    pts_dkv = mla.loc[on_grid, "down_dim_kv"].to_numpy(dtype=float)
    pts_nb = nb_int[on_grid].to_numpy()
    pts_kve_nb = (mla.loc[on_grid, "kv_dim_eff"] * nb_int[on_grid]).to_numpy(dtype=float)

    seq_levels = [512, 4096, 32768, 131072]
    seq_labels = ["512", "4K", "32K", "128K"]
    colors_seq = ["#3A86FF", "#8338EC", "#FF006E", "#FB5607"]

    fig = plt.figure(figsize=(18, 9))
    fig.patch.set_facecolor("#F8F9FA")

    # ── One subplot per seq_len ───────────────────────────────────────────────
    for pi, (seq_len, seq_lbl, seq_col) in enumerate(zip(seq_levels, seq_labels, colors_seq)):
        ax = _ax3(fig, (1, 4, pi + 1), elev=28, azim=-55)
        Z = kve_2d * NB2D * seq_len * 4 / 1e9  # GB
        z_cap = float(np.nanmax(Z))

        ax.plot_surface(DKV2D, NB2D, Z,
                        color=seq_col, alpha=0.50,
                        linewidth=0, antialiased=True, zorder=3)

        # Floor heatmap
        ax.contourf(DKV2D, NB2D, Z, zdir="z", offset=0,
                    levels=np.linspace(0, z_cap, 10),
                    cmap="Blues", alpha=0.40)

        # VRAM planes (skipped when above the top of this panel's surface)
        for vram, vc, vlbl in VRAM_GB:
            if vram > z_cap:
                continue
            verts = [[(dkv_lims[0], nb_lims[0], vram),
                      (dkv_lims[1], nb_lims[0], vram),
                      (dkv_lims[1], nb_lims[1], vram),
                      (dkv_lims[0], nb_lims[1], vram)]]
            poly = Poly3DCollection(verts, alpha=0.22)
            poly.set_facecolor(vc)
            poly.set_edgecolor(vc)
            ax.add_collection3d(poly)
            ax.text(dkv_lims[1], nb_lims[0], vram + z_cap * 0.02,
                    f" {vlbl}", color=vc, fontsize=7, fontweight="bold")

        # Actual runs, with their cache re-evaluated at this panel's seq_len
        if pts_dkv.size:
            ax.scatter(pts_dkv, pts_nb, pts_kve_nb * seq_len * 4 / 1e9,
                       c="white", s=22, edgecolors=seq_col, lw=1.2,
                       alpha=0.85, zorder=7)

        ax.set_xlabel("down_dim_kv", fontsize=8, labelpad=5)
        ax.set_ylabel("num_blocks", fontsize=8, labelpad=5)
        ax.set_zlabel("KV Cache (GB)", fontsize=8, labelpad=7)
        ax.tick_params(labelsize=7)
        ax.set_title(f"seq = {seq_lbl}", fontsize=12,
                     fontweight="bold", color=seq_col, pad=6)

    fig.suptitle(
        "MLA KV Cache = f(down_dim_kv, num_blocks) — shown at four context lengths\n"
        "Both knobs are tunable: reduce down_dim_kv OR use fewer layers to save cache",
        fontsize=12, fontweight="bold", y=1.01
    )
    plt.tight_layout()
    _save(fig, "3d_8_dkv_num_blocks.png")
483# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Run from the script's own directory so the relative IN_CSV and OUT_DIR
    # paths resolve next to this file regardless of where it was launched.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

    df = load()
    print(f"Loaded {len(df)} runs")
    print("\nGenerating down_dim_kv 3D figures...")

    # Figures are produced in numeric order (5 → 8) from the same table.
    figure_builders = (
        fig5_dkv_cache_seqlen,
        fig6_kv_decoupling,
        fig7_mla_quality,
        fig8_dkv_numblocks,
    )
    for build_figure in figure_builders:
        build_figure(df)
    print("Done.")