Commit b6ff53e

refactor deepspeed and lightning scripts
1 parent cb4cd4a commit b6ff53e

File tree: 3 files changed (+216 −106 lines)

docs/source/examples/scripts/accelerate_train.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 # dependencies = [
 # "accelerate",
 # "datasets",
-# "tensorboard",
 # "torch",
 # "torchrunx",
 # "transformers",
Lines changed: 104 additions & 52 deletions
@@ -1,88 +1,140 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+# "deepspeed",
+# "datasets",
+# "tensorboard",
+# "torch",
+# "torchrunx",
+# "transformers",
+# "tyro",
+# ]
+# ///
+
+import argparse
+import functools
+import os
 from dataclasses import dataclass
-from pathlib import Path
+from typing import Annotated
 
 import deepspeed
+from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
 import torch
 
 from datasets import load_dataset
-from torch import nn
 from torch.utils.data import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer, AutoConfig
 
 import torchrunx
+import tyro
 
 
-class GPT2CausalLMDataset(Dataset):
-    def __init__(self, text_dataset):
-        self.dataset = text_dataset
-        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.max_length = 1024
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        encoded = self.tokenizer(
-            self.dataset[idx]["text"],
-            max_length=self.max_length,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt",
-        )
+@dataclass
+class ModelConfig:
+    name: str
 
-        input_ids = encoded.input_ids.squeeze()
-        attention_mask = encoded.attention_mask.squeeze()
-        labels = input_ids.clone()
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels,
-        }
+@dataclass
+class DatasetConfig:
+    path: str
+    name: str | None = None
+    split: str | None = None
+    text_column: str = "text"
+    num_samples: int | None = None
 
 
 @dataclass
-class DSPArgs:
+class DeepSpeedArgs:
     deepspeed_config: str
-    # train_batch_size: int
-    # batch_size: int
+    local_rank: int | None = None
+
+
+def load_training_data(
+    tokenizer_name: str,
+    dataset_config: DatasetConfig,
+) -> Dataset:
+    # Load dataset
+
+    dataset = load_dataset(dataset_config.path, name=dataset_config.name, split=dataset_config.split)
+    if dataset_config.num_samples is not None:
+        dataset = dataset.select(range(dataset_config.num_samples))
+
+    # Build tokenizer
+
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    tokenize_fn = functools.partial(
+        tokenizer,
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        padding="max_length",
+    )
+
+    # Tokenize dataset
 
+    return dataset.map(
+        tokenize_fn,
+        batched=True,
+        input_columns=[dataset_config.text_column],
+        remove_columns=[dataset_config.text_column],
+    ).map(lambda x: {"labels": x["input_ids"]})
 
-def train():
-    model = AutoModelForCausalLM.from_pretrained("gpt2")
-    # optimizer = torch.optim.Adam(model.parameters())
-    wikitext_train = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="train")
-    train_dataset = GPT2CausalLMDataset(text_dataset=wikitext_train)
 
-    loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
+def train(
+    model: PreTrainedModel,
+    train_dataset: Dataset,
+    deepspeed_args: DeepSpeedArgs
+) -> str:
 
-    model_engine, optimizer, _, _ = deepspeed.initialize(
-        args=DSPArgs(deepspeed_config="dsp_config.json"),
+    deepspeed_args.local_rank = int(os.environ["LOCAL_RANK"])
+
+    model_engine, _, loader, _ = deepspeed.initialize(
+        args=deepspeed_args,
         model=model,
         model_parameters=model.parameters(),
+        training_data=train_dataset
     )
 
-    model.train()
+    model_engine.train()
     for batch_idx, batch in enumerate(loader):
         if batch_idx == 10:
             break
-        print(f"Step {batch_idx}")
-
-        device_batch = {k: v.to(model.device) for k, v in batch.items()}
-
-        model.zero_grad()
+        device_batch = {k: torch.stack(v, dim=0).to(model_engine.device) for k, v in batch.items()}
+        model_engine.zero_grad()
 
         loss = model_engine(**device_batch).loss
+        print(f"Step {batch_idx}, loss: {loss.item()}", flush=True, end="")
         model_engine.backward(loss)
 
         model_engine.step()
 
+    checkpoint_dir = "output"
+    model_engine.save_checkpoint(checkpoint_dir)
 
-if __name__ == "__main__":
-    Path("output").mkdir(exist_ok=True)
-    results = torchrunx.launch(
-        func=train,
-        hostnames=["localhost"],
-        workers_per_host=1,
+    return checkpoint_dir
+
+
+def main(
+    launcher: torchrunx.Launcher,
+    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
+    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
+    deepspeed_args: Annotated[DeepSpeedArgs, tyro.conf.arg(name="deepspeed")]
+):
+    model = AutoModelForCausalLM.from_pretrained(model_config.name)
+    train_dataset = load_training_data(tokenizer_name=model_config.name, dataset_config=dataset_config)
+
+    # Launch training
+    results = launcher.run(train, (model, train_dataset, deepspeed_args))
+
+    # Loading trained model from checkpoint
+    checkpoint_path = results.rank(0)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_path)
+    trained_model = AutoModelForCausalLM.from_config(
+        AutoConfig.from_pretrained(model_config.name)
     )
+    trained_model.load_state_dict(state_dict)
+
+
+if __name__ == "__main__":
+    tyro.cli(main)
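Two usage notes on the refactored script. Both sketches below are illustrative assumptions, not part of the commit.

First, deepspeed.initialize() now builds the optimizer and dataloader from the JSON file passed as DeepSpeedArgs.deepspeed_config, so that file must define the batch sizes, optimizer, and any ZeRO settings. A minimal sketch of such a config; the file name and every value are assumptions:

import json

# Hypothetical DeepSpeed config; values are examples only.
ds_config = {
    "train_batch_size": 8,                # must equal micro_batch * grad_accum * world_size (one worker here)
    "train_micro_batch_size_per_gpu": 8,  # per-GPU batch size used by the engine's dataloader
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 5e-5}},
    "zero_optimization": {"stage": 2},    # ZeRO shards are what get_fp32_state_dict_from_zero_checkpoint reassembles
}

with open("deepspeed_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

Second, tyro.cli(main) turns each annotated dataclass into a dotted argument group on the command line (roughly --model.name, --dataset.path, --deepspeed.deepspeed-config, plus the torchrunx.Launcher options). A hypothetical programmatic equivalent, reusing the dataclasses defined in the new script, with every value chosen only for illustration:

# Hypothetical direct call to main(); mirrors what the tyro CLI would construct.
main(
    launcher=torchrunx.Launcher(hostnames=["localhost"], workers_per_host=1),  # assumed Launcher fields, matching the old torchrunx.launch() arguments
    model_config=ModelConfig(name="gpt2"),
    dataset_config=DatasetConfig(
        path="Salesforce/wikitext",
        name="wikitext-2-v1",
        split="train",
        num_samples=80,
    ),
    deepspeed_args=DeepSpeedArgs(deepspeed_config="deepspeed_config.json"),
)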
0 commit comments
