-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
304 lines (266 loc) · 12.7 KB
/
train.py
File metadata and controls
304 lines (266 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""
RegisterGPT — PyTorch/CUDA training for all model versions.
Compatible with torchrun for multi-GPU training.
"""
from __future__ import annotations
import io
import json
import os
import random
import time
import zlib
from datetime import datetime
from pathlib import Path
import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from core.config import Hyperparameters
from core.data import DistributedTokenLoader, load_validation_tokens
from core.eval import build_sentencepiece_luts, eval_val
from core.quantize import (
CONTROL_TENSOR_NAME_PATTERNS,
dequantize_state_dict_int8,
quantize_state_dict_int8,
)
from core.registry import get_registry, build_model
def main():
code = Path(__file__).read_text(encoding="utf-8")
args = Hyperparameters()
distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
grad_accum_steps = args.grad_accum_steps
grad_scale = 1.0 / grad_accum_steps
if not torch.cuda.is_available():
raise RuntimeError("CUDA required")
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
master = rank == 0
if distributed:
os.environ["NCCL_P2P_DISABLE"] = args.nccl_p2p_disable
if master:
print(f"[init] NCCL_P2P_DISABLE={args.nccl_p2p_disable}, calling init_process_group")
dist.init_process_group(backend="nccl")
if master:
print("[init] init_process_group done, waiting on barrier")
dist.barrier()
if master:
print("[init] DDP barrier passed")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
logfile = None
if master:
os.makedirs("logs", exist_ok=True)
logfile = f"logs/{args.run_id}.txt"
print(logfile)
def log0(msg, console=True):
if not master: return
ts = datetime.now().strftime("%H:%M:%S")
line = f"[{ts}] {msg}"
if console: print(line)
if logfile:
with open(logfile, "a") as f: print(line, file=f)
log0(code, console=False)
log0(f"[config]\n{json.dumps(args.to_dict(), indent=2, default=str)}", console=False)
log0(f"[init] rank={rank} world_size={world_size} device={device} model_version={args.model_version}")
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
log0(f"[init] loading tokenizer from {args.tokenizer_path}")
sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
if int(sp.vocab_size()) != args.vocab_size:
raise ValueError(f"VOCAB_SIZE mismatch: tokenizer has {sp.vocab_size()}, expected {args.vocab_size}")
log0(f"[init] loading validation data from {args.val_files}")
val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
bbl, hsl, ibl = build_sentencepiece_luts(sp, args.vocab_size, device)
log0(f"[init] data loaded: {val_tokens.numel()} val tokens")
base_model = build_model(args.model_version, args).to(device).bfloat16()
# Keep small control params in fp32
with torch.no_grad():
for name, p in base_model.named_parameters():
if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) and p.dtype != torch.float32:
p.data = p.data.float()
n_params = sum(p.numel() for p in base_model.parameters())
log0(f"[init] model built: {n_params/1e3:.0f}K params on {device}")
compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.torch_compile else base_model
if distributed:
dist.barrier()
log0(f"[init] wrapping with DDP (distributed={distributed})")
model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
log0(f"[init] DDP ready")
optimizer = torch.optim.Adam(
base_model.parameters(),
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.adam_eps,
weight_decay=args.weight_decay,
fused=True,
)
for g in optimizer.param_groups:
g["base_lr"] = args.lr
n_params = sum(p.numel() for p in base_model.parameters())
n_trainable = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
raw_bytes = sum(p.numel() * p.element_size() for p in base_model.parameters())
est_int8_bytes = sum(p.numel() for p in base_model.parameters()) # ~1 byte per param
log0(f"run_id:{args.run_id}")
model_cls = get_registry()[args.model_version]
log0(f"model_version:{args.model_version}")
log0(f"model_params:{n_params} trainable:{n_trainable} vocab=dim={args.vocab_size} raw:{raw_bytes/1e6:.1f}MB est_int8:{est_int8_bytes/1e6:.1f}MB")
log0(f"architecture:{model_cls.architecture}")
log0(f"cross_position:{model_cls.cross_position} within_position:{model_cls.within_position}")
log0(f"config:{model_cls.build_kwargs(args)}")
log0(f"activation:{args.activation} lr:{args.lr} grad_clip:{args.grad_clip_norm} decay_init:{args.decay_init}")
log0(f"world_size:{world_size} grad_accum:{grad_accum_steps} batch:{args.train_batch_tokens}")
log0(f"[init] loading training data from {args.train_files}")
train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds else None
def lr_mul(step, elapsed_ms):
if args.warmdown_iters <= 0:
return 1.0
if max_wc_ms is None:
ws = max(args.iterations - args.warmdown_iters, 0)
return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0
sms = elapsed_ms / max(step, 1)
wms = args.warmdown_iters * sms
rms = max(max_wc_ms - elapsed_ms, 0.0)
return rms / max(wms, 1e-9) if rms <= wms else 1.0
# Resume from checkpoint if provided
start_step = 0
if args.resume_from and Path(args.resume_from).exists():
ckpt = torch.load(args.resume_from, map_location=device, weights_only=False)
if "model" in ckpt and "step" in ckpt:
base_model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"]
log0(f"resumed checkpoint from {args.resume_from} at step {start_step}")
else:
base_model.load_state_dict(ckpt, strict=False)
log0(f"loaded weights from {args.resume_from} (no optimizer/step, restarting from 0)")
def save_checkpoint(step):
if not master:
return
ckpt_path = f"logs/{args.run_id}_ckpt.pt"
torch.save({
"model": base_model.state_dict(),
"optimizer": optimizer.state_dict(),
"step": step,
}, ckpt_path)
log0(f"checkpoint:{ckpt_path} step:{step}")
# Warmup (skip if resuming)
log0(f"[init] starting warmup ({args.warmup_steps} steps)")
if start_step == 0:
for ws in range(args.warmup_steps):
optimizer.zero_grad(set_to_none=True)
for _ in range(grad_accum_steps):
x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
loss = model(x, y)
(loss * grad_scale).backward()
if args.grad_clip_norm > 0:
torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
optimizer.step()
torch.cuda.synchronize()
if master and (ws + 1) % 10 == 0 or ws + 1 == args.warmup_steps:
log0(f"warmup:{ws + 1}/{args.warmup_steps}")
train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
# Training
train_ms = 0.0
stop_after = None
t0 = time.perf_counter()
step = start_step
while True:
last = step == args.iterations or (stop_after is not None and step >= stop_after)
if last or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
train_ms += 1000.0 * (time.perf_counter() - t0)
vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, bbl, hsl, ibl)
log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} train_time:{train_ms:.0f}ms")
t0 = time.perf_counter()
if last:
if stop_after is not None and step < args.iterations:
log0(f"stopping_early: step:{step}/{args.iterations}")
break
lm = lr_mul(step, train_ms + 1000.0 * (time.perf_counter() - t0))
for g in optimizer.param_groups:
g["lr"] = g["base_lr"] * lm
step_t0 = time.perf_counter()
optimizer.zero_grad(set_to_none=True)
train_loss_accum = 0.0
for _ in range(grad_accum_steps):
x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
loss = model(x, y)
(loss * grad_scale).backward()
train_loss_accum += loss.item() * grad_scale
if args.grad_clip_norm > 0:
torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
optimizer.step()
torch.cuda.synchronize()
step_ms = 1000.0 * (time.perf_counter() - step_t0)
approx_ms = train_ms + 1000.0 * (time.perf_counter() - t0)
step += 1
if master and args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0):
log0(f"step:{step}/{args.iterations} train_loss:{train_loss_accum:.4f} time:{approx_ms:.0f}ms avg:{approx_ms / step:.1f}ms tok/s:{args.train_batch_tokens / (step_ms / 1000):.0f}")
if max_wc_ms and stop_after is None and approx_ms >= max_wc_ms:
stop_after = step
if args.checkpoint_every > 0 and step % args.checkpoint_every == 0:
save_checkpoint(step)
# Final checkpoint
save_checkpoint(step)
# Serialize
if master:
sd = {k: v for k, v in base_model.state_dict().items()}
out = f"logs/{args.run_id}_model.pt"
torch.save(sd, out)
log0(f"saved:{out} bytes:{Path(out).stat().st_size}")
qobj, qstats = quantize_state_dict_int8(sd)
buf = io.BytesIO()
torch.save(qobj, buf)
compressed = zlib.compress(buf.getvalue(), 9)
qpath = f"logs/{args.run_id}_model.int8.ptz"
Path(qpath).write_bytes(compressed)
log0(f"quantized:{qpath} bytes:{len(compressed)} ratio:{qstats['baseline_tensor_bytes'] / max(qstats['int8_payload_bytes'], 1):.2f}x")
if args.roundtrip_eval:
dq = dequantize_state_dict_int8(torch.load(io.BytesIO(zlib.decompress(Path(qpath).read_bytes())), weights_only=False))
base_model.load_state_dict(dq, strict=False)
qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, bbl, hsl, ibl)
log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f}")
# Write manifest
manifest = {
"run_id": args.run_id,
"model_version": args.model_version,
"params": n_params,
"trainable": n_trainable,
"raw_bytes": raw_bytes,
"est_int8_bytes": est_int8_bytes,
"quantized_bytes": len(compressed),
"vocab_size": args.vocab_size,
"num_steps": args.num_steps,
"n_channels": getattr(args, "n_channels", None),
"n_fourier_basis": getattr(args, "n_fourier_basis", None),
"state_dim": getattr(args, "state_dim", None),
"inner_dim": getattr(args, "inner_dim", None),
"n_ops": getattr(args, "n_ops", None),
"lr": args.lr,
"activation": args.activation,
"decay_init": args.decay_init,
"steps_trained": step,
"final_train_loss": train_loss_accum if step > start_step else None,
"val_loss": vl,
"val_bpb": vbpb,
"train_time_ms": train_ms,
"world_size": world_size,
"grad_accum_steps": grad_accum_steps,
"batch_tokens": args.train_batch_tokens,
"model_path": out,
"quantized_path": qpath,
}
manifest_path = f"logs/{args.run_id}_manifest.json"
with open(manifest_path, "w") as f:
json.dump(manifest, f, indent=2)
log0(f"manifest:{manifest_path}")
if __name__ == "__main__":
main()