Add Gemma3 #390


Draft: wants to merge 11 commits into main
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/gemma3/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
779 changes: 779 additions & 0 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py

Large diffs are not rendered by default.

29 changes: 17 additions & 12 deletions QEfficient/transformers/models/modeling_auto.py
@@ -789,7 +789,9 @@ def kv_offload_generate(
         }

         vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")
+        vision_start = perf_counter()
         vision_outputs = vision_session.run(vision_inputs)
+        vision_end = perf_counter()

         lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
         lang_inputs["position_ids"] = np.where(
@@ -798,22 +800,28 @@ def kv_offload_generate(

         vision_session.deactivate()
         lang_session.activate()

-        lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
+        lang_session.set_buffers(vision_outputs)

         prefill_start = perf_counter()
         # Run prefill
+        chunk_inputs = lang_inputs.copy()
+        chunk_inputs["index"] = np.array([[0]])
         for i in range(num_chunks):
-            chunk_inputs = lang_inputs.copy()
             chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             chunk_inputs["position_ids"] = lang_inputs["position_ids"][
                 :, i * prefill_seq_len : (i + 1) * prefill_seq_len
             ]
             outputs = lang_session.run(chunk_inputs)
+            chunk_inputs["index"] = outputs["index_output"]

Review comment from @quic-xiyushi, May 15, 2025, on the chunk_inputs["index"] line above:

> Could you explain what chunk_inputs["index"] and outputs["index_output"] are?
> Also, with the new approach, is batching supported?

-        prefill_time = perf_counter() - prefill_start
+        prefill_time = perf_counter() - prefill_start + vision_end - vision_start
         # Skip inputs/outputs again
         lang_session.skip_buffers(
-            [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
+            [
+                x
+                for x in lang_session.input_names + lang_session.output_names
+                if x.startswith("past_") or x.endswith("_RetainedState")
+            ]
         )

         # Get first token
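
A note on the new prefill loop, which also bears on the review question above: chunk_inputs["index"] appears to act as a running write offset into the on-device KV cache, initialized to 0 and carried from one chunk to the next via outputs["index_output"]. A minimal sketch of the pattern, assuming a generic run() callable standing in for lang_session.run / qpc_session.run (the helper name is illustrative, not from this PR):

```python
import numpy as np

# Hedged sketch of the chunked-prefill pattern in the diff above; `index`
# is assumed to be the KV-cache write offset returned as `index_output`.
def chunked_prefill(run, input_ids, position_ids, prefill_seq_len):
    num_chunks = input_ids.shape[1] // prefill_seq_len
    chunk_inputs = {"index": np.array([[0]])}  # start writing at offset 0
    outputs = {}
    for i in range(num_chunks):
        sl = slice(i * prefill_seq_len, (i + 1) * prefill_seq_len)
        chunk_inputs["input_ids"] = input_ids[:, sl]
        chunk_inputs["position_ids"] = position_ids[:, sl]
        outputs = run(chunk_inputs)
        # Feed the updated offset into the next chunk.
        chunk_inputs["index"] = outputs["index_output"]
    return outputs
```
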
@@ -844,7 +852,7 @@ def kv_offload_generate(
         streamer.end()

         decode_perf = (num_token - 1) / (decode_end - decode_start)
-        total_time = decode_end - prefill_start
+        total_time = decode_end - decode_start + prefill_time
         total_perf = num_token / total_time

         return CloudAI100ExecInfoNew(
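
The revised accounting stops charging the prefill window to decode throughput: decode_perf uses only the decode interval, while total_time adds prefill_time, which now also folds in the vision-encoder run. A toy check with made-up numbers (not measurements):

```python
# Illustrative values only: vision + prefill take 0.8 s, decode runs 2.2 s,
# and 128 tokens are generated in total.
prefill_time = 0.8                     # includes vision_end - vision_start
decode_start, decode_end = 10.0, 12.2  # arbitrary wall-clock instants
num_token = 128

decode_perf = (num_token - 1) / (decode_end - decode_start)  # ~57.7 tok/s
total_time = decode_end - decode_start + prefill_time        # 3.0 s
total_perf = num_token / total_time                          # ~42.7 tok/s
```
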
@@ -1025,11 +1033,8 @@ def cloud_ai_100_generate(
         qpc_session = QAICInferenceSession(
             self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False
         )
-
         batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path)
-
         pad_token_id = 1
-
         # Skip inputs/outputs
         qpc_session.skip_buffers(
             [
@@ -1085,16 +1090,16 @@ def cloud_ai_100_generate(
         inputs["pixel_values"] = inputs["pixel_values"].astype("float16")

         inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
-
+        inputs["index"] = np.array([[0]])
         qpc_session.activate()

+        chunk_inputs = inputs.copy()
         # Run prefill
-
         for i in range(num_chunks):
-            chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             outputs = qpc_session.run(chunk_inputs)
+            chunk_inputs["index"] = outputs["index_output"]

         prefill_time = perf_counter() - prefill_start
         # Get first token
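
Both execution paths now skip the KV-cache I/O plus anything ending in _RetainedState, which from the diff appears to mark outputs the AIC runtime retains on-device between invocations, so they are never copied back to the host. A sketch of the filter as a reusable helper (the helper name is mine, not from this PR):

```python
# Hypothetical helper mirroring the skip_buffers calls above; assumes a
# session exposing input_names, output_names, and skip_buffers, as
# QAICInferenceSession does in this diff.
def skip_kv_buffers(session) -> None:
    session.skip_buffers(
        [
            name
            for name in session.input_names + session.output_names
            if name.startswith("past_") or name.endswith("_RetainedState")
        ]
    )
```
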
23 changes: 23 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -35,6 +35,14 @@
     Gemma2Model,
     Gemma2RMSNorm,
 )
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3Attention,
+    Gemma3DecoderLayer,
+    Gemma3ForCausalLM,
+    Gemma3ForConditionalGeneration,
+    Gemma3RMSNorm,
+    Gemma3TextModel,
+)
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
 from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
     GPTBigCodeAttention,
@@ -157,6 +165,14 @@
     QEffGemma2ForCausalLM,
     QEffGemma2Model,
 )
+from QEfficient.transformers.models.gemma3.modeling_gemma3 import (
+    QEffGemma3Attention,
+    QEffGemma3CustomRMSNormAIC,
+    QEffGemma3DecoderLayer,
+    QEffGemma3ForCausalLMModel,
+    QEffGemma3ForConditionalGeneration,
+    QEffGemma3TextModel,
+)
 from QEfficient.transformers.models.gpt2.modeling_gpt2 import (
     QEffGPT2Attention,
     QEffGPT2Block,
@@ -284,6 +300,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
         GraniteMoeRMSNorm: CustomRMSNormAIC,
+        Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
     }
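
Gemma3RMSNorm maps to a dedicated QEffGemma3CustomRMSNormAIC rather than the shared CustomRMSNormAIC, presumably because Gemma-family norms scale by (1 + weight) with a zero-initialized weight, unlike the LLaMA-style norm behind the generic custom op. A plain-PyTorch reference sketch of that semantics (class name is illustrative):

```python
import torch
from torch import nn

# Gemma-style RMSNorm for reference: note the (1.0 + weight) scaling that
# the custom AIC op is assumed to reproduce.
class GemmaStyleRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-init, used as (1 + w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        normed = x.float() * torch.rsqrt(variance + self.eps)
        return (normed * (1.0 + self.weight.float())).type_as(x)
```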


@@ -328,6 +345,12 @@ class KVCacheTransform(ModuleMappingTransform):
         Gemma2DecoderLayer: QEffGemma2DecoderLayer,
         Gemma2Model: QEffGemma2Model,
         Gemma2ForCausalLM: QEffGemma2ForCausalLM,
+        # Gemma3
+        Gemma3Attention: QEffGemma3Attention,
+        Gemma3DecoderLayer: QEffGemma3DecoderLayer,
+        Gemma3TextModel: QEffGemma3TextModel,
+        Gemma3ForCausalLM: QEffGemma3ForCausalLMModel,
+        Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration,
         # Granite
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
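
With these mappings registered, Gemma3 text checkpoints should load through the existing auto classes. A hypothetical end-to-end sketch; the model id, compile arguments, and prompt are illustrative and not taken from this PR:

```python
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

# Hypothetical usage; assumes from_pretrained applies the transforms above
# automatically, as it does for other supported architectures.
model = QEFFAutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")
model.compile(prefill_seq_len=128, ctx_len=1024, num_cores=16)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
model.generate(prompts=["What does a KV cache store?"], tokenizer=tokenizer)
```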