
Commit df36ae1

Merge branch 'quic:main' into pp_ddp
2 parents: c213173 + 38989e9

27 files changed: +1617, -147 lines

.github/CODEOWNERS (+1, -1)

@@ -7,6 +7,6 @@

 # Default owners
 # review when someone opens a pull request and assign appropriate reviewer
-* @quic-rishinr @ochougul @quic-hemagnih
+* @quic-rishinr @ochougul @quic-hemagnih @quic-amitraj
 pyproject.toml @carlstreeter-quic

QEfficient/cloud/finetune.py (+36, -31)

@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import math
 import random
 import warnings

@@ -30,49 +31,44 @@
     get_preprocessed_dataset,
 )
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.utils._utils import get_num_layers_from_config, login_and_download_hf_lm

 try:
     import torch_qaic # noqa: F401
 except ImportError as e:
     print(f"Warning: {e}. Moving ahead without these qaic modules.")


-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 # Suppress all warnings
 warnings.filterwarnings("ignore")

-def get_device_map_for_llama_70B(dev0, dev1, dev2, dev3, dev4, dev5): # total_num_layers, num_stages
+
+def get_device_map(rank, num_pp_stages, num_layers):
     device_map = {
-        'model.embed_tokens': dev0,
-        'lm_head': dev5,
-        'model.norm': dev5,
-        'model.rotary_emb': dev5
+        "model.embed_tokens": rank * num_pp_stages,
+        "lm_head": rank * num_pp_stages,
+        "model.norm": rank * num_pp_stages + (num_pp_stages - 1),
+        "model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
     }
-    for i in range(80):
-        if i < 14:
-            device_map[f"model.layers.{i}"] = dev0
-        elif i < 28:
-            device_map[f"model.layers.{i}"] = dev1
-        elif i < 42:
-            device_map[f"model.layers.{i}"] = dev2
-        elif i < 56:
-            device_map[f"model.layers.{i}"] = dev3
-        elif i < 70:
-            device_map[f"model.layers.{i}"] = dev4
-        else:
-            device_map[f"model.layers.{i}"] = dev5
+    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)  # number of layers per device 80/6 = 13.3 ~ 14
+    for j in range(num_pp_stages):
+        for i in range(n_layer_per_stage * j, n_layer_per_stage * (j + 1)):
+            if i < num_layers:
+                device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
+
     return device_map


-def setup_distributed_training():
-    torch_device = torch.device("qaic")
+def setup_distributed_training(train_config):
+    torch_device = torch.device(train_config.device)
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
-    dist.init_process_group(backend="qccl")
-    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-    #getattr(torch, torch_device.type).set_device(dist.get_rank()*2)
+    dist.init_process_group(backend=train_config.dist_backend)
+    if not train_config.enable_pp:
+        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+        getattr(torch, torch_device.type).set_device(dist.get_rank())


 def main(**kwargs):
@@ -87,6 +83,13 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+
+    if train_config.enable_ddp or train_config.enable_pp:
+        setup_distributed_training(train_config)
+        if train_config.enable_pp:
+            assert dist.get_world_size() % train_config.num_pp_stages == 0, (
+                "total available devices should be multiple of number of pipeline stages"
+            )
     dataset_config = generate_dataset_config(train_config, kwargs)

     # Set the seeds for reproducibility
@@ -95,7 +98,7 @@ def main(**kwargs):
     np.random.seed(train_config.seed)

     # Load the pre-trained model and setup its configuration
-    # config = AutoConfig.from_pretrained(train_config.model_name)
+    model_config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
@@ -115,9 +118,12 @@ def main(**kwargs):
             if param.requires_grad:
                 param.data = param.data.to(torch.float32)
     else:
-        rank = dist.get_rank()
-
-        device_map = get_device_map_for_llama_70B(rank*6, rank*6+1, rank*6+2, rank*6+3, rank*6+4, rank*6+5)
+        if train_config.enable_pp and train_config.enable_ddp:
+            rank = dist.get_rank()
+            num_layers = get_num_layers_from_config(model_config)
+            device_map = get_device_map(rank, train_config.num_pp_stages, num_layers)
+        else:
+            device_map = "auto"
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_path,
             use_cache=False,
@@ -246,7 +252,7 @@ def main(**kwargs):

     # wrap model with DDP
     if train_config.enable_ddp:
-        model = nn.parallel.DistributedDataParallel(model)#, device_ids=[dist.get_rank()])
+        model = nn.parallel.DistributedDataParallel(model)  # , device_ids=[dist.get_rank()])

     _ = train(
         model,
@@ -268,5 +274,4 @@ def main(**kwargs):


 if __name__ == "__main__":
-    setup_distributed_training()
     fire.Fire(main)
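
Note on the new device map: get_device_map spreads the decoder layers of one model replica evenly across the num_pp_stages devices owned by a DDP rank, pinning the embedding table and lm_head to the rank's first device and the final norm and rotary embedding to its last. Below is a minimal, self-contained sketch of the same logic with an illustrative call; the rank, stage count and layer count are made up for the example, not taken from the commit.

import math


def get_device_map(rank, num_pp_stages, num_layers):
    # Embeddings and lm_head go to the rank's first device; final norm and rotary embedding to its last.
    device_map = {
        "model.embed_tokens": rank * num_pp_stages,
        "lm_head": rank * num_pp_stages,
        "model.norm": rank * num_pp_stages + (num_pp_stages - 1),
        "model.rotary_emb": rank * num_pp_stages + (num_pp_stages - 1),
    }
    # Contiguous chunks of ceil(num_layers / num_pp_stages) layers per stage; the last chunk may be shorter.
    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
    for j in range(num_pp_stages):
        for i in range(n_layer_per_stage * j, n_layer_per_stage * (j + 1)):
            if i < num_layers:
                device_map[f"model.layers.{i}"] = rank * num_pp_stages + j
    return device_map


# Illustrative only: rank 1 of a run with 2 pipeline stages per rank and a 5-layer model
# maps layers 0-2 to device 2 and layers 3-4 (plus norm/rotary_emb) to device 3.
print(get_device_map(rank=1, num_pp_stages=2, num_layers=5))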

QEfficient/finetune/configs/training.py (+2)

@@ -53,6 +53,8 @@ class train_config:
     # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

     # dist-related
+    enable_pp: bool = False
+    num_pp_stages: int = 1
     enable_ddp: bool = False
     dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"
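
Both new fields default to "pipeline parallelism off", so existing DDP-only runs keep their behaviour. A rough, self-contained sketch of how finetune.py consumes them follows; the dataclass only mirrors the relevant fields and is not the repo's train_config, and the world size is illustrative.

from dataclasses import dataclass


@dataclass
class TrainConfigSketch:  # stand-in for QEfficient.finetune.configs.training.train_config
    device: str = "qaic"
    enable_pp: bool = False   # new field: enable pipeline parallelism
    num_pp_stages: int = 1    # new field: devices per pipeline (per DDP rank)
    enable_ddp: bool = False
    dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"


cfg = TrainConfigSketch(enable_ddp=True, enable_pp=True, num_pp_stages=2)

# Mirrors the assertion added in finetune.py: the world size must split evenly into pipelines.
world_size = 4  # illustrative; in finetune.py this comes from dist.get_world_size()
assert world_size % cfg.num_pp_stages == 0, "total available devices should be multiple of number of pipeline stages"
print(f"{world_size // cfg.num_pp_stages} DDP replicas x {cfg.num_pp_stages} pipeline stages")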

QEfficient/transformers/models/internvl/modeling_internvl.py (+13, -13)

@@ -20,8 +20,8 @@ def __init__(self, model):
         self.model = model

     def forward(self, pixel_values):
-        vit_embeds = self.model.extract_feature(pixel_values)
-        return vit_embeds
+        vision_embeds = self.model.extract_feature(pixel_values)
+        return vision_embeds


 class QEffInternDecoderWrapper(nn.Module):
@@ -31,21 +31,21 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model

-    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vit_embeds, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values


 class QEffInternVLModel(nn.Module):
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}

         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -139,15 +139,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes

     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["vit_embeds"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")

         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
-        inputs_shapes["vit_embeds"] = (
+        inputs_shapes["vision_embeds"] = (
             constants.INTERN_NUM_PATCHES,
             constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("vit_embeds")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}

         return inputs

     def forward(self, input_ids, pixel_values, position_ids, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-        vit_embeds = self.extract_feature(pixel_values)
+        vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(
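
Aside from the vit_embeds to vision_embeds rename, the embedding-merge logic is unchanged: a cumulative sum over the image-context-token mask gives each image token its row index into the flattened vision embeddings, and torch.where scatters those rows into the text embeddings. A stripped-down sketch of that indexing trick with toy shapes (plain tensors, no model classes, values chosen only for illustration):

import torch

B, N, C = 1, 6, 4        # batch, sequence length, hidden size (toy values)
IMG_CONTEXT_TOKEN = 99   # stand-in for constants.INTERN_IMG_CONTEXT_TOKEN

input_ids = torch.tensor([[5, 99, 99, 7, 99, 8]])
input_embeds = torch.zeros(B, N, C)                                      # pretend text embeddings
vision_embeds = torch.arange(3 * C, dtype=torch.float32).reshape(3, C)   # one row per image token

selected = input_ids.reshape(B * N) == IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1           # running count of image tokens
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
merged = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)

print(merged[0, 1], merged[0, 2], merged[0, 4])  # rows 0, 1 and 2 of vision_embeds
print(merged[0, 0])                              # non-image position keeps its text embedding (zeros)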

QEfficient/transformers/models/llava/modeling_llava.py (+15, -15)

@@ -38,9 +38,9 @@ def forward(self, pixel_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}")
-        image_features = self.model.multi_modal_projector(selected_image_feature)
+        vision_embeds = self.model.multi_modal_projector(selected_image_feature)

-        return image_features
+        return vision_embeds


 class QEFFLlavaDecoderWrapper(nn.Module):
@@ -50,21 +50,21 @@ def __init__(self, model):
         self.config = self.model.config
         self.language_model = self.model.language_model

-    def forward(self, input_ids, image_features, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
         mask = input_ids == self.model.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             past_key_values=past_key_values,
         )

-        return outputs.logits, image_features, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values


 class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
@@ -86,14 +86,14 @@ def forward(self, input_ids, position_ids, pixel_values, past_key_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-        image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = self.multi_modal_projector(selected_image_feature)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)

         mask = input_ids == self.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        image_inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         # *where to skip image encoder for decode*
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
         outputs = self.language_model(
@@ -118,7 +118,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
         }
         lang_inputs = {
             "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
-            "image_features": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
+            "vision_embeds": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
             "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
         }
         lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
@@ -137,7 +137,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("image_features")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
         return inputs

@@ -218,15 +218,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes

     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["image_features"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")

         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "image_features_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
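
The same rename carries into the exported graph interface: with kv_offload set, the vision graph produces "vision_embeds" and the language graph retains it as "vision_embeds_RetainedState". A toy sketch of how the renamed output lists end up looking; the layer count is illustrative and the non-offload branch is omitted here.

num_hidden_layers = 2  # illustrative
vision_output_names = ["vision_embeds"]
lang_output_names = ["logits"]
for i in range(num_hidden_layers):
    for kv in ["key", "value"]:
        lang_output_names.append(f"past_{kv}.{i}_RetainedState")

# kv_offload case: vision and language outputs are reported separately,
# and the language graph retains the vision embeddings between invocations.
lang_output_names.insert(1, "vision_embeds_RetainedState")
output_names = {"vision": vision_output_names, "lang": lang_output_names}
print(output_names)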

QEfficient/transformers/models/mllama/modeling_mllama.py (+5, -3)

@@ -161,8 +161,8 @@ def forward(
         value_states_new = torch.index_put(value_states_old, indices, value_states)

         # Select old or new image KV states based on q_len
-        key_states = torch.where(q_len == 1, key_states_old, key_states_new)
-        value_states = torch.where(q_len == 1, value_states_old, value_states_new)
+        key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
+        value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new)

         # Update the image cache
         past_key_value.key_cache[self.layer_idx] = key_states
@@ -924,7 +924,7 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
         )
-
+        outputs["pixel_values"] = pixel_values
         return outputs

     def get_dummy_inputs(self, kv_offload: bool = False):
@@ -1092,6 +1092,8 @@ def get_output_names(self, kv_offload: bool = False):
             "logits",
             *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
         ]
+        if not kv_offload:
+            lang_output_names.append("pixel_values_RetainedState")

         output_names = {}
         if kv_offload:
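
The mllama change wraps the q_len comparison in torch.tensor, so torch.where receives an explicit 0-d boolean tensor that broadcasts against the full key/value tensors (torch.where expects a tensor condition). A tiny sketch of that selection pattern; shapes and values are illustrative.

import torch

q_len = 1  # illustrative; in the attention forward this is the current query length

key_states_old = torch.zeros(1, 2, 4)  # stand-in for the cached image KV states
key_states_new = torch.ones(1, 2, 4)   # stand-in for the freshly computed KV states

# A 0-d bool tensor broadcasts over the whole tensors, picking one of them elementwise.
cond = torch.tensor(q_len == 1)
key_states = torch.where(cond, key_states_old, key_states_new)

assert torch.equal(key_states, key_states_old)  # decode step (q_len == 1) keeps the cached states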
