# Suppress all warnings
warnings.filterwarnings("ignore")

+ def get_device_map_for_llama_70B(dev0, dev1, dev2, dev3, dev4, dev5):  # total_num_layers, num_stages
+     device_map = {
+         'model.embed_tokens': dev0,
+         'lm_head': dev5,
+         'model.norm': dev5,
+         'model.rotary_emb': dev5
+     }
+     for i in range(80):
+         if i < 14:
+             device_map[f"model.layers.{i}"] = dev0
+         elif i < 28:
+             device_map[f"model.layers.{i}"] = dev1
+         elif i < 42:
+             device_map[f"model.layers.{i}"] = dev2
+         elif i < 56:
+             device_map[f"model.layers.{i}"] = dev3
+         elif i < 70:
+             device_map[f"model.layers.{i}"] = dev4
+         else:
+             device_map[f"model.layers.{i}"] = dev5
+     return device_map
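+ # The split above puts 14 decoder layers on each of dev0-dev4 and the remaining 10 layers
+ # (plus lm_head, model.norm and model.rotary_emb) on dev5, while model.embed_tokens stays on
+ # dev0 with the first stage; the hard-coded range(80) assumes the 80-layer Llama 70B config.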
+
+
+ def setup_distributed_training():
+     torch_device = torch.device("qaic")
+     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+     dist.init_process_group(backend="qccl")
+     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+     # getattr(torch, torch_device.type).set_device(dist.get_rank()*2)
+
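+ # Note: dist.init_process_group() above relies on the default env:// rendezvous, so this script
+ # is expected to be launched with torchrun (or an equivalent launcher) that sets RANK, WORLD_SIZE,
+ # MASTER_ADDR and MASTER_PORT for every process; "qccl" is the collective backend passed here for
+ # the QAIC devices.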


def main(**kwargs):
    """
@@ -57,19 +88,6 @@ def main(**kwargs):
    train_config = TRAIN_CONFIG()
    update_config(train_config, **kwargs)
    dataset_config = generate_dataset_config(train_config, kwargs)
-     device = train_config.device
-
-     # dist init
-     if train_config.enable_ddp:
-         # TODO: may have to init qccl backend, next try run with torchrun command
-         torch_device = torch.device(device)
-         assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-         assert torch_device.index is None, (
-             f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-         )
-         dist.init_process_group(backend=train_config.dist_backend)
-         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-         getattr(torch, torch_device.type).set_device(dist.get_rank())

    # Set the seeds for reproducibility
    torch.manual_seed(train_config.seed)
@@ -97,12 +115,17 @@ def main(**kwargs):
            if param.requires_grad:
                param.data = param.data.to(torch.float32)
    else:
+         rank = dist.get_rank()
+
+         device_map = get_device_map_for_llama_70B(rank * 6, rank * 6 + 1, rank * 6 + 2, rank * 6 + 3, rank * 6 + 4, rank * 6 + 5)
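+         # Assumption: each DDP rank owns six accelerator devices, indexed rank * 6 through
+         # rank * 6 + 5 (rank 0 -> devices 0-5, rank 1 -> devices 6-11, ...), so the host must
+         # expose world_size * 6 devices for this mapping to be valid.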
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_path,
            use_cache=False,
            attn_implementation="sdpa",
            torch_dtype=torch.float16,
+             device_map=device_map,
        )
+         print(model.hf_device_map)
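+         # Passing a device_map dict to from_pretrained() lets accelerate dispatch each submodule
+         # onto its assigned device while the checkpoint is loaded; hf_device_map records the
+         # placement that was actually applied, so printing it is a quick check of the 6-way split.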

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
@@ -213,7 +236,7 @@ def main(**kwargs):
            f"passed context length is {train_config.context_length} and overall model's context length is "
            f"{model.config.max_position_embeddings}"
        )
-     model.to(train_config.device)
+     # model.to(train_config.device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_config.lr,
@@ -223,7 +246,7 @@ def main(**kwargs):

    # wrap model with DDP
    if train_config.enable_ddp:
-         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+         model = nn.parallel.DistributedDataParallel(model)  # , device_ids=[dist.get_rank()])
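+         # Because the model is already spread over several devices by device_map, DDP sees a
+         # multi-device module; PyTorch requires device_ids to be left unset in that case, and
+         # gradients are still all-reduced across ranks as usual.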

    _ = train(
        model,
@@ -245,4 +268,5 @@ def main(**kwargs):


if __name__ == "__main__":
+     setup_distributed_training()
    fire.Fire(main)