[QEff Finetune]: Enable PP+DDP #394
base: main
Conversation
Signed-off-by: Mamta Singh <[email protected]>
Force-pushed e8b1da7 to df36ae1
Force-pushed 3ca1229 to 53ff3c4
Good work, Mamta! Please address the comments. Let us discuss offline if anything is confusing.
QEfficient/cloud/finetune.py (Outdated)
-    model.to(train_config.device)
     optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
+    # model.to(train_config.device)
Will this commented line be required for the non-(PP+DDP) use case?
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
-        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+        model = nn.parallel.DistributedDataParallel(model)  # , device_ids=[dist.get_rank()])
Why did we remove device_ids in the DDP case? Is it because we are using device_map now?
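(For context, a minimal sketch of the constraint being discussed, not the PR's code: PyTorch's DistributedDataParallel requires device_ids to be left unset when the wrapped module already spans several devices, which is the case once a pipeline-parallel device_map has been applied. The helper name below is hypothetical.)

```python
import torch.nn as nn
import torch.distributed as dist


def wrap_with_ddp(model: nn.Module, spans_multiple_devices: bool) -> nn.Module:
    """Hypothetical helper: pick the DDP constructor form based on placement."""
    if spans_multiple_devices:
        # Model is already sharded across devices (PP); DDP must not be pinned
        # to a single device, so device_ids stays unset.
        return nn.parallel.DistributedDataParallel(model)
    # Single-device replica per rank: pin DDP to this rank's device as before.
    return nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
```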
QEfficient/cloud/finetune.py (Outdated)
    )
    print(model.hf_device_map)
If this is just for debugging, please remove it. If we actually want to show the user which part of the model is placed on which device, add some DEBUG logs about the split instead. That would help the user debug easily and make our tool less of a black box. :)
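(A possible shape for such DEBUG logs, sketched with Python's standard logging module; the repo may expose its own logger utility, and log_device_split is a hypothetical helper name.)

```python
import logging

logger = logging.getLogger(__name__)


def log_device_split(model) -> None:
    """Log which submodule ended up on which device after device_map placement."""
    device_map = getattr(model, "hf_device_map", None)
    if not device_map:
        logger.debug("No hf_device_map found; the model sits on a single device.")
        return
    for module_name, device in sorted(device_map.items()):
        logger.debug("module %s -> device %s", module_name, device)
```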
@@ -99,6 +99,8 @@ class TrainConfig:
     # profiler_dir: str = "PATH/to/save/profiler/results"  # will be used if using profiler

+    # dist-related
+    enable_pp: bool = False
I think this support is only added for decoder-style models, so this needs to be properly documented. Maybe we can share some numerical guidance as well, e.g. if the user's model is larger than, say, 8B parameters, they may need 4 PP stages; if it is larger than 30B, they may need 16 PP stages, and so on.
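(A sketch of how the flag could be documented in TrainConfig; the sizing thresholds are the illustrative numbers from this comment, not benchmarked values, the default for num_pp_stages is a placeholder, and the other fields are elided.)

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # ... existing fields elided ...

    # dist-related
    # Pipeline parallelism (PP) has only been exercised with decoder-style
    # (causal LM) models so far. Illustrative sizing guidance, not benchmarked:
    # models above roughly 8B parameters may need ~4 PP stages; above roughly
    # 30B, up to ~16 PP stages.
    enable_pp: bool = False
    num_pp_stages: int = 1  # placeholder default
    enable_ddp: bool = False
```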
    getattr(torch, torch_device.type).set_device(dist.get_rank())
    if train_config.enable_pp:
        assert dist.get_world_size() % train_config.num_pp_stages == 0, (
            "total available devices should be multiple of number of pipeline stages"
"Total" instead of "total", and add a full stop at the end.
Also, can we inform the user that:
if dist.get_world_size() // train_config.num_pp_stages == 1, this will be pure PP;
if dist.get_world_size() // train_config.num_pp_stages > 1, this will actually be PP+DDP.
This might help make our system foolproof.
Also, we need another assert condition:
assert dist.get_world_size() * train_config.num_pp_stages == total_available_devices
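(A rough sketch of the additional assert plus a user-facing message; check_pp_ddp_topology is a hypothetical helper, total_available_devices is passed in because how devices are counted depends on the backend, and the message condition here assumes each torchrun process drives one pipeline of num_pp_stages devices, as in the PR's example command, which is a slight variation on the exact expression proposed above.)

```python
import torch.distributed as dist


def check_pp_ddp_topology(num_pp_stages: int, total_available_devices: int) -> None:
    """Validate the PP/DDP topology and tell the user which mode is active."""
    world_size = dist.get_world_size()  # number of torchrun processes (DDP replicas)
    assert world_size * num_pp_stages == total_available_devices, (
        "world_size x num_pp_stages must equal the number of visible devices."
    )
    if world_size == 1:
        print("Running pure pipeline parallelism (PP only).")
    else:
        print("Running pipeline parallelism combined with DDP (PP+DDP).")
```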
    - This device map structure is verified for llama models only.
    """
    device_map = {
        "model.embed_tokens": rank * num_pp_stages,
Please add a short explanation of why these particular layers are mapped to a particular device (L64 to L67).
"model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1), | ||
} | ||
n_layer_per_stage = math.ceil(num_layers / num_pp_stages) | ||
for j in range(num_pp_stages): |
Please add thorough documentation for this double for loop. It is difficult to understand without working through a case; better to add an example and explain with it.
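(For reference, a reconstruction of what such documentation could look like, with a worked example; the loop body is inferred from the fragments visible in this diff and is not necessarily the PR's exact code.)

```python
import math


def get_device_map(rank: int, num_pp_stages: int, num_layers: int) -> dict:
    """Map decoder submodules to devices for one DDP rank (illustrative reconstruction).

    Each rank owns a contiguous block of num_pp_stages devices starting at
    rank * num_pp_stages. The embedding goes on the rank's first device, the
    final norm and rotary embedding on its last device, and the decoder layers
    are split into ceil(num_layers / num_pp_stages)-sized chunks, one chunk per
    pipeline stage.
    """
    first_device = rank * num_pp_stages
    last_device = first_device + (num_pp_stages - 1)
    device_map = {
        "model.embed_tokens": first_device,
        "model.norm": last_device,
        "model.rotary_emb": last_device,
    }
    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
    # Outer loop walks the pipeline stages of this rank; inner loop assigns that
    # stage's slice of decoder layers to the stage's device.
    for j in range(num_pp_stages):
        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
            device_map[f"model.layers.{i}"] = first_device + j
    return device_map


# Worked example: rank 1 in a 2-stage pipeline with an 8-layer model owns
# devices 2 and 3; layers 0-3 land on device 2, layers 4-7 on device 3.
print(get_device_map(rank=1, num_pp_stages=2, num_layers=8))
```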
QEfficient/cloud/finetune.py (Outdated)
        num_layers = get_num_layers_from_config(model_config)
        device_map = get_device_map(rank, train_config.num_pp_stages, num_layers)
    else:
        device_map = "auto"
Does "auto" work well for both the DDP-only use case and the single-device use case?
"model.norm": rank * num_pp_stages + (num_pp_stages - 1), | ||
"model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1), | ||
} | ||
n_layer_per_stage = math.ceil(num_layers / num_pp_stages) |
Suggestion: use np.ceil so that no new module needs to be imported.
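(A quick sketch of the suggested swap; unlike math.ceil, np.ceil returns a float, so a cast back to int is needed before using the result as a loop bound. This assumes numpy is already imported elsewhere in finetune.py.)

```python
import numpy as np

num_layers, num_pp_stages = 32, 4  # example values only

# np.ceil returns a float (e.g. 8.0), so cast to int for use in range().
n_layer_per_stage = int(np.ceil(num_layers / num_pp_stages))
print(n_layer_per_stage)  # 8
```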
Signed-off-by: Mamta Singh <[email protected]>
Added support for PP and DDP
Command for PP only: QAIC_VISIBLE_DEVICES=0,1,2,3 python -m QEfficient.cloud.finetune --device qaic --enable_pp --dist_backend qccl (the number of pipeline stages will equal the number of visible devices)
Command for DDP only: QAIC_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 -m QEfficient.cloud.finetune --device qaic --enable_ddp --dist_backend qccl
Command for PP+DDP (4 qaic devices, i.e. 1 Ultra, with 2 pipeline stages):
QAIC_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node 2 -m QEfficient.cloud.finetune --device qaic --enable_ddp --enable_pp --num_pp_stages 2 --dist_backend qccl
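(A small illustrative calculation, inferred from the 4-device / 2-stage example above rather than stated anywhere in the PR: each torchrun process drives one pipeline of num_pp_stages devices.)

```python
# Pick --nproc-per-node for torchrun given the visible devices and PP stages.
total_visible_devices = 4   # QAIC_VISIBLE_DEVICES=0,1,2,3
num_pp_stages = 2

nproc_per_node = total_visible_devices // num_pp_stages
print(nproc_per_node)  # 2, matching --nproc-per-node 2 in the PP+DDP command above
```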