diff --git a/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/README.md b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/README.md new file mode 100644 index 0000000000..aefb190b93 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/README.md @@ -0,0 +1,62 @@ +# Record: Parallel Residuals + Mini Depth Recurrence + +**val_bpb: 1.1063** (3-seed mean, std 0.0017) | **1.8679 nats** | **~15.94 MB** | 8×H100 SXM, 600s | No TTT + +I started this submission from [PR #1179](https://github.com/openai/parameter-golf/pull/1179), which gave me the base training stack I wanted to iterate on here. On top of that, I ported over the mixed-quantization and autoregressive GPTQ path from [PR #1105](https://github.com/openai/parameter-golf/pull/1105). That was partly a modeling choice and partly a practical one: AR self-generated GPTQ calibration was already a known acceptable path for this challenge, and it let me avoid having the quantization step depend on last-minute training-data access in a way that makes the 10-minute budget awkward to manage. + +## Results (8×H100 80GB SXM, 600s, no TTT) + +| Seed | Steps | ms/step | Post-EMA BPB | **Sliding BPB** | val_loss (nats) | Artifact | +|------|-------|---------|--------------|-----------------|-----------------|----------| +| 1337 | 6,242 | 96.1 | 1.1232 | **1.1066** | 1.8684 | 15,942,395 | +| 42 | 6,248 | 96.0 | 1.1235 | **1.1077** | 1.8704 | 15,919,617 | +| 2024 | 6,240 | 96.2 | 1.1216 | **1.1044** | 1.8648 | 15,946,657 | +| **Mean** | **6,243** | **96.1** | **1.1228** | **1.1063** | **1.8679** | **15,936,223** | + +Comparison baseline [PR #1179](https://github.com/openai/parameter-golf/pull/1179): **1.11053346 BPB** (**1.87508426 nats**). +This run's exact 3-seed mean: **1.10625353 BPB** (**1.86785780 nats**). +Delta vs PR #1179: **-0.00722646 nats** (**-0.00427993 BPB**). 
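As a consistency note, the nats and BPB columns are tied together by the validation set's bytes-per-token ratio, so every seed row should imply (nearly) the same ratio. A minimal sanity check, with the per-seed values copied from the table above (the ≈2.436 bytes/token figure is derived here, not something the harness reports):

```python
import math

# (val_loss_nats, val_bpb) per seed, copied from the results table above
seeds = {
    1337: (1.86841284, 1.10658225),
    42:   (1.87037189, 1.10774252),
    2024: (1.86478866, 1.10443581),
}

# bpb = nats / (ln 2 * bytes_per_token)  =>  bytes_per_token = nats / (ln 2 * bpb)
ratios = {s: nats / (math.log(2) * bpb) for s, (nats, bpb) in seeds.items()}

# Every seed should imply the same bytes-per-token ratio (about 2.436)
assert max(ratios.values()) - min(ratios.values()) < 1e-3
```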
+ +Current merged SOTA ([2026-03-25 AR Self-Gen GPTQ + XSA-all + BigramHash 3072×112](https://github.com/openai/parameter-golf/blob/main/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md)): **1.11473509 BPB** (**1.88217853 nats**). +Delta vs current merged SOTA: **-0.01432073 nats** (**-0.00848156 BPB**). + +## Parallel residuals + +I took this idea from my modded-nanogpt record in [KellerJordan/modded-nanogpt PR #230](https://github.com/KellerJordan/modded-nanogpt/pull/230) and adapted it to this codebase. + +Chronologically, this change actually came last. I am putting it first here because it ended up being the single biggest gain on top of the base + mini-depth-recurrence stack: relative to the under-budget mini-DR baseline (`1.8705` val loss / `1.1078` BPB in sliding-window eval), it improved things by roughly another `0.0037` nats and `0.0022` BPB, landing around `1.8668` / `1.1056`. But this is still a one-sample observation, so I do not want to overstate the precision of that delta. + +Starting from layer 7, attention and MLP read from different residual lanes, and each sublayer learns how strongly to write back into both lanes. + +One interesting pattern is that the learned routing is quite asymmetric, which is also what I saw in the modded-nanogpt run: MLP barely writes back into attention's residual stream, especially in the deeper partitioned layers. + +| Virtual layer | Physical layer | `attn_to_attn` | `attn_to_mlp` | `mlp_to_attn` | `mlp_to_mlp` | +|---|---:|---:|---:|---:|---:| +| 9 | 7 | 1.3030 | 0.8484 | 0.3851 | 1.3043 | +| 10 | 8 | 2.0972 | 0.8114 | 0.0557 | 1.7884 | +| 11 | 9 | 0.4523 | 0.9251 | 0.0098 | 0.2692 | +| 12 | 10 | 1.0153 | -0.0160 | 0.0844 | 0.0844 | + +Despite that pattern, I also tried the followup optimization from [modded-nanogpt PR #241](https://github.com/KellerJordan/modded-nanogpt/pull/241), where MLP simply does not write to the attention lane at all in order to get a speedup. 
In this repo that change brought a slight regression, so I kept the original parallel-residual formulation instead. + +## Mini Depth Recurrence + +Note: Most of the recurrence sweeps under this section were run on an older baseline, and I later transferred the final recipe over to the newer baseline used for this submission. + +After some early failed attempts at full recurrence, I backed off to a much smaller version of the idea: instead of recurring the whole stack, I only repeated a couple of middle layers. I had already convinced myself from over-budget probes that extra depth was real, so the question became how much of that gain I could recover with minimal weight sharing. + +The main sweeps were simple but informative. Repeating one layer helped, repeating two consecutive layers helped more, and repeating three was already losing to the step-time penalty. I also swept the position of the repeated pair and found a clear sweet spot at layers `4,5`, right around the U-Net hinge point. So the useful regime here was not “add recurrence everywhere”; it was “reuse a very small part of the middle of the stack.” + +The next improvement was to turn recurrence on only mid-training. Since repeated layers slow every step down, I trained the cheaper non-recurrent model first and only activated recurrence later. In the earlier sweep, always-on recurrence reached about `1.1163` BPB post-TTT, while delayed recurrence improved that to about `1.1153`, with `RECUR_START_STEP=3000` working well. + +Finally, because mixed precision left me some parameter budget headroom, I found that the best place to spend it was untying the repeated MLPs while leaving the rest of the recurrent block shared. That gave another small but real improvement. Roughly speaking, mini depth recurrence was worth about `0.003-0.004` nats and `0.002-0.003` BPB over the best under-budget non-recurrent depth probe I had at the time. 
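The recipe above amounts to a virtual-to-physical layer map: with recurrence off, each physical layer runs once; once training passes `RECUR_START_STEP`, layers `4,5` are visited a second time, which is also what produces the virtual-vs-physical indices in the parallel-residuals table. A minimal sketch of that mapping (hypothetical helper names, not the actual `train_gpt.py` code):

```python
def virtual_to_physical(num_layers, recur_layers, recurrence_active):
    """Order in which physical layers are executed.

    With recurrence off, each of the num_layers physical layers runs once.
    With recurrence on, the layers in recur_layers (e.g. [4, 5]) get a
    second pass right after their first one, adding virtual depth without
    new attention weights; only the repeated MLPs are untied in the final
    recipe, so weight handling is not modeled here.
    """
    order = []
    for i in range(num_layers):
        order.append(i)
        if recurrence_active and i == max(recur_layers):
            order.extend(recur_layers)  # second pass over the repeated pair
    return order

def recurrence_active_at(step, recur_start_step=3000):
    # Delayed recurrence: train the cheaper non-recurrent model first,
    # then switch the repeated pair on mid-training.
    return step >= recur_start_step

# 11 physical layers, recurrence on layers 4,5 -> 13 virtual layers, and
# virtual layers 9..12 map to physical 7..10 as in the table above.
order = virtual_to_physical(11, [4, 5], True)
assert order == [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
assert order[9] == 7 and order[12] == 10
```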
+ +## Reproducibility + +The main training runs for this submission used the following command: + +```bash +SEED=$SEED POST_GPTQ_EVAL_ONLY=0 BIGRAM_DIM=112 MIXED_QUANT=1 N_INT6_LAYERS=32 NUM_LAYERS=11 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 REPEAT_UNTIE_MLP=full REPEAT_UNTIE_MLP_LAYERS=4,5 DISABLE_LAYER0_ATTN=1 PARALLEL_RESIDUAL=1 PARALLEL_START_LAYER=7 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +`brotli` also needs to be installed for the final artifact path. It is included in the copied [requirements.txt](/root/parameter-golf/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/requirements.txt). diff --git a/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/requirements.txt b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/requirements.txt new file mode 100644 index 0000000000..7d206a3220 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece +brotli \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/submission.json b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/submission.json new file mode 100644 index 0000000000..1f71c77207 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/submission.json @@ -0,0 +1,51 @@ +{ + "author": "Marko Sisovic", + "github_id": "msisovic", + "name": "Parallel Residuals + Mini Depth Recurrence", + "blurb": "Built from PR #1179 with AR self-generated GPTQ and mixed quantization ported from PR #1105. Adds parallel residual routing from layer 7 plus delayed mini depth recurrence on layers 4,5 with untied repeated MLPs. 
Exact 3-seed mean: 1.10625353 BPB / 1.86785780 nats, improving on PR #1179 by 0.00722646 nats and on the current merged SOTA by 0.01432073 nats.", + "date": "2026-03-31", + "track": "10min_16mb", + "val_loss": 1.86785780, + "val_bpb": 1.10625353, + "val_loss_std": 0.00283270, + "val_bpb_std": 0.00167769, + "seeds": [1337, 42, 2024], + "seed_results": { + "1337": { + "val_loss": 1.86841284, + "val_bpb": 1.10658225, + "artifact_bytes": 15942395, + "steps": 6242, + "step_avg_ms": 96.14 + }, + "42": { + "val_loss": 1.87037189, + "val_bpb": 1.10774252, + "artifact_bytes": 15919617, + "steps": 6248, + "step_avg_ms": 96.04 + }, + "2024": { + "val_loss": 1.86478866, + "val_bpb": 1.10443581, + "artifact_bytes": 15946657, + "steps": 6240, + "step_avg_ms": 96.16 + } + }, + "comparison_baseline_pr": 1179, + "delta_vs_pr1179_nats": -0.00722646, + "delta_vs_pr1179_bpb": -0.00427993, + "merged_sota_reference": "records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md", + "delta_vs_current_merged_sota_nats": -0.01432073, + "delta_vs_current_merged_sota_bpb": -0.00848156, + "artifact_bytes_mean": 15936223, + "artifact_bytes_max": 15946657, + "bytes_total": 15946657, + "code_bytes": 93853, + "train_steps_mean": 6243.33, + "step_avg_ms_mean": 96.11, + "hardware": "8xH100 80GB SXM", + "calibration": "AR self-generated GPTQ calibration", + "technique_summary": "Parallel residuals + mini depth recurrence + mixed quantization + AR self-generated GPTQ + Brotli compression" +} diff --git a/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/train_gpt.py b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/train_gpt.py new file mode 100644 index 0000000000..ab11274f7b --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/train_gpt.py @@ -0,0 +1,880 @@ +from __future__ import annotations +_i='passthrough_ctrl' +_h='passthrough_orig_dtypes' +_g='dtypes' +_f='scales' +_e='quantized' 
+_d='per_row' +_c='scheme' +_b='torch.' +_a='momentum' +_Z='shard_mom' +_Y='padded_grad' +_X='fineweb_train_*.bin' +_W='little' +_V='.scale' +_U='mlp_down_bank' +_T='mlp_up_bank' +_S='kv_bank' +_R='qo_bank' +_Q='X.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if closure is not _A: + with 
torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if 
piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + 
quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or 
int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or 
self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + 
perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +def apply_canon_residual(x,w): + w=w.to(dtype=x.dtype);y=x*w[0][_A,_A,:] + y=y+F.pad(x[:,:-1],(0,0,1,0))*w[1][_A,_A,:] + y=y+F.pad(x[:,:-2],(0,0,2,0))*w[2][_A,_A,:] + y=y+F.pad(x[:,:-3],(0,0,3,0))*w[3][_A,_A,:] + return x+y +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and 
w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.recur_layers=sorted(set(recur_layers or[]));self.repeat_untie_mlp=repeat_untie_mlp + 
self.canon_ac_layers=sorted(set(canon_ac_layers or[]));self._canon_ac_layer_set=set(self.canon_ac_layers) + for cl in self.canon_ac_layers: + if not 0<=cl0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,self.virtual_num_layers-xsa_last_n),self.virtual_num_layers):self.blocks[i].attn.use_xsa=_B + self.set_recurrence_active(recurrence_active);self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for repeat_mlp in self.repeat_mlp: + if repeat_mlp.fc is not _A:nn.init.zeros_(repeat_mlp.fc.weight) + if repeat_mlp.proj is not _A:nn.init.zeros_(repeat_mlp.proj.weight) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if 
getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def set_recurrence_active(self,active): + was_active=getattr(self,'_recurrence_active',_C);self._recurrence_active=bool(active)and bool(self.recur_layers) + if self._recurrence_active:self.v2p=self._v2p_recur;self.num_encoder_layers=self._enc_recur;self.num_decoder_layers=self._dec_recur + else:self.v2p=self._v2p_no_recur;self.num_encoder_layers=self._enc_no_recur;self.num_decoder_layers=self._dec_no_recur + if self._recurrence_active and not was_active and self.repeat_mlp:self._sync_repeat_mlp_from_base() + def _sync_repeat_mlp_from_base(self): + with torch.no_grad(): + for(repeat_idx,physical_idx)in enumerate(self.recur_layers): + repeat_mlp=self.repeat_mlp[repeat_idx] + if repeat_mlp.fc is not _A:repeat_mlp.fc.weight.copy_(self.mlp_up_bank[physical_idx]) + if repeat_mlp.proj is not _A:repeat_mlp.proj.weight.copy_(self.mlp_down_bank[physical_idx]) + def _is_repeated_virtual_index(self,virtual_idx):return self._recurrence_active and bool(self.recur_layers) and self._enc_recur<=virtual_idx=self.parallel_start_layer + return virtual_idx>=self.parallel_start_layer + def _mix_with_x0(self,lane,x0,resid_mix): + mix=resid_mix.to(dtype=lane.dtype);return mix[0][_A,_A,:]*lane+mix[1][_A,_A,:]*x0 + def _apply_skip_single(self,x,skip,i): + if isinstance(skip,tuple):skip=skip[1] + 
g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skip;return torch.lerp(scaled_skip,x,g) + def _apply_skip_parallel(self,lane0,lane1,skip,i): + if isinstance(skip,tuple):skip0,skip1=skip + else:skip0=skip1=skip + g=torch.sigmoid(self.skip_gates[i].to(dtype=lane0.dtype))[_A,_A,:];w=self.skip_weights[i].to(dtype=lane0.dtype)[_A,_A,:] + return torch.lerp(w*skip0,lane0,g),torch.lerp(w*skip1,lane1,g) + def _final_parallel_hidden(self,lane0,lane1): + # The branch starts as a clone, so average the summed lanes to keep the output scale close to the single-lane path. + return (lane0+lane1)*.5 + def _parallel_block(self,virtual_idx,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=_A,canon_a_w=_A,canon_c_w=_A): + block=self.blocks[virtual_idx];physical_idx=self.v2p[virtual_idx] + if not block.disable_attn: + attn_read=self._mix_with_x0(lane0,x0,block.resid_mix);attn_in=block.attn_norm(attn_read)*block.ln_scale_factor + if canon_a_w is not _A:attn_in=apply_canon_residual(attn_in,canon_a_w) + attn_out=block.attn(attn_in,q_w,k_w,v_w,out_w,v_embed=v_embed);attn_out=block.attn_scale.to(dtype=attn_out.dtype)[_A,_A,:]*attn_out;resid=self.parallel_resid_lambdas[physical_idx,0].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,0].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*attn_out;lane1=resid*lane1+post[1]*attn_out + mlp_read=self._mix_with_x0(lane1,x0,block.resid_mix);mlp_in=block.mlp_norm(mlp_read)*block.ln_scale_factor + if canon_c_w is not _A:mlp_in=apply_canon_residual(mlp_in,canon_c_w) + mlp_out=block.mlp_scale.to(dtype=lane1.dtype)[_A,_A,:]*block.mlp(mlp_in,up_w,down_w);resid=self.parallel_resid_lambdas[physical_idx,1].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,1].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*mlp_out;lane1=resid*lane1+post[1]*mlp_out;return lane0,lane1 + def _get_block_weights(self,virtual_idx): + 
n=self.num_layers;physical_idx=self.v2p[virtual_idx];q_w=self.qo_bank[physical_idx];k_w=self.kv_bank[physical_idx];v_w=self.kv_bank[n+physical_idx];out_w=self.qo_bank[n+physical_idx];up_w=self.mlp_up_bank[physical_idx];down_w=self.mlp_down_bank[physical_idx];canon_a_w=self.canon_a[physical_idx]if self.canon_a is not _A and physical_idx in self._canon_ac_layer_set else _A;canon_c_w=self.canon_c[physical_idx]if self.canon_c is not _A and physical_idx in self._canon_ac_layer_set else _A + if self._is_repeated_virtual_index(virtual_idx): + repeated_idx=virtual_idx-self._enc_recur + if self.repeat_mlp: + repeat_mlp=self.repeat_mlp[repeated_idx] + if repeat_mlp.fc is not _A:up_w=repeat_mlp.fc.weight + if repeat_mlp.proj is not _A:down_w=repeat_mlp.proj.weight + return q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w + def _backbone(self,input_ids): + x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={};lane0=lane1=_A + for i in range(self.num_encoder_layers): + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(i);ve=self._get_ve(i,input_ids,ve_cache) + if self._parallel_active_for_layer(i): + if lane0 is _A:lane0=lane1=x + lane0,lane1=self._parallel_block(i,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append((lane0,lane1)) + else:x=self.blocks[i](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(bi);ve=self._get_ve(bi,input_ids,ve_cache) + if self._parallel_active_for_layer(bi): + if lane0 is _A:lane0=lane1=x + if skips:lane0,lane1=self._apply_skip_parallel(lane0,lane1,skips.pop(),i) + 
lane0,lane1=self._parallel_block(bi,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + else: + if skips:x=self._apply_skip_single(x,skips.pop(),i) + x=self.blocks[bi](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + return self.final_norm(self._final_parallel_hidden(lane0,lane1) if lane1 is not _A else x) + def forward(self,input_ids,target_ids): + x=self._backbone(input_ids);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids):return self._backbone(input_ids) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if 
min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:-1].reshape(-1,adapted.size(-1)),y_batch[:,:seq_len-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + 
nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} 
freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else 
max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + 
val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def generate_autoregressive_calib(model,device,num_seqs=64,seq_len=2048,vocab_size=1024,temperature=.8,batch_size=8,seed=42): + was_training=model.training;model.eval();rng=torch.Generator(device=device);rng.manual_seed(seed);all_tokens=[] + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for batch_start in range(0,num_seqs,batch_size): + bs=min(batch_size,num_seqs-batch_start);tokens=torch.randint(0,vocab_size,(bs,1),device=device,generator=rng) + for _ in range(seq_len-1): + logits=model.forward_logits(tokens);next_logit=logits[:,-1,:];probs=torch.softmax(next_logit/max(temperature,1e-4),dim=-1);next_tok=torch.multinomial(probs,1,generator=rng);tokens=torch.cat([tokens,next_tok],dim=1) + for i in range(bs):all_tokens.append(tokens[i:i+1].detach().clone()) + model.train(was_training);return all_tokens +def gptq_collect_hessians_from_tokens(base_model,token_seqs,device): + dim=base_model.tok_emb.weight.shape[1];mlp_dim=base_model.mlp_up_bank.shape[1];hessians=_init_hessians(base_model,dim,mlp_dim,device) + for block in base_model.blocks:block.attn._save_gptq=_B;block.mlp._save_gptq=_B + was_training=base_model.training;base_model.eval() + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for seq in token_seqs:x=seq[:,:-1].to(device=device,dtype=torch.int64);y=seq[:,1:].to(device=device,dtype=torch.int64);base_model(x,y);_accum_hessians(hessians,base_model,dim,mlp_dim) + for block in base_model.blocks:block.attn._save_gptq=_C;block.mlp._save_gptq=_C + _finalize_hessians(hessians,max(len(token_seqs),1));base_model.train(was_training);return hessians +def _classify_param(name): + A='.mlp.' 
+ if'tok_emb'in name or'lm_head'in name:return'embed' + if name.startswith('canon_a'):return'attn' + if name.startswith('canon_c'):return'mlp' + if A in name or name.startswith('repeat_mlp.'):return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def _parse_layer_list(layers_str): + return[int(x)for x in layers_str.split(',')if x.strip()] +def _get_block_idx_from_name(name): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + return _A +def _get_physical_layer_idx_from_name(name,recur_layers): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + if len(parts)>2 and parts[0]=='repeat_mlp'and parts[1].isdigit(): + repeat_idx=int(parts[1]) + if 0<=repeat_idx0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[template_sd[_R][i]for i in range(2*n)];kv_slices=[template_sd[_S][i]for i in range(2*n)];up_slices=[template_sd[_T][i]for i in range(n)];down_slices=[template_sd[_U][i]for i in range(n)];consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + 
kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def _drop_disabled_layer0_attn_unbanked(sd,disable_layer0_attn): + if not disable_layer0_attn:return sd + disabled_keys={'blocks.0.attn.c_q.weight','blocks.0.attn.c_k.weight','blocks.0.attn.c_v.weight','blocks.0.attn.proj.weight'} + return{k:v for(k,v)in sd.items()if k not in disabled_keys} +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A,clip_ranges=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A;cr=clip_ranges.get(name,clip_range)if isinstance(clip_ranges,dict)else clip_range + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=cr);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=cr);naive_count+=1 + 
result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'if cr>=31 else 'int5'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in 
range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");recur_layers=_parse_layer_list(args.recur_layers_str);repeat_untie_mlp_layers=_parse_layer_list(args.repeat_untie_mlp_layers);canon_ac_layers=_parse_layer_list(args.canon_ac_layers) + if args.post_gptq_eval_only: + eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=bool(recur_layers),repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank
.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model) + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);template_sd={k:v.detach().cpu()for(k,v)in eval_model.state_dict().items()};template_unbanked=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(template_sd,args.num_layers),args.disable_layer0_attn);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],template_unbanked);eval_model.load_state_dict(_rebank_state_dict(deq_unbanked,args.num_layers,template_sd),strict=_B);q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not 
_A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"recurrence:layers={recur_layers} start_step={args.recur_start_step} active={int(base_model._recurrence_active)}");log0(f"canon_ac:layers={canon_ac_layers} params={0 if base_model.canon_a is _A else base_model.canon_a.numel()+base_model.canon_c.numel()} 
physical_only=1");log0(f"parallel_residual:active={int(base_model.parallel_post_lambdas is not _A)} start_layer={base_model.parallel_start_layer} start_mode={'physical'if base_model.parallel_start_layer_is_physical else 'virtual'} params={0 if base_model.parallel_post_lambdas is _A else base_model.parallel_post_lambdas.numel()+base_model.parallel_resid_lambdas.numel()} final_lane=mlp");log0(f"repeat_untie_mlp:mode={args.repeat_untie_mlp} layers={repeat_untie_mlp_layers if repeat_untie_mlp_layers else recur_layers if args.repeat_untie_mlp!='none' else []} params={sum(p.numel()for p in base_model.repeat_mlp.parameters())}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train();run_warmup_steps(args.warmup_steps,'base') + if 
recur_layers:base_model.set_recurrence_active(_B);log0(f"recurrence:prewarm active={int(base_model._recurrence_active)} virtual_layers:{base_model.virtual_num_layers}");run_warmup_steps(args.warmup_steps,'recur');base_model.set_recurrence_active(_C) + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();base_model.set_recurrence_active(_C);train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();timed_wallclock_t0=time.perf_counter();t0=timed_wallclock_t0;step=0 + while _B: + if recur_layers and not base_model._recurrence_active and step>=args.recur_start_step:base_model.set_recurrence_active(_B);log0(f"recurrence:activated step:{step} layers={recur_layers} virtual_layers:{base_model.virtual_num_layers}") + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if 
args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm)
+ if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr:
+ s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr
+ with torch.no_grad():
+ for bank in[base_model.qo_bank,base_model.kv_bank]:
+ if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls)
+ for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]:
+ if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls)
+ optimizer_muon.launch_reduce_scatters()
+ if distributed:
+ for p in replicated_params:
+ if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG)
+ optimizer_tok.step();optimizer_scalar.step()
+ if optimizer_head is not _A:optimizer_head.step()
+ optimizer_muon.step();zero_grad_all()
+ with torch.no_grad():
+ for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay)
+ step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0)
+ if args.late_qat_threshold>0 and scale<args.late_qat_threshold and step>=2000:
+ if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
+ qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress
+ if args.swa_enabled and scale<.2 and step%args.swa_every==0:
+ if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}")
+ else:
+ for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu()
+ swa_count+=1
+ should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A)
+ if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} 
train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);log_parallel_residual_converged(log0,base_model);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(sd_cpu,args.num_layers),args.disable_layer0_attn);gptq_hessians=_A + if args.use_gptq: + t_gptq=time.perf_counter();recur_was_active=base_model._recurrence_active;base_model.set_recurrence_active(recur_was_active);log0(f"gptq:calibration recurrence_active={int(base_model._recurrence_active)} repeat_mlp={len(base_model.repeat_mlp)} parallel_residual={int(base_model.parallel_post_lambdas is not _A)} 
ar_selfgen={int(args.gptq_ar_selfgen)}") + if args.gptq_ar_selfgen: + log0(f"gptq:generating autoregressive calibration data ({args.gptq_calib_samples} seqs x {args.train_seq_len} tokens, temp={args.gptq_temperature:.2f})...");t_gen=time.perf_counter();ar_tokens=generate_autoregressive_calib(base_model,device,num_seqs=args.gptq_calib_samples,seq_len=args.train_seq_len,vocab_size=args.vocab_size,temperature=args.gptq_temperature,batch_size=args.gptq_batch_size,seed=args.seed);log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s");log0("gptq:collecting hessians from autoregressive data...");gptq_hessians=gptq_collect_hessians_from_tokens(base_model,ar_tokens,device);del ar_tokens;log0(f"gptq:collected hessians for {len(gptq_hessians)} layers (AR self-gen)") + else: + log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;log0(f"gptq:calibrated {len(gptq_hessians)} layers from training data") + base_model.set_recurrence_active(recur_was_active);gptq_elapsed=time.perf_counter()-t_gptq;total_wallclock_elapsed=time.perf_counter()-timed_wallclock_t0;log0(f"gptq:done in {gptq_elapsed:.1f}s");log0(f"wallclock:post_gptq total_elapsed:{total_wallclock_elapsed:.1f}s train_budget:{args.max_wallclock_seconds:.1f}s");torch.cuda.empty_cache() + clip_ranges=_A + if args.mixed_quant and gptq_hessians is not _A: + quant_names=[n for n in unbanked_sd if _classify_param(n)in{'mlp','attn'}and unbanked_sd[n].ndim>=1 and unbanked_sd[n].numel()>65536];sens={n:gptq_hessians[n].diag().sum().item()if n in gptq_hessians else 0.0 for n in quant_names};ranked=sorted(sens.items(),key=lambda x:-x[1]);clip_ranges={n:15 for n in 
quant_names};recur_layer_set=set(recur_layers);recur_quant_names=[name for name in quant_names if _get_physical_layer_idx_from_name(name,recur_layers)in recur_layer_set];recur_ranked=sorted(recur_quant_names,key=lambda name:-sens[name]);forced_int6=min(args.n_int6_layers,len(recur_ranked));selected_int6_names=recur_ranked[:forced_int6];selected_int6_set=set(selected_int6_names) + for(name,_)in ranked: + if len(selected_int6_names)>=args.n_int6_layers:break + if name in selected_int6_set:continue + selected_int6_names.append(name);selected_int6_set.add(name) + for name in selected_int6_names:clip_ranges[name]=31 + int6_names=[n for n,cr in clip_ranges.items()if cr==31];int5_names=[n for n,cr in clip_ranges.items()if cr==15];log0(f"mixed_quant: {len(int6_names)} int6, {len(int5_names)} int5");log0(f"mixed_quant: forced_recur_int6={forced_int6}/{len(recur_ranked)} recur_layers={recur_layers}");log0(f"mixed_quant: int6 layers: {int6_names[:5]}...") + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians,clip_ranges=clip_ranges);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) + if master_process: + with open(F,'wb')as f:f.write(quant_blob) + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") + if distributed:dist.barrier() + with open(F,'rb')as f:quant_blob_disk=f.read() + 
quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=base_model._recurrence_active,repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_strideX.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if 
closure is not _A: + with torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if 
piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + 
quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or 
int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or 
self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + 
perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +def apply_canon_residual(x,w): + w=w.to(dtype=x.dtype);y=x*w[0][_A,_A,:] + y=y+F.pad(x[:,:-1],(0,0,1,0))*w[1][_A,_A,:] + y=y+F.pad(x[:,:-2],(0,0,2,0))*w[2][_A,_A,:] + y=y+F.pad(x[:,:-3],(0,0,3,0))*w[3][_A,_A,:] + return x+y +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and 
w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.recur_layers=sorted(set(recur_layers or[]));self.repeat_untie_mlp=repeat_untie_mlp + 
self.canon_ac_layers=sorted(set(canon_ac_layers or[]));self._canon_ac_layer_set=set(self.canon_ac_layers) + for cl in self.canon_ac_layers: + if not 0<=cl0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,self.virtual_num_layers-xsa_last_n),self.virtual_num_layers):self.blocks[i].attn.use_xsa=_B + self.set_recurrence_active(recurrence_active);self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for repeat_mlp in self.repeat_mlp: + if repeat_mlp.fc is not _A:nn.init.zeros_(repeat_mlp.fc.weight) + if repeat_mlp.proj is not _A:nn.init.zeros_(repeat_mlp.proj.weight) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if 
getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def set_recurrence_active(self,active): + was_active=getattr(self,'_recurrence_active',_C);self._recurrence_active=bool(active)and bool(self.recur_layers) + if self._recurrence_active:self.v2p=self._v2p_recur;self.num_encoder_layers=self._enc_recur;self.num_decoder_layers=self._dec_recur + else:self.v2p=self._v2p_no_recur;self.num_encoder_layers=self._enc_no_recur;self.num_decoder_layers=self._dec_no_recur + if self._recurrence_active and not was_active and self.repeat_mlp:self._sync_repeat_mlp_from_base() + def _sync_repeat_mlp_from_base(self): + with torch.no_grad(): + for(repeat_idx,physical_idx)in enumerate(self.recur_layers): + repeat_mlp=self.repeat_mlp[repeat_idx] + if repeat_mlp.fc is not _A:repeat_mlp.fc.weight.copy_(self.mlp_up_bank[physical_idx]) + if repeat_mlp.proj is not _A:repeat_mlp.proj.weight.copy_(self.mlp_down_bank[physical_idx]) + def _is_repeated_virtual_index(self,virtual_idx):return self._recurrence_active and bool(self.recur_layers) and self._enc_recur<=virtual_idx=self.parallel_start_layer + return virtual_idx>=self.parallel_start_layer + def _mix_with_x0(self,lane,x0,resid_mix): + mix=resid_mix.to(dtype=lane.dtype);return mix[0][_A,_A,:]*lane+mix[1][_A,_A,:]*x0 + def _apply_skip_single(self,x,skip,i): + if isinstance(skip,tuple):skip=skip[1] + 
g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skip;return torch.lerp(scaled_skip,x,g) + def _apply_skip_parallel(self,lane0,lane1,skip,i): + if isinstance(skip,tuple):skip0,skip1=skip + else:skip0=skip1=skip + g=torch.sigmoid(self.skip_gates[i].to(dtype=lane0.dtype))[_A,_A,:];w=self.skip_weights[i].to(dtype=lane0.dtype)[_A,_A,:] + return torch.lerp(w*skip0,lane0,g),torch.lerp(w*skip1,lane1,g) + def _final_parallel_hidden(self,lane0,lane1): + # The branch starts as a clone, so average the summed lanes to keep the output scale close to the single-lane path. + return (lane0+lane1)*.5 + def _parallel_block(self,virtual_idx,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=_A,canon_a_w=_A,canon_c_w=_A): + block=self.blocks[virtual_idx];physical_idx=self.v2p[virtual_idx] + if not block.disable_attn: + attn_read=self._mix_with_x0(lane0,x0,block.resid_mix);attn_in=block.attn_norm(attn_read)*block.ln_scale_factor + if canon_a_w is not _A:attn_in=apply_canon_residual(attn_in,canon_a_w) + attn_out=block.attn(attn_in,q_w,k_w,v_w,out_w,v_embed=v_embed);attn_out=block.attn_scale.to(dtype=attn_out.dtype)[_A,_A,:]*attn_out;resid=self.parallel_resid_lambdas[physical_idx,0].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,0].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*attn_out;lane1=resid*lane1+post[1]*attn_out + mlp_read=self._mix_with_x0(lane1,x0,block.resid_mix);mlp_in=block.mlp_norm(mlp_read)*block.ln_scale_factor + if canon_c_w is not _A:mlp_in=apply_canon_residual(mlp_in,canon_c_w) + mlp_out=block.mlp_scale.to(dtype=lane1.dtype)[_A,_A,:]*block.mlp(mlp_in,up_w,down_w);resid=self.parallel_resid_lambdas[physical_idx,1].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,1].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*mlp_out;lane1=resid*lane1+post[1]*mlp_out;return lane0,lane1 + def _get_block_weights(self,virtual_idx): + 
n=self.num_layers;physical_idx=self.v2p[virtual_idx];q_w=self.qo_bank[physical_idx];k_w=self.kv_bank[physical_idx];v_w=self.kv_bank[n+physical_idx];out_w=self.qo_bank[n+physical_idx];up_w=self.mlp_up_bank[physical_idx];down_w=self.mlp_down_bank[physical_idx];canon_a_w=self.canon_a[physical_idx]if self.canon_a is not _A and physical_idx in self._canon_ac_layer_set else _A;canon_c_w=self.canon_c[physical_idx]if self.canon_c is not _A and physical_idx in self._canon_ac_layer_set else _A + if self._is_repeated_virtual_index(virtual_idx): + repeated_idx=virtual_idx-self._enc_recur + if self.repeat_mlp: + repeat_mlp=self.repeat_mlp[repeated_idx] + if repeat_mlp.fc is not _A:up_w=repeat_mlp.fc.weight + if repeat_mlp.proj is not _A:down_w=repeat_mlp.proj.weight + return q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w + def _backbone(self,input_ids): + x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={};lane0=lane1=_A + for i in range(self.num_encoder_layers): + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(i);ve=self._get_ve(i,input_ids,ve_cache) + if self._parallel_active_for_layer(i): + if lane0 is _A:lane0=lane1=x + lane0,lane1=self._parallel_block(i,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append((lane0,lane1)) + else:x=self.blocks[i](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(bi);ve=self._get_ve(bi,input_ids,ve_cache) + if self._parallel_active_for_layer(bi): + if lane0 is _A:lane0=lane1=x + if skips:lane0,lane1=self._apply_skip_parallel(lane0,lane1,skips.pop(),i) + 
lane0,lane1=self._parallel_block(bi,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + else: + if skips:x=self._apply_skip_single(x,skips.pop(),i) + x=self.blocks[bi](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + return self.final_norm(self._final_parallel_hidden(lane0,lane1) if lane1 is not _A else x) + def forward(self,input_ids,target_ids): + x=self._backbone(input_ids);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids):return self._backbone(input_ids) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if 
min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:-1].reshape(-1,adapted.size(-1)),y_batch[:,:seq_len-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + 
nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} 
freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else 
max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + 
val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def generate_autoregressive_calib(model,device,num_seqs=64,seq_len=2048,vocab_size=1024,temperature=.8,batch_size=8,seed=42): + was_training=model.training;model.eval();rng=torch.Generator(device=device);rng.manual_seed(seed);all_tokens=[] + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for batch_start in range(0,num_seqs,batch_size): + bs=min(batch_size,num_seqs-batch_start);tokens=torch.randint(0,vocab_size,(bs,1),device=device,generator=rng) + for _ in range(seq_len-1): + logits=model.forward_logits(tokens);next_logit=logits[:,-1,:];probs=torch.softmax(next_logit/max(temperature,1e-4),dim=-1);next_tok=torch.multinomial(probs,1,generator=rng);tokens=torch.cat([tokens,next_tok],dim=1) + for i in range(bs):all_tokens.append(tokens[i:i+1].detach().clone()) + model.train(was_training);return all_tokens +def gptq_collect_hessians_from_tokens(base_model,token_seqs,device): + dim=base_model.tok_emb.weight.shape[1];mlp_dim=base_model.mlp_up_bank.shape[1];hessians=_init_hessians(base_model,dim,mlp_dim,device) + for block in base_model.blocks:block.attn._save_gptq=_B;block.mlp._save_gptq=_B + was_training=base_model.training;base_model.eval() + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for seq in token_seqs:x=seq[:,:-1].to(device=device,dtype=torch.int64);y=seq[:,1:].to(device=device,dtype=torch.int64);base_model(x,y);_accum_hessians(hessians,base_model,dim,mlp_dim) + for block in base_model.blocks:block.attn._save_gptq=_C;block.mlp._save_gptq=_C + _finalize_hessians(hessians,max(len(token_seqs),1));base_model.train(was_training);return hessians +def _classify_param(name): + A='.mlp.' 
+ if'tok_emb'in name or'lm_head'in name:return'embed' + if name.startswith('canon_a'):return'attn' + if name.startswith('canon_c'):return'mlp' + if A in name or name.startswith('repeat_mlp.'):return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def _parse_layer_list(layers_str): + return[int(x)for x in layers_str.split(',')if x.strip()] +def _get_block_idx_from_name(name): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + return _A +def _get_physical_layer_idx_from_name(name,recur_layers): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + if len(parts)>2 and parts[0]=='repeat_mlp'and parts[1].isdigit(): + repeat_idx=int(parts[1]) + if 0<=repeat_idx0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[template_sd[_R][i]for i in range(2*n)];kv_slices=[template_sd[_S][i]for i in range(2*n)];up_slices=[template_sd[_T][i]for i in range(n)];down_slices=[template_sd[_U][i]for i in range(n)];consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + 
kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def _drop_disabled_layer0_attn_unbanked(sd,disable_layer0_attn): + if not disable_layer0_attn:return sd + disabled_keys={'blocks.0.attn.c_q.weight','blocks.0.attn.c_k.weight','blocks.0.attn.c_v.weight','blocks.0.attn.proj.weight'} + return{k:v for(k,v)in sd.items()if k not in disabled_keys} +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A,clip_ranges=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A;cr=clip_ranges.get(name,clip_range)if isinstance(clip_ranges,dict)else clip_range + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=cr);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=cr);naive_count+=1 + 
result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'if cr>=31 else 'int5'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in 
range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");recur_layers=_parse_layer_list(args.recur_layers_str);repeat_untie_mlp_layers=_parse_layer_list(args.repeat_untie_mlp_layers);canon_ac_layers=_parse_layer_list(args.canon_ac_layers) + if args.post_gptq_eval_only: + eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=bool(recur_layers),repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank
.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model) + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);template_sd={k:v.detach().cpu()for(k,v)in eval_model.state_dict().items()};template_unbanked=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(template_sd,args.num_layers),args.disable_layer0_attn);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],template_unbanked);eval_model.load_state_dict(_rebank_state_dict(deq_unbanked,args.num_layers,template_sd),strict=_B);q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not 
_A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"recurrence:layers={recur_layers} start_step={args.recur_start_step} active={int(base_model._recurrence_active)}");log0(f"canon_ac:layers={canon_ac_layers} params={0 if base_model.canon_a is _A else base_model.canon_a.numel()+base_model.canon_c.numel()} 
physical_only=1");log0(f"parallel_residual:active={int(base_model.parallel_post_lambdas is not _A)} start_layer={base_model.parallel_start_layer} start_mode={'physical'if base_model.parallel_start_layer_is_physical else 'virtual'} params={0 if base_model.parallel_post_lambdas is _A else base_model.parallel_post_lambdas.numel()+base_model.parallel_resid_lambdas.numel()} final_lane=mlp");log0(f"repeat_untie_mlp:mode={args.repeat_untie_mlp} layers={repeat_untie_mlp_layers if repeat_untie_mlp_layers else recur_layers if args.repeat_untie_mlp!='none' else []} params={sum(p.numel()for p in base_model.repeat_mlp.parameters())}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train();run_warmup_steps(args.warmup_steps,'base') + if 
recur_layers:base_model.set_recurrence_active(_B);log0(f"recurrence:prewarm active={int(base_model._recurrence_active)} virtual_layers:{base_model.virtual_num_layers}");run_warmup_steps(args.warmup_steps,'recur');base_model.set_recurrence_active(_C) + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();base_model.set_recurrence_active(_C);train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();timed_wallclock_t0=time.perf_counter();t0=timed_wallclock_t0;step=0 + while _B: + if recur_layers and not base_model._recurrence_active and step>=args.recur_start_step:base_model.set_recurrence_active(_B);log0(f"recurrence:activated step:{step} layers={recur_layers} virtual_layers:{base_model.virtual_num_layers}") + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if 
args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm)
 + if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr:
 + s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr
 + with torch.no_grad():
 + for bank in[base_model.qo_bank,base_model.kv_bank]:
 + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls)
 + for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]:
 + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls)
 + optimizer_muon.launch_reduce_scatters()
 + if distributed:
 + for p in replicated_params:
 + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG)
 + optimizer_tok.step();optimizer_scalar.step()
 + if optimizer_head is not _A:optimizer_head.step()
 + optimizer_muon.step();zero_grad_all()
 + with torch.no_grad():
 + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay)
 + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0)
 + if args.late_qat_threshold>0 and scale<args.late_qat_threshold and step>=2000:
 + if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
 + qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress
 + if args.swa_enabled and scale<.2 and step%args.swa_every==0:
 + if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}")
 + else:
 + for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu()
 + swa_count+=1
 + should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A)
 + if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} 
train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);log_parallel_residual_converged(log0,base_model);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(sd_cpu,args.num_layers),args.disable_layer0_attn);gptq_hessians=_A + if args.use_gptq: + t_gptq=time.perf_counter();recur_was_active=base_model._recurrence_active;base_model.set_recurrence_active(recur_was_active);log0(f"gptq:calibration recurrence_active={int(base_model._recurrence_active)} repeat_mlp={len(base_model.repeat_mlp)} parallel_residual={int(base_model.parallel_post_lambdas is not _A)} 
ar_selfgen={int(args.gptq_ar_selfgen)}") + if args.gptq_ar_selfgen: + log0(f"gptq:generating autoregressive calibration data ({args.gptq_calib_samples} seqs x {args.train_seq_len} tokens, temp={args.gptq_temperature:.2f})...");t_gen=time.perf_counter();ar_tokens=generate_autoregressive_calib(base_model,device,num_seqs=args.gptq_calib_samples,seq_len=args.train_seq_len,vocab_size=args.vocab_size,temperature=args.gptq_temperature,batch_size=args.gptq_batch_size,seed=args.seed);log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s");log0("gptq:collecting hessians from autoregressive data...");gptq_hessians=gptq_collect_hessians_from_tokens(base_model,ar_tokens,device);del ar_tokens;log0(f"gptq:collected hessians for {len(gptq_hessians)} layers (AR self-gen)") + else: + log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;log0(f"gptq:calibrated {len(gptq_hessians)} layers from training data") + base_model.set_recurrence_active(recur_was_active);gptq_elapsed=time.perf_counter()-t_gptq;total_wallclock_elapsed=time.perf_counter()-timed_wallclock_t0;log0(f"gptq:done in {gptq_elapsed:.1f}s");log0(f"wallclock:post_gptq total_elapsed:{total_wallclock_elapsed:.1f}s train_budget:{args.max_wallclock_seconds:.1f}s");torch.cuda.empty_cache() + clip_ranges=_A + if args.mixed_quant and gptq_hessians is not _A: + quant_names=[n for n in unbanked_sd if _classify_param(n)in{'mlp','attn'}and unbanked_sd[n].ndim>=1 and unbanked_sd[n].numel()>65536];sens={n:gptq_hessians[n].diag().sum().item()if n in gptq_hessians else 0.0 for n in quant_names};ranked=sorted(sens.items(),key=lambda x:-x[1]);clip_ranges={n:15 for n in 
quant_names};recur_layer_set=set(recur_layers);recur_quant_names=[name for name in quant_names if _get_physical_layer_idx_from_name(name,recur_layers)in recur_layer_set];recur_ranked=sorted(recur_quant_names,key=lambda name:-sens[name]);forced_int6=min(args.n_int6_layers,len(recur_ranked));selected_int6_names=recur_ranked[:forced_int6];selected_int6_set=set(selected_int6_names)
 + for(name,_)in ranked:
 + if len(selected_int6_names)>=args.n_int6_layers:break
 + if name in selected_int6_set:continue
 + selected_int6_names.append(name);selected_int6_set.add(name)
 + for name in selected_int6_names:clip_ranges[name]=31
 + int6_names=[n for n,cr in clip_ranges.items()if cr==31];int5_names=[n for n,cr in clip_ranges.items()if cr==15];log0(f"mixed_quant: {len(int6_names)} int6, {len(int5_names)} int5");log0(f"mixed_quant: forced_recur_int6={forced_int6}/{len(recur_ranked)} recur_layers={recur_layers}");log0(f"mixed_quant: int6 layers: {int6_names[:5]}...")
 + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians,clip_ranges=clip_ranges);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11)
 + if master_process:
 + with open(F,'wb')as f:f.write(quant_blob)
 + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes")
 + if distributed:dist.barrier()
 + with open(F,'rb')as f:quant_blob_disk=f.read()
 + 
quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=base_model._recurrence_active,repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_strideX.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if 
closure is not _A: + with torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if 
piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + 
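The `quantize_state_dict_int8` path above stores one fp scale per output row and int8 codes in `[-127, 127]`. A minimal standalone sketch of that per-row symmetric scheme (helper names here are illustrative, not from the script):

```python
import torch

def quantize_per_row_int8(t: torch.Tensor, eps: float = 1e-8):
    # Symmetric per-row int8: one float scale per output row, codes in [-127, 127].
    row_max = t.abs().amax(dim=1, keepdim=True)
    scale = (row_max / 127.0).clamp_min(eps)
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale.squeeze(1)

def dequantize_per_row_int8(q: torch.Tensor, scale: torch.Tensor):
    # Reconstruction: broadcast each row's scale back over its int8 codes.
    return q.float() * scale.view(-1, 1)

w = torch.randn(4, 16)
q, s = quantize_per_row_int8(w)
w_hat = dequantize_per_row_int8(q, s)
# Rounding error is bounded by half a quantization step per element.
assert (w - w_hat).abs().max().item() <= s.max().item() / 2 + 1e-6
```

The roundtrip error bound is what makes the small-tensor passthrough above (`INT8_KEEP_FLOAT_MAX_NUMEL`) worthwhile: tiny control tensors are cheap to keep in float and sensitive to even half-step rounding.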
quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or 
int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or 
self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + 
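The loader's cursor logic above walks each shard's blocks from a random start with a random stride that is coprime to the block count, which traverses every block exactly once before the cursor resets. A toy sketch of just that traversal (function name is illustrative):

```python
import math
import random

def coprime_stride_order(n: int, rng: random.Random):
    # Visit indices 0..n-1 exactly once via (start + k*stride) mod n.
    # gcd(stride, n) == 1 guarantees the walk is a full permutation cycle.
    if n <= 1:
        return list(range(n))
    while True:
        stride = rng.randrange(1, n)
        if math.gcd(stride, n) == 1:
            break
    start = rng.randrange(n)
    return [(start + k * stride) % n for k in range(n)]

order = coprime_stride_order(10, random.Random(0))
assert sorted(order) == list(range(10))  # a permutation: every block hit once
```

This gives shuffled-looking access without materializing a permutation array per shard, which matters when the shard count and block counts are large.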
perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +def apply_canon_residual(x,w): + w=w.to(dtype=x.dtype);y=x*w[0][_A,_A,:] + y=y+F.pad(x[:,:-1],(0,0,1,0))*w[1][_A,_A,:] + y=y+F.pad(x[:,:-2],(0,0,2,0))*w[2][_A,_A,:] + y=y+F.pad(x[:,:-3],(0,0,3,0))*w[3][_A,_A,:] + return x+y +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and 
w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.recur_layers=sorted(set(recur_layers or[]));self.repeat_untie_mlp=repeat_untie_mlp + 
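The `CastedLinear` QAT branch above replaces hard rounding with `floor + sigmoid(alpha*(frac-0.5))`, a differentiable surrogate that approaches true rounding as `alpha` grows. A self-contained sketch of that soft-rounding forward, mirroring the per-row scale to 31 levels used in the script:

```python
import torch

def soft_round_qat(w: torch.Tensor, alpha: float, levels: int = 31):
    # Differentiable stand-in for round(): floor(x) + sigmoid(alpha*(frac(x)-0.5)).
    # Small alpha keeps gradients smooth; large alpha approaches hard rounding.
    row_max = w.abs().amax(dim=1)
    s = (row_max / levels).clamp_min(1.0 / levels)   # per-row scale
    scaled = w / s[:, None]
    frac = scaled - scaled.floor()
    soft = scaled.floor() + torch.sigmoid(alpha * (frac - 0.5))
    return torch.clamp(soft, -levels, levels) * s[:, None]

w = torch.randn(8, 32, requires_grad=True)
w_q = soft_round_qat(w, alpha=50.0)
w_q.sum().backward()  # gradients flow through the soft rounding
assert w.grad is not None
```

Training through this surrogate makes the weights aware of the int6 grid before the post-hoc GPTQ/per-row quantization step, shrinking the quantization-time loss hit.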
self.canon_ac_layers=sorted(set(canon_ac_layers or[]));self._canon_ac_layer_set=set(self.canon_ac_layers) + for cl in self.canon_ac_layers: + if not 0<=cl0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,self.virtual_num_layers-xsa_last_n),self.virtual_num_layers):self.blocks[i].attn.use_xsa=_B + self.set_recurrence_active(recurrence_active);self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for repeat_mlp in self.repeat_mlp: + if repeat_mlp.fc is not _A:nn.init.zeros_(repeat_mlp.fc.weight) + if repeat_mlp.proj is not _A:nn.init.zeros_(repeat_mlp.proj.weight) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if 
getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def set_recurrence_active(self,active): + was_active=getattr(self,'_recurrence_active',_C);self._recurrence_active=bool(active)and bool(self.recur_layers) + if self._recurrence_active:self.v2p=self._v2p_recur;self.num_encoder_layers=self._enc_recur;self.num_decoder_layers=self._dec_recur + else:self.v2p=self._v2p_no_recur;self.num_encoder_layers=self._enc_no_recur;self.num_decoder_layers=self._dec_no_recur + if self._recurrence_active and not was_active and self.repeat_mlp:self._sync_repeat_mlp_from_base() + def _sync_repeat_mlp_from_base(self): + with torch.no_grad(): + for(repeat_idx,physical_idx)in enumerate(self.recur_layers): + repeat_mlp=self.repeat_mlp[repeat_idx] + if repeat_mlp.fc is not _A:repeat_mlp.fc.weight.copy_(self.mlp_up_bank[physical_idx]) + if repeat_mlp.proj is not _A:repeat_mlp.proj.weight.copy_(self.mlp_down_bank[physical_idx]) + def _is_repeated_virtual_index(self,virtual_idx):return self._recurrence_active and bool(self.recur_layers) and self._enc_recur<=virtual_idx=self.parallel_start_layer + return virtual_idx>=self.parallel_start_layer + def _mix_with_x0(self,lane,x0,resid_mix): + mix=resid_mix.to(dtype=lane.dtype);return mix[0][_A,_A,:]*lane+mix[1][_A,_A,:]*x0 + def _apply_skip_single(self,x,skip,i): + if isinstance(skip,tuple):skip=skip[1] + 
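The `_parallel_block` method further down implements the two-lane scheme described in the README: attention reads lane 0, the MLP reads lane 1, and each sublayer writes its output into both lanes with learned per-lane scalars. A minimal sketch with toy sublayers (names and fixed scalars here are illustrative):

```python
import torch

def parallel_block_step(lane0, lane1, attn, mlp, attn_post, mlp_post, resid=1.0):
    # Attention reads from lane0; its output is written into BOTH lanes
    # with per-lane scalars (attn_post[0] -> lane0, attn_post[1] -> lane1).
    a = attn(lane0)
    lane0 = resid * lane0 + attn_post[0] * a
    lane1 = resid * lane1 + attn_post[1] * a
    # The MLP reads from lane1 and likewise writes into both lanes.
    m = mlp(lane1)
    lane0 = resid * lane0 + mlp_post[0] * m
    lane1 = resid * lane1 + mlp_post[1] * m
    return lane0, lane1

x = torch.randn(2, 5, 16)
lane0, lane1 = parallel_block_step(
    x.clone(), x.clone(),
    attn=torch.tanh, mlp=torch.relu,
    attn_post=(1.0, 0.8), mlp_post=(0.05, 1.2))
final = 0.5 * (lane0 + lane1)  # lanes start as clones, so average at the output
assert final.shape == x.shape
```

The asymmetric routing the README reports (small `mlp_to_attn`) corresponds to `mlp_post[0]` being learned close to zero in the deeper partitioned layers.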
g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skip;return torch.lerp(scaled_skip,x,g) + def _apply_skip_parallel(self,lane0,lane1,skip,i): + if isinstance(skip,tuple):skip0,skip1=skip + else:skip0=skip1=skip + g=torch.sigmoid(self.skip_gates[i].to(dtype=lane0.dtype))[_A,_A,:];w=self.skip_weights[i].to(dtype=lane0.dtype)[_A,_A,:] + return torch.lerp(w*skip0,lane0,g),torch.lerp(w*skip1,lane1,g) + def _final_parallel_hidden(self,lane0,lane1): + # The branch starts as a clone, so average the summed lanes to keep the output scale close to the single-lane path. + return (lane0+lane1)*.5 + def _parallel_block(self,virtual_idx,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=_A,canon_a_w=_A,canon_c_w=_A): + block=self.blocks[virtual_idx];physical_idx=self.v2p[virtual_idx] + if not block.disable_attn: + attn_read=self._mix_with_x0(lane0,x0,block.resid_mix);attn_in=block.attn_norm(attn_read)*block.ln_scale_factor + if canon_a_w is not _A:attn_in=apply_canon_residual(attn_in,canon_a_w) + attn_out=block.attn(attn_in,q_w,k_w,v_w,out_w,v_embed=v_embed);attn_out=block.attn_scale.to(dtype=attn_out.dtype)[_A,_A,:]*attn_out;resid=self.parallel_resid_lambdas[physical_idx,0].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,0].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*attn_out;lane1=resid*lane1+post[1]*attn_out + mlp_read=self._mix_with_x0(lane1,x0,block.resid_mix);mlp_in=block.mlp_norm(mlp_read)*block.ln_scale_factor + if canon_c_w is not _A:mlp_in=apply_canon_residual(mlp_in,canon_c_w) + mlp_out=block.mlp_scale.to(dtype=lane1.dtype)[_A,_A,:]*block.mlp(mlp_in,up_w,down_w);resid=self.parallel_resid_lambdas[physical_idx,1].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,1].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*mlp_out;lane1=resid*lane1+post[1]*mlp_out;return lane0,lane1 + def _get_block_weights(self,virtual_idx): + 
n=self.num_layers;physical_idx=self.v2p[virtual_idx];q_w=self.qo_bank[physical_idx];k_w=self.kv_bank[physical_idx];v_w=self.kv_bank[n+physical_idx];out_w=self.qo_bank[n+physical_idx];up_w=self.mlp_up_bank[physical_idx];down_w=self.mlp_down_bank[physical_idx];canon_a_w=self.canon_a[physical_idx]if self.canon_a is not _A and physical_idx in self._canon_ac_layer_set else _A;canon_c_w=self.canon_c[physical_idx]if self.canon_c is not _A and physical_idx in self._canon_ac_layer_set else _A + if self._is_repeated_virtual_index(virtual_idx): + repeated_idx=virtual_idx-self._enc_recur + if self.repeat_mlp: + repeat_mlp=self.repeat_mlp[repeated_idx] + if repeat_mlp.fc is not _A:up_w=repeat_mlp.fc.weight + if repeat_mlp.proj is not _A:down_w=repeat_mlp.proj.weight + return q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w + def _backbone(self,input_ids): + x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={};lane0=lane1=_A + for i in range(self.num_encoder_layers): + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(i);ve=self._get_ve(i,input_ids,ve_cache) + if self._parallel_active_for_layer(i): + if lane0 is _A:lane0=lane1=x + lane0,lane1=self._parallel_block(i,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append((lane0,lane1)) + else:x=self.blocks[i](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(bi);ve=self._get_ve(bi,input_ids,ve_cache) + if self._parallel_active_for_layer(bi): + if lane0 is _A:lane0=lane1=x + if skips:lane0,lane1=self._apply_skip_parallel(lane0,lane1,skips.pop(),i) + 
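The `_apply_skip_single` / `_apply_skip_parallel` helpers above blend the decoder stream with a stored encoder activation through a per-channel sigmoid gate: `lerp(w*skip, x, sigmoid(g))`, so a strongly positive gate logit keeps the current stream and a strongly negative one substitutes the scaled skip. A minimal sketch (names illustrative):

```python
import torch

def gated_skip(x, skip, gate_logit, skip_weight):
    # Per-channel sigmoid gate: g -> 1 keeps the current stream x,
    # g -> 0 replaces it with the scaled stored encoder activation.
    g = torch.sigmoid(gate_logit)              # shape (dim,), broadcasts over (B, T, dim)
    return torch.lerp(skip_weight * skip, x, g)

x = torch.randn(2, 5, 16)
skip = torch.randn(2, 5, 16)
out = gated_skip(x, skip, gate_logit=torch.full((16,), 10.0), skip_weight=torch.ones(16))
# With a strongly positive gate logit the output is ~x.
assert torch.allclose(out, x, atol=1e-2)
```

In the parallel region the same gate and weight are shared across both lanes, which is what `_apply_skip_parallel` does with the `(skip0, skip1)` tuple stored during the encoder pass.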
lane0,lane1=self._parallel_block(bi,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + else: + if skips:x=self._apply_skip_single(x,skips.pop(),i) + x=self.blocks[bi](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + return self.final_norm(self._final_parallel_hidden(lane0,lane1) if lane1 is not _A else x) + def forward(self,input_ids,target_ids): + x=self._backbone(input_ids);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids):return self._backbone(input_ids) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if 
min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:-1].reshape(-1,adapted.size(-1)),y_batch[:,:seq_len-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + 
nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} 
freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else 
max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + 
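The scoring bookkeeping in the sliding evaluators above credits each window only for the positions the previous window did not score (the first window scores everything, later windows their last `stride` positions), then converts mean nats/token to bits/byte via the token and byte counts. A toy sketch of that accounting, simplified to the aligned case where the windows tile the span exactly (the script also handles truncated tail windows, which this sketch omits):

```python
import math

def sliding_scored_ranges(total_tokens, seq_len, stride):
    # First window scores all seq_len positions; each later window scores
    # only its last `stride` positions, so every position is scored once
    # (assuming stride divides both seq_len and total_tokens - seq_len).
    ranges = []
    for ws in range(0, total_tokens - seq_len + 1, stride):
        s = 0 if ws == 0 else seq_len - stride
        ranges.append((ws + s, ws + seq_len))
    return ranges

ranges = sliding_scored_ranges(total_tokens=96, seq_len=32, stride=8)
covered = [t for a, b in ranges for t in range(a, b)]
assert covered == list(range(96))  # every position scored exactly once

def nats_to_bpb(nats_per_token, token_count, byte_count):
    # nats/token -> bits/token (divide by ln 2) -> bits/byte (times tokens/byte).
    return nats_per_token / math.log(2.0) * (token_count / byte_count)
```

This is why the sliding BPB in the results table differs from the plain chunked eval: every scored token gets close to a full `seq_len` of left context instead of a random offset within its chunk.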
val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def generate_autoregressive_calib(model,device,num_seqs=64,seq_len=2048,vocab_size=1024,temperature=.8,batch_size=8,seed=42): + was_training=model.training;model.eval();rng=torch.Generator(device=device);rng.manual_seed(seed);all_tokens=[] + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for batch_start in range(0,num_seqs,batch_size): + bs=min(batch_size,num_seqs-batch_start);tokens=torch.randint(0,vocab_size,(bs,1),device=device,generator=rng) + for _ in range(seq_len-1): + logits=model.forward_logits(tokens);next_logit=logits[:,-1,:];probs=torch.softmax(next_logit/max(temperature,1e-4),dim=-1);next_tok=torch.multinomial(probs,1,generator=rng);tokens=torch.cat([tokens,next_tok],dim=1) + for i in range(bs):all_tokens.append(tokens[i:i+1].detach().clone()) + model.train(was_training);return all_tokens +def gptq_collect_hessians_from_tokens(base_model,token_seqs,device): + dim=base_model.tok_emb.weight.shape[1];mlp_dim=base_model.mlp_up_bank.shape[1];hessians=_init_hessians(base_model,dim,mlp_dim,device) + for block in base_model.blocks:block.attn._save_gptq=_B;block.mlp._save_gptq=_B + was_training=base_model.training;base_model.eval() + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for seq in token_seqs:x=seq[:,:-1].to(device=device,dtype=torch.int64);y=seq[:,1:].to(device=device,dtype=torch.int64);base_model(x,y);_accum_hessians(hessians,base_model,dim,mlp_dim) + for block in base_model.blocks:block.attn._save_gptq=_C;block.mlp._save_gptq=_C + _finalize_hessians(hessians,max(len(token_seqs),1));base_model.train(was_training);return hessians +def _classify_param(name): + A='.mlp.' 
+ if'tok_emb'in name or'lm_head'in name:return'embed' + if name.startswith('canon_a'):return'attn' + if name.startswith('canon_c'):return'mlp' + if A in name or name.startswith('repeat_mlp.'):return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def _parse_layer_list(layers_str): + return[int(x)for x in layers_str.split(',')if x.strip()] +def _get_block_idx_from_name(name): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + return _A +def _get_physical_layer_idx_from_name(name,recur_layers): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + if len(parts)>2 and parts[0]=='repeat_mlp'and parts[1].isdigit(): + repeat_idx=int(parts[1]) + if 0<=repeat_idx0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[template_sd[_R][i]for i in range(2*n)];kv_slices=[template_sd[_S][i]for i in range(2*n)];up_slices=[template_sd[_T][i]for i in range(n)];down_slices=[template_sd[_U][i]for i in range(n)];consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + 
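The `_unbank_state_dict`/`_rebank_state_dict` pair above converts between the banked layout (per-layer weight slices stacked along dim 0) and flat per-layer keys. A toy round-trip sketch, assuming the same convention the code uses for `qo_bank` (c_q slices for layers 0..n-1, then attn.proj slices for layers 0..n-1):

```python
import numpy as np

# Toy banked tensor: 2 layers, 4x4 weights, c_q slices then proj slices.
n, d = 2, 4
qo_bank = np.arange(2 * n * d * d, dtype=np.float32).reshape(2 * n, d, d)

def unbank(bank, n):
    # Split the stacked bank into per-layer state-dict entries.
    out = {}
    for i in range(n):
        out[f"blocks.{i}.attn.c_q.weight"] = bank[i]
        out[f"blocks.{i}.attn.proj.weight"] = bank[n + i]
    return out

def rebank(sd, n):
    # Stack the per-layer entries back into the banked layout.
    slices = [sd[f"blocks.{i}.attn.c_q.weight"] for i in range(n)]
    slices += [sd[f"blocks.{i}.attn.proj.weight"] for i in range(n)]
    return np.stack(slices)

roundtrip = rebank(unbank(qo_bank, n), n)
```

Unbanking lets quantization treat each layer's matrix independently (per-matrix Hessians, scales, clip ranges) while the model itself keeps the fused banked parameters.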
kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def _drop_disabled_layer0_attn_unbanked(sd,disable_layer0_attn): + if not disable_layer0_attn:return sd + disabled_keys={'blocks.0.attn.c_q.weight','blocks.0.attn.c_k.weight','blocks.0.attn.c_v.weight','blocks.0.attn.proj.weight'} + return{k:v for(k,v)in sd.items()if k not in disabled_keys} +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A,clip_ranges=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A;cr=clip_ranges.get(name,clip_range)if isinstance(clip_ranges,dict)else clip_range + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=cr);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=cr);naive_count+=1 + 
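The naive fallback path in `mixed_quantize_int6` is symmetric per-row quantization: each row's max absolute value maps to `clip_range` (31 for int6-style levels, 15 for int5-style), with the per-row scale stored in fp16. A NumPy sketch of that scheme on random data:

```python
import numpy as np

def quantize_per_row(w: np.ndarray, clip_range: int = 31):
    # Symmetric per-row quantization; the scale is rounded to fp16 for
    # storage (as in the surrounding code) and reused for quantization.
    scale = (np.abs(w).max(axis=1, keepdims=True) / clip_range).astype(np.float16)
    sf = np.maximum(scale.astype(np.float32), 1e-8)   # guard all-zero rows
    q = np.clip(np.round(w / sf), -clip_range, clip_range).astype(np.int8)
    return q, scale

def dequantize_per_row(q, scale):
    return q.astype(np.float32) * scale.astype(np.float32)

w = np.random.default_rng(0).standard_normal((4, 16)).astype(np.float32)
q, s = quantize_per_row(w, clip_range=31)
max_err = np.abs(dequantize_per_row(q, s) - w).max()
```

The GPTQ path in the code improves on this by redistributing each column's rounding error onto not-yet-quantized columns using the calibration Hessian, rather than rounding every entry independently.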
result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'if cr>=31 else 'int5'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in 
range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");recur_layers=_parse_layer_list(args.recur_layers_str);repeat_untie_mlp_layers=_parse_layer_list(args.repeat_untie_mlp_layers);canon_ac_layers=_parse_layer_list(args.canon_ac_layers) + if args.post_gptq_eval_only: + eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=bool(recur_layers),repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank
.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model) + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);template_sd={k:v.detach().cpu()for(k,v)in eval_model.state_dict().items()};template_unbanked=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(template_sd,args.num_layers),args.disable_layer0_attn);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],template_unbanked);eval_model.load_state_dict(_rebank_state_dict(deq_unbanked,args.num_layers,template_sd),strict=_B);q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not 
_A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"recurrence:layers={recur_layers} start_step={args.recur_start_step} active={int(base_model._recurrence_active)}");log0(f"canon_ac:layers={canon_ac_layers} params={0 if base_model.canon_a is _A else base_model.canon_a.numel()+base_model.canon_c.numel()} 
physical_only=1");log0(f"parallel_residual:active={int(base_model.parallel_post_lambdas is not _A)} start_layer={base_model.parallel_start_layer} start_mode={'physical'if base_model.parallel_start_layer_is_physical else 'virtual'} params={0 if base_model.parallel_post_lambdas is _A else base_model.parallel_post_lambdas.numel()+base_model.parallel_resid_lambdas.numel()} final_lane=mlp");log0(f"repeat_untie_mlp:mode={args.repeat_untie_mlp} layers={repeat_untie_mlp_layers if repeat_untie_mlp_layers else recur_layers if args.repeat_untie_mlp!='none' else []} params={sum(p.numel()for p in base_model.repeat_mlp.parameters())}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train();run_warmup_steps(args.warmup_steps,'base') + if 
recur_layers:base_model.set_recurrence_active(_B);log0(f"recurrence:prewarm active={int(base_model._recurrence_active)} virtual_layers:{base_model.virtual_num_layers}");run_warmup_steps(args.warmup_steps,'recur');base_model.set_recurrence_active(_C) + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();base_model.set_recurrence_active(_C);train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();timed_wallclock_t0=time.perf_counter();t0=timed_wallclock_t0;step=0 + while _B: + if recur_layers and not base_model._recurrence_active and step>=args.recur_start_step:base_model.set_recurrence_active(_B);log0(f"recurrence:activated step:{step} layers={recur_layers} virtual_layers:{base_model.virtual_num_layers}") + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if 
args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm) + if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr: + s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr + with torch.no_grad(): + for bank in[base_model.qo_bank,base_model.kv_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls) + for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls) + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + optimizer_tok.step();optimizer_scalar.step() + if optimizer_head is not _A:optimizer_head.step() + optimizer_muon.step();zero_grad_all() + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0) + if args.late_qat_threshold>0 and scale=2000: + if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress + if args.swa_enabled and scale<.2 and step%args.swa_every==0: + if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}") + else: + for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu() + swa_count+=1 + should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A) + if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} 
train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);log_parallel_residual_converged(log0,base_model);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(sd_cpu,args.num_layers),args.disable_layer0_attn);gptq_hessians=_A + if args.use_gptq: + t_gptq=time.perf_counter();recur_was_active=base_model._recurrence_active;base_model.set_recurrence_active(recur_was_active);log0(f"gptq:calibration recurrence_active={int(base_model._recurrence_active)} repeat_mlp={len(base_model.repeat_mlp)} parallel_residual={int(base_model.parallel_post_lambdas is not _A)} 
ar_selfgen={int(args.gptq_ar_selfgen)}") + if args.gptq_ar_selfgen: + log0(f"gptq:generating autoregressive calibration data ({args.gptq_calib_samples} seqs x {args.train_seq_len} tokens, temp={args.gptq_temperature:.2f})...");t_gen=time.perf_counter();ar_tokens=generate_autoregressive_calib(base_model,device,num_seqs=args.gptq_calib_samples,seq_len=args.train_seq_len,vocab_size=args.vocab_size,temperature=args.gptq_temperature,batch_size=args.gptq_batch_size,seed=args.seed);log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s");log0("gptq:collecting hessians from autoregressive data...");gptq_hessians=gptq_collect_hessians_from_tokens(base_model,ar_tokens,device);del ar_tokens;log0(f"gptq:collected hessians for {len(gptq_hessians)} layers (AR self-gen)") + else: + log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;log0(f"gptq:calibrated {len(gptq_hessians)} layers from training data") + base_model.set_recurrence_active(recur_was_active);gptq_elapsed=time.perf_counter()-t_gptq;total_wallclock_elapsed=time.perf_counter()-timed_wallclock_t0;log0(f"gptq:done in {gptq_elapsed:.1f}s");log0(f"wallclock:post_gptq total_elapsed:{total_wallclock_elapsed:.1f}s train_budget:{args.max_wallclock_seconds:.1f}s");torch.cuda.empty_cache() + clip_ranges=_A + if args.mixed_quant and gptq_hessians is not _A: + quant_names=[n for n in unbanked_sd if _classify_param(n)in{'mlp','attn'}and unbanked_sd[n].ndim>=1 and unbanked_sd[n].numel()>65536];sens={n:gptq_hessians[n].diag().sum().item()if n in gptq_hessians else 0.0 for n in quant_names};ranked=sorted(sens.items(),key=lambda x:-x[1]);clip_ranges={n:15 for n in 
quant_names};recur_layer_set=set(recur_layers);recur_quant_names=[name for name in quant_names if _get_physical_layer_idx_from_name(name,recur_layers)in recur_layer_set];recur_ranked=sorted(recur_quant_names,key=lambda name:-sens[name]);forced_int6=min(args.n_int6_layers,len(recur_ranked));selected_int6_names=recur_ranked[:forced_int6];selected_int6_set=set(selected_int6_names) + for(name,_)in ranked: + if len(selected_int6_names)>=args.n_int6_layers:break + if name in selected_int6_set:continue + selected_int6_names.append(name);selected_int6_set.add(name) + [clip_ranges.__setitem__(name,31) for name in selected_int6_names];int6_names=[n for n,cr in clip_ranges.items()if cr==31];int5_names=[n for n,cr in clip_ranges.items()if cr==15];log0(f"mixed_quant: {len(int6_names)} int6, {len(int5_names)} int5");log0(f"mixed_quant: forced_recur_int6={forced_int6}/{len(recur_ranked)} recur_layers={recur_layers}");log0(f"mixed_quant: int6 layers: {int6_names[:5]}...") + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians,clip_ranges=clip_ranges);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) + if master_process: + with open(F,'wb')as f:f.write(quant_blob) + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") + if distributed:dist.barrier() + with open(F,'rb')as f:quant_blob_disk=f.read() + 
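The serialization path above byte-shuffles the raw payload before brotli compression (`brotli.compress(_byte_shuffle(quant_raw), quality=11)`). The actual `_byte_shuffle`/`_byte_unshuffle` helpers are defined elsewhere in this file; the sketch below shows the standard (Blosc-style) idea under assumed parameters (`itemsize=2`, input length divisible by itemsize, no padding handling):

```python
import numpy as np

def byte_shuffle(raw: bytes, itemsize: int = 2) -> bytes:
    # Group the k-th byte of every element together so a byte-oriented
    # compressor (here brotli) sees long runs of similar high/low bytes.
    a = np.frombuffer(raw, dtype=np.uint8)
    return a.reshape(-1, itemsize).T.tobytes()

def byte_unshuffle(raw: bytes, itemsize: int = 2) -> bytes:
    # Inverse transform: un-interleave the byte planes.
    a = np.frombuffer(raw, dtype=np.uint8)
    return a.reshape(itemsize, -1).T.tobytes()

data = bytes(range(8))
shuffled = byte_shuffle(data)
```

For fp16 scales and similar multi-byte data, the exponent/high bytes are far more correlated than the raw interleaved stream, which is why shuffling before a general-purpose compressor shrinks the checkpoint.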
quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=base_model._recurrence_active,repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_strideX.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if 
closure is not _A: + with torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if 
piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + 
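The LUTs built in `build_sentencepiece_luts` feed the bits-per-byte accounting used throughout eval: each scored token contributes its piece's UTF-8 bytes plus one byte for the leading space it implies, unless the previous token is a boundary/control token (the `base + (has_leading_space[tgt] & ~is_boundary[prev])` logic). A pure-Python sketch with tiny hand-made LUTs:

```python
import math

def bits_per_byte(nll_nats, tokens, prev_tokens,
                  base_bytes, has_leading_space, is_boundary):
    # bpb = total NLL in bits / total UTF-8 bytes of the decoded text.
    total_bytes = 0
    for tok, prev in zip(tokens, prev_tokens):
        total_bytes += base_bytes[tok]
        # The '▁'-prefixed piece decodes with a leading space, except
        # right after a boundary/control token.
        if has_leading_space[tok] and not is_boundary[prev]:
            total_bytes += 1
    return (sum(nll_nats) / math.log(2)) / total_bytes

# Two scored tokens, each with NLL = ln(2) nats (i.e. exactly 1 bit).
bpb = bits_per_byte(
    nll_nats=[math.log(2), math.log(2)],
    tokens=[1, 2], prev_tokens=[0, 1],
    base_bytes=[0, 3, 2],
    has_leading_space=[False, True, True],
    is_boundary=[True, False, False],
)
```

Here the two tokens account for 6 bytes (3, then 2 + 1 for the leading space) against 2 bits of NLL, giving bpb = 1/3; this is the same normalization that converts the run's val_loss in nats into the headline BPB numbers.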
quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or 
int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or 
self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + 
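The cursor logic above (`_pick_coprime_stride` / `_reset_cursor`) walks every block of a shard exactly once per pass in a scrambled order. A stdlib-only sketch of the underlying trick, with illustrative names:

```python
import math
import random

def coprime_stride_order(n, seed=0):
    # Visit 0..n-1 exactly once by stepping with a stride coprime to n,
    # from a random start (mirrors _reset_cursor's start/stride cursor).
    rng = random.Random(seed)
    if n <= 1:
        return list(range(n))
    while True:
        stride = rng.randrange(1, n)
        if math.gcd(stride, n) == 1:
            break
    start = rng.randrange(n)
    return [(start + k * stride) % n for k in range(n)]

order = coprime_stride_order(10, seed=3)
```

Because `gcd(stride, n) == 1`, the walk `(start + k*stride) mod n` is a permutation of `0..n-1`, so no training window repeats within a pass over a shard.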
perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +def apply_canon_residual(x,w): + w=w.to(dtype=x.dtype);y=x*w[0][_A,_A,:] + y=y+F.pad(x[:,:-1],(0,0,1,0))*w[1][_A,_A,:] + y=y+F.pad(x[:,:-2],(0,0,2,0))*w[2][_A,_A,:] + y=y+F.pad(x[:,:-3],(0,0,3,0))*w[3][_A,_A,:] + return x+y +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and 
w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.recur_layers=sorted(set(recur_layers or[]));self.repeat_untie_mlp=repeat_untie_mlp + 
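The QAT path in `CastedLinear` above replaces hard rounding with a sigmoid-based soft round so the quantization step stays differentiable during training. A minimal NumPy sketch of just that soft-round (the per-row scaling and the ±31 clamp are left out):

```python
import numpy as np

def soft_round(x, alpha):
    # Differentiable rounding: floor(x) plus a sigmoid step centered at
    # frac = 0.5; as alpha grows this approaches hard round-to-nearest.
    frac = x - np.floor(x)
    return np.floor(x) + 1.0 / (1.0 + np.exp(-alpha * (frac - 0.5)))

x = np.array([0.2, 0.8, 3.4])
hard_like = soft_round(x, alpha=100.0)  # near round-to-nearest
smooth = soft_round(x, alpha=4.0)       # gradients flow through the step
```

Annealing `alpha` upward over training (as `_qat_alpha` presumably allows) moves the weights smoothly toward values that survive the post-training int6 rounding.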
self.canon_ac_layers=sorted(set(canon_ac_layers or[]));self._canon_ac_layer_set=set(self.canon_ac_layers) + for cl in self.canon_ac_layers: + if not 0<=cl0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,self.virtual_num_layers-xsa_last_n),self.virtual_num_layers):self.blocks[i].attn.use_xsa=_B + self.set_recurrence_active(recurrence_active);self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for repeat_mlp in self.repeat_mlp: + if repeat_mlp.fc is not _A:nn.init.zeros_(repeat_mlp.fc.weight) + if repeat_mlp.proj is not _A:nn.init.zeros_(repeat_mlp.proj.weight) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if 
getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def set_recurrence_active(self,active): + was_active=getattr(self,'_recurrence_active',_C);self._recurrence_active=bool(active)and bool(self.recur_layers) + if self._recurrence_active:self.v2p=self._v2p_recur;self.num_encoder_layers=self._enc_recur;self.num_decoder_layers=self._dec_recur + else:self.v2p=self._v2p_no_recur;self.num_encoder_layers=self._enc_no_recur;self.num_decoder_layers=self._dec_no_recur + if self._recurrence_active and not was_active and self.repeat_mlp:self._sync_repeat_mlp_from_base() + def _sync_repeat_mlp_from_base(self): + with torch.no_grad(): + for(repeat_idx,physical_idx)in enumerate(self.recur_layers): + repeat_mlp=self.repeat_mlp[repeat_idx] + if repeat_mlp.fc is not _A:repeat_mlp.fc.weight.copy_(self.mlp_up_bank[physical_idx]) + if repeat_mlp.proj is not _A:repeat_mlp.proj.weight.copy_(self.mlp_down_bank[physical_idx]) + def _is_repeated_virtual_index(self,virtual_idx):return self._recurrence_active and bool(self.recur_layers) and self._enc_recur<=virtual_idx=self.parallel_start_layer + return virtual_idx>=self.parallel_start_layer + def _mix_with_x0(self,lane,x0,resid_mix): + mix=resid_mix.to(dtype=lane.dtype);return mix[0][_A,_A,:]*lane+mix[1][_A,_A,:]*x0 + def _apply_skip_single(self,x,skip,i): + if isinstance(skip,tuple):skip=skip[1] + 
g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skip;return torch.lerp(scaled_skip,x,g) + def _apply_skip_parallel(self,lane0,lane1,skip,i): + if isinstance(skip,tuple):skip0,skip1=skip + else:skip0=skip1=skip + g=torch.sigmoid(self.skip_gates[i].to(dtype=lane0.dtype))[_A,_A,:];w=self.skip_weights[i].to(dtype=lane0.dtype)[_A,_A,:] + return torch.lerp(w*skip0,lane0,g),torch.lerp(w*skip1,lane1,g) + def _final_parallel_hidden(self,lane0,lane1): + # The branch starts as a clone, so average the summed lanes to keep the output scale close to the single-lane path. + return (lane0+lane1)*.5 + def _parallel_block(self,virtual_idx,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=_A,canon_a_w=_A,canon_c_w=_A): + block=self.blocks[virtual_idx];physical_idx=self.v2p[virtual_idx] + if not block.disable_attn: + attn_read=self._mix_with_x0(lane0,x0,block.resid_mix);attn_in=block.attn_norm(attn_read)*block.ln_scale_factor + if canon_a_w is not _A:attn_in=apply_canon_residual(attn_in,canon_a_w) + attn_out=block.attn(attn_in,q_w,k_w,v_w,out_w,v_embed=v_embed);attn_out=block.attn_scale.to(dtype=attn_out.dtype)[_A,_A,:]*attn_out;resid=self.parallel_resid_lambdas[physical_idx,0].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,0].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*attn_out;lane1=resid*lane1+post[1]*attn_out + mlp_read=self._mix_with_x0(lane1,x0,block.resid_mix);mlp_in=block.mlp_norm(mlp_read)*block.ln_scale_factor + if canon_c_w is not _A:mlp_in=apply_canon_residual(mlp_in,canon_c_w) + mlp_out=block.mlp_scale.to(dtype=lane1.dtype)[_A,_A,:]*block.mlp(mlp_in,up_w,down_w);resid=self.parallel_resid_lambdas[physical_idx,1].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,1].to(dtype=lane0.dtype) + lane0=resid*lane0+post[0]*mlp_out;lane1=resid*lane1+post[1]*mlp_out;return lane0,lane1 + def _get_block_weights(self,virtual_idx): + 
n=self.num_layers;physical_idx=self.v2p[virtual_idx];q_w=self.qo_bank[physical_idx];k_w=self.kv_bank[physical_idx];v_w=self.kv_bank[n+physical_idx];out_w=self.qo_bank[n+physical_idx];up_w=self.mlp_up_bank[physical_idx];down_w=self.mlp_down_bank[physical_idx];canon_a_w=self.canon_a[physical_idx]if self.canon_a is not _A and physical_idx in self._canon_ac_layer_set else _A;canon_c_w=self.canon_c[physical_idx]if self.canon_c is not _A and physical_idx in self._canon_ac_layer_set else _A + if self._is_repeated_virtual_index(virtual_idx): + repeated_idx=virtual_idx-self._enc_recur + if self.repeat_mlp: + repeat_mlp=self.repeat_mlp[repeated_idx] + if repeat_mlp.fc is not _A:up_w=repeat_mlp.fc.weight + if repeat_mlp.proj is not _A:down_w=repeat_mlp.proj.weight + return q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w + def _backbone(self,input_ids): + x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={};lane0=lane1=_A + for i in range(self.num_encoder_layers): + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(i);ve=self._get_ve(i,input_ids,ve_cache) + if self._parallel_active_for_layer(i): + if lane0 is _A:lane0=lane1=x + lane0,lane1=self._parallel_block(i,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append((lane0,lane1)) + else:x=self.blocks[i](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + q_w,k_w,v_w,out_w,up_w,down_w,canon_a_w,canon_c_w=self._get_block_weights(bi);ve=self._get_ve(bi,input_ids,ve_cache) + if self._parallel_active_for_layer(bi): + if lane0 is _A:lane0=lane1=x + if skips:lane0,lane1=self._apply_skip_parallel(lane0,lane1,skips.pop(),i) + 
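A simplified sketch of the two-lane update in `_parallel_block`: attention reads lane 0 and writes into both lanes, the MLP reads lane 1 and writes into both, and the write strengths stand in for the learned `parallel_post_lambdas` (scalars here instead of per-layer parameters; the residual-decay lambdas are fixed to 1, and the sublayer functions are stand-ins):

```python
import numpy as np

def parallel_residual_step(lane0, lane1, attn_fn, mlp_fn,
                           attn_post=(1.0, 1.0), mlp_post=(0.1, 1.0)):
    # Attention reads lane 0 and writes to both lanes; the MLP reads
    # lane 1 and writes to both. The small mlp->attn weight mirrors the
    # asymmetric learned routing reported in this record's README.
    a = attn_fn(lane0)
    lane0, lane1 = lane0 + attn_post[0] * a, lane1 + attn_post[1] * a
    m = mlp_fn(lane1)
    lane0, lane1 = lane0 + mlp_post[0] * m, lane1 + mlp_post[1] * m
    return lane0, lane1

x = np.ones((2, 4))
l0, l1 = parallel_residual_step(x, x, attn_fn=lambda h: 0.5 * h,
                                mlp_fn=lambda h: 0.25 * h)
final = 0.5 * (l0 + l1)  # average the lanes, as in _final_parallel_hidden
```

Since both lanes start as copies of the same residual, averaging them at the end keeps the output scale close to the single-lane path.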
lane0,lane1=self._parallel_block(bi,lane0,lane1,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + else: + if skips:x=self._apply_skip_single(x,skips.pop(),i) + x=self.blocks[bi](x,x0,q_w,k_w,v_w,out_w,up_w,down_w,v_embed=ve,canon_a_w=canon_a_w,canon_c_w=canon_c_w) + return self.final_norm(self._final_parallel_hidden(lane0,lane1) if lane1 is not _A else x) + def forward(self,input_ids,target_ids): + x=self._backbone(input_ids);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids):return self._backbone(input_ids) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if 
min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:-1].reshape(-1,adapted.size(-1)),y_batch[:,:seq_len-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + 
nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} 
freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else 
max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + 
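Both sliding evals reduce their accumulated NLL to bits-per-byte the same way; the conversion in isolation looks like this (the token/byte counts below are made up, picked so the ratio is near this record's tokens-per-byte):

```python
import math

def nats_per_token_to_bpb(nats_per_token, total_tokens, total_bytes):
    # bits/token = nats / ln(2); rescale by tokens-per-byte to get
    # bits/byte, matching the final reduction in eval_val_sliding.
    return (nats_per_token / math.log(2.0)) * (total_tokens / total_bytes)

# Illustrative counts: ~2.44 bytes per token under this tokenizer.
bpb = nats_per_token_to_bpb(1.8679, total_tokens=1000, total_bytes=2436)
```

This is why the headline numbers come in pairs: 1.8679 nats/token corresponds to roughly 1.106 bits/byte under the tokenizer's tokens-per-byte ratio.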
val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def generate_autoregressive_calib(model,device,num_seqs=64,seq_len=2048,vocab_size=1024,temperature=.8,batch_size=8,seed=42): + was_training=model.training;model.eval();rng=torch.Generator(device=device);rng.manual_seed(seed);all_tokens=[] + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for batch_start in range(0,num_seqs,batch_size): + bs=min(batch_size,num_seqs-batch_start);tokens=torch.randint(0,vocab_size,(bs,1),device=device,generator=rng) + for _ in range(seq_len-1): + logits=model.forward_logits(tokens);next_logit=logits[:,-1,:];probs=torch.softmax(next_logit/max(temperature,1e-4),dim=-1);next_tok=torch.multinomial(probs,1,generator=rng);tokens=torch.cat([tokens,next_tok],dim=1) + for i in range(bs):all_tokens.append(tokens[i:i+1].detach().clone()) + model.train(was_training);return all_tokens +def gptq_collect_hessians_from_tokens(base_model,token_seqs,device): + dim=base_model.tok_emb.weight.shape[1];mlp_dim=base_model.mlp_up_bank.shape[1];hessians=_init_hessians(base_model,dim,mlp_dim,device) + for block in base_model.blocks:block.attn._save_gptq=_B;block.mlp._save_gptq=_B + was_training=base_model.training;base_model.eval() + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16): + for seq in token_seqs:x=seq[:,:-1].to(device=device,dtype=torch.int64);y=seq[:,1:].to(device=device,dtype=torch.int64);base_model(x,y);_accum_hessians(hessians,base_model,dim,mlp_dim) + for block in base_model.blocks:block.attn._save_gptq=_C;block.mlp._save_gptq=_C + _finalize_hessians(hessians,max(len(token_seqs),1));base_model.train(was_training);return hessians +def _classify_param(name): + A='.mlp.' 
+ if'tok_emb'in name or'lm_head'in name:return'embed' + if name.startswith('canon_a'):return'attn' + if name.startswith('canon_c'):return'mlp' + if A in name or name.startswith('repeat_mlp.'):return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def _parse_layer_list(layers_str): + return[int(x)for x in layers_str.split(',')if x.strip()] +def _get_block_idx_from_name(name): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + return _A +def _get_physical_layer_idx_from_name(name,recur_layers): + parts=name.split('.') + if len(parts)>2 and parts[0]=='blocks'and parts[1].isdigit():return int(parts[1]) + if len(parts)>2 and parts[0]=='repeat_mlp'and parts[1].isdigit(): + repeat_idx=int(parts[1]) + if 0<=repeat_idx0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[template_sd[_R][i]for i in range(2*n)];kv_slices=[template_sd[_S][i]for i in range(2*n)];up_slices=[template_sd[_T][i]for i in range(n)];down_slices=[template_sd[_U][i]for i in range(n)];consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + 
kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def _drop_disabled_layer0_attn_unbanked(sd,disable_layer0_attn): + if not disable_layer0_attn:return sd + disabled_keys={'blocks.0.attn.c_q.weight','blocks.0.attn.c_k.weight','blocks.0.attn.c_v.weight','blocks.0.attn.proj.weight'} + return{k:v for(k,v)in sd.items()if k not in disabled_keys} +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A,clip_ranges=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A;cr=clip_ranges.get(name,clip_range)if isinstance(clip_ranges,dict)else clip_range + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=cr);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=cr);naive_count+=1 + 
result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'if cr>=31 else 'int5'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in 
range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");recur_layers=_parse_layer_list(args.recur_layers_str);repeat_untie_mlp_layers=_parse_layer_list(args.repeat_untie_mlp_layers);canon_ac_layers=_parse_layer_list(args.canon_ac_layers) + if args.post_gptq_eval_only: + eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=bool(recur_layers),repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank
.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model) + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);template_sd={k:v.detach().cpu()for(k,v)in eval_model.state_dict().items()};template_unbanked=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(template_sd,args.num_layers),args.disable_layer0_attn);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],template_unbanked);eval_model.load_state_dict(_rebank_state_dict(deq_unbanked,args.num_layers,template_sd),strict=_B);q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not 
_A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"recurrence:layers={recur_layers} start_step={args.recur_start_step} active={int(base_model._recurrence_active)}");log0(f"canon_ac:layers={canon_ac_layers} params={0 if base_model.canon_a is _A else base_model.canon_a.numel()+base_model.canon_c.numel()} 
physical_only=1");log0(f"parallel_residual:active={int(base_model.parallel_post_lambdas is not _A)} start_layer={base_model.parallel_start_layer} start_mode={'physical'if base_model.parallel_start_layer_is_physical else 'virtual'} params={0 if base_model.parallel_post_lambdas is _A else base_model.parallel_post_lambdas.numel()+base_model.parallel_resid_lambdas.numel()} final_lane=mlp");log0(f"repeat_untie_mlp:mode={args.repeat_untie_mlp} layers={repeat_untie_mlp_layers if repeat_untie_mlp_layers else recur_layers if args.repeat_untie_mlp!='none' else []} params={sum(p.numel()for p in base_model.repeat_mlp.parameters())}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train();run_warmup_steps(args.warmup_steps,'base') + if 
recur_layers:base_model.set_recurrence_active(_B);log0(f"recurrence:prewarm active={int(base_model._recurrence_active)} virtual_layers:{base_model.virtual_num_layers}");run_warmup_steps(args.warmup_steps,'recur');base_model.set_recurrence_active(_C) + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();base_model.set_recurrence_active(_C);train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();timed_wallclock_t0=time.perf_counter();t0=timed_wallclock_t0;step=0 + while _B: + if recur_layers and not base_model._recurrence_active and step>=args.recur_start_step:base_model.set_recurrence_active(_B);log0(f"recurrence:activated step:{step} layers={recur_layers} virtual_layers:{base_model.virtual_num_layers}") + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if 
args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm) + if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr: + s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr + with torch.no_grad(): + for bank in[base_model.qo_bank,base_model.kv_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls) + for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls) + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + optimizer_tok.step();optimizer_scalar.step() + if optimizer_head is not _A:optimizer_head.step() + optimizer_muon.step();zero_grad_all() + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0) + if args.late_qat_threshold>0 and scale=2000: + if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress + if args.swa_enabled and scale<.2 and step%args.swa_every==0: + if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}") + else: + for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu() + swa_count+=1 + should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A) + if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} 
train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);log_parallel_residual_converged(log0,base_model);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_drop_disabled_layer0_attn_unbanked(_unbank_state_dict(sd_cpu,args.num_layers),args.disable_layer0_attn);gptq_hessians=_A + if args.use_gptq: + t_gptq=time.perf_counter();recur_was_active=base_model._recurrence_active;base_model.set_recurrence_active(recur_was_active);log0(f"gptq:calibration recurrence_active={int(base_model._recurrence_active)} repeat_mlp={len(base_model.repeat_mlp)} parallel_residual={int(base_model.parallel_post_lambdas is not _A)} 
ar_selfgen={int(args.gptq_ar_selfgen)}") + if args.gptq_ar_selfgen: + log0(f"gptq:generating autoregressive calibration data ({args.gptq_calib_samples} seqs x {args.train_seq_len} tokens, temp={args.gptq_temperature:.2f})...");t_gen=time.perf_counter();ar_tokens=generate_autoregressive_calib(base_model,device,num_seqs=args.gptq_calib_samples,seq_len=args.train_seq_len,vocab_size=args.vocab_size,temperature=args.gptq_temperature,batch_size=args.gptq_batch_size,seed=args.seed);log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s");log0("gptq:collecting hessians from autoregressive data...");gptq_hessians=gptq_collect_hessians_from_tokens(base_model,ar_tokens,device);del ar_tokens;log0(f"gptq:collected hessians for {len(gptq_hessians)} layers (AR self-gen)") + else: + log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;log0(f"gptq:calibrated {len(gptq_hessians)} layers from training data") + base_model.set_recurrence_active(recur_was_active);gptq_elapsed=time.perf_counter()-t_gptq;total_wallclock_elapsed=time.perf_counter()-timed_wallclock_t0;log0(f"gptq:done in {gptq_elapsed:.1f}s");log0(f"wallclock:post_gptq total_elapsed:{total_wallclock_elapsed:.1f}s train_budget:{args.max_wallclock_seconds:.1f}s");torch.cuda.empty_cache() + clip_ranges=_A + if args.mixed_quant and gptq_hessians is not _A: + quant_names=[n for n in unbanked_sd if _classify_param(n)in{'mlp','attn'}and unbanked_sd[n].ndim>=1 and unbanked_sd[n].numel()>65536];sens={n:gptq_hessians[n].diag().sum().item()if n in gptq_hessians else 0.0 for n in quant_names};ranked=sorted(sens.items(),key=lambda x:-x[1]);clip_ranges={n:15 for n in 
quant_names};recur_layer_set=set(recur_layers);recur_quant_names=[name for name in quant_names if _get_physical_layer_idx_from_name(name,recur_layers)in recur_layer_set];recur_ranked=sorted(recur_quant_names,key=lambda name:-sens[name]);forced_int6=min(args.n_int6_layers,len(recur_ranked));selected_int6_names=recur_ranked[:forced_int6];selected_int6_set=set(selected_int6_names) + for(name,_)in ranked: + if len(selected_int6_names)>=args.n_int6_layers:break + if name in selected_int6_set:continue + selected_int6_names.append(name);selected_int6_set.add(name) + for name in selected_int6_names:clip_ranges[name]=31 + int6_names=[n for n,cr in clip_ranges.items()if cr==31];int5_names=[n for n,cr in clip_ranges.items()if cr==15];log0(f"mixed_quant: {len(int6_names)} int6, {len(int5_names)} int5");log0(f"mixed_quant: forced_recur_int6={forced_int6}/{len(recur_ranked)} recur_layers={recur_layers}");log0(f"mixed_quant: int6 layers: {int6_names[:5]}...") + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians,clip_ranges=clip_ranges);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) + if master_process: + with open(F,'wb')as f:f.write(quant_blob) + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") + if distributed:dist.barrier() + with open(F,'rb')as f:quant_blob_disk=f.read() + 
quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,canon_ac_layers=canon_ac_layers,parallel_residual=args.parallel_residual,parallel_start_layer=args.parallel_start_layer,parallel_start_layer_is_physical=args.parallel_start_layer_is_physical,neg_slope=args.negative_slope,disable_layer0_attn=args.disable_layer0_attn,recur_layers=recur_layers,recurrence_active=base_model._recurrence_active,repeat_untie_mlp=args.repeat_untie_mlp,repeat_untie_mlp_layers=repeat_untie_mlp_layers).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,eval_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} 
val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride
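
The naive fallback in `mixed_quantize_int6` is `quantize_int6_per_row`: per-row symmetric quantization with one fp16 scale per output row and integer codes clamped to `[-clip_range, clip_range]`. Its definition is not in this hunk; the sketch below is an assumed implementation, reusing the broadcasting pattern visible in `dequantize_mixed_int6` and the `clamp_min` scale floor seen in `gptq_quantize_weight`.

```python
import torch

def quantize_int6_per_row(W, clip_range=31):
    # One fp16 scale per output row; integer codes land in
    # [-clip_range, clip_range], i.e. symmetric int6 when clip_range == 31.
    s = (W.abs().amax(dim=1) / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
    q = torch.clamp(torch.round(W / s.float().unsqueeze(1)), -clip_range, clip_range).to(torch.int8)
    return q, s

def dequantize_per_row(q, s):
    # Mirrors dequantize_mixed_int6: the per-row scale broadcasts over
    # all remaining dimensions.
    return q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1))

torch.manual_seed(0)
W = torch.randn(4, 64)
q, s = quantize_int6_per_row(W)
err = (dequantize_per_row(q, s) - W).abs().max().item()
```

The worst-case reconstruction error per entry is about half a quantization step, i.e. roughly `row_max / (2 * clip_range)`.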
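
`gptq_quantize_weight` above is a blocked GPTQ pass: columns are permuted into descending Hessian-diagonal order, and each column's rounding error is folded into the not-yet-quantized columns through the upper-triangular Cholesky factor of the damped inverse Hessian. A stripped-down sketch of that inner loop (no permutation, no blocking, no clip-percentile search; the `1e-8` scale floor is an assumption):

```python
import torch

def gptq_quantize(W, H, clip_range=31, percdamp=0.01):
    # Quantize columns left to right; after each column, fold its rounding
    # error into the remaining columns via the inverse-Hessian factor.
    W = W.float().clone()
    rows, cols = W.shape
    H = H.float().clone()
    H.diagonal().add_(percdamp * H.diag().mean())            # damping, as in the code
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    Hinv = torch.linalg.cholesky(Hinv, upper=True)           # upper-triangular factor
    s = (W.abs().amax(dim=1) / clip_range).clamp_min(1e-8)   # hypothetical scale floor
    Q = torch.zeros(rows, cols, dtype=torch.int8)
    for j in range(cols):
        d = Hinv[j, j]
        q_col = torch.clamp(torch.round(W[:, j] / s), -clip_range, clip_range)
        Q[:, j] = q_col.to(torch.int8)
        err = (W[:, j] - q_col * s) / d
        W[:, j:] -= err.unsqueeze(1) * Hinv[j, j:].unsqueeze(0)
    return Q, s

torch.manual_seed(0)
W = torch.randn(8, 16)
X = torch.randn(256, 16)
Q, s = gptq_quantize(W, X.T @ X / 256)
```

With an identity Hessian the error feedback only touches the column being quantized, so the procedure reduces to plain round-to-nearest; the Hessian off-diagonals are what let later columns compensate for earlier rounding.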
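
The training loop maintains an fp32 EMA of the full `state_dict` with decay `0.997` and loads it back before export (`ema:applying EMA weights`). The update is the standard exponential moving average, isolated here for clarity:

```python
import torch

def ema_update(ema_state, model_state, decay=0.997):
    # ema <- decay * ema + (1 - decay) * current, accumulated in fp32
    # regardless of the model's dtype, as in the training loop above.
    with torch.no_grad():
        for name, t in model_state.items():
            ema_state[name].mul_(decay).add_(t.detach().float(), alpha=1.0 - decay)

start = {"w": torch.ones(3)}
ema = {k: t.float().clone() for k, t in start.items()}
for _ in range(10):
    ema_update(ema, {"w": torch.full((3,), 2.0)})
```

For a constant target `c` and initial value `x0`, `n` updates give `c + (x0 - c) * decay**n`, so with `decay=0.997` the EMA tracks roughly the last few hundred steps.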
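
The mixed-quant block ranks weight tensors by a sensitivity proxy, the sum of each tensor's GPTQ Hessian diagonal, and gives the `n_int6_layers` most sensitive tensors a 6-bit budget (clip 31) while the rest drop to 5 bits (clip 15). A sketch of just that ranking, omitting the code's extra rule that recurrent-layer tensors are forced into the int6 set first:

```python
import torch

def assign_clip_ranges(hessians, names, n_int6):
    # Sensitivity proxy from the mixed-quant block: the sum of each
    # tensor's GPTQ Hessian diagonal. The n_int6 most sensitive tensors
    # keep 6-bit codes (clip 31); everything else drops to 5-bit (clip 15).
    sens = {n: hessians[n].diag().sum().item() if n in hessians else 0.0 for n in names}
    ranked = sorted(sens.items(), key=lambda x: -x[1])
    clip_ranges = {n: 15 for n in names}
    for name, _ in ranked[:n_int6]:
        clip_ranges[name] = 31
    return clip_ranges

# Toy Hessians: "a" is most sensitive, then "c", then "b".
hessians = {"a": 3.0 * torch.eye(4), "b": torch.eye(4), "c": 2.0 * torch.eye(4)}
clips = assign_clip_ranges(hessians, ["a", "b", "c"], n_int6=2)
```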
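
Before brotli compression (`quality=11`) the serialized blob passes through `_byte_shuffle`, and `_byte_unshuffle` inverts it on load. The helper's definition is not in this hunk; a plausible stride-based implementation, grouping same-offset byte planes so that, for example, fp16 high bytes form long homogeneous runs that compress better, might look like:

```python
def byte_shuffle(raw: bytes, stride: int = 2) -> bytes:
    # Group same-offset byte planes: all bytes at position 0 mod stride,
    # then 1 mod stride, and so on. For fp16 payloads (stride=2) this puts
    # low bytes and high bytes in separate contiguous runs.
    return b"".join(raw[i::stride] for i in range(stride))

def byte_unshuffle(shuf: bytes, stride: int = 2) -> bytes:
    # Exact inverse: slice the planes back out and re-interleave them,
    # handling lengths that are not a multiple of the stride.
    n = len(shuf)
    base, rem = divmod(n, stride)
    out = bytearray(n)
    pos = 0
    for i in range(stride):
        plane = base + (1 if i < rem else 0)
        out[i::stride] = shuf[pos:pos + plane]
        pos += plane
    return bytes(out)
```

The transform is size-preserving; the gain comes entirely from handing the compressor more predictable byte runs.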