Commit e68913c

Onboarding Mistral3.1_24B
Signed-off-by: Mohit Soni <[email protected]>
1 parent c5a5c17 commit e68913c

File tree

4 files changed: +265 −0 lines changed


QEfficient/transformers/modeling_utils.py

Lines changed: 12 additions & 0 deletions
@@ -58,6 +58,10 @@
     MistralModel,
     MistralRMSNorm,
 )
+from transformers.models.mistral3.modeling_mistral3 import (
+    Mistral3ForConditionalGeneration,
+    Mistral3RMSNorm,
+)
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
     MixtralDecoderLayer,
@@ -70,6 +74,7 @@
 from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
 from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
 from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
 from transformers.models.starcoder2.modeling_starcoder2 import (
     Starcoder2Attention,
@@ -88,6 +93,7 @@
 )

 from QEfficient.customop import CustomRMSNormAIC
+from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration

 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
@@ -179,6 +185,7 @@
         GPTBigCodeForCausalLM.__name__,
         MllamaForCausalLM.__name__,
         WhisperForConditionalGeneration.__name__,
+        Mistral3ForConditionalGeneration.__name__,
     ]
 )
@@ -230,6 +237,9 @@
     MistralModel: QEffMistralModel,
     MistralForCausalLM: QEffMistralForCausalLM,
     MistralRMSNorm: CustomRMSNormAIC,
+    # Mistral3 model layers
+    Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
+    Mistral3RMSNorm: CustomRMSNormAIC,
     # Mixtral model layers
     MixtralAttention: QEffMixtralAttention,
     MixtralDecoderLayer: QeffMixtralDecoderLayer,
@@ -246,6 +256,8 @@
     PhiAttention: QEffPhiAttention,
     PhiModel: QEffPhiModel,
     PhiForCausalLM: QEffPhiForCausalLM,
+    # Pixtral model layers
+    PixtralRMSNorm: CustomRMSNormAIC,
     # Falcon model layers
     FalconAttention: QEffFalconAttention,
     FalconForCausalLM: QEffFalconForCausalLM,

QEfficient/transformers/models/mistral3/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------

QEfficient/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config

BS = 1
NUM_CHANNEL = 3
SEQ_LEN = 3072
CTX_LEN = 4096


class QEFFMistral3EncoderWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.vision_model = self.model.vision_tower

    def forward(self, pixel_values):
        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
        image_features = self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.model.config.vision_feature_layer,
            image_sizes=image_sizes,
        )
        return image_features


class QEFFMistral3DecoderWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = self.model.config
        self.language_model = self.model.language_model

    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        mask = input_ids == self.model.config.image_token_index
        indices1 = mask.to(torch.int64).cumsum(1) - 1
        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
        image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
        outputs = self.model.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        return outputs.logits, vit_embeds, outputs.past_key_values


class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration):
    def get_qeff_vision_encoder(self):
        return QEFFMistral3EncoderWrapper(self)

    def get_qeff_language_decoder(self):
        return QEFFMistral3DecoderWrapper(self)

    def forward(self, pixel_values, input_ids, position_ids, past_key_values):
        inputs_embeds = self.get_input_embeddings()(input_ids)
        # Image features
        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.config.vision_feature_layer,
            image_sizes=image_sizes,
        )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        mask = input_ids == self.config.image_token_index
        indices1 = mask.to(torch.int64).cumsum(1) - 1
        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
        image_features_expanded = image_features.unsqueeze(0)[indices0, indices1]
        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return outputs.logits, pixel_values, outputs.past_key_values

    def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
        inputs_shapes = {}
        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
        inputs_shapes["vit_embeds"] = (
            constants.MISTRAL3_FEATURE_SIZE,
            self.language_model.config.hidden_size,
        )
        inputs_shapes["position_ids"] = (
            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )
        inputs_shapes["pixel_values"] = (
            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            constants.MISTRAL3_NUM_CHANNELS,
            constants.MISTRAL3_HEIGHT,
            constants.MISTRAL3_WIDTH,
        )

        # Define inputs
        vision_inputs = {}
        lang_inputs = {}
        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
        lang_inputs["position_ids"] = (
            torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
            .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
            .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
        )

        # Add data for KV
        kv_cache_shape = get_padding_shape_from_config(
            config=self.language_model.config,
            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )

        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))

        inputs = {}
        if kv_offload:
            inputs["vision"] = vision_inputs
            inputs["lang"] = lang_inputs
        else:
            lang_inputs.pop("vit_embeds")
            inputs = {**vision_inputs, **lang_inputs}

        return inputs

    def get_specializations(
        self,
        batch_size: int,
        prefill_seq_len: int,
        ctx_len: int,
        img_size: int,
        kv_offload: bool = False,
        **compiler_options,
    ):
        prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
        ctx_len = ctx_len if ctx_len else CTX_LEN
        height = constants.MISTRAL3_HEIGHT
        width = constants.MISTRAL3_WIDTH

        vision = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            }
        ]
        lang = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            },
            {
                "batch_size": batch_size,
                "seq_len": "1",
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            },
        ]
        specializations = {}

        if kv_offload:
            specializations["vision"] = vision
            specializations["lang"] = lang
            return specializations, compiler_options
        else:
            return lang, compiler_options

    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
        # Define dynamic axes
        num_layers = self.config.text_config.num_hidden_layers

        vision_dynamic_axes = {
            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        }
        lang_dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "position_ids": {0: "batch_size", 1: "seq_len"},
        }

        for i in range(num_layers):
            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

        dynamic_axes = {}
        if kv_offload:
            dynamic_axes["vision"] = vision_dynamic_axes
            dynamic_axes["lang"] = lang_dynamic_axes
        else:
            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
        return dynamic_axes

    def get_output_names(self, kv_offload: bool = False):
        vision_output_names = ["vit_embeds"]
        lang_output_names = ["logits"]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_output_names.append(f"past_{kv}.{i}_RetainedState")

        output_names = {}
        if kv_offload:
            lang_output_names.insert(1, "vit_embeds_RetainedState")
            output_names["vision"] = vision_output_names
            output_names["lang"] = lang_output_names
        else:
            lang_output_names.insert(1, "pixel_values_RetainedState")
            return lang_output_names
        return output_names

    def get_inputs_info(self):
        return [
            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
        ]
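
Side note on the embedding splice used in both the decoder wrapper and the fused forward above: the mask/cumsum/gather/where sequence places image features at the image-token positions without data-dependent indexing, which keeps the graph export-friendly. A tiny standalone demonstration of the same trick (toy tensors and a made-up image_token_index, not values from this commit):

import torch

image_token_index = 9                                  # hypothetical placeholder id
input_ids = torch.tensor([[5, 9, 9, 7]])               # batch of 1, two image tokens
inputs_embeds = torch.zeros(1, 4, 3)                   # pretend text embeddings (hidden size 3)
vit_embeds = torch.arange(6, dtype=torch.float).view(2, 3)  # two image feature rows

mask = input_ids == image_token_index                  # [[False, True, True, False]]
indices1 = mask.to(torch.int64).cumsum(1) - 1          # running count of image tokens, -1 before the first one
indices0 = torch.arange(mask.shape[0]).view(-1, 1)     # batch index per row
# Gather one image feature per position; positions before the first image token
# pick row -1 (the last row), but torch.where masks them back to text embeddings.
image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

print(inputs_embeds[0, 1])  # tensor([0., 1., 2.]) -> first image feature row
print(inputs_embeds[0, 2])  # tensor([3., 4., 5.]) -> second image feature row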

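Not part of this commit, but for orientation: a hedged sketch of how a model onboarded this way is usually exercised through QEfficient's high-level API. The checkpoint id, image URL, prompt formatting, and compile values below are illustrative assumptions; in practice the model's chat template may be needed to insert the image tokens.

import requests
from PIL import Image
from transformers import AutoProcessor, TextStreamer

from QEfficient import QEFFAutoModelForImageTextToText

# Assumed checkpoint id for Mistral 3.1 24B; substitute the card you actually use.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
processor = AutoProcessor.from_pretrained(model_id)

# kv_offload=True exports separate vision and language graphs, matching the
# "vision"/"lang" split in get_dummy_inputs/get_specializations above.
model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=True)
model.compile(prefill_seq_len=3072, ctx_len=4096, num_cores=16, num_devices=4)  # mirrors SEQ_LEN/CTX_LEN defaults

image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)  # placeholder image URL
inputs = processor(images=image, text="Describe the image.", return_tensors="pt")
streamer = TextStreamer(processor.tokenizer)
model.generate(inputs=inputs, streamer=streamer, generation_len=128)
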
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 10 additions & 0 deletions
@@ -100,6 +100,10 @@
     MistralModel,
     MistralRMSNorm,
 )
+from transformers.models.mistral3.modeling_mistral3 import (
+    Mistral3ForConditionalGeneration,
+    Mistral3RMSNorm,
+)
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
     MixtralDecoderLayer,
@@ -129,6 +133,7 @@
     Phi3Model,
     Phi3RMSNorm,
 )
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
 from transformers.models.qwen2.modeling_qwen2 import (
     Qwen2Attention,
     Qwen2DecoderLayer,
@@ -260,6 +265,7 @@
     QEffMistralForCausalLM,
     QEffMistralModel,
 )
+from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration
 from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
     QEffMixtralAttention,
     QeffMixtralDecoderLayer,
@@ -332,13 +338,15 @@ class CustomOpsTransform(ModuleMappingTransform):
         LlamaRMSNorm: CustomRMSNormAIC,
         Llama4TextRMSNorm: CustomRMSNormAIC,
         MistralRMSNorm: CustomRMSNormAIC,
+        Mistral3RMSNorm: CustomRMSNormAIC,
         MixtralRMSNorm: CustomRMSNormAIC,
         Phi3RMSNorm: CustomRMSNormAIC,
         Qwen2RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
         GraniteMoeRMSNorm: CustomRMSNormAIC,
         Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
+        PixtralRMSNorm: CustomRMSNormAIC,
     }


@@ -426,6 +434,8 @@ class KVCacheTransform(ModuleMappingTransform):
         MistralDecoderLayer: QEffMistralDecoderLayer,
         MistralModel: QEffMistralModel,
         MistralForCausalLM: QEffMistralForCausalLM,
+        # Mistral3
+        Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
         # Mixtral
         MixtralAttention: QEffMixtralAttention,
         MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
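
For readers new to these mapping tables: CustomOpsTransform and KVCacheTransform consume them by swapping module classes, which is why adding the Mistral3/Pixtral entries above is enough to route those layers through the Cloud AI 100-friendly implementations. A minimal sketch of the class-swap idea, as a simplified stand-in rather than QEfficient's actual ModuleMappingTransform:

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    """Simplified illustration of a module-mapping transform.

    For every submodule whose type appears in `mapping`, retarget its class in
    place so the weights are kept but forward() comes from the replacement
    (e.g. Mistral3RMSNorm -> CustomRMSNormAIC). Details differ in QEfficient.
    """
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model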
