Commit

Merge branch 'synapse_1_20' into transformers_future
regisss committed Feb 17, 2025
2 parents d053218 + 8044aa4 commit fe01ca2
Showing 6 changed files with 270 additions and 75 deletions.
2 changes: 2 additions & 0 deletions examples/image-to-text/README.md
@@ -168,6 +168,8 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```

The single-card training command for `llava-hf/llava-1.5-7b-hf` is similar.
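
A hedged sketch of that single-card invocation is given below. The script name and the LoRA regex come from this example; the remaining argument values are illustrative and should mirror the multi-card command above.

```bash
# Illustrative single-card sketch: the multi-card example above, launched directly
# on one HPU instead of through ../gaudi_spawn.py. The dataset, logging and remaining
# training arguments stay the same as in that example; values shown here are placeholders.
python3 run_image2text_lora_finetune.py \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --output_dir ./llava-lora-finetuned \
    --bf16 True \
    --max_seq_length 512 \
    --lora_rank 8 \
    --lora_alpha 8 \
    --lora_dropout 0.1 \
    --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```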

> For other models, adjust the training parameters and `lora_target_modules` accordingly. For example, for HuggingFaceM4/idefics2-8b, replace `lora_target_modules` with:
> '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
174 changes: 145 additions & 29 deletions examples/image-to-text/run_image2text_lora_finetune.py
@@ -298,7 +298,58 @@ def __call__(self, examples):
return batch


def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
class LLavaDataCollator:
def __init__(self, processor, max_seq_length):
self.processor = processor

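# One placeholder token per image patch of the cropped image, plus one for the CLS token
# (subtracted again below when the vision feature selection strategy is "default")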
num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * (
self.processor.image_processor.crop_size["width"] // self.processor.patch_size
) + 1
if self.processor.vision_feature_select_strategy == "default":
num_image_tokens -= 1

# text length + image length
self.max_seq_length = max_seq_length + num_image_tokens

def __call__(self, examples):
texts = []
images = []

keys = list(examples[0].keys())
if not all(key in ["image", "query", "answers"] for key in keys):
raise ValueError("Unsupported dataset format")
for example in examples:
image = example["image"]
question = example["query"]["en"]
answer = random.choice(example["answers"])
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Answer briefly."},
{"type": "image"},
{"type": "text", "text": question},
],
},
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
]
text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
texts.append(text.strip())
images.append(image)

batch = self.processor(
images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length
)

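# Use input_ids as labels, masking padding positions with -100 so they are ignored by the loss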
labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

return batch


def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length, model_type):
from tqdm import tqdm

answers_unique = []
@@ -307,7 +358,6 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
for i in tqdm(range(0, len(dataset), batch_size)):
examples = dataset[i : i + batch_size]
answers_unique.extend(examples["answers"])
images = [[im] for im in examples["image"]]
texts = []
for q in examples["query"]:
messages = [
@@ -322,14 +372,31 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
texts.append(text.strip())
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)

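# LLaVA's processor is called positionally (images first) and padded on the left for
# generation; other model types keep the original keyword-argument call below.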
if model_type is not None and model_type == "llava":
images = []
for im in examples["image"]:
images.append(im)

inputs = processor(
images,
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
padding_side="left",
)
else:
images = [[im] for im in examples["image"]]
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs
@@ -346,6 +413,22 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m
return anls


def find_all_linear_names(model):
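"""Collect the names of all torch.nn.Linear modules outside the vision tower/projector,
used as default LoRA target modules when --lora_target_modules is not provided."""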
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -380,7 +463,7 @@ def main():
do_image_splitting=model_args.do_image_splitting,
padding_side="right",
)
setattr(processor.image_processor, "pad_to_longest_edge", True)

config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
@@ -395,7 +478,13 @@ def main():
else:
raise ValueError("Please provide value for model_name_or_path or config_name.")

# Load model
if config.model_type == "llava":
setattr(processor, "patch_size", config.vision_config.patch_size)
setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy)
else:
setattr(processor.image_processor, "pad_to_longest_edge", True)

# Load model
if model_args.model_name_or_path:
model_dtype = torch.bfloat16 if training_args.bf16 else None
model = AutoModelForVision2Seq.from_pretrained(
@@ -413,11 +502,16 @@ def main():
else:
raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")

if finetune_args.lora_target_modules is None:
target_modules = find_all_linear_names(model)
else:
target_modules = finetune_args.lora_target_modules

lora_config = LoraConfig(
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=finetune_args.lora_target_modules,
target_modules=target_modules,
init_lora_weights="gaussian",
)
model = get_peft_model(model, lora_config)
@@ -456,15 +550,21 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
if config.model_type == "llava":
data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length)
else:
raise ValueError("Please provide value for image_token_id")
data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama
image_token_id = config.image_token_index
else:
raise ValueError("Please provide value for image_token_id")

data_collator = MyDataCollator(
processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id
)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
@@ -509,14 +609,29 @@ def main():
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)

if config.model_type == "llava":
# don't expand image_token_id
setattr(processor, "patch_size", None)
setattr(processor, "vision_feature_select_strategy", None)
inputs = processor(
[image],
[text.strip()],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
padding_side="left",
)
else:
inputs = processor(
text=[text.strip()],
images=[image],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=data_args.max_seq_length,
)
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(
**inputs,
@@ -543,6 +658,7 @@ def main():
use_lazy_mode=training_args.use_lazy_mode,
use_hpu_graphs=training_args.use_hpu_graphs_for_inference,
max_seq_length=data_args.max_seq_length,
model_type=config.model_type,
)
eval_metrics = {"eval_accuracy": anls}
trainer.log_metrics("eval", eval_metrics)