Coverage for core/config.py: 89%
110 statements
coverage.py v7.13.5, created at 2026-05-02 10:39 +0200
from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from typing import Any

import yaml


@dataclass
class Config:
    # ── Model Architecture ───────────────────────────────────────────────────
    dim: int = 512
    n_heads: int = 16
    head_size: int = 32
    num_blocks: int = 20
    vocab_size: int = 200
    max_context: int = 512
    kv_heads: int = 4
    weight_tying: bool = True
    activation: str = "gelu"
    gradient_checkpointing: bool = True
    dropout_rate: float = 0.15
    use_swiglu: bool = True

    # ── MoE ─────────────────────────────────────────────────────────────────
    use_moe: bool = False
    n_experts: int = 4
    top_k_mlp: int = 2
    expansion: int = 4
    alpha_balance: float = 0.1

    # ── Attention & Positional ───────────────────────────────────────────────
    use_rotary_pos: bool = True
    trainable_pos: bool = False
    absolute_pos: bool = False
    sliding_window: bool = False
    context_window: int = 4
    no_sink: bool = True
    use_flash_attention: bool = False

    # ── Multi-Head Latent Attention (MLA) ────────────────────────────────────
    mla: bool = False
    inference: bool = False
    down_dim_q: int = 256
    down_dim_kv: int = 256
    rope_dim: int = 32

    # ── Tokenizer ───────────────────────────────────────────────────────────
    tokenizer_type: str = "char"
    tokenizer_path: str | None = None

    # ── Dataset ─────────────────────────────────────────────────────────────
    dataset_source: str = "local"
    dataset_name: str = ""
    dataset_config: str = ""  # HF subset/config name (e.g. "en" for allenai/c4)
    dataset_text_field: str = "text"  # column that contains the raw text
    dataset_split: str = "train"  # HF split to load
    streaming: bool = False

    # ── Training & Optimisation ──────────────────────────────────────────────
    lr: float = 0.005
    batch_size: int = 128
    grad_accum: int = 16
    seed: int = 42
    optimizer: str = "adamw"
    epochs: int = 1000
    warmup_steps: int = 420
    grad_clip: float = 1.0
    patience: int = 0
    use_bf16: bool = False

    # ── Normalisation ────────────────────────────────────────────────────────
    norm_type: str = "layernorm"  # "layernorm" | "rmsnorm"

    # ── RoPE scaling ─────────────────────────────────────────────────────────
    rope_scale_factor: float = 1.0  # >1 compresses frequencies for long-ctx (NTK-aware)

    # ── LR schedule ──────────────────────────────────────────────────────────
    lr_schedule: str = "cosine"  # "cosine" | "linear" | "constant" | "wsd"

    # ── LoRA fine-tuning ──────────────────────────────────────────────────────
    use_lora: bool = False
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_targets: str = "attention"  # "attention" | "mlp" | "all"

    # ── Multi-GPU data-parallel ───────────────────────────────────────────────
    n_devices: int = 0  # 0 = use all available local devices

    # ── Logging & Metrics ───────────────────────────────────────────────────
    eval_iters: int = 20
    log_file: str = "training_log.csv"
    summary_file: str = "model_summary.json"

    def __post_init__(self) -> None:
        if self.kv_heads is None:
            self.kv_heads = self.n_heads // 4

        if self.dim != self.n_heads * self.head_size:
            raise ValueError(
                f"dim ({self.dim}) must equal n_heads * head_size "
                f"({self.n_heads} * {self.head_size} = {self.n_heads * self.head_size})"
            )
        if self.n_heads % self.kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by kv_heads ({self.kv_heads})"
            )
        if self.mla and self.use_rotary_pos and self.rope_dim > self.head_size:
            raise ValueError(
                f"rope_dim ({self.rope_dim}) must be <= head_size ({self.head_size}) "
                "when using MLA with rotary positional encoding"
            )
        if self.norm_type not in ("layernorm", "rmsnorm"):
            raise ValueError(
                f"norm_type must be 'layernorm' or 'rmsnorm', got {self.norm_type!r}"
            )
        if self.lr_schedule not in ("cosine", "linear", "constant", "wsd"):
            raise ValueError(
                f"lr_schedule must be 'cosine', 'linear', 'constant', or 'wsd', "
                f"got {self.lr_schedule!r}"
            )
        if self.rope_scale_factor <= 0:
            raise ValueError(
                f"rope_scale_factor must be > 0, got {self.rope_scale_factor}"
            )
        if self.lora_targets not in ("attention", "mlp", "all"):
            raise ValueError(
                f"lora_targets must be 'attention', 'mlp', or 'all', got {self.lora_targets!r}"
            )
        if self.lora_rank < 1:
            raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}")
        if self.n_devices < 0:
            raise ValueError(f"n_devices must be >= 0, got {self.n_devices}")

    def __repr__(self) -> str:
        attn = "MLA" if self.mla else ("GQA" if self.kv_heads < self.n_heads else "MHA")
        moe = "+MoE" if self.use_moe else ""
        return (
            f"Config(dim={self.dim}, heads={self.n_heads}, blocks={self.num_blocks}, "
            f"ctx={self.max_context}, attn={attn}{moe})"
        )

    # ── Serialisation ────────────────────────────────────────────────────────

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict of all config fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Config:
        """Construct a Config from a plain dict, ignoring unknown keys."""
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in valid})

    @classmethod
    def from_yaml(cls, path: str) -> Config:
        """Load a Config from a YAML file (flat or sectioned)."""
        with open(path) as f:
            raw = yaml.safe_load(f)

        flat: dict[str, Any] = {}
        for v in raw.values():
            if isinstance(v, dict):
                flat.update(v)
        if not flat:
            flat = raw

        return cls.from_dict(flat)

    def save_yaml(self, path: str) -> None:
        """Write the config to a YAML file."""
        with open(path, "w") as f:
            yaml.dump(self.to_dict(), f)
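
A minimal usage sketch of the module above (not part of core/config.py itself; the core.config import path is assumed from the coverage header, and the file name is illustrative):

# Usage sketch for Config (assumed import path; file name illustrative).
from core.config import Config

# Construction runs the __post_init__ validation: dim must equal
# n_heads * head_size, and n_heads must be divisible by kv_heads.
cfg = Config(dim=256, n_heads=8, head_size=32, kv_heads=4)
print(cfg)  # Config(dim=256, heads=8, blocks=20, ctx=512, attn=GQA)

# Dict round-trip: from_dict silently drops keys that are not Config fields.
d = cfg.to_dict()
d["unknown_key"] = 123
assert Config.from_dict(d) == cfg

# YAML round-trip: save_yaml writes a flat mapping, which from_yaml accepts.
cfg.save_yaml("config.yaml")
assert Config.from_yaml("config.yaml") == cfg

# from_yaml also accepts a sectioned file whose top-level values are
# mappings; the sections are merged into one flat dict, e.g.:
#
#   model:
#     dim: 256
#     n_heads: 8
#     head_size: 32
#   training:
#     lr: 0.001
#     batch_size: 64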