Coverage for core/config.py: 89%

110 statements  

coverage.py v7.13.5, created at 2026-05-02 10:39 +0200

from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from typing import Any

import yaml


@dataclass
class Config:
    # ── Model Architecture ───────────────────────────────────────────────────
    dim: int = 512
    n_heads: int = 16
    head_size: int = 32
    num_blocks: int = 20
    vocab_size: int = 200
    max_context: int = 512
    kv_heads: int = 4
    weight_tying: bool = True
    activation: str = "gelu"
    gradient_checkpointing: bool = True
    dropout_rate: float = 0.15
    use_swiglu: bool = True

    # ── MoE ─────────────────────────────────────────────────────────────────
    use_moe: bool = False
    n_experts: int = 4
    top_k_mlp: int = 2
    expansion: int = 4
    alpha_balance: float = 0.1

    # ── Attention & Positional ───────────────────────────────────────────────
    use_rotary_pos: bool = True
    trainable_pos: bool = False
    absolute_pos: bool = False
    sliding_window: bool = False
    context_window: int = 4
    no_sink: bool = True
    use_flash_attention: bool = False

    # ── Multi-Head Latent Attention (MLA) ────────────────────────────────────
    mla: bool = False
    inference: bool = False
    down_dim_q: int = 256
    down_dim_kv: int = 256
    rope_dim: int = 32

    # ── Tokenizer ───────────────────────────────────────────────────────────
    tokenizer_type: str = "char"
    tokenizer_path: str | None = None

    # ── Dataset ─────────────────────────────────────────────────────────────
    dataset_source: str = "local"
    dataset_name: str = ""
    dataset_config: str = ""  # HF subset/config name (e.g. "en" for allenai/c4)
    dataset_text_field: str = "text"  # column that contains the raw text
    dataset_split: str = "train"  # HF split to load
    streaming: bool = False

    # ── Training & Optimisation ──────────────────────────────────────────────
    lr: float = 0.005
    batch_size: int = 128
    grad_accum: int = 16
    seed: int = 42
    optimizer: str = "adamw"
    epochs: int = 1000
    warmup_steps: int = 420
    grad_clip: float = 1.0
    patience: int = 0
    use_bf16: bool = False

    # ── Normalisation ────────────────────────────────────────────────────────
    norm_type: str = "layernorm"  # "layernorm" | "rmsnorm"

    # ── RoPE scaling ─────────────────────────────────────────────────────────
    rope_scale_factor: float = 1.0  # >1 compresses frequencies for long-ctx (NTK-aware)

    # ── LR schedule ──────────────────────────────────────────────────────────
    lr_schedule: str = "cosine"  # "cosine" | "linear" | "constant" | "wsd"

    # ── LoRA fine-tuning ─────────────────────────────────────────────────────
    use_lora: bool = False
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_targets: str = "attention"  # "attention" | "mlp" | "all"

    # ── Multi-GPU data-parallel ──────────────────────────────────────────────
    n_devices: int = 0  # 0 = use all available local devices

    # ── Logging & Metrics ────────────────────────────────────────────────────
    eval_iters: int = 20
    log_file: str = "training_log.csv"
    summary_file: str = "model_summary.json"

    def __post_init__(self) -> None:
        # An explicit kv_heads=None falls back to one KV head per four query heads.
        if self.kv_heads is None:
            self.kv_heads = self.n_heads // 4

        if self.dim != self.n_heads * self.head_size:
            raise ValueError(
                f"dim ({self.dim}) must equal n_heads * head_size "
                f"({self.n_heads} * {self.head_size} = {self.n_heads * self.head_size})"
            )
        if self.n_heads % self.kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by kv_heads ({self.kv_heads})"
            )
        if self.mla and self.use_rotary_pos and self.rope_dim > self.head_size:
            raise ValueError(
                f"rope_dim ({self.rope_dim}) must be <= head_size ({self.head_size}) "
                "when using MLA with rotary positional encoding"
            )
        if self.norm_type not in ("layernorm", "rmsnorm"):
            raise ValueError(
                f"norm_type must be 'layernorm' or 'rmsnorm', got {self.norm_type!r}"
            )
        if self.lr_schedule not in ("cosine", "linear", "constant", "wsd"):
            raise ValueError(
                f"lr_schedule must be 'cosine', 'linear', 'constant', or 'wsd', "
                f"got {self.lr_schedule!r}"
            )
        if self.rope_scale_factor <= 0:
            raise ValueError(
                f"rope_scale_factor must be > 0, got {self.rope_scale_factor}"
            )
        if self.lora_targets not in ("attention", "mlp", "all"):
            raise ValueError(
                f"lora_targets must be 'attention', 'mlp', or 'all', got {self.lora_targets!r}"
            )
        if self.lora_rank < 1:
            raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}")
        if self.n_devices < 0:
            raise ValueError(f"n_devices must be >= 0, got {self.n_devices}")

    def __repr__(self) -> str:
        attn = "MLA" if self.mla else ("GQA" if self.kv_heads < self.n_heads else "MHA")
        moe = "+MoE" if self.use_moe else ""
        return (
            f"Config(dim={self.dim}, heads={self.n_heads}, blocks={self.num_blocks}, "
            f"ctx={self.max_context}, attn={attn}{moe})"
        )

    # ── Serialisation ────────────────────────────────────────────────────────

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict of all config fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Config:
        """Construct a Config from a plain dict, ignoring unknown keys."""
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in valid})

    @classmethod
    def from_yaml(cls, path: str) -> Config:
        """Load a Config from a YAML file (flat or sectioned)."""
        with open(path) as f:
            raw = yaml.safe_load(f)

        # Sectioned files ({model: {...}, training: {...}}) are merged into one
        # flat mapping; if no nested sections are found, the file is treated as flat.
        flat: dict[str, Any] = {}
        for v in raw.values():
            if isinstance(v, dict):
                flat.update(v)
        if not flat:
            flat = raw

        return cls.from_dict(flat)

    def save_yaml(self, path: str) -> None:
        """Write the config to a YAML file."""
        with open(path, "w") as f:
            yaml.dump(self.to_dict(), f)
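
Usage sketch (not part of the measured file): a minimal round-trip through the YAML helpers above. The import path core.config, the file name run.yaml, and the section names "model" and "training" are assumptions for illustration; from_yaml merges any top-level mappings, so the section names themselves do not matter.

import yaml

from core.config import Config

# Hypothetical sectioned config file; any top-level sections are flattened.
doc = {
    "model": {"dim": 256, "n_heads": 8, "head_size": 32},
    "training": {"lr": 3e-4, "batch_size": 32},
}
with open("run.yaml", "w") as f:
    yaml.dump(doc, f)

cfg = Config.from_yaml("run.yaml")   # unknown keys would be dropped by from_dict
print(cfg)                           # Config(dim=256, heads=8, blocks=20, ctx=512, attn=GQA)
cfg.save_yaml("run_resolved.yaml")   # writes the fully resolved, flat config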