@@ -38,18 +38,18 @@ class Hyperparameters():
3838 run_id = os .environ .get ("RUN_ID" , str (uuid .uuid4 ()))
3939
4040 # Training length
41- iterations = int (os .environ .get ('ITERATIONS' , 20000 ))
41+ iterations = int (os .environ .get ('ITERATIONS' , 200 ))
4242 warmdown_frac = float (os .environ .get ('WARMDOWN_FRAC' , 0.667 ))
4343 warmup_steps = int (os .environ .get ('WARMUP_STEPS' , 20 ))
4444 train_batch_tokens = int (os .environ .get ('TRAIN_BATCH_TOKENS' , 2048 * 48 * 8 ))
4545 train_seq_len = int (os .environ .get ('TRAIN_SEQ_LEN' , 2048 ))
4646 eval_seq_len = int (os .environ .get ('EVAL_SEQ_LEN' , 2048 ))
4747 max_wallclock_seconds = float (os .environ .get ('MAX_WALLCLOCK_SECONDS' , 600.0 ))
48- train_log_every = int (os .environ .get ('TRAIN_LOG_EVERY' , 500 ))
48+ train_log_every = int (os .environ .get ('TRAIN_LOG_EVERY' , 10 ))
4949
5050 # Validation/Evals
5151 val_batch_tokens = int (os .environ .get ('VAL_BATCH_TOKENS' , 2048 * 32 * 8 ))
52- val_loss_every = int (os .environ .get ('VAL_LOSS_EVERY' , 4000 ))
52+ val_loss_every = int (os .environ .get ('VAL_LOSS_EVERY' , 200 ))
5353 sliding_window_enabled = bool (int (os .environ .get ('SLIDING_WINDOW_ENABLED' , '1' )))
5454
5555 # Model architecture
@@ -115,10 +115,14 @@ class Hyperparameters():
115115
116116 # Compression
117117 compressor = os .environ .get ('COMPRESSOR' , 'brotli' ) #(lzma or brotli)
118- gptq_enabled = bool (int (os .environ .get ('GPTQ_ENABLED' , '1 ' )))
118+ gptq_enabled = bool (int (os .environ .get ('GPTQ_ENABLED' , '0 ' )))
119119 gptq_calibration_batches = int (os .environ .get ('GPTQ_CALIBRATION_BATCHES' , 64 ))
120120 gptq_reserve_seconds = float (os .environ .get ('GPTQ_RESERVE_SECONDS' , 10.0 ))
121121
122+ # CompTrain
123+ comptrain_enabled = bool (int (os .environ .get ('COMPTRAIN_ENABLED' , '1' )))
124+ comptrain_alpha = float (os .environ .get ('COMPTRAIN_ALPHA' , '0.5' ))
125+
122126 # Distributed setup
123127 distributed = "RANK" in os .environ and "WORLD_SIZE" in os .environ
124128 rank = int (os .environ .get ("RANK" , "0" ))
@@ -591,6 +595,19 @@ def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tenso
591595 return x_out
592596
593597
def build_bigram_table(train_files: str, vocab_size: int, device: torch.device,
                       num_shards: int = 3) -> Tensor:
    """Estimate a row-normalized bigram transition table from training shards.

    Args:
        train_files: glob pattern matching the tokenized training shard files.
        vocab_size: number of distinct token ids (table is vocab_size x vocab_size).
        device: device the finished table is moved to.
        num_shards: how many of the (sorted) shards to count over. Defaults to 3,
            matching the original hard-coded behavior.

    Returns:
        A (vocab_size, vocab_size) float32 tensor where entry [p, c] approximates
        P(next_token == c | prev_token == p). Rows for tokens never observed as a
        bigram prefix remain all zeros (the row sum is clamped to 1.0, so there is
        no division by zero).
    """
    files = sorted(glob.glob(train_files))[:num_shards]
    # Flat 1-D count buffer: pair (prev, curr) maps to prev * vocab_size + curr.
    counts = torch.zeros(vocab_size * vocab_size, dtype=torch.float32)
    for f in files:
        tokens = load_data_shard(Path(f)).long()
        prev = tokens[:-1]
        curr = tokens[1:]
        flat_idx = prev * vocab_size + curr
        # scatter_add_ accumulates one count per observed transition.
        counts.scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float32))
    counts = counts.reshape(vocab_size, vocab_size)
    # clamp(min=1.0) keeps unseen-prefix rows at zero instead of producing NaNs.
    row_sums = counts.sum(dim=1, keepdim=True).clamp(min=1.0)
    return (counts / row_sums).to(device)
594611class GPT (nn .Module ):
595612 def __init__ (self , h : Hyperparameters ):
596613 super ().__init__ ()
@@ -652,6 +669,8 @@ def __init__(self, h: Hyperparameters):
652669 else :
653670 self .lane_merge = None
654671
672+ self ._bigram_table = None
673+ self ._comptrain_alpha = 0.5
655674 self ._init_weights ()
656675
657676 def set_recurrence_active (self , active : bool ) -> None :
@@ -774,10 +793,18 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
774793 logits_proj = self .lm_head (x )
775794 return self .logit_softcap * torch .tanh (logits_proj / self .logit_softcap )
776795
def forward(self, input_ids: Tensor, target_ids: Tensor):
    """Return (training_loss, control_loss) for one batch.

    When a bigram table is attached (CompTrain enabled), the training loss
    down-weights tokens whose continuation the bigram model already assigns
    high probability, while the control loss stays the plain mean
    cross-entropy so training logs remain comparable across runs. Without a
    table, both returned values are the ordinary mean cross-entropy (the
    second one detached).
    """
    logits = self.forward_logits(input_ids)
    flat_logits = logits.reshape(-1, logits.size(-1)).float()
    flat_targets = target_ids.reshape(-1)

    if self._bigram_table is None:
        loss = F.cross_entropy(flat_logits, flat_targets, reduction="mean")
        return loss, loss.detach()

    per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")
    ctrl_loss = per_token.detach().mean()
    # Weight of each token: 1 - alpha * P_bigram(target | input); tokens the
    # bigram model finds easy contribute less to the gradient.
    bigram_prob = self._bigram_table[input_ids.reshape(-1), flat_targets]
    weights = 1.0 - self._comptrain_alpha * bigram_prob
    weighted_loss = (per_token * weights).sum() / weights.sum()
    return weighted_loss, ctrl_loss
782809
783810def classify_param (name : str ) -> str :
@@ -1724,6 +1751,12 @@ def run_evals(
17241751def train_model (h : Hyperparameters , device : torch .device , val_data : ValidationData ) -> None :
17251752 # Set up model
17261753 base_model = GPT (h ).to (device ).bfloat16 ()
1754+ if h .comptrain_enabled :
1755+ log ("comptrain:building bigram table from first 3 shards" )
1756+ bigram_table = build_bigram_table (h .train_files , h .vocab_size , device )
1757+ base_model ._bigram_table = bigram_table
1758+ base_model ._comptrain_alpha = h .comptrain_alpha
1759+ log (f"comptrain:enabled alpha={ h .comptrain_alpha } mean_bigram_prob={ bigram_table .mean ().item ():.6f} " )
17271760 restore_fp32_params (base_model )
17281761 compiled_model = torch .compile (base_model , dynamic = False , fullgraph = True )
17291762 if h .distributed :
@@ -1758,15 +1791,18 @@ def lr_mul(frac: float) -> float:
17581791 def step_fn (step , lr_scale ):
17591792 optimizers .zero_grad_all ()
17601793 train_loss = torch .zeros ((), device = device )
1794+ weighted_loss_accum = torch .zeros ((), device = device )
17611795 for micro_step in range (h .grad_accum_steps ):
17621796 if h .distributed :
17631797 model .require_backward_grad_sync = micro_step == h .grad_accum_steps - 1
17641798 x , y = train_loader .next_batch (h .train_batch_tokens , h .train_seq_len , h .grad_accum_steps )
17651799 with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 , enabled = True ):
1766- loss = model (x , y )
1767- train_loss += loss .detach ()
1800+ loss , ctrl_loss = model (x , y )
1801+ train_loss += ctrl_loss
1802+ weighted_loss_accum += loss .detach ()
17681803 (loss / h .grad_accum_steps ).backward ()
17691804 train_loss /= h .grad_accum_steps
1805+ weighted_loss_accum /= h .grad_accum_steps
17701806
17711807 frac = min (step / h .muon_momentum_warmup_steps , 1.0 ) if h .muon_momentum_warmup_steps > 0 else 1.0
17721808 muon_momentum = (1 - frac ) * h .muon_momentum_warmup_start + frac * h .muon_momentum
@@ -1781,7 +1817,7 @@ def step_fn(step, lr_scale):
17811817 torch .nn .utils .clip_grad_norm_ (base_model .parameters (), h .grad_clip_norm )
17821818
17831819 optimizers .step ()
1784- return train_loss
1820+ return train_loss , weighted_loss_accum
17851821
17861822 # Model warmup
17871823 if h .warmup_steps > 0 :
@@ -1839,7 +1875,7 @@ def step_fn(step, lr_scale):
18391875 elapsed_ms = training_time_ms + 1000.0 * (time .perf_counter () - t0 )
18401876 frac = training_frac (step , elapsed_ms )
18411877 scale = lr_mul (frac )
1842- train_loss = step_fn (step , scale )
1878+ train_loss , weighted_loss = step_fn (step , scale )
18431879
18441880 with torch .no_grad ():
18451881 for name , t in base_model .state_dict ().items ():
@@ -1855,8 +1891,9 @@ def step_fn(step, lr_scale):
18551891 if should_log_train :
18561892 tok_per_sec = step * h .train_batch_tokens / (approx_training_time_ms / 1000.0 )
18571893 log (
1858- f"{ step } /{ h .iterations } train_loss: { train_loss .item ():.4f} "
1859- f"train_time: { approx_training_time_ms / 60000 :.1f} m tok/s: { tok_per_sec :.0f} "
1894+ f"{ step } /{ h .iterations } train_loss:{ train_loss .item ():.4f} "
1895+ f"weighted_loss:{ weighted_loss .item ():.4f} "
1896+ f"train_time:{ approx_training_time_ms / 60000 :.1f} m tok/s:{ tok_per_sec :.0f} "
18601897 )
18611898
18621899 reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
0 commit comments