Commit 5357f7f

Re-apply CompTrain microbenchmark on new base (openai#1296 SP4096)

Parent: feeb243

1 file changed: train_gpt.py (49 additions, 12 deletions)
@@ -38,18 +38,18 @@ class Hyperparameters():
     run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))

     # Training length
-    iterations = int(os.environ.get('ITERATIONS', 20000))
+    iterations = int(os.environ.get('ITERATIONS', 200))
     warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667))
     warmup_steps = int(os.environ.get('WARMUP_STEPS', 20))
     train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8))
     train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048))
     eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048))
     max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0))
-    train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500))
+    train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 10))

     # Validation/Evals
     val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8))
-    val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000))
+    val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 200))
     sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1')))

     # Model architecture
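
The new defaults turn the full run into a short microbenchmark: 200 iterations instead of 20,000, training logs every 10 steps, and validation loss every 200 steps. For a rough sense of scale, the token budget implied by these defaults, assuming TRAIN_BATCH_TOKENS is the global number of tokens per optimizer step (a reading of the variable name, not something the diff states):

    # Rough token budget implied by the microbenchmark defaults above.
    train_batch_tokens = 2048 * 48 * 8   # 786,432 tokens per step (assumed global)
    iterations = 200                     # shortened from 20,000
    total_tokens = train_batch_tokens * iterations
    print(f"{total_tokens:,} tokens")    # 157,286,400 ~= 0.16B tokens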
@@ -115,10 +115,14 @@ class Hyperparameters():

     # Compression
     compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli)
-    gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '1')))
+    gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '0')))
     gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
     gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 10.0))

+    # CompTrain
+    comptrain_enabled = bool(int(os.environ.get('COMPTRAIN_ENABLED', '1')))
+    comptrain_alpha = float(os.environ.get('COMPTRAIN_ALPHA', '0.5'))
+
     # Distributed setup
     distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
     rank = int(os.environ.get("RANK", "0"))
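
The CompTrain knobs follow the file's existing environment-variable pattern. The bool(int(...)) round-trip is what makes COMPTRAIN_ENABLED=0 actually disable the feature; a minimal illustration of why the int cast matters:

    import os

    # bool("0") is True in Python, so parsing "0"/"1" flag strings needs int() first.
    os.environ["COMPTRAIN_ENABLED"] = "0"
    naive = bool(os.environ.get("COMPTRAIN_ENABLED", "1"))        # True  (wrong)
    parsed = bool(int(os.environ.get("COMPTRAIN_ENABLED", "1")))  # False (intended)
    print(naive, parsed)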
@@ -591,6 +595,19 @@ def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
         return x_out


+def build_bigram_table(train_files: str, vocab_size: int, device: torch.device) -> Tensor:
+    files = sorted(glob.glob(train_files))[:3]
+    counts = torch.zeros(vocab_size * vocab_size, dtype=torch.float32)
+    for f in files:
+        tokens = load_data_shard(Path(f)).long()
+        prev = tokens[:-1]
+        curr = tokens[1:]
+        flat_idx = prev * vocab_size + curr
+        counts.scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float32))
+    counts = counts.reshape(vocab_size, vocab_size)
+    row_sums = counts.sum(dim=1, keepdim=True).clamp(min=1.0)
+    return (counts / row_sums).to(device)
+
 class GPT(nn.Module):
     def __init__(self, h: Hyperparameters):
         super().__init__()
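
build_bigram_table scans the first three training shards and returns a row-normalized table of next-token frequencies: entry (i, j) is the empirical probability of token j following token i. A self-contained toy version of the same counting scheme (load_data_shard and the shard files are not available here, so it uses an in-memory token stream instead):

    import torch

    vocab_size = 4
    tokens = torch.tensor([0, 1, 2, 1, 2, 3, 0, 1]).long()

    # Same flattened-index counting as build_bigram_table.
    counts = torch.zeros(vocab_size * vocab_size, dtype=torch.float32)
    flat_idx = tokens[:-1] * vocab_size + tokens[1:]
    counts.scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float32))
    table = counts.reshape(vocab_size, vocab_size)
    table = table / table.sum(dim=1, keepdim=True).clamp(min=1.0)

    print(table)             # row i = empirical P(next token | prev token i)
    print(table.sum(dim=1))  # rows that saw any counts sum to 1.0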
@@ -652,6 +669,8 @@ def __init__(self, h: Hyperparameters):
         else:
             self.lane_merge = None

+        self._bigram_table = None
+        self._comptrain_alpha = 0.5
         self._init_weights()

     def set_recurrence_active(self, active: bool) -> None:
@@ -774,10 +793,18 @@ def forward_logits(self, input_ids: Tensor) -> Tensor:
         logits_proj = self.lm_head(x)
         return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

-    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+    def forward(self, input_ids: Tensor, target_ids: Tensor):
         logits = self.forward_logits(input_ids)
-        return F.cross_entropy(
+        if self._bigram_table is not None:
+            per_token = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="none")
+            ctrl_loss = per_token.detach().mean()
+            w = 1.0 - self._comptrain_alpha * self._bigram_table[input_ids.reshape(-1), target_ids.reshape(-1)]
+            weighted_loss = (per_token * w).sum() / w.sum()
+            return weighted_loss, ctrl_loss
+        loss = F.cross_entropy(
             logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean")
+        return loss, loss.detach()


 def classify_param(name: str) -> str:
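
With the table attached, forward switches to the CompTrain objective: per-token cross-entropy is down-weighted for bigrams the table considers likely (by up to comptrain_alpha), while the plain unweighted mean is still returned, detached, as a control metric. A minimal numeric sketch with toy logits and a made-up uniform table, using alpha = 0.5 as in the new default:

    import torch
    import torch.nn.functional as F

    alpha = 0.5
    vocab_size = 4
    logits = torch.randn(6, vocab_size)                  # flattened (tokens, vocab)
    inputs = torch.tensor([0, 1, 2, 1, 2, 3])
    targets = torch.tensor([1, 2, 1, 2, 3, 0])
    bigram = torch.full((vocab_size, vocab_size), 0.25)  # stand-in for the real table

    per_token = F.cross_entropy(logits.float(), targets, reduction="none")
    ctrl_loss = per_token.detach().mean()                # plain CE, only for logging
    w = 1.0 - alpha * bigram[inputs, targets]            # likely bigrams -> smaller weight
    weighted_loss = (per_token * w).sum() / w.sum()      # this is what gets backpropped

With a uniform table every weight is identical, so weighted_loss coincides with ctrl_loss; with the real table, the loss shifts toward tokens the bigram model does not predict well.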
@@ -1724,6 +1751,12 @@ def run_evals(
 def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> None:
     # Set up model
     base_model = GPT(h).to(device).bfloat16()
+    if h.comptrain_enabled:
+        log("comptrain:building bigram table from first 3 shards")
+        bigram_table = build_bigram_table(h.train_files, h.vocab_size, device)
+        base_model._bigram_table = bigram_table
+        base_model._comptrain_alpha = h.comptrain_alpha
+        log(f"comptrain:enabled alpha={h.comptrain_alpha} mean_bigram_prob={bigram_table.mean().item():.6f}")
     restore_fp32_params(base_model)
     compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
     if h.distributed:
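
train_model builds the table once, before the model is compiled, and attaches it as a plain attribute. A side observation about PyTorch, not something the commit states: because _bigram_table is neither a parameter nor a registered buffer, Module.bfloat16() would not convert it even if the assignment happened earlier, so the table stays float32:

    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self._bigram_table = torch.rand(4, 4)  # plain attribute, not a buffer

    m = Tiny().bfloat16()
    print(m.linear.weight.dtype)   # torch.bfloat16
    print(m._bigram_table.dtype)   # torch.float32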
@@ -1758,15 +1791,18 @@ def lr_mul(frac: float) -> float:
     def step_fn(step, lr_scale):
         optimizers.zero_grad_all()
         train_loss = torch.zeros((), device=device)
+        weighted_loss_accum = torch.zeros((), device=device)
         for micro_step in range(h.grad_accum_steps):
             if h.distributed:
                 model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1
             x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps)
             with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
-                loss = model(x, y)
-            train_loss += loss.detach()
+                loss, ctrl_loss = model(x, y)
+            train_loss += ctrl_loss
+            weighted_loss_accum += loss.detach()
             (loss / h.grad_accum_steps).backward()
         train_loss /= h.grad_accum_steps
+        weighted_loss_accum /= h.grad_accum_steps

         frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0
         muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum
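
step_fn now carries two accumulators: train_loss collects the detached control (unweighted) loss for logging, while the CompTrain-weighted loss is the one that is scaled and backpropagated. As a reminder of why each micro-loss is divided by grad_accum_steps before backward (standard gradient accumulation, not specific to this commit), the accumulated scaled gradients match the gradient of the mean loss:

    import torch

    torch.manual_seed(0)
    w = torch.randn(3, requires_grad=True)
    micro_batches = [torch.randn(3) for _ in range(4)]

    # Accumulate gradients of each micro-loss scaled by 1/grad_accum_steps.
    for x in micro_batches:
        loss = (w * x).sum() ** 2
        (loss / 4).backward()
    grad_accumulated = w.grad.clone()

    # Compare with a single backward pass on the mean of the micro-losses.
    w.grad = None
    mean_loss = torch.stack([(w * x).sum() ** 2 for x in micro_batches]).mean()
    mean_loss.backward()

    print(torch.allclose(grad_accumulated, w.grad))  # True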
@@ -1781,7 +1817,7 @@ def step_fn(step, lr_scale):
         torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)

         optimizers.step()
-        return train_loss
+        return train_loss, weighted_loss_accum

     # Model warmup
     if h.warmup_steps > 0:
@@ -1839,7 +1875,7 @@ def step_fn(step, lr_scale):
         elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
         frac = training_frac(step, elapsed_ms)
         scale = lr_mul(frac)
-        train_loss = step_fn(step, scale)
+        train_loss, weighted_loss = step_fn(step, scale)

         with torch.no_grad():
             for name, t in base_model.state_dict().items():
@@ -1855,8 +1891,9 @@ def step_fn(step, lr_scale):
         if should_log_train:
             tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0)
             log(
-                f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} "
-                f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}"
+                f"{step}/{h.iterations} train_loss:{train_loss.item():.4f} "
+                f"weighted_loss:{weighted_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms / 60000:.1f}m tok/s:{tok_per_sec:.0f}"
             )

         reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
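
The logged tok/s is a cumulative average: total tokens processed through the current step divided by total training time so far, not a per-step rate. A worked example with the defaults above and a hypothetical elapsed time:

    step = 100
    train_batch_tokens = 2048 * 48 * 8   # 786,432
    approx_training_time_ms = 120_000.0  # hypothetical: 2 minutes of training time
    tok_per_sec = step * train_batch_tokens / (approx_training_time_ms / 1000.0)
    print(f"{tok_per_sec:,.0f} tok/s")   # 655,360 tok/s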

0 commit comments