
Early fusion dual qpc modification for index-based embedding interleaving. #438

Closed
wants to merge 25 commits into from
Commits (25)
7c4ed22
Add Llama4 Multi-Modal Support
vbaddi Apr 29, 2025
602253e
nit: modeling changes
vbaddi Apr 29, 2025
9f183f5
Adding Vision Part and Chunking (#383)
mohiso22 Apr 30, 2025
ecd4b8c
nit: update moe implementation and add sample export/compile script
vbaddi May 2, 2025
eef428e
nit: fix linter for example script
vbaddi May 2, 2025
a149c76
nit: update pytorch transforms to map Llama4TextExperts
vbaddi May 2, 2025
8df2303
nit: update modeling with new freq apply computation and sample mm ex…
vbaddi May 4, 2025
1a16539
nit: update llama4 mm example script
vbaddi May 4, 2025
db4fc4e
nit: update modeling to avoid >2GiB issue in Onnx, rope max-position
vbaddi May 6, 2025
e8a36f3
Added pytorch transform for the split_gate_up_weights and removed exa…
quic-amitraj May 8, 2025
b6b6e3d
Ruff Check and format
quic-amitraj May 8, 2025
240bc32
Minor fixes-1
quic-amitraj May 8, 2025
517f3fd
Added logger for new transform
quic-amitraj May 8, 2025
941b272
fixed Llama4 MOE accuracy bug
ochougul May 18, 2025
307b655
Updating index method in Wrappers (#410)
mohiso22 May 19, 2025
7cccc33
nit: add position_ids to attn_scales instead of cache_position in use…
vbaddi May 20, 2025
21439a9
Minor Fixes (#421)
mohiso22 May 21, 2025
53d3314
Updating Specialization and modeling auto files
mohiso22 May 21, 2025
ac31c9e
Fix for Multi Image Chunking
mohiso22 May 23, 2025
e877f9f
Adding SingleQPC
mohiso22 Jun 1, 2025
77afbbc
Rebase and Minor Fixes
mohiso22 Jun 9, 2025
502c30b
Addressed Comments
mohiso22 Jun 9, 2025
5965426
Modified Early Fusin VLMs to enable index based interleaving of embed…
quic-dhirajku Jun 10, 2025
96bc0dc
Renamed vision_idx to image_idx
quic-dhirajku Jun 10, 2025
c6634bf
Merge branch 'main' into early_fusion_dual_qpc
quic-rishinr Jun 10, 2025
3 changes: 2 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -20,7 +20,7 @@
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.base.pytorch_transforms import PytorchTransform, append_tranform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants, dump_qconfig
@@ -46,6 +46,7 @@ class QEFFBaseModel(ABC):
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

@append_tranform
def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model
70 changes: 70 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -9,6 +9,8 @@

from torch import nn

from QEfficient.utils.logging_utils import logger


class PytorchTransform:
"""
@@ -110,3 +112,71 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = True

return model, transformed


class SplitGateUpWeightsTransform(PytorchTransform):
"""
split fused Gate+Up weights and copy into the model

For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()
for layer_idx in range(num_layers):
# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
fused = sd[fused_key] # [E, H, 2I] (no .weight here)
E, H, two_I = fused.shape
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

experts = model_tmp.model.layers[layer_idx].feed_forward.experts
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if delete_fused_key:
del sd[fused_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp
return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration", "Llama4TextModel"]


def append_tranform(func):
def wrapper(*args, **kwargs):
model_class = args[1].model.__class__.__name__ if hasattr(args[1], "model") else args[1].__class__.__name__
if model_class in VLM_SPLIT_GATE_UP_WEIGHTS:
args[0]._pytorch_transforms.append(SplitGateUpWeightsTransform)
return func(*args, **kwargs)

return wrapper
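
For reference, a minimal standalone sketch of the split that SplitGateUpWeightsTransform performs, using toy shapes (E, H, I below are illustrative placeholders, not values taken from any Llama4 config):

import torch

# Toy stand-in for a fused MoE projection: E experts, hidden size H,
# gate and up halves concatenated on the last axis -> [E, H, 2I].
E, H, I = 4, 8, 16
fused = torch.randn(E, H, 2 * I)

# The split performed per layer: two [E, H, I] views, no copy yet.
gate, up = fused.split(I, dim=-1)

# Copy into destination parameters, mirroring experts.gate_proj / experts.up_proj.
gate_proj = torch.nn.Parameter(torch.empty(E, H, I))
up_proj = torch.nn.Parameter(torch.empty(E, H, I))
gate_proj.data.copy_(gate)
up_proj.data.copy_(up)

# Re-concatenating the halves recovers the fused tensor exactly.
assert torch.equal(torch.cat([gate_proj.data, up_proj.data], dim=-1), fused)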
31 changes: 21 additions & 10 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -1,6 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
@@ -31,21 +31,23 @@ def __init__(self, model):
self.config = self.model.language_model.config
self.language_model = self.model.language_model

def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
return outputs.logits, vision_embeds, outputs.past_key_values
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return outputs.logits, vision_embeds, image_idx, outputs.past_key_values


class QEffInternVLModel(nn.Module):
@@ -81,13 +83,14 @@ def get_specializations(
logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
if img_size != constants.INTERN_IMG_SIZE and kv_offload:
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")

per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
vision_size = int(num_patches * per_patch_embed_size)
vision = [
{
"batch_size": batch_size,
"num_patches": num_patches,
"img_size": img_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
}
]
lang = [
@@ -97,13 +100,15 @@
"ctx_len": ctx_len,
"num_patches": num_patches,
"img_size": img_size,
"vision_size": vision_size,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"num_patches": num_patches,
"img_size": img_size,
"vision_size": vision_size,
},
]

@@ -122,7 +127,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -148,10 +153,12 @@ def get_output_names(self, kv_offload: bool = False):
output_names = {}
if kv_offload:
lang_output_names.insert(1, "vision_embeds_RetainedState")
lang_output_names.insert(2, "image_idx_output")
output_names["vision"] = vision_output_names
output_names["lang"] = lang_output_names
else:
lang_output_names.insert(1, "pixel_values_RetainedState")
lang_output_names.insert(2, "image_idx_output")
return lang_output_names
return output_names

@@ -176,8 +183,8 @@ def get_dummy_inputs(self, kv_offload: bool = False):
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
constants.INTERN_NUM_PATCHES,
constants.INTERN_FEATURE_SIZE,
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
computed_feature_size,
self.language_model.config.hidden_size,
)
inputs_shapes["position_ids"] = (
@@ -202,6 +209,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)

# Add data for KV
kv_cache_shape = get_padding_shape_from_config(
@@ -225,22 +233,25 @@

return inputs

def forward(self, input_ids, pixel_values, position_ids, past_key_values):
def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values):
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vision_embeds = self.extract_feature(pixel_values)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
return outputs.logits, pixel_values, outputs.past_key_values
next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx)
return outputs.logits, pixel_values, image_idx, outputs.past_key_values

def get_inputs_info(self):
return [
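
The index-based interleaving above can be illustrated with a self-contained sketch (token id 999, the tensor sizes, and the starting offset are toy values, not the model's): image_idx carries the position of the first unconsumed vision embedding, so placeholder tokens in each prefill chunk pull the correct slice, and the returned image_idx lets the next chunk resume where this one stopped.

import torch

C, IMG_TOKEN = 4, 999                                     # toy embedding width and placeholder id
input_ids = torch.tensor([[1, 999, 999, 2, 3, 999, 4, 5, 6, 7]])
input_embeds = torch.zeros(1, input_ids.shape[1], C)      # text embeddings for this chunk
vision_embeds = torch.arange(6 * C, dtype=torch.float32).reshape(1, 6, C)
image_idx = torch.tensor([[3]])                           # 3 embeddings were consumed by earlier chunks

selected = (input_ids == IMG_TOKEN).reshape(-1)                          # placeholder mask, [B*N]
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1           # running index per placeholder
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)   # shift by what earlier chunks used
indices0 = torch.arange(1).view(-1, 1)
picked = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
merged = torch.where(selected.unsqueeze(0).unsqueeze(-1), picked, input_embeds)

next_image_idx = (indices1.max() + 1).reshape(1, 1)
print(next_image_idx)  # tensor([[6]]) -> the next chunk starts at vision embedding 6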
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama4/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------