diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py index bbe5ab294..8002038b4 100644 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py +++ b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py @@ -47,11 +47,14 @@ class Hyperparameters: val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + val_max_tokens = int(os.environ.get("VAL_MAX_TOKENS", 0)) + eval_progress_every = int(os.environ.get("EVAL_PROGRESS_EVERY", 0)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -89,6 +92,21 @@ class Hyperparameters: bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + control_enabled = bool(int(os.environ.get("CONTROL_ENABLED", "1"))) + control_detach_gates = bool(int(os.environ.get("CONTROL_DETACH_GATES", "1"))) + control_gate_fuse_mode = os.environ.get("CONTROL_GATE_FUSE_MODE", "k") + control_energy_norm = os.environ.get("CONTROL_ENERGY_NORM", "ema_zscore") + control_energy_clip = float(os.environ.get("CONTROL_ENERGY_CLIP", 0.99)) + control_warmup_tokens = int(os.environ.get("CONTROL_WARMUP_TOKENS", 8)) + control_scan_block_size = int(os.environ.get("CONTROL_SCAN_BLOCK_SIZE", 16)) + compile_model = bool(int(os.environ.get("COMPILE_MODEL", "1"))) + compile_model_target = os.environ.get("COMPILE_MODEL_TARGET", "modules") + 
compile_fullgraph = bool(int(os.environ.get("COMPILE_FULLGRAPH", "0"))) + compile_zeropower = bool(int(os.environ.get("COMPILE_ZEROPOWER", "1"))) + compile_control_scan = bool(int(os.environ.get("COMPILE_CONTROL_SCAN", "1"))) + compile_control_scan_fullgraph = bool(int(os.environ.get("COMPILE_CONTROL_SCAN_FULLGRAPH", "0"))) + compile_mode = os.environ.get("COMPILE_MODE", "reduce-overhead") + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) swa_every = int(os.environ.get("SWA_EVERY", 50)) @@ -201,11 +219,16 @@ def build_sentencepiece_luts( ) -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: +def load_validation_tokens(pattern: str, seq_len: int, max_tokens: int = 0) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + if max_tokens > 0 and tokens.numel() - 1 > max_tokens: + capped = (max_tokens // seq_len) * seq_len + if capped <= 0: + raise ValueError(f"VAL_MAX_TOKENS={max_tokens} is too small for TRAIN_SEQ_LEN={seq_len}") + tokens = tokens[: capped + 1].contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") @@ -240,13 +263,15 @@ def eval_val( val_byte_count = torch.zeros((), device=device, dtype=torch.float64) model.eval() with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + total_local_batches = max((seq_end - seq_start + local_batch_seqs - 1) // local_batch_seqs, 1) + for batch_idx, batch_seq_start in enumerate(range(seq_start, seq_end, local_batch_seqs), start=1): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) raw_start = batch_seq_start * args.train_seq_len raw_end = batch_seq_end * args.train_seq_len + 1 local = 
val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) + _cudagraph_mark_step_begin() with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) @@ -257,6 +282,19 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() + if rank == 0 and args.eval_progress_every > 0 and ( + batch_idx % args.eval_progress_every == 0 or batch_idx == total_local_batches + ): + pct = 100.0 * batch_idx / total_local_batches + running_bpb = 0.0 + if val_token_count.item() > 0 and val_byte_count.item() > 0: + running_loss = (val_loss_sum / val_token_count).item() + running_bpb = running_loss / math.log(2.0) * (val_token_count.item() / val_byte_count.item()) + print( + f" eval [{pct:5.1f}%] {batch_idx}/{total_local_batches} batches " + f"running_bpb={running_bpb:.6f}", + flush=True, + ) if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) @@ -489,6 +527,19 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: param.data = param.data.float() +def _disable_compile(fn): + if hasattr(torch, "compiler") and hasattr(torch.compiler, "disable"): + return torch.compiler.disable(fn) + if hasattr(torch, "_dynamo") and hasattr(torch._dynamo, "disable"): + return torch._dynamo.disable(fn) + return fn + + +def _cudagraph_mark_step_begin() -> None: + if hasattr(torch, "compiler") and hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + + class Rotary(nn.Module): def __init__(self, dim: int, base: float = 10000.0): super().__init__() @@ -498,6 +549,7 @@ def 
__init__(self, dim: int, base: float = 10000.0): self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None + @_disable_compile def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None @@ -519,6 +571,271 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +def _causal_ema_scan(x: Tensor, lam: float) -> Tensor: + if lam <= 0.0: + return torch.zeros_like(x) + _, seq_len, _ = x.shape + alpha = 1.0 - lam + y = lam * x + cur_alpha = alpha + stride = 1 + while stride < seq_len: + y[:, stride:, :] = y[:, stride:, :] + cur_alpha * y[:, :-stride, :] + cur_alpha = cur_alpha * cur_alpha + stride *= 2 + return y + + +def _anchored_delta_seq(hidden: Tensor, tok_emb_weight: Tensor, input_ids: Tensor, lam_delta: float) -> tuple[Tensor, Tensor]: + normed = F.rms_norm(hidden, (hidden.size(-1),)) + logits = F.linear(normed, tok_emb_weight.to(normed.dtype)) + logp = F.log_softmax(logits.float(), dim=-1) + next_ids = input_ids[:, 1:] + picked = torch.gather(logp[:, :-1, :], dim=-1, index=next_ids.unsqueeze(-1)) + delta = -picked + zero = torch.zeros((delta.size(0), 1, 1), dtype=delta.dtype, device=delta.device) + delta_seq = torch.cat((zero, delta), dim=1) + return delta_seq, _causal_ema_scan(delta_seq, lam_delta) + + +def _apply_energy_transforms(e_seq: Tensor, energy_norm: str, energy_clip: float | None, lam: float = 0.05) -> Tensor: + out = e_seq.float() + if energy_norm == "ema_zscore": + mean = _causal_ema_scan(out, lam) + mean2 = _causal_ema_scan(out * out, lam) + var = torch.clamp(mean2 - mean * mean, min=1e-6) + out = (out - mean) / torch.sqrt(var) + elif energy_norm == "robust_mad": + med = torch.median(out, dim=1, keepdim=True).values + mad = torch.median(torch.abs(out - med), dim=1, keepdim=True).values + out = (out - med) / (mad + 1e-6) + elif energy_norm != "none": + raise 
ValueError(f"Unsupported CONTROL_ENERGY_NORM={energy_norm}") + if energy_clip is not None: + if energy_clip <= 1.0: + q = torch.quantile(out, energy_clip, dim=1, keepdim=True) + out = torch.minimum(torch.clamp_min(out, 0.0), q) + else: + out = torch.clamp(out, 0.0, energy_clip) + return out + + +def _control_scan_block_konly( + e_blk: Tensor, + phi_blk: Tensor, + g_blk: Tensor, + logR: Tensor, + P: Tensor, + M: Tensor, + V: Tensor, + E: Tensor, + lam_E: float, + lam_M: float, + lam_V: float, + eta_R: float, + alpha: float, + beta: float, + gamma: float, + v0: float, + Q0: float, + Q1: float, + min_var: float, + max_log_var: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + _, block_len, _ = e_blk.shape + k_blk = torch.empty_like(e_blk) + for t in range(block_len): + R_minus = torch.clamp(torch.exp(logR), min=min_var) + e_t = e_blk[:, t : t + 1, :] + phi_t = phi_blk[:, t : t + 1, :] + g_t = g_blk[:, t : t + 1, :] + z2 = e_t / R_minus + E = (1.0 - lam_E) * E + lam_E * phi_t + M = (1.0 - lam_M) * M + lam_M * z2 + diff = z2 - 1.0 + V = torch.clamp((1.0 - lam_V) * V + lam_V * (diff * diff), min=0.0) + c = torch.sigmoid(alpha * (M - beta) + gamma * (V - v0)) + P_prior = P + (Q0 + Q1 * c) + K = P_prior / torch.clamp(P_prior + R_minus, min=min_var) + logR = torch.clamp(logR + eta_R * (E * g_t * diff), -max_log_var, max_log_var) + P = torch.clamp((1.0 - K) * P_prior, min=min_var) + k_blk[:, t : t + 1, :] = K + return k_blk, logR, P, M, V, E + + +def _torch_compile_if_available(fn, *, fullgraph: bool, mode: str | None): + if not hasattr(torch, "compile"): + return fn + kwargs = {"fullgraph": fullgraph} + if mode and mode.lower() not in {"", "none"}: + kwargs["mode"] = mode + try: + return torch.compile(fn, **kwargs) + except Exception: + return fn + + +def _try_compile_control_scan(fn): + if os.environ.get("COMPILE_CONTROL_SCAN", "1") != "1": + return fn + fullgraph = os.environ.get("COMPILE_CONTROL_SCAN_FULLGRAPH", "0") == "1" + mode = 
os.environ.get("COMPILE_MODE", "reduce-overhead") + return _torch_compile_if_available(fn, fullgraph=fullgraph, mode=mode) + + +_compiled_control_scan_block_konly = _try_compile_control_scan(_control_scan_block_konly) + + +def _compile_module_if_enabled(module: nn.Module, *, fullgraph: bool, mode: str | None) -> nn.Module: + if not hasattr(torch, "compile"): + return module + kwargs = {"dynamic": False, "fullgraph": fullgraph} + if mode and mode.lower() not in {"", "none"}: + kwargs["mode"] = mode + try: + return torch.compile(module, **kwargs) + except Exception: + return module + + +class InternalControl(nn.Module): + def __init__( + self, + *, + enabled: bool, + detach_gates: bool, + gate_fuse_mode: str, + energy_norm: str, + energy_clip: float | None, + warmup_tokens: int, + scan_block_size: int, + ): + super().__init__() + self.enabled = enabled + self.detach_gates = detach_gates + self.gate_fuse_mode = gate_fuse_mode + self.energy_norm = energy_norm + self.energy_clip = energy_clip + self.warmup_tokens = warmup_tokens + self.scan_block_size = max(scan_block_size, 1) + self.R0 = 1.0 + self.P0 = 1.0 + self.M0 = 1.0 + self.V0 = 0.0 + self.E0 = 0.0 + self.eta_R = 0.02 + self.lam_E = 0.05 + self.lam_M = 0.05 + self.lam_V = 0.05 + self.lam_delta = 0.05 + self.k_g = 0.5 + self.alpha = 4.0 + self.beta = 1.5 + self.gamma = 2.0 + self.v0 = 0.1 + self.Q0 = 0.01 + self.Q1 = 0.25 + self.min_var = 1e-6 + self.max_log_var = 8.0 + + def enabled_for_layer(self, layer_idx: int, total_layers: int) -> bool: + return self.enabled and layer_idx >= total_layers // 2 + + def _anchored_gate(self, hidden: Tensor, tok_emb_weight: Tensor, input_ids: Tensor) -> Tensor: + delta_seq, bar = _anchored_delta_seq(hidden, tok_emb_weight, input_ids, self.lam_delta) + gate = torch.sigmoid(self.k_g * (delta_seq - bar)).float() + return gate.detach() if self.detach_gates else gate + + @_disable_compile + def _k_seq(self, u: Tensor, gate: Tensor) -> Tensor: + batch_size, seq_len, _ = u.shape + dtype 
= torch.float32 + state_shape = (batch_size, 1, 1) + logR = torch.full(state_shape, math.log(self.R0), dtype=dtype, device=u.device) + P = torch.full(state_shape, self.P0, dtype=dtype, device=u.device) + M = torch.full(state_shape, self.M0, dtype=dtype, device=u.device) + V = torch.full(state_shape, self.V0, dtype=dtype, device=u.device) + E = torch.full(state_shape, self.E0, dtype=dtype, device=u.device) + # Keep the recurrent chunk loop out of the outer model graph. Only the + # fixed-size scan block kernel should be torch.compile'd. + u32 = u.float() + e_seq = torch.mean(u32 * u32, dim=-1, keepdim=True) + e_seq = _apply_energy_transforms(e_seq, self.energy_norm, self.energy_clip) + phi_seq = torch.mean(u32 * u32, dim=-1, keepdim=True) + k_chunks = [] + full_chunks = seq_len // self.scan_block_size + for chunk_idx in range(full_chunks): + start = chunk_idx * self.scan_block_size + end = start + self.scan_block_size + k_blk, logR, P, M, V, E = _compiled_control_scan_block_konly( + e_seq[:, start:end, :], + phi_seq[:, start:end, :], + gate[:, start:end, :], + logR, + P, + M, + V, + E, + self.lam_E, + self.lam_M, + self.lam_V, + self.eta_R, + self.alpha, + self.beta, + self.gamma, + self.v0, + self.Q0, + self.Q1, + self.min_var, + self.max_log_var, + ) + k_chunks.append(k_blk) + remainder = seq_len % self.scan_block_size + if remainder: + start = seq_len - remainder + k_blk, logR, P, M, V, E = _control_scan_block_konly( + e_seq[:, start:, :], + phi_seq[:, start:, :], + gate[:, start:, :], + logR, + P, + M, + V, + E, + self.lam_E, + self.lam_M, + self.lam_V, + self.eta_R, + self.alpha, + self.beta, + self.gamma, + self.v0, + self.Q0, + self.Q1, + self.min_var, + self.max_log_var, + ) + k_chunks.append(k_blk) + k_seq = torch.cat(k_chunks, dim=1) if k_chunks else torch.zeros_like(e_seq) + if self.detach_gates: + k_seq = k_seq.detach() + if self.warmup_tokens > 0: + warm = min(self.warmup_tokens, seq_len) + k_seq = torch.cat((torch.zeros_like(k_seq[:, :warm, :]), 
k_seq[:, warm:, :]), dim=1) + return k_seq + + def fuse(self, u: Tensor, hidden: Tensor, tok_emb_weight: Tensor, input_ids: Tensor) -> Tensor: + gate = self._anchored_gate(hidden, tok_emb_weight, input_ids) + k_seq = self._k_seq(u, gate) + if self.gate_fuse_mode == "k": + return (k_seq * (gate.detach() if self.detach_gates else gate)).to(u.dtype) * u + if self.gate_fuse_mode == "fusion": + return k_seq.to(u.dtype) * u * (gate.detach() if self.detach_gates else gate) + if self.gate_fuse_mode == "none": + return k_seq.to(u.dtype) * u + raise ValueError(f"Unsupported CONTROL_GATE_FUSE_MODE={self.gate_fuse_mode}") + + class CausalSelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): super().__init__() @@ -622,12 +939,22 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + def mix_residual(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + return mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + def attn_residual(self, x: Tensor) -> Tensor: attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + def mlp_residual(self, x: Tensor) -> Tensor: + mlp_out = self.mlp(self.mlp_norm(x)) + return self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + x = self.mix_residual(x, x0) + x = x + self.attn_residual(x) + x = x + self.mlp_residual(x) return x @@ -647,6 +974,13 @@ def __init__( qk_gain_init: 
float, bigram_vocab_size: int = 0, bigram_dim: int = 128, + control_enabled: bool = True, + control_detach_gates: bool = True, + control_gate_fuse_mode: str = "k", + control_energy_norm: str = "ema_zscore", + control_energy_clip: float | None = 0.99, + control_warmup_tokens: int = 8, + control_scan_block_size: int = 64, ): super().__init__() if logit_softcap <= 0.0: @@ -661,6 +995,15 @@ def __init__( self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) self.smear = SmearGate(model_dim) + self.control = InternalControl( + enabled=control_enabled, + detach_gates=control_detach_gates, + gate_fuse_mode=control_gate_fuse_mode, + energy_norm=control_energy_norm, + energy_clip=control_energy_clip, + warmup_tokens=control_warmup_tokens, + scan_block_size=control_scan_block_size, + ) self.blocks = nn.ModuleList( [ Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) @@ -695,13 +1038,27 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.smear(x) x0 = x skips: list[Tensor] = [] + total_layers = len(self.blocks) for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + block = self.blocks[i] + x = block.mix_residual(x, x0) + attn_out = block.attn_residual(x) + if self.control.enabled_for_layer(i, total_layers): + attn_out = self.control.fuse(attn_out, x, self.tok_emb.weight, input_ids) + x = x + attn_out + x = x + block.mlp_residual(x) skips.append(x) for i in range(self.num_decoder_layers): if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + block_idx = self.num_encoder_layers + i + block = self.blocks[block_idx] + x = block.mix_residual(x, x0) + attn_out = block.attn_residual(x) + if self.control.enabled_for_layer(block_idx, total_layers): + attn_out = self.control.fuse(attn_out, x, 
self.tok_emb.weight, input_ids) + x = x + attn_out + x = x + block.mlp_residual(x) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: @@ -721,13 +1078,27 @@ def forward_logits(self, input_ids: Tensor) -> Tensor: x = self.smear(x) x0 = x skips: list[Tensor] = [] + total_layers = len(self.blocks) for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + block = self.blocks[i] + x = block.mix_residual(x, x0) + attn_out = block.attn_residual(x) + if self.control.enabled_for_layer(i, total_layers): + attn_out = self.control.fuse(attn_out, x, self.tok_emb.weight, input_ids) + x = x + attn_out + x = x + block.mlp_residual(x) skips.append(x) for i in range(self.num_decoder_layers): if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + block_idx = self.num_encoder_layers + i + block = self.blocks[block_idx] + x = block.mix_residual(x, x0) + attn_out = block.attn_residual(x) + if self.control.enabled_for_layer(block_idx, total_layers): + attn_out = self.control.fuse(attn_out, x, self.tok_emb.weight, input_ids) + x = x + attn_out + x = x + block.mlp_residual(x) x = self.final_norm(x) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) @@ -777,6 +1148,7 @@ def eval_val_sliding( chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) x_batch[i, :wlen] = chunk[:-1] y_batch[i, :wlen] = chunk[1:] + _cudagraph_mark_step_begin() with torch.autocast(device_type="cuda", dtype=torch.bfloat16): logits = base_model.forward_logits(x_batch) nll = F.cross_entropy( @@ -825,7 +1197,12 @@ def main() -> None: code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + if args.compile_zeropower: + zeropower_via_newtonschulz5 = _torch_compile_if_available( + zeropower_via_newtonschulz5, + fullgraph=False, + 
mode=args.compile_mode, + ) distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) @@ -833,9 +1210,12 @@ def main() -> None: local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size <= 0: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + else: + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") @@ -893,13 +1273,15 @@ def log0(msg: str, console: bool = True) -> None: ) dataset_dir = Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len, args.val_max_tokens) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_max_tokens > 0: + log0(f"val_loader:capped_tokens requested={args.val_max_tokens} actual={val_tokens.numel() - 1}") # MODEL + OPTIMIZER SETUP base_model = GPT( @@ -916,12 +1298,47 @@ def log0(msg: str, console: bool = True) -> None: qk_gain_init=args.qk_gain_init, bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + control_enabled=args.control_enabled, + 
control_detach_gates=args.control_detach_gates, + control_gate_fuse_mode=args.control_gate_fuse_mode, + control_energy_norm=args.control_energy_norm, + control_energy_clip=args.control_energy_clip, + control_warmup_tokens=args.control_warmup_tokens, + control_scan_block_size=args.control_scan_block_size, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if args.compile_model and hasattr(torch, "compile"): + if args.compile_model_target == "model": + compile_kwargs = {"dynamic": False, "fullgraph": args.compile_fullgraph} + if args.compile_mode and args.compile_mode.lower() not in {"", "none"}: + compile_kwargs["mode"] = args.compile_mode + compiled_model = torch.compile(base_model, **compile_kwargs) + elif args.compile_model_target == "modules": + for block in base_model.blocks: + block.attn = _compile_module_if_enabled( + block.attn, + fullgraph=args.compile_fullgraph, + mode=args.compile_mode, + ) + block.mlp = _compile_module_if_enabled( + block.mlp, + fullgraph=args.compile_fullgraph, + mode=args.compile_mode, + ) + if base_model.bigram is not None: + base_model.bigram = _compile_module_if_enabled( + base_model.bigram, + fullgraph=args.compile_fullgraph, + mode=args.compile_mode, + ) + compiled_model = base_model + else: + raise ValueError(f"Unsupported COMPILE_MODEL_TARGET={args.compile_model_target}") + else: + compiled_model = base_model model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model block_named_params = list(base_model.blocks.named_parameters()) @@ -992,6 +1409,17 @@ def log0(msg: str, console: bool = True) -> None: f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) + log0( + f"control:enabled:{args.control_enabled} 
detach_gates:{args.control_detach_gates} " + f"gate_fuse:{args.control_gate_fuse_mode} energy_norm:{args.control_energy_norm} " + f"energy_clip:{args.control_energy_clip} warmup_tokens:{args.control_warmup_tokens} " + f"scan_block_size:{args.control_scan_block_size}" + ) + log0( + f"compile:model:{args.compile_model} target:{args.compile_model_target} fullgraph:{args.compile_fullgraph} " + f"zeropower:{args.compile_zeropower} control_scan:{args.compile_control_scan} " + f"control_scan_fullgraph:{args.compile_control_scan_fullgraph} mode:{args.compile_mode}" + ) log0(f"seed:{args.seed}") # DATA LOADER & MODEL WARMUP @@ -1024,6 +1452,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + _cudagraph_mark_step_begin() with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): warmup_loss = model(x, y) (warmup_loss * grad_scale).backward() @@ -1083,6 +1512,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + _cudagraph_mark_step_begin() with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) train_loss += loss.detach()