Commit c213173

PP+DDP for 70B
Signed-off-by: Mamta Singh <[email protected]>
1 parent 504a850 commit c213173

1 file changed: +39 −15 lines changed

QEfficient/cloud/finetune.py

@@ -43,6 +43,37 @@
     # Suppress all warnings
     warnings.filterwarnings("ignore")
 
+def get_device_map_for_llama_70B(dev0, dev1, dev2, dev3, dev4, dev5):  # total_num_layers, num_stages
+    device_map = {
+        'model.embed_tokens': dev0,
+        'lm_head': dev5,
+        'model.norm': dev5,
+        'model.rotary_emb': dev5,
+    }
+    for i in range(80):
+        if i < 14:
+            device_map[f"model.layers.{i}"] = dev0
+        elif i < 28:
+            device_map[f"model.layers.{i}"] = dev1
+        elif i < 42:
+            device_map[f"model.layers.{i}"] = dev2
+        elif i < 56:
+            device_map[f"model.layers.{i}"] = dev3
+        elif i < 70:
+            device_map[f"model.layers.{i}"] = dev4
+        else:
+            device_map[f"model.layers.{i}"] = dev5
+    return device_map
+
+
+def setup_distributed_training():
+    torch_device = torch.device("qaic")
+    assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+    assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+    dist.init_process_group(backend="qccl")
+    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+    # getattr(torch, torch_device.type).set_device(dist.get_rank()*2)
+
 
 def main(**kwargs):
     """
@@ -57,19 +88,6 @@ def main(**kwargs):
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
     dataset_config = generate_dataset_config(train_config, kwargs)
-    device = train_config.device
-
-    # dist init
-    if train_config.enable_ddp:
-        # TODO: may have to init qccl backend, next try run with torchrun command
-        torch_device = torch.device(device)
-        assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert torch_device.index is None, (
-            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-        )
-        dist.init_process_group(backend=train_config.dist_backend)
-        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-        getattr(torch, torch_device.type).set_device(dist.get_rank())
 
     # Set the seeds for reproducibility
     torch.manual_seed(train_config.seed)
@@ -97,12 +115,17 @@ def main(**kwargs):
         if param.requires_grad:
             param.data = param.data.to(torch.float32)
     else:
+        rank = dist.get_rank()
+
+        device_map = get_device_map_for_llama_70B(rank*6, rank*6+1, rank*6+2, rank*6+3, rank*6+4, rank*6+5)
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_path,
             use_cache=False,
             attn_implementation="sdpa",
             torch_dtype=torch.float16,
+            device_map=device_map,
         )
+        print(model.hf_device_map)
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
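
With the rank*6 offsets above, each DDP rank owns a contiguous block of six devices and runs its own pipeline over them, while gradients are synchronized across ranks. An illustration only, assuming two ranks sharing a single 12-device host:

    # Assumed device ownership for two DDP ranks on one 12-device host,
    # matching the rank*6 .. rank*6+5 offsets passed to the device map helper.
    for rank in range(2):
        devices = [rank * 6 + i for i in range(6)]
        print(f"rank {rank} -> devices {devices}")
    # rank 0 -> devices [0, 1, 2, 3, 4, 5]
    # rank 1 -> devices [6, 7, 8, 9, 10, 11]
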
@@ -213,7 +236,7 @@ def main(**kwargs):
             f"passed context length is {train_config.context_length} and overall model's context length is "
             f"{model.config.max_position_embeddings}"
         )
-    model.to(train_config.device)
+    # model.to(train_config.device)
     optimizer = optim.AdamW(
         model.parameters(),
         lr=train_config.lr,
@@ -223,7 +246,7 @@ def main(**kwargs):
 
     # wrap model with DDP
    if train_config.enable_ddp:
-        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+        model = nn.parallel.DistributedDataParallel(model)  # , device_ids=[dist.get_rank()])
 
     _ = train(
         model,
@@ -245,4 +268,5 @@ def main(**kwargs):
 
 
 if __name__ == "__main__":
+    setup_distributed_training()
     fire.Fire(main)
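
Because setup_distributed_training() now initializes the "qccl" process group before fire.Fire(main), the script is expected to be started by a distributed launcher with one process per six-device pipeline group. A hypothetical launch for two ranks on one host; the script's own flag (--enable_ddp, parsed by fire) is taken from the existing config, but the exact invocation is an assumption, not part of this commit:

    torchrun --nnodes 1 --nproc-per-node 2 QEfficient/cloud/finetune.py --enable_ddp True

Other training options would follow as additional --key value pairs handled by fire and update_config.
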
