Feat: Onboard PlamoForCausalLM Architecture #351


Draft: wants to merge 54 commits into base: main
Changes from all commits (54 commits)
56c5030
Onboard Plamo model
quic-shagun Apr 4, 2025
de15585
Update Plamo modeling file to support Opset 13
quic-shagun Apr 7, 2025
ed9d55d
Fix Plamo accuracy issues
quic-shagun Apr 9, 2025
ea87ba2
Merge branch 'quic:main' into plamo
quic-shagun Apr 9, 2025
e426634
Fix lint issues
quic-shagun Apr 9, 2025
537ba6f
Add Qualcomm Signature in new modeling file
quic-shagun Apr 9, 2025
1516dee
nit: Add Plamo in test file and Update README
quic-shagun Apr 9, 2025
34d9d97
Merge branch 'quic:main' into plamo
quic-shagun Apr 22, 2025
4edea79
Merge branch 'quic:main' into plamo
quic-shagun May 7, 2025
5c100c5
Update modeling file as per latest guidelines
quic-shagun May 8, 2025
845a9c9
Merge branch 'quic:main' into plamo
quic-shagun Jun 3, 2025
ad64436
Clean up modeling file
quic-shagun Jun 10, 2025
d4f8a32
Merge branch 'quic:main' into plamo
quic-shagun Jun 10, 2025
d91fe8b
Llama4 vlm changes (#443)
quic-dhirajku Jun 11, 2025
2514c0b
Features upgrade of `Embedding model` (#424)
quic-amitraj Jun 11, 2025
1e8039b
[QEff Finetune]: Enable --help for finetune CLI (#392)
quic-mamta Jun 12, 2025
6bcf5de
Grok-1Modelling changes and On device sampling (#447)
quic-rishinr Jun 12, 2025
64f4a04
[QEff Finetune]: Adding steps about how to fine tune on any custom da…
quic-swatia Jun 13, 2025
0d50e29
Gemma3 + llama4 bug fix (#453)
quic-rishinr Jun 13, 2025
0a693fa
Lint Errors Fixed (#468)
quic-amitraj Jun 20, 2025
80691b1
[QEff Finetune]: Made some formatting changes and removed obsolete fl…
quic-swatia Jun 20, 2025
8dcd85b
Updated example script of embedding model (#469)
quic-amitraj Jun 20, 2025
5c5acce
[QEff Finetune]: Removing indentation for urls to reflect. (#470)
quic-swatia Jun 20, 2025
04a668f
Add llama chunk + gemma changes (#461)
quic-amitraj Jun 20, 2025
1453942
Announcement update for Granite Vision (#474)
qcdipankar Jun 22, 2025
740f7c2
Fixes for mllama (#462)
qcdipankar Jun 23, 2025
61b1445
BugFix: Fix reshape error for llama swiftkv models (#432)
quic-shagun Jun 25, 2025
eff9472
Gemma 3 minor fixes (#476)
quic-akuruvil Jun 25, 2025
77cfb29
Bug fix for spdTransform (#467)
qcdipankar Jun 27, 2025
6c64d35
[QEff. Finetune]: Enabled FT CI tests. (#420)
quic-meetkuma Jul 1, 2025
10fb2ac
Gemma 3 minor fixes (#476) - CPR (#484)
quic-akuruvil Jul 1, 2025
71e554f
Revert "Gemma 3 minor fixes (#476) - CPR" (#485)
quic-hemagnih Jul 1, 2025
d823503
[Docs/Readme]: Main Readme updating for latest news and adding the on…
abukhoy Jul 2, 2025
c5a5c17
QUICKFIX: Removed the redundant breakpoint comment in modeling_llava_…
quic-dhirajku Jul 3, 2025
b90c1ac
MDP hash support (#479)
quic-rishinr Jul 3, 2025
db38927
[QEff Finetune] Adding dataset padding changes (#478)
quic-swatia Jul 4, 2025
6254efe
Fixed QNN data format config issue. (#480)
shubhagr-qc Jul 7, 2025
2ba491d
Corrected Total Inference Time unit (#505)
asmigosw Jul 9, 2025
3aaa2d8
[QEff. Finetune]: Added support to sync gradients across devices duri…
quic-meetkuma Jul 9, 2025
30d1579
[QEff Finetune]: Implement logger for finetuning and enable dumping (…
quic-mamta Jul 9, 2025
09c05db
Adding Fix for Falcon model (#508)
qcdipankar Jul 10, 2025
50b1404
[QEff. Finetune]: Removed samsum dataset references from FT code. (#482)
quic-meetkuma Jul 10, 2025
31771e3
Dynamic cache support on llama4 (#494)
quic-rishinr Jul 13, 2025
a78e983
Dependency package upgrade (#407)
qcdipankar Jul 14, 2025
7ba58e4
[QEff Finetune] : fix task_type variable in configs (#514)
quic-mamta Jul 14, 2025
761d339
Onboard Plamo model
quic-shagun Apr 4, 2025
5078513
Update Plamo modeling file to support Opset 13
quic-shagun Apr 7, 2025
a6caa54
Fix Plamo accuracy issues
quic-shagun Apr 9, 2025
ea7c330
Fix lint issues
quic-shagun Apr 9, 2025
1363b22
Add Qualcomm Signature in new modeling file
quic-shagun Apr 9, 2025
f64a5b7
nit: Add Plamo in test file and Update README
quic-shagun Apr 9, 2025
3fe384c
Update modeling file as per latest guidelines
quic-shagun May 8, 2025
0e07bd3
Clean up modeling file
quic-shagun Jun 10, 2025
55103e0
Merge branch 'plamo' of https://github.com/quic-shagun/efficient-tran…
quic-shagun Jul 15, 2025
7 changes: 6 additions & 1 deletion QEfficient/__init__.py
@@ -6,16 +6,21 @@
# -----------------------------------------------------------------------------

import os
import warnings

from QEfficient.utils import custom_format_warning

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils.logging_utils import logger

# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning


def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
9 changes: 6 additions & 3 deletions QEfficient/base/common.py
@@ -18,7 +18,7 @@
from transformers import AutoConfig

from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING, MODEL_CLASS_MAPPING
from QEfficient.utils import login_and_download_hf_lm


@@ -40,9 +40,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
"""
Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
class_name = (
MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
or EXTERNAL_MODEL_CLASS_MAPPING[config.__class__.__name__]
)
if class_name:
module = __import__("QEfficient.transformers.models.modeling_auto")
model_class = getattr(module, class_name)
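
A minimal sketch of the lookup order introduced above, assuming both registries are plain dicts keyed by the transformers config class name; the helper name below is hypothetical and only illustrates the fallback:

# Hypothetical helper mirroring the class resolution above: built-in
# architectures are tried first, then the external registry (a missing key
# there raises KeyError, matching the bracket lookup in the hunk).
def resolve_qeff_class_name(config_class_name, model_class_mapping, external_model_class_mapping):
    return (
        model_class_mapping.get(config_class_name, None)
        or external_model_class_mapping[config_class_name]
    )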
63 changes: 31 additions & 32 deletions QEfficient/base/modeling_qeff.py
@@ -7,7 +7,6 @@

import hashlib
import inspect
import json
import logging
import shutil
import subprocess
@@ -23,7 +22,7 @@
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants, dump_qconfig
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
from QEfficient.utils.cache import QEFF_HOME, to_hashable

logger = logging.getLogger(__name__)
@@ -269,17 +268,17 @@ def _compile(
specializations=specializations,
custom_io=custom_io,
device_group=list(range(mdp_ts_num_devices)),
num_cores=compiler_options.get("aic_num_cores", 16),
mxfp6=compiler_options.get("mxfp6_matmul", False),
num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES),
mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL),
mxint8=mxint8_kv_cache,
qnn_config=qnn_config,
)

return self.qpc_path

command = constants.COMPILER + [f"-m={onnx_path}"]
if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
mdp_ts_num_devices = None

if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

for key, value in compiler_options.items():
@@ -289,6 +288,17 @@ def _compile(
command.append(option)
continue
command.append(f"{option}={value}")

# Create a dummy mdp_ts_json if mdp-load-partition-config not provided and num_devices > 1
if mdp_ts_json_path is not None:
mdp_ts_json = load_json(str(mdp_ts_json_path))
elif mdp_ts_num_devices > 1:
mdp_ts_json = generate_mdp_partition_config(
mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
)
else:
mdp_ts_json = None

compile_hash = hashlib.sha256(to_hashable(command))

if specializations is not None:
@@ -300,27 +310,36 @@ def _compile(
if num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

# Hash the MDP partition config and the number of devices.
compile_hash.update(to_hashable(mdp_ts_json))
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

# Check if already compiled
compile_hash = compile_hash.hexdigest()[:16]
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
qpc_path = compile_dir / "qpc"
qpc_path.mkdir(parents=True, exist_ok=True)

if qpc_path.is_dir():
if (qpc_path / "programqpc.bin").is_file():
self.qpc_path = qpc_path
return qpc_path
# Probably compilation failure last time, delete directory to start over
shutil.rmtree(qpc_path)

# write the MDP partition config file if not provided
if mdp_ts_json is not None:
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
create_json(str(mdp_ts_json_path), mdp_ts_json)
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

# Write specializations.json file
if specializations is not None:
specializations_json = compile_dir / "specializations.json"
with open(specializations_json, "w") as fp:
json.dump(
{"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]},
fp,
indent=4,
)
specializations_data = {
"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]
}
create_json(str(specializations_json), specializations_data)
command.append(f"-network-specialization-config={specializations_json}")

# Write custom_io.yaml file
@@ -331,26 +350,6 @@ def _compile(
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
command.append(f"-custom-IO-list-file={custom_io_yaml}")

# Write mdp_config.json file
if not mdp_ts_json_path and mdp_ts_num_devices > 1:
num_cores = compiler_options.get("aic_num_cores", 16)
mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
with open(mdp_ts_json, "w") as fp:
json.dump(
{
"connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
"partitions": [
{
"name": "Partition0",
"devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
}
],
},
fp,
indent=4,
)
command.append(f"-mdp-load-partition-config={mdp_ts_json}")

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
try:
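
The inline JSON writer removed above in this file shows the shape of the multi-device partition config; a sketch of what generate_mdp_partition_config plausibly produces for the same inputs (the real helper lives in QEfficient.utils and may differ in detail):

# Sketch of the MDP partition config shape, reconstructed from the inline
# writer removed in this hunk; the real generate_mdp_partition_config may differ.
def sketch_mdp_partition_config(mdp_ts_num_devices: int, num_cores: int) -> dict:
    return {
        "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
        "partitions": [
            {
                "name": "Partition0",
                "devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
            }
        ],
    }

# e.g. sketch_mdp_partition_config(4, 16) describes four devices connected p2p, 16 cores each.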
69 changes: 68 additions & 1 deletion QEfficient/base/pytorch_transforms.py
@@ -9,6 +9,8 @@

from torch import nn

from QEfficient.utils.logging_utils import logger


class PytorchTransform:
"""
@@ -90,7 +92,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")


class ModuleMethodMapperTransform(PytorchTransform):
class ExternalModuleMapperTransform(PytorchTransform):
"""
Serves as base class for any transform that want to map a particular method of a class to a new method implementation.
"""
@@ -107,6 +109,71 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
# Handling the __init__ calls in the models
if hasattr(module, "__qeff_init__"):
module.__qeff_init__()
transformed = True

return model, transformed


class SplitGateUpWeightsTransform(PytorchTransform):
"""
split fused Gate+Up weights and copy into the model
For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()
for layer_idx in range(num_layers):
# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
fused = sd[fused_key] # [E, H, 2I] (no .weight here)
E, H, two_I = fused.shape
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

experts = model_tmp.model.layers[layer_idx].feed_forward.experts
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if delete_fused_key:
del sd[fused_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp
return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
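
A standalone sketch of the per-layer split performed by SplitGateUpWeightsTransform, with illustrative tensor sizes (E experts, H hidden, I intermediate):

import torch

# Fused [E, H, 2I] gate_up_proj weights become two [E, H, I] tensors, as in the transform above.
E, H, I = 4, 8, 16                       # illustrative sizes
fused = torch.randn(E, H, 2 * I)         # stands in for sd[prefix + "gate_up_proj"]
ffn_dim = fused.shape[-1] // 2
gate, up = fused.split(ffn_dim, dim=-1)  # views, no copy
assert gate.shape == up.shape == (E, H, I)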
138 changes: 49 additions & 89 deletions QEfficient/cloud/finetune.py
@@ -5,11 +5,11 @@
#
# -----------------------------------------------------------------------------

import logging
import random
import warnings
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

import fire
import numpy as np
import torch
import torch.distributed as dist
@@ -18,31 +18,33 @@
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
generate_dataset_config,
generate_peft_config,
get_dataloader_kwargs,
update_config,
)
from QEfficient.finetune.utils.dataset_utils import (
get_custom_data_collator,
get_preprocessed_dataset,
from QEfficient.finetune.utils.dataset_utils import get_dataloader
from QEfficient.finetune.utils.helper import Task_Mode
from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import (
get_longest_seq_length,
print_model_size,
print_trainable_parameters,
train,
)
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm
from QEfficient.utils._utils import hf_download

# Try importing QAIC-specific module, proceed without it if unavailable
try:
import torch_qaic # noqa: F401
except ImportError as e:
print(f"Warning: {e}. Proceeding without QAIC modules.")
logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.", logging.WARNING)


from transformers import AutoModelForSequenceClassification

# Suppress all warnings
warnings.filterwarnings("ignore")

@@ -68,7 +70,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"

dist.init_process_group(backend=train_config.dist_backend)
dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
dist.init_process_group(backend=dist_backend_map[torch_device.type])
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())

@@ -88,14 +91,13 @@ def setup_seeds(seed: int) -> None:


def load_model_and_tokenizer(
train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
train_config: TrainConfig, dataset_config: Any, **kwargs
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Load the pre-trained model and tokenizer from Hugging Face.
Args:
config (TrainConfig): Training configuration object containing model and tokenizer names.
dataset_config (Any): A dataclass object representing dataset configuration.
peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
kwargs: Additional arguments to override PEFT config.
Returns:
@@ -109,8 +111,9 @@ def load_model_and_tokenizer(
- Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
- Sets pad_token_id to eos_token_id if not defined in the tokenizer.
"""
pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
if train_config.task_type == "seq_classification":
logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
pretrained_model_path = hf_download(train_config.model_name)
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_path,
num_labels=dataset_config.num_labels,
@@ -119,7 +122,7 @@ def load_model_and_tokenizer(
)

if not hasattr(model, "base_model_prefix"):
raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)

for param in getattr(model, model.base_model_prefix).parameters():
param.requires_grad = False
@@ -144,11 +147,10 @@ def load_model_and_tokenizer(
# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
model.resize_token_embeddings(len(tokenizer))

# FIXME (Meet): Cover below line inside the logger once it is implemented.
print_model_size(model, train_config)
print_model_size(model)

# Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
# Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -160,27 +162,25 @@ def load_model_and_tokenizer(
if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
else:
raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
logger.raise_error(
"Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
)

model = apply_peft(model, train_config, peft_config_file, **kwargs)
model = apply_peft(model, train_config, **kwargs)

return model, tokenizer


def apply_peft(
model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
) -> Union[AutoModel, PeftModel]:
def apply_peft(model: AutoModel, train_config: TrainConfig, **kwargs) -> Union[AutoModel, PeftModel]:
"""Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
Args:
model (AutoModel): Huggingface model.
train_config (TrainConfig): Training configuration object.
peft_config_file (str, optional): Path to YAML/JSON file containing
PEFT (LoRA) config. Defaults to None.
kwargs: Additional arguments to override PEFT config params.
Returns:
Union[AutoModel, PeftModel]: If the use_peft in train_config is True
Union[AutoModel, PeftModel]: If use_peft in train_config is True
then PeftModel object is returned else original model object
(AutoModel) is returned.
"""
@@ -193,9 +193,9 @@ def apply_peft(
peft_config = model.peft_config
# Generate the peft config and start fine-tuning from original model
else:
peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
peft_config = generate_peft_config(train_config, **kwargs)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print_trainable_parameters(model)

return model

@@ -220,70 +220,26 @@ def setup_dataloaders(
- Length of longest sequence in the dataset.
Raises:
ValueError: If validation is enabled but the validation set is too small.
RuntimeError: If validation is enabled but the validation set is too small.
Notes:
- Applies a custom data collator if provided by get_custom_data_collator.
- Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
"""
# Get the dataset utils
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation
dataset_train = get_preprocessed_dataset(
dataset_processer, dataset_config, split="train", context_length=train_config.context_length
)

dataset_val = get_preprocessed_dataset(
dataset_processer, dataset_config, split="test", context_length=train_config.context_length
)

# TODO: vbaddi, check if its necessary to do this?
# dataset_train = ConcatDataset(
# dataset_train, chunk_size=train_config.context_length
# )
##
train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
print("length of dataset_train", len(dataset_train))

# FIXME (Meet): Add custom data collator registration from the outside by the user.
custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
if custom_data_collator:
print("custom_data_collator is used")
train_dl_kwargs["collate_fn"] = custom_data_collator

# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
dataset_train,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**train_dl_kwargs,
)
print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

eval_dataloader = None
if train_config.run_validation:
# if train_config.batching_strategy == "packing":
# dataset_val = ConcatDataset(
# dataset_val, chunk_size=train_config.context_length
# )

val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
if custom_data_collator:
val_dl_kwargs["collate_fn"] = custom_data_collator

eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**val_dl_kwargs,
)
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
if len(eval_dataloader) == 0:
raise ValueError(
f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
logger.raise_error(
f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
ValueError,
)
else:
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

longest_seq_length, _ = get_longest_seq_length(
torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -294,12 +250,11 @@ def setup_dataloaders(
return train_dataloader, eval_dataloader, longest_seq_length


def main(peft_config_file: str = None, **kwargs) -> None:
def main(**kwargs) -> None:
"""
Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
Args:
peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
kwargs: Additional arguments to override TrainConfig.
Example:
@@ -316,23 +271,25 @@ def main(peft_config_file: str = None, **kwargs) -> None:
--model_name "meta-llama/Llama-3.2-1B" \\
--lr 5e-4
"""
# TODO:Remove TrainConfig() and update_config() as all params are passed in kwargs by parser
train_config = TrainConfig()
update_config(train_config, **kwargs)
dataset_config = generate_dataset_config(train_config.dataset)
update_config(dataset_config, **kwargs)

logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)

setup_distributed_training(train_config)
setup_seeds(train_config.seed)
model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, **kwargs)

# Create DataLoaders for the training and validation dataset
train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
print(
logger.log_rank_zero(
f"The longest sequence length in the train data is {longest_seq_length}, "
f"passed context length is {train_config.context_length} and overall model's context length is "
f"{model.config.max_position_embeddings}"
)

model.to(train_config.device)
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
@@ -354,4 +311,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:


if __name__ == "__main__":
fire.Fire(main)
parser = get_finetune_parser()
args = parser.parse_args()
args_dict = vars(args)
main(**args_dict)
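
With fire.Fire replaced by an explicit argparse parser, the same overrides can also be passed programmatically; a hedged sketch based on the docstring example above (the exact flag set is defined by get_finetune_parser and may differ):

# Hedged sketch: calling the finetune entry point with keyword overrides,
# mirroring what the argparse-based CLI forwards via main(**args_dict).
from QEfficient.cloud.finetune import main

main(
    model_name="meta-llama/Llama-3.2-1B",  # from the docstring example above
    lr=5e-4,
)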
9 changes: 9 additions & 0 deletions QEfficient/cloud/infer.py
@@ -111,6 +111,7 @@ def main(
allow_mxint8_mdp_io: bool = False,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
**kwargs,
) -> None:
"""
@@ -140,6 +141,7 @@ def main(
:allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
:trust_remote_code (bool): Trust remote code execution. ``Defaults to False.``
:kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
-allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
-qpc_crc=True -> -qpc-crc
@@ -164,6 +166,7 @@ def main(
hf_token=hf_token,
full_batch_size=full_batch_size,
local_model_dir=local_model_dir,
trust_remote_code=trust_remote_code,
)

image_path = kwargs.pop("image_path", None)
@@ -264,6 +267,12 @@ def main(
action="store_true",
help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression",
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
default=False,
help="Enable trusting remote code when loading models. Default is False; set to True by passing this flag.",
)
parser.add_argument(
"--mxint8",
"--mxint8_kv_cache",
1 change: 0 additions & 1 deletion QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -129,7 +129,6 @@ def export_bertstyle_model_to_onnx(model_name, model, tokenizer, onnx_dir_path,
)

# Generate inputFiles
# todo(ochougul):rename to bert_style_input_list.txt
input_list_file = os.path.join(onnx_dir_path, "input_list.txt")
generate_input_files(
input_files_path=os.path.join(onnx_dir_path, "inputFiles"),
4 changes: 0 additions & 4 deletions QEfficient/exporter/export_utils.py
@@ -218,8 +218,6 @@ def fix_onnx_fp16(
:str: Updated base name of exported ONNX model.
"""
model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
# TODO: Remove this `fix_onnx_fp16` function and replace with this transform
# as we're not utilizing the validations done in this function
model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)

if fp16_fix:
@@ -256,8 +254,6 @@ def fix_onnx_fp16(
if ort_outputs is not None:
for oname, orto, ortof in zip(output_names, ort_outputs, ort_outputs_fixed):
fix_diff = np.abs(orto.astype(np.float32) - ortof.astype(np.float32)).max()
# TODO: need to the debug this
# info(oname, fix_diff)
close_outputs.append(fix_diff < 1e-5)
else:
info("No constants out of FP16 range")
7 changes: 0 additions & 7 deletions QEfficient/finetune/configs/dataset_config.py
@@ -8,13 +8,6 @@
from dataclasses import dataclass


@dataclass
class samsum_dataset:
dataset: str = "samsum_dataset"
train_split: str = "train"
test_split: str = "validation"


@dataclass
class grammar_dataset:
dataset: str = "grammar_dataset"
7 changes: 0 additions & 7 deletions QEfficient/finetune/configs/peft_config.py
@@ -30,10 +30,3 @@ class LoraConfig:
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False # should be False for finetuning


# CAUTION prefix tuning is currently not supported
@dataclass
class PrefixConfig:
num_virtual_tokens: int = 30
task_type: str = "CAUSAL_LM"
67 changes: 33 additions & 34 deletions QEfficient/finetune/configs/training.py
@@ -4,8 +4,12 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import logging
from dataclasses import dataclass

from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode


# Configuration Classes
@dataclass
@@ -16,10 +20,13 @@ class TrainConfig:
model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B").
tokenizer_name (str): Name of the tokenizer (defaults to model_name if None).
run_validation (bool): Whether to run validation during training (default: True).
batch_size_training (int): Batch size for training (default: 1).
train_batch_size (int): Batch size for training (default: 1).
val_batch_size (int): Batch size for validation (default: 1).
context_length (Optional[int]): Maximum sequence length for inputs (default: None).
gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4).
gradient checkpointing (bool): Enable gradient checkpointing to save the memory by compromising the speed. (default: False).
use_autocast (bool): Use autocast for mixed precision (default: True).
grad_scaler (bool): Use gradient scaler (default: True).
num_epochs (int): Number of training epochs (default: 1).
max_train_step (int): Maximum training steps (default: 0, unlimited if 0).
max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0).
@@ -29,17 +36,13 @@ class TrainConfig:
weight_decay (float): Weight decay for optimizer (default: 0.0).
gamma (float): Learning rate decay factor (default: 0.85).
seed (int): Random seed for reproducibility (default: 42).
use_fp16 (bool): Use mixed precision training (default: True).
use_autocast (bool): Use autocast for mixed precision (default: True).
val_batch_size (int): Batch size for validation (default: 1).
dataset (str): Dataset name for training (default: "samsum_dataset").
task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
dataset (str): Dataset name for training (default: "alpaca_dataset").
task_mode (str): Mode of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
use_peft (bool): Whether to use PEFT (default: True).
from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
num_freeze_layers (int): Number of layers to freeze (default: 1).
one_qaic (bool): Use single QAIC device (default: False).
peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
peft_config_file (str): Path to YAML/JSON file containing PEFT (LoRA) config. (default: None)
from_peft_checkpoint (str): Path to PEFT checkpoint (default: None).
output_dir (str): Directory to save outputs (default: "training_results").
save_model (bool): Save the trained model (default: True).
save_metrics (bool): Save training metrics (default: True).
intermediate_step_save (int): Steps between intermediate saves (default: 1000).
@@ -49,43 +52,42 @@ class TrainConfig:
convergence_loss (float): Loss threshold for convergence (default: 1e-4).
use_profiler (bool): Enable profiling (default: False).
enable_ddp (bool): Enable distributed data parallel (default: False).
dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo").
grad_scaler (bool): Use gradient scaler (default: True).
dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
dump_logs (bool): Whether to dump logs (default: True).
log_level (str): logging level (default: logging.INFO)
"""

model_name: str = "meta-llama/Llama-3.2-1B"
tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name
run_validation: bool = True
batch_size_training: int = 1
train_batch_size: int = 1
val_batch_size: int = 1
context_length: int = None
gradient_accumulation_steps: int = 4
gradient_checkpointing: bool = False
use_autocast: bool = True
grad_scaler: bool = True
num_epochs: int = 1
max_train_step: int = 0
max_eval_step: int = 0
device: str = "qaic"
device: str = Device.QAIC.value
num_workers_dataloader: int = 1
lr: float = 3e-4
weight_decay: float = 0.0
gamma: float = 0.85 # multiplicatively decay the learning rate by gamma after each epoch
seed: int = 42
use_fp16: bool = True
use_autocast: bool = True
val_batch_size: int = 1
dataset = "samsum_dataset"
task_type = "generation" # "generation" / "seq_classification"
peft_method: str = "lora"
use_peft: bool = True # use parameter efficient fine tuning
from_peft_checkpoint: str = "" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "meta-llama-samsum"
num_freeze_layers: int = 1
one_qaic: bool = False
dataset: str = "alpaca_dataset"
task_mode: str = Task_Mode.GENERATION.value # "generation" / "seq_classification"
use_peft: bool = True # use parameter efficient finetuning
peft_method: str = Peft_Method.LORA.value
peft_config_file: str = None
from_peft_checkpoint: str = None # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "training_results"
save_model: bool = True
save_metrics: bool = True # saves training metrics to a json file for later plotting
intermediate_step_save: int = 1000
batching_strategy: str = "packing"
batching_strategy: str = Batching_Strategy.PADDING.value
enable_ddp: bool = False
enable_sorting_for_ddp: bool = True
convergence_counter: int = 5 # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
convergence_loss: float = (
@@ -98,10 +100,7 @@ class TrainConfig:
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
# profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

# dist-related
enable_ddp: bool = False
dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"

grad_scaler: bool = True
dump_root_dir: str = "meta-llama-samsum-mismatches/step_"
opByOpVerifier: bool = False

dump_logs: bool = True
log_level: str = logging.INFO
16 changes: 10 additions & 6 deletions QEfficient/finetune/data/sampler.py
@@ -4,11 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import random
from itertools import islice

import numpy as np
import torch


@@ -22,14 +20,14 @@ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle
self.data_source = data_source

def __iter__(self):
ids = np.argsort(self.lengths, kind="mergesort")
ids = list(range(len(self.data_source)))
if self.drop_last:
ids = ids[: len(ids) // self.batch_size * self.batch_size]

batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

if self.shuffle:
random.shuffle(batches)

@@ -45,11 +43,17 @@ def __len__(self):

class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
def __init__(
self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
self,
data_source,
batch_size: int,
num_replicas: int,
rank: int,
shuffle: bool = True,
seed: int = 0,
) -> None:
random.seed(seed)
self.batch_sampler = LengthBasedBatchSampler(
data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
)
self.num_replicas = num_replicas
self.rank = rank
10 changes: 9 additions & 1 deletion QEfficient/finetune/dataset/alpaca_dataset.py
@@ -11,6 +11,8 @@
import torch
from torch.utils.data import Dataset

from QEfficient.finetune.utils.logging_utils import logger

PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -27,7 +29,13 @@

class InstructionDataset(Dataset):
def __init__(self, dataset_config, tokenizer, partition="train", context_length=None):
self.ann = json.load(open(dataset_config.data_path))
try:
self.ann = json.load(open(dataset_config.data_path))
except FileNotFoundError:
logger.raise_error(
"Loading of alpaca dataset failed! Please use (wget -c https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json -P dataset/) to download the alpaca dataset.",
FileNotFoundError,
)
# Use 5% of the dataset for evaluation
eval_length = int(len(self.ann) / 20)
if partition == "train":
32 changes: 20 additions & 12 deletions QEfficient/finetune/dataset/custom_dataset.py
@@ -8,6 +8,8 @@
import importlib
from pathlib import Path

from QEfficient.finetune.utils.logging_utils import logger


def load_module_from_py_file(py_file: str) -> object:
"""
@@ -23,27 +25,29 @@ def load_module_from_py_file(py_file: str) -> object:
return module


def get_custom_dataset(dataset_config, tokenizer, split: str):
def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=None):
if ":" in dataset_config.file:
module_path, func_name = dataset_config.file.split(":")
else:
module_path, func_name = dataset_config.file, "get_custom_dataset"

if not module_path.endswith(".py"):
raise ValueError(f"Dataset file {module_path} is not a .py file.")
logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)

module_path = Path(module_path)
if not module_path.is_file():
raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
logger.raise_error(
f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
)

module = load_module_from_py_file(module_path.as_posix())
try:
return getattr(module, func_name)(dataset_config, tokenizer, split)
except AttributeError as e:
print(
f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
except AttributeError:
logger.raise_error(
f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).",
AttributeError,
)
raise e


def get_data_collator(dataset_processer, dataset_config):
@@ -53,16 +57,20 @@ def get_data_collator(dataset_processer, dataset_config):
module_path, func_name = dataset_config.file, "get_data_collator"

if not module_path.endswith(".py"):
raise ValueError(f"Dataset file {module_path} is not a .py file.")
logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)

module_path = Path(module_path)
if not module_path.is_file():
raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
logger.raise_error(
f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
)

module = load_module_from_py_file(module_path.as_posix())
try:
return getattr(module, func_name)(dataset_processer)
except AttributeError:
print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
print("Using the default data_collator instead.")
logger.log_rank_zero(
f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()})."
)
logger.log_rank_zero("Using the default data_collator instead.")
return None
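
The loader above resolves dataset_config.file either as "path/to/file.py" or "path/to/file.py:func_name" and now also forwards context_length; a minimal sketch of what such a user-supplied file could expose (everything in it is illustrative):

# my_custom_dataset.py -- hypothetical file referenced via dataset_config.file.
# get_custom_dataset is the default function name looked up above; get_data_collator
# is optional, and returning None keeps the default collator.
def get_custom_dataset(dataset_config, tokenizer, split, context_length=None):
    # Return any torch.utils.data.Dataset-compatible object for the requested split.
    ...


def get_data_collator(dataset_processer):
    return None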
4 changes: 0 additions & 4 deletions QEfficient/finetune/dataset/dataset_config.py
@@ -21,14 +21,10 @@
from QEfficient.finetune.dataset.imdb_dataset import (
get_preprocessed_imdb as get_imdb_dataset,
)
from QEfficient.finetune.dataset.samsum_dataset import (
get_preprocessed_samsum as get_samsum_dataset,
)

DATASET_PREPROC = {
"alpaca_dataset": partial(get_alpaca_dataset),
"grammar_dataset": get_grammar_dataset,
"samsum_dataset": get_samsum_dataset,
"gsm8k_dataset": get_gsm8k_dataset,
"custom_dataset": get_custom_dataset,
"imdb_dataset": get_imdb_dataset,
17 changes: 8 additions & 9 deletions QEfficient/finetune/dataset/grammar_dataset.py
@@ -10,6 +10,8 @@
from datasets import load_dataset
from torch.utils.data import Dataset

from QEfficient.finetune.utils.logging_utils import logger


class grammar(Dataset):
def __init__(self, tokenizer, csv_name=None, context_length=None):
@@ -19,11 +21,11 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"},
delimiter=",",
)
except Exception as e:
print(
"Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
except FileNotFoundError:
logger.raise_error(
"Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset.",
FileNotFoundError,
)
raise e

self.context_length = context_length
self.tokenizer = tokenizer
@@ -36,7 +38,7 @@ def convert_to_features(self, example_batch):
# Create prompt and tokenize contexts and questions

if self.print_text:
print("Input Text: ", self.clean_text(example_batch["text"]))
logger.log_rank_zero("Input Text: ", self.clean_text(example_batch["text"]))

input_ = example_batch["input"]
target_ = example_batch["target"]
@@ -71,9 +73,6 @@ def get_dataset(dataset_config, tokenizer, csv_name=None, context_length=None):
"""cover function for handling loading the working dataset"""
"""dataset loading"""
currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
print(f"Loading dataset {currPath}")
csv_name = str(currPath)
print(csv_name)
dataset = grammar(tokenizer=tokenizer, csv_name=csv_name, context_length=context_length)
dataset = grammar(tokenizer=tokenizer, csv_name=str(currPath), context_length=context_length)

return dataset
48 changes: 0 additions & 48 deletions QEfficient/finetune/dataset/samsum_dataset.py

This file was deleted.

65 changes: 21 additions & 44 deletions QEfficient/finetune/eval.py
@@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

import os
import random
import warnings

@@ -13,25 +14,19 @@
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.config_utils import (
generate_dataset_config,
get_dataloader_kwargs,
update_config,
)
from utils.dataset_utils import (
get_custom_data_collator,
get_preprocessed_dataset,
)
from utils.config_utils import generate_dataset_config, update_config
from utils.dataset_utils import get_dataloader
from utils.train_utils import evaluation, print_model_size

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.logging_utils import logger

try:
import torch_qaic # noqa: F401

device = "qaic:0"
except ImportError as e:
print(f"Warning: {e}. Moving ahead without these qaic modules.")
logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Suppress all warnings
@@ -42,18 +37,24 @@ def main(**kwargs):
# update the configuration for the training process
train_config = TrainConfig()
update_config(train_config, **kwargs)
dataset_config = generate_dataset_config(train_config.dataset)
update_config(dataset_config, **kwargs)

# Set the seeds for reproducibility
torch.manual_seed(train_config.seed)
random.seed(train_config.seed)
np.random.seed(train_config.seed)

# Load the pre-trained model and setup its configuration
# config = AutoConfig.from_pretrained(train_config.model_name)
save_dir = "meta-llama-samsum/trained_weights/step_14000"
# Load the pre-trained model from latest checkpoint
trained_weights_path = os.path.join(train_config.output_dir, "trained_weights")
epoch_max_index = max([int(name.split("_")[-1]) for name in os.listdir(trained_weights_path)])
epochs_path = os.path.join(trained_weights_path, "epoch_" + str(epoch_max_index))
step_max_index = max([int(name.split("_")[-1]) for name in os.listdir(epochs_path)])
save_dir = os.path.join(epochs_path, "step_" + str(step_max_index))

# Load PEFT model on CPU
model_peft = AutoPeftModelForCausalLM.from_pretrained(save_dir)

# Merge LoRA and base model and save
merged_model = model_peft.merge_and_unload()
merged_model.save_pretrained(train_config.output_dir, safe_serialization=True)
@@ -77,44 +78,20 @@ def main(**kwargs):
# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.")
model.resize_token_embeddings(len(tokenizer))

print_model_size(model, train_config)

# Get the dataset utils
dataset_config = generate_dataset_config(train_config, kwargs)
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation
dataset_val = get_preprocessed_dataset(
dataset_processer, dataset_config, split="test", context_length=train_config.context_length
)
print_model_size(model)

eval_dataloader = None
custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
if train_config.run_validation:
# TODO: vbaddi enable packing later in entire infra.
# if train_config.batching_strategy == "packing":
# dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
if custom_data_collator:
val_dl_kwargs["collate_fn"] = custom_data_collator

eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**val_dl_kwargs,
)
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
if len(eval_dataloader) == 0:
raise ValueError(
f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
logger.raise_error(
f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
ValueError,
)
else:
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

model.to(device)
_ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)
92 changes: 27 additions & 65 deletions QEfficient/finetune/utils/config_utils.py
@@ -4,27 +4,22 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import inspect
import json
import os
from dataclasses import asdict
from typing import Any, Dict

import torch.distributed as dist
import torch.utils.data as data_utils
import yaml
from peft import (
AdaptionPromptConfig,
PrefixTuningConfig,
)
from peft import LoraConfig as PeftLoraConfig
from transformers.data import DataCollatorForSeq2Seq

import QEfficient.finetune.configs.dataset_config as datasets
from QEfficient.finetune.configs.peft_config import LoraConfig, PrefixConfig
from QEfficient.finetune.configs.peft_config import LoraConfig
from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
from QEfficient.finetune.utils.helper import Peft_Method
from QEfficient.finetune.utils.logging_utils import logger


def update_config(config, **kwargs):
@@ -50,39 +45,34 @@ def update_config(config, **kwargs):
if hasattr(config, param_name):
setattr(config, param_name, v)
else:
raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
logger.raise_error(
f"Config '{config_name}' does not have parameter: '{param_name}'", ValueError
)
else:
config_type = type(config).__name__
# FIXME (Meet): Once logger is available put this in debug level.
print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")
logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")


def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:
def generate_peft_config(train_config: TrainConfig, **kwargs) -> Any:
"""Generate a PEFT-compatible configuration from a custom config based on peft_method.
Args:
train_config (TrainConfig): Training configuration with peft_method.
custom_config: Custom configuration object (e.g., LoraConfig).
Returns:
Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).
Raises:
RuntimeError: If the peft_method is not supported.
"""
if peft_config_file:
peft_config_data = load_config_file(peft_config_file)
validate_config(peft_config_data, config_type="lora")
if train_config.peft_config_file:
peft_config_data = load_config_file(train_config.peft_config_file)
validate_config(peft_config_data, config_type=Peft_Method.LORA)
peft_config = PeftLoraConfig(**peft_config_data)
else:
config_map = {
"lora": (LoraConfig, PeftLoraConfig),
"prefix": (PrefixConfig, PrefixTuningConfig),
"adaption_prompt": (None, AdaptionPromptConfig),
}

config_map = {Peft_Method.LORA: (LoraConfig, PeftLoraConfig)}
if train_config.peft_method not in config_map:
raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
logger.raise_error(f"Peft config not found: {train_config.peft_method}", RuntimeError)

config_cls, peft_config_cls = config_map[train_config.peft_method]
if config_cls is None:
@@ -115,37 +105,7 @@ def generate_dataset_config(dataset_name: str) -> Any:
return dataset_config


def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
if train_config.enable_ddp:
if train_config.enable_sorting_for_ddp:
if train_config.context_length:
raise ValueError(
"Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
)
else:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=False,
)
else:
kwargs["sampler"] = data_utils.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
else:
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
return kwargs


def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
def validate_config(config_data: Dict[str, Any], config_type: str = Peft_Method.LORA) -> None:
"""Validate the provided YAML/JSON configuration for required fields and types.
Args:
@@ -160,8 +120,8 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
- Validates required fields for LoraConfig: r, lora_alpha, target_modules.
- Ensures types match expected values (int, float, list, etc.).
"""
if config_type.lower() != "lora":
raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
if config_type.lower() != Peft_Method.LORA:
logger.raise_error(f"Unsupported config_type: {config_type}. Only 'lora' is supported.", ValueError)

required_fields = {
"r": int,
@@ -178,26 +138,28 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
# Check for missing required fields
missing_fields = [field for field in required_fields if field not in config_data]
if missing_fields:
raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
logger.raise_error(f"Missing required fields in {config_type} config: {missing_fields}", ValueError)

# Validate types of required fields
for field, expected_type in required_fields.items():
if not isinstance(config_data[field], expected_type):
raise ValueError(
logger.raise_error(
f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
f"got {type(config_data[field]).__name__}"
f"got {type(config_data[field]).__name__}",
ValueError,
)

# Validate target_modules contains strings
if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
raise ValueError("All elements in 'target_modules' must be strings")
logger.raise_error("All elements in 'target_modules' must be strings", ValueError)

# Validate types of optional fields if present
for field, expected_type in optional_fields.items():
if field in config_data and not isinstance(config_data[field], expected_type):
raise ValueError(
logger.raise_error(
f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
f"got {type(config_data[field]).__name__}"
f"got {type(config_data[field]).__name__}",
ValueError,
)


@@ -215,12 +177,12 @@ def load_config_file(config_path: str) -> Dict[str, Any]:
ValueError: If the file format is unsupported.
"""
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
logger.raise_error(f"Config file not found: {config_path}", FileNotFoundError)

with open(config_path, "r") as f:
if config_path.endswith(".yaml") or config_path.endswith(".yml"):
return yaml.safe_load(f)
elif config_path.endswith(".json"):
return json.load(f)
else:
raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")
logger.raise_error("Unsupported config file format. Use .yaml, .yml, or .json", ValueError)
84 changes: 77 additions & 7 deletions QEfficient/finetune/utils/dataset_utils.py
@@ -4,19 +4,22 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import datasets
import torch
import torch.distributed as dist
from transformers.data import DataCollatorForSeq2Seq

# from QEfficient.finetune.data.concatenator import ConcatDataset
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
from QEfficient.finetune.utils.config_utils import get_dataloader_kwargs
from QEfficient.finetune.utils.helper import get_num_ddp_devices
from QEfficient.finetune.utils.logging_utils import logger


def get_preprocessed_dataset(
tokenizer, dataset_config, split: str = "train", context_length: int = None
) -> torch.utils.data.Dataset:
if dataset_config.dataset not in DATASET_PREPROC:
raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
logger.raise_error(f"{dataset_config.dataset} is not (yet) implemented", NotImplementedError)

def get_split():
return dataset_config.train_split if split == "train" else dataset_config.test_split
@@ -31,12 +34,79 @@ def get_custom_data_collator(dataset_processer, dataset_config) -> torch.utils.d
return DATALOADER_COLLATE_FUNC[dataset_config.dataset](dataset_processer, dataset_config)


def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
kwargs = {}
batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
if train_config.enable_ddp:
if train_config.enable_sorting_for_ddp:
if train_config.context_length:
                logger.raise_error(
                    "Sorting cannot be used together with padding. Please disable sorting or pass context_length as None to disable padding.",
                    ValueError,
                )
else:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=False,
)
else:
kwargs["sampler"] = torch.utils.data.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = False
else:
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = False
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
return kwargs


def padding_dataset(train_config, dataset, batch_size):
if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
if isinstance(dataset, datasets.Dataset):
# Hugging Face Dataset transformation
dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
dataset = dataset.sort("input_length")

else:
dataset = sorted(dataset, key=lambda x: len(x["input_ids"]))

dummy_row = next(iter(dataset))
dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
    num_replicas = get_num_ddp_devices()
    remainder = len(dataset) % (num_replicas * batch_size)
    # Pad only when the dataset size is not already a multiple of (num_replicas * batch_size).
    padding_size = ((num_replicas * batch_size) - remainder) % (num_replicas * batch_size)

dummy_data = [dummy_row.copy() for _ in range(padding_size)]
dummy_dataset = datasets.Dataset.from_list(dummy_data)
if isinstance(dataset, datasets.Dataset):
combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
else:
combined_dataset = dataset + list(dummy_dataset)
return combined_dataset
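The padding arithmetic above can be checked in isolation: the dummy rows carry labels of -100 so they are ignored by the loss, and their count makes the dataset length divisible by num_replicas * batch_size (the numbers below are illustrative, not from this PR):

num_replicas, batch_size, dataset_len = 4, 2, 21
remainder = dataset_len % (num_replicas * batch_size)                                   # 21 % 8 = 5
padding_size = ((num_replicas * batch_size) - remainder) % (num_replicas * batch_size)  # 3 dummy rows
assert (dataset_len + padding_size) % (num_replicas * batch_size) == 0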


def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)

batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
dataset = padding_dataset(train_config, dataset, batch_size)

dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

# if split == "train" and train_config.batching_strategy == "packing":
# dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)
# FIXME (Meet): Add custom data collator registration from the outside by the user.
custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)

if custom_data_collator:
print("custom_data_collator is used")
dl_kwargs["collate_fn"] = custom_data_collator

logger.log_rank_zero(f"Length of {split} dataset is {len(dataset)}")

# Create data loader
dataloader = torch.utils.data.DataLoader(
79 changes: 79 additions & 0 deletions QEfficient/finetune/utils/helper.py
@@ -0,0 +1,79 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
import os
from contextlib import nullcontext
from enum import Enum

import torch

try:
import torch_qaic.debug as qaic_debug # noqa: F401
except ImportError as e:
print(f"Warning: {e}. Moving ahead without these qaic modules.")


class Batching_Strategy(str, Enum):
PADDING = "padding"
PACKING = "packing"


class Device(str, Enum):
QAIC = "qaic"
CPU = "cpu"
CUDA = "cuda"


class Peft_Method(str, Enum):
LORA = "lora"


class Task_Mode(str, Enum):
GENERATION = "generation"
SEQ_CLASSIFICATION = "seq_classification"


def enum_names(enum_cls):
return [member.value for member in enum_cls]


def is_rank_zero():
return int(os.getenv("LOCAL_RANK", 0)) == 0


def get_num_ddp_devices():
return int(os.getenv("WORLD_SIZE", 1))


def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()


def get_op_verifier_ctx(
use_op_by_op_verifier,
train_device,
dump_dir,
step,
ref_device="cpu",
ref_dtype=torch.float32,
atol=1e-1,
rtol=1e-5,
use_ref_output_on_mismatch=True,
):
if not use_op_by_op_verifier:
return nullcontext()

filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
dump_dir = dump_dir + "/mismatches/step_" + str(step)
return qaic_debug.OpByOpVerifierMode(
ref_device=ref_device,
ref_dtype=ref_dtype,
atol=atol,
rtol=rtol,
use_ref_output_on_mismatch=use_ref_output_on_mismatch,
filter_config=filter_config,
dump_root_dir=dump_dir,
)
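For reference (not part of this diff), a minimal sketch of how the two context helpers wrap a training step; the toy model, tensors, and dump directory are placeholders:

import torch
import torch.nn as nn

from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

model = nn.Linear(8, 2)     # toy stand-in for the fine-tuned model
batch = torch.randn(4, 8)   # placeholder input batch

# With the feature flags off, both helpers return nullcontext(), so the same
# call site also works on hosts without torch_qaic installed.
autocast_ctx = get_autocast_ctx(use_autocast=False, device_type="cpu")
verify_ctx = get_op_verifier_ctx(False, train_device="cpu", dump_dir="training_results", step=0)

with autocast_ctx, verify_ctx:
    loss = model(batch).sum()
loss.backward()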
54 changes: 54 additions & 0 deletions QEfficient/finetune/utils/logging_utils.py
@@ -0,0 +1,54 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import logging
import os
from datetime import datetime

from QEfficient.finetune.utils.helper import is_rank_zero


class FTLogger:
def __init__(self):
self.logger = logging.getLogger("QEfficient")
if not getattr(self.logger, "_custom_methods_added", False):
self._bind_custom_methods()
self.logger._custom_methods_added = True # Prevent adding handlers/methods twice

def _bind_custom_methods(self):
def raise_error(message, errortype=RuntimeError):
self.logger.error(message)
raise errortype(message)

def log_rank_zero(msg: str, level: int = logging.INFO):
if is_rank_zero():
self.logger.log(level, msg, stacklevel=2)

def prepare_for_logs(output_path, dump_logs=False, level=logging.INFO):
self.logger.setLevel(level)
if dump_logs:
logs_path = os.path.join(output_path, "logs")
if not os.path.exists(logs_path):
os.makedirs(logs_path, exist_ok=True)
file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
log_file = os.path.join(logs_path, file_name)

fh = logging.FileHandler(log_file)
fh.setLevel(level)
formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
fh.setFormatter(formatter)
self.logger.addHandler(fh)

self.logger.raise_error = raise_error
self.logger.log_rank_zero = log_rank_zero
self.logger.prepare_for_logs = prepare_for_logs

def get_logger(self):
return self.logger


logger = FTLogger().get_logger()
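For reference (not part of this diff), a short sketch of the three bound helpers; the output directory is a placeholder that mirrors the trainer's default:

import logging

from QEfficient.finetune.utils.logging_utils import logger

logger.prepare_for_logs("training_results", dump_logs=True, level=logging.INFO)  # adds a file handler under training_results/logs
logger.log_rank_zero("Starting fine-tuning")  # emitted only when LOCAL_RANK == 0

try:
    logger.raise_error("Unsupported option", ValueError)  # logs the message, then raises ValueError
except ValueError:
    pass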
290 changes: 290 additions & 0 deletions QEfficient/finetune/utils/parser.py
@@ -0,0 +1,290 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import argparse
import logging

from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode, enum_names


def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")


def get_finetune_parser():
parser = argparse.ArgumentParser(
        description="Finetune command: the model is downloaded from Hugging Face, fine-tuned on Cloud AI 100, and checkpoints are saved."
)
parser.add_argument(
"--model_name",
"--model-name",
required=False,
type=str,
default="meta-llama/Llama-3.2-1B",
help="Name of the pre-trained model to fine-tune",
)
parser.add_argument(
"--tokenizer_name",
"--tokenizer-name",
required=False,
type=str,
default=None,
        help="Name of the tokenizer; if not passed, it uses the value of model_name",
)
parser.add_argument(
"--run_validation",
"--run-validation",
type=str2bool,
nargs="?",
const=True,
default=True,
help="To run validation during training",
)
parser.add_argument(
"--train_batch_size", "--train-batch-size", required=False, type=int, default=1, help="Batch size for training"
)
parser.add_argument(
"--val_batch_size", "--val-batch-size", required=False, type=int, default=1, help="Batch size for validation"
)
parser.add_argument(
"--context_length",
"--context-length",
required=False,
type=int,
default=None,
help="Maximum sequence length for inputs",
)
parser.add_argument(
"--gradient_accumulation_steps",
"--gradient-accumulation-steps",
required=False,
type=int,
default=4,
help="Steps for gradient accumulation",
)
parser.add_argument(
"--gradient_checkpointing",
"--gradient-checkpointing",
action="store_true",
help="Use gradient checkpointing",
)
parser.add_argument(
"--use_autocast",
"--use-autocast",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Use autocast for mixed precision",
)
parser.add_argument(
"--grad_scaler", "--grad-scaler", type=str2bool, nargs="?", const=True, default=True, help="Use gradient scaler"
)
parser.add_argument(
"--num_epochs", "--num-epochs", required=False, type=int, default=1, help="Number of training epochs"
)
parser.add_argument(
"--max_train_step",
"--max-train-step",
required=False,
type=int,
default=0,
help="Maximum training steps, unlimited if 0",
)
parser.add_argument(
"--max_eval_step",
"--max-eval-step",
required=False,
type=int,
default=0,
help="Maximum evaluation steps, unlimited if 0",
)
parser.add_argument(
"--device",
required=False,
type=str,
default=Device.QAIC.value,
choices=enum_names(Device),
help="Device to train on",
)
parser.add_argument(
"--num_workers_dataloader",
"--num-workers-dataloader",
required=False,
type=int,
default=1,
help="Number of workers for dataloader",
)
    parser.add_argument("--lr", required=False, type=float, default=3e-4, help="Learning rate")
parser.add_argument(
"--weight_decay", "--weight-decay", required=False, type=float, default=0.0, help="Weight decay for optimizer"
)
parser.add_argument(
"--gamma",
required=False,
type=float,
default=0.85,
help="Learning rate decay factor, multiplicatively decays the learning rate by gamma after each epoch",
)
parser.add_argument("--seed", required=False, type=int, default=42, help="Random seed for reproducibility")
parser.add_argument(
"--dataset",
required=False,
default="alpaca_dataset",
type=str,
choices=DATASET_PREPROC.keys(),
help="Dataset name to be used for finetuning (default: %(default)s)",
)
parser.add_argument(
"--task_mode",
"--task-mode",
required=False,
type=str,
default=Task_Mode.GENERATION.value,
choices=enum_names(Task_Mode),
help="Task used for finetuning. Use 'generation' for decoder based models and 'seq_classification' for encoder based models.",
)
parser.add_argument(
"--use_peft",
"--use-peft",
type=str2bool,
nargs="?",
const=True,
default=True,
        help="Whether to use PEFT (parameter-efficient fine-tuning)",
)
parser.add_argument(
"--peft_method",
"--peft-method",
required=False,
type=str,
default=Peft_Method.LORA.value,
choices=enum_names(Peft_Method),
help="Parameter efficient finetuning technique to be used. Currently only 'lora' is supported.",
)
parser.add_argument(
"--from_peft_checkpoint",
"--from-peft-checkpoint",
required=False,
type=str,
default="",
help="Path to load PEFT checkpoint and resume the fine-tuning on that checkpoint",
)
parser.add_argument(
"--output_dir",
"--output-dir",
required=False,
type=str,
default="training_results",
help="Directory to save outputs of training",
)
parser.add_argument(
"--save_model",
"--save-model",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Save the final trained model checkpoints",
)
parser.add_argument(
"--save_metrics",
"--save-metrics",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Save training metrics to a json file for later plotting",
)
parser.add_argument(
"--intermediate_step_save",
"--intermediate-step-save",
required=False,
type=int,
default=1000,
help="Steps between intermediate saves of checkpoint",
)
parser.add_argument(
"--batching_strategy",
"--batching-strategy",
required=False,
type=str,
default=Batching_Strategy.PADDING.value,
choices=enum_names(Batching_Strategy),
        help="Strategy for making batches of data points. Packing groups data points into batches by minimizing unnecessary empty spaces. Padding adds extra values (often zeros) to batch sequences so they align in size. Currently only 'padding' is supported, and it is the default.",
)
    parser.add_argument(
        "--enable_sorting_for_ddp",
        "--enable-sorting-for-ddp",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Sort the data points according to sequence length for DDP",
)
parser.add_argument(
"--convergence_counter",
"--convergence-counter",
required=False,
type=int,
default=5,
        help="Number of steps used to check convergence; must be >= 1. Fine-tuning stops when loss <= convergence_loss (defined below) for convergence_counter consecutive steps",
)
parser.add_argument(
"--convergence_loss",
"--convergence-loss",
required=False,
type=float,
default=1e-4,
        help="Loss threshold for convergence; if the loss is <= convergence_loss for convergence_counter consecutive steps, fine-tuning stops",
)
parser.add_argument(
"--use_profiler",
"--use-profiler",
action="store_true",
help="Enable profiling for the operations during pytorch eager mode execution.",
)
parser.add_argument(
"--enable_ddp",
"--enable-ddp",
action="store_true",
        help="Enable distributed data parallel (DDP) training. This loads a replica of the model on each of the given devices and trains them in parallel. It should be launched through the torchrun interface; please check the docs for exact usage.",
)
parser.add_argument(
"--opByOpVerifier",
action="store_true",
help=argparse.SUPPRESS,
# This is for debugging purpose only.
# Enables operation-by-operation verification w.r.t reference device(cpu).
# It is a context manager interface that captures and verifies each operator against reference device.
# In case results of test & reference do not match under given tolerances, a standalone unittest is generated at output_dir/mismatches.
)
parser.add_argument(
"--log_level",
"--log-level",
required=False,
type=str,
default=logging.INFO,
help="logging level",
)
parser.add_argument(
"--peft_config_file",
"--peft-config-file",
type=str,
default=None,
help="Path to YAML/JSON file containing PEFT (LoRA) config.",
)

return parser
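For reference (not part of this diff), driving the parser programmatically; the argument values are illustrative:

from QEfficient.finetune.utils.parser import get_finetune_parser

parser = get_finetune_parser()
args = parser.parse_args(
    ["--model-name", "meta-llama/Llama-3.2-1B", "--num-epochs", "2", "--enable-ddp"]
)
print(args.model_name, args.num_epochs, args.enable_ddp, args.peft_method)
# meta-llama/Llama-3.2-1B 2 True lora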
8 changes: 5 additions & 3 deletions QEfficient/finetune/utils/plot_metrics.py
@@ -11,6 +11,8 @@

import matplotlib.pyplot as plt

from QEfficient.finetune.utils.logging_utils import logger


def plot_metric(data, metric_name, x_label, y_label, title, colors):
plt.figure(figsize=(7, 6))
@@ -67,14 +69,14 @@ def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):

def plot_metrics(file_path):
if not os.path.exists(file_path):
print(f"File {file_path} does not exist.")
logger.raise_error(f"File {file_path} does not exist.", FileNotFoundError)
return

with open(file_path, "r") as f:
try:
data = json.load(f)
except json.JSONDecodeError:
print("Invalid JSON file.")
        except json.JSONDecodeError as e:
            logger.raise_error(f"Invalid JSON file: {e}", ValueError)
return

directory = os.path.dirname(file_path)
271 changes: 144 additions & 127 deletions QEfficient/finetune/utils/train_utils.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -60,7 +60,7 @@ def __repr__(self):
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)} sec\
\nDecode is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)} tokens/sec\
\nTotal is= {round(self.perf_metrics.total_perf * self.batch_size, 2)} tokens/sec\
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} tokens/sec"
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} sec"


@dataclass
209 changes: 207 additions & 2 deletions QEfficient/transformers/cache_utils.py
@@ -6,10 +6,10 @@
# -----------------------------------------------------------------------------


from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache

from QEfficient.customop import (
CtxGatherFunc,
@@ -283,3 +283,208 @@ def from_legacy_cache(
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
cache.is_updated[layer_idx] = True
return cache


class QEffHybridCache(HybridCache):
def __init__(self, config, batch_size, max_cache_len):
super().__init__(config, batch_size, max_cache_len=max_cache_len)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

@classmethod
def from_legacy_cache(
cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "HybridCache":
        """Converts a cache in the legacy cache format into an equivalent `QEffHybridCache`. Used for
backward compatibility."""
cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2])
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `QEffHybridCache` instance into its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
layer_ctx_len = self.key_cache[layer_idx].shape[2]
kv_position_ids = torch.where(
(~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1)
)

kv_position_ids = torch.where(
is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2),
(position_ids + 1) % layer_ctx_len,
kv_position_ids,
)

valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
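To make the sliding-window bookkeeping in `update` easier to follow, here is just the core slot-remapping step on plain tensors; the window length is illustrative, and the wrap-around and rolling-gather branches are not shown:

import torch

layer_ctx_len = 8                                  # per-layer sliding-window cache length (illustrative)
position_ids = torch.arange(0, 12).unsqueeze(0)    # global positions 0..11 for a single batch

# Sliding layers scatter position p into slot p % (layer_ctx_len - 1), so new tokens
# overwrite the oldest slots while the cache shape stays fixed.
kv_position_ids = position_ids % (layer_ctx_len - 1)
print(kv_position_ids)
# tensor([[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4]])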


class QEffHybridChunkedCache(HybridChunkedCache):
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `HybridChunkedCache` instance into its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(
cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "HybridChunkedCache":
"""Converts a cache in the legacy cache format into an equivalent `HybridChunkedCache`. Used for
backward compatibility."""
cache = cls(config, max_batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2])
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states

else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))

# Update the position_ids to handle the sliding window
layer_ctx_len = self.key_cache[layer_idx].shape[2]
kv_position_ids = torch.where(
(~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1)
)

kv_position_ids = torch.where(
is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2),
(position_ids + 1) % layer_ctx_len,
kv_position_ids,
)

valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
ctx_len = min(layer_ctx_len, k_out.shape[2])
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

# Rolling indices for sliding window
all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
6 changes: 6 additions & 0 deletions QEfficient/transformers/embeddings/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
125 changes: 125 additions & 0 deletions QEfficient/transformers/embeddings/embedding_utils.py
@@ -0,0 +1,125 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import inspect
from typing import Optional

import torch
import torch.nn as nn


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs mean pooling on the last hidden states of a transformer model.
Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
Returns:
torch.Tensor: The mean pooled last hidden states.
"""
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs average pooling on the last hidden states of a transformer model.
Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
Returns:
torch.Tensor: The average pooled last hidden states.
"""
last_hidden = last_hidden_states[0].masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs max pooling on the last hidden states of a transformer model.
Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
Returns:
torch.Tensor: The max pooled last hidden states.
"""
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
last_hidden_states[input_mask_expanded == 0] = -1e9
return torch.max(last_hidden_states, 1)[0]


def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Performs CLS pooling on the last hidden states of a transformer model.
Args:
last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
Returns:
torch.Tensor: The CLS pooled last hidden states.
"""
return last_hidden_states[:, 0]


POOLING_MAP = {
"mean": mean_pooling,
"avg": average_pool,
"cls": cls_pooling,
"max": max_pooling,
}


class PooledModel(nn.Module):
"""
Adds pooling functionality to embedding model.
"""

def __init__(self, base_model, pooling_fn):
super().__init__()
self.config = base_model.config
self.base_model = base_model
self.pooling_fn = pooling_fn

def forward(
self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
):
output = self.base_model(input_ids, attention_mask, **kwargs)
return self.pooling_fn(output[0], attention_mask)


def validate_user_pooling_function(user_function):
"""
Validate a user-provided pooling function to ensure it meets the required interface.
The function should take two arguments:
- last_hidden_states (torch.Tensor): The last hidden states of the model.
- attention_mask (torch.Tensor): The attention mask of the input sequence.
It should return a torch.Tensor representing the pooled output.
Args:
user_function (callable): The user-provided pooling function.
Raises:
ValueError: If the user-provided function does not meet the required interface.
"""

if not callable(user_function):
raise TypeError("Provided pooling function is not callable.")

sig = inspect.signature(user_function)
required_args = {"last_hidden_states", "attention_mask"}
if not required_args.issubset(sig.parameters.keys()):
raise ValueError(f"Pooling function must accept arguments: {required_args}")
return user_function
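For reference (not part of this diff), a quick sketch of the pooling map and the custom-function check; the tensor shapes and the custom pooling function are illustrative:

import torch

from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, validate_user_pooling_function

last_hidden_states = torch.randn(2, 5, 16)                       # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

pooled = POOLING_MAP["mean"](last_hidden_states, attention_mask)
print(pooled.shape)  # torch.Size([2, 16])


# A user-supplied pooling function only needs the same two argument names.
def last_token_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    lengths = attention_mask.sum(dim=1) - 1
    return last_hidden_states[torch.arange(last_hidden_states.shape[0]), lengths]


validate_user_pooling_function(last_token_pooling)  # passes the signature check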
4 changes: 4 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -182,6 +182,8 @@
]
)

DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
@@ -283,6 +285,8 @@ def build_model_class_mapping(auto_model_class, qeff_class_name):
}


EXTERNAL_MODEL_CLASS_MAPPING = {"Grok1Config": "QEFFAutoModelForCausalLM"}

MODEL_CLASS_MAPPING = {
**build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM"),
**build_model_class_mapping(mapping.AutoModelForImageTextToText, "QEFFAutoModelForImageTextToText"),
Original file line number Diff line number Diff line change
@@ -85,7 +85,6 @@ def forward(
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
qkv = self.qkv_proj(hidden_states)
# TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
mp_num = 4
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

6 changes: 5 additions & 1 deletion QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -183,7 +183,11 @@ def forward(
):
residual = hidden_states

attention_layernorm_out = self.input_layernorm(hidden_states)
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)

# Self attention.
attn_outputs = self.self_attention(
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/gemma3/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------