[QEff Finetune]: Enable PP+DDP #394
base: main
The diff touches two files. The first set of hunks below modifies the finetuning script:
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import math
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -18,7 +19,7 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
@@ -32,7 +33,7 @@
     get_preprocessed_dataset,
 )
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.utils._utils import get_num_layers_from_config, login_and_download_hf_lm

 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
@@ -41,12 +42,37 @@
     print(f"Warning: {e}. Proceeding without QAIC modules.")


-from transformers import AutoModelForSequenceClassification

 # Suppress all warnings
 warnings.filterwarnings("ignore")


+def get_device_map(rank, num_pp_stages, num_layers):
+    """Returns device map for model layers and given process rank based on number of pipeline stages.
+
+    Args:
+        rank (int): process rank
+        num_pp_stages (int): number of stages in pipeline
+        num_layers (int): total number of layers in the model
+
+    Returns:
+        Dict: A dictionary of layers and corresponding device id.
+
+    Notes:
+        - This device map structure is verified for llama models only.
+    """
+    device_map = {
+        "model.embed_tokens": rank * num_pp_stages,
+        "lm_head": rank * num_pp_stages,
+        "model.norm": rank * num_pp_stages + (num_pp_stages - 1),
+        "model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
+    }
+    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
Review comment: Suggestion: use np.ceil so that no new module will be imported.
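As a side note on that suggestion (not part of the PR): the same rounding can also be done with plain integer arithmetic, so neither math nor numpy is strictly needed. A minimal sketch:

```python
# Ceiling division without importing math or numpy:
# (a + b - 1) // b equals ceil(a / b) for positive integers.
num_layers = 32        # hypothetical layer count
num_pp_stages = 3      # hypothetical number of pipeline stages
n_layer_per_stage = (num_layers + num_pp_stages - 1) // num_pp_stages
print(n_layer_per_stage)  # 11, same as math.ceil(32 / 3)
```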
+    for j in range(num_pp_stages):

Review comment: Please add some strong documentation for this double for loop. It is difficult to understand without working through a case; better to add an example and explain with it.
+        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
+            device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
+    return device_map
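To make the nested loops concrete (an illustration with made-up numbers, not part of the PR): with 2 pipeline stages and a 4-layer model, rank 1 owns device ids 2 and 3; the first half of the decoder layers goes to device 2 and the second half to device 3. A self-contained sketch of the same logic:

```python
import math

def get_device_map(rank, num_pp_stages, num_layers):
    # Same mapping as the function above: rank r owns the device ids
    # r * num_pp_stages ... r * num_pp_stages + num_pp_stages - 1.
    device_map = {
        "model.embed_tokens": rank * num_pp_stages,
        "lm_head": rank * num_pp_stages,
        "model.norm": rank * num_pp_stages + (num_pp_stages - 1),
        "model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
    }
    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
    # Outer loop walks the pipeline stages; the inner loop assigns that stage's
    # contiguous slice of decoder layers to the stage's device.
    for j in range(num_pp_stages):
        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
            device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
    return device_map

# Hypothetical case: rank 1, 2 pipeline stages, 4 decoder layers.
print(get_device_map(rank=1, num_pp_stages=2, num_layers=4))
# {'model.embed_tokens': 2, 'lm_head': 2, 'model.norm': 3, 'model.rotary_emb': 3,
#  'model.layers.0': 2, 'model.layers.1': 2, 'model.layers.2': 3, 'model.layers.3': 3}
```

Note what the mapping does: the embedding table and lm_head land on the first device of the rank's slice (presumably because they are weight-tied in most decoder models), while the final norm and rotary embedding sit on the last device, next to the last decoder layers, which is the placement the final review comment asks to have documented.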
+
+
 def setup_distributed_training(train_config: TrainConfig) -> None:
     """Initialize distributed training environment if enabled.
@@ -69,8 +95,13 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

     dist.init_process_group(backend=train_config.dist_backend)
-    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-    getattr(torch, torch_device.type).set_device(dist.get_rank())
+    if train_config.enable_pp:
+        assert dist.get_world_size() % train_config.num_pp_stages == 0, (
+            "total available devices should be multiple of number of pipeline stages"
Review comment: "Total" instead of "total".

Review comment: Also, can we intimate the user that [...]? This might be helpful to make our system idiot-proof.

Review comment: Also, we need another assert condition.
+        )
+    else:
+        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+        getattr(torch, torch_device.type).set_device(dist.get_rank())


 def setup_seeds(seed: int) -> None:
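Coming back to the new guard inside setup_distributed_training above, a tiny illustration of what it accepts and rejects (the values are made up; dist.get_world_size() and train_config.num_pp_stages supply the real ones):

```python
# Illustration only: the check added under enable_pp.
world_size = 4       # e.g. number of launched processes
num_pp_stages = 2    # pipeline stages per model replica
assert world_size % num_pp_stages == 0, (
    "total available devices should be multiple of number of pipeline stages"
)
print("4 % 2 == 0, so the check passes; a world size of 3 would raise AssertionError")
```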
@@ -128,12 +159,29 @@ def load_model_and_tokenizer(
             if param.requires_grad:
                 param.data = param.data.to(torch.float32)
    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_path,
-            use_cache=False,
-            attn_implementation="sdpa",
-            torch_dtype=torch.float16,
-        )
+        if train_config.enable_pp:
+            if train_config.enable_ddp:
+                rank = dist.get_rank()
+                model_config = AutoConfig.from_pretrained(train_config.model_name)
+                num_layers = get_num_layers_from_config(model_config)
+                device_map = get_device_map(rank, train_config.num_pp_stages, num_layers)
+            else:
+                device_map = "auto"
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_path,
+                use_cache=False,
+                attn_implementation="sdpa",
+                torch_dtype=torch.float16,
+                device_map=device_map,
+            )
+            print(model.hf_device_map)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_path,
+                use_cache=False,
+                attn_implementation="sdpa",
+                torch_dtype=torch.float16,
+            )

     tokenizer = AutoTokenizer.from_pretrained(
         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
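Since the pipeline-parallel branch prints model.hf_device_map, here is a hedged sketch of how that mapping could be summarized per device. It assumes only that hf_device_map is the dict of module name to device id that transformers populates when from_pretrained receives a device_map, as in the code above; summarize_device_map is a hypothetical helper, not part of the PR:

```python
from collections import Counter

def summarize_device_map(hf_device_map):
    # Count how many top-level modules ended up on each device.
    per_device = Counter(hf_device_map.values())
    for device, n_modules in sorted(per_device.items(), key=lambda kv: str(kv[0])):
        print(f"device {device}: {n_modules} modules")

# Usage (hypothetical): summarize_device_map(model.hf_device_map)
```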
@@ -332,12 +380,17 @@ def main(peft_config_file: str = None, **kwargs) -> None:
        f"passed context length is {train_config.context_length} and overall model's context length is "
        f"{model.config.max_position_embeddings}"
    )

-    model.to(train_config.device)
-    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
+    if not train_config.enable_pp:
+        model.to(train_config.device)
+    optimizer = optim.AdamW(
+        model.parameters(),
+        lr=train_config.lr,
+        weight_decay=train_config.weight_decay,
+    )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
-        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+        model = nn.parallel.DistributedDataParallel(model)  # , device_ids=[dist.get_rank()])
Review comment: Why did we remove device_ids in the DDP case? Is it because we are using device_map now?
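Some background on that question (an observation, not the PR author's answer): PyTorch's DistributedDataParallel requires device_ids to be None when the wrapped module already has parameters on more than one device, which is exactly the situation once a pipeline-parallel device_map has been applied; for a single-device module the old form is still valid. A hedged sketch of wrapping both ways, with wrap_ddp being a hypothetical helper:

```python
import torch.nn as nn
import torch.distributed as dist

def wrap_ddp(model, enable_pp: bool):
    if enable_pp:
        # Module is spread over several devices via device_map:
        # DDP requires device_ids (and output_device) to stay None here.
        return nn.parallel.DistributedDataParallel(model)
    # Single-device module: pinning DDP to this rank's device still works.
    return nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
```

Note that the diff as written drops device_ids unconditionally, i.e. also when enable_pp is False, which appears to be what prompted the question.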


     results = train(
         model,
         tokenizer,
The second file in the diff updates the TrainConfig dataclass:
@@ -99,6 +99,8 @@ class TrainConfig:
     # profiler_dir: str = "PATH/to/save/profiler/results"  # will be used if using profiler

     # dist-related
+    enable_pp: bool = False
Review comment: I think this support is only added for decoder-style models, so this needs to be properly documented. Maybe we can share some numerical guidance as well, e.g. if the user's model is larger than, say, 8B they may need 4 PP stages; if it is larger than 30B they may need 16 PP stages. Like that.
+    num_pp_stages: int = 1
     enable_ddp: bool = False
     dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"
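To show how the new fields fit together, a short sketch that constructs the config directly. It assumes TrainConfig is a plain dataclass that accepts these fields as keyword arguments; the import path comes from the finetune script's own imports, and the values are hypothetical:

```python
from QEfficient.finetune.configs.training import TrainConfig

# Hypothetical PP+DDP run: each DDP rank drives a 2-stage pipeline, so
# (per get_device_map above) 4 accelerator cards would host ranks 0 and 1.
config = TrainConfig(
    enable_pp=True,
    num_pp_stages=2,
    enable_ddp=True,
    dist_backend="cpu:gloo,qaic:qccl,cuda:gloo",  # default shown in the diff
)
print(config.enable_pp, config.num_pp_stages, config.enable_ddp)
```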
Review comment: Please add some explanation of why these particular layers are mapped to a particular device (L64 to L67).