
Commit 8705af0

Onboarding Mistral3.1_24B
Signed-off-by: Mohit Soni <[email protected]>
1 parent 6199051 commit 8705af0

File tree: 5 files changed, +272 -0 lines changed
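
This commit onboards Mistral 3.1 24B (mistralai/Mistral-Small-3.1-24B-Instruct-2503) as a supported vision-language model. A usage sketch follows; it assumes that the repository's existing QEFFAutoModelForImageTextToText entry point and compile flow also route this model through the new wrappers, so the exact names and arguments are illustrative rather than taken from this commit.

# Hypothetical usage sketch (not part of this commit): assumes the generic
# image-text-to-text entry point picks up the Mistral3 mapping added below.
from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForImageTextToText

model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
processor = AutoProcessor.from_pretrained(model_id)

# kv_offload=True exports the vision encoder and language decoder as separate
# graphs, the two-QPC path exercised by get_dummy_inputs/get_specializations.
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=True)
qeff_model.compile(prefill_seq_len=3072, ctx_len=4096)  # illustrative compile arguments

The kv_offload flag mirrors the two paths visible in the new modeling file: separate vision and language graphs when offloading, a single combined graph otherwise.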

QEfficient/transformers/modeling_utils.py

Lines changed: 12 additions & 0 deletions
@@ -57,6 +57,10 @@
     MistralModel,
     MistralRMSNorm,
 )
+from transformers.models.mistral3.modeling_mistral3 import (
+    Mistral3ForConditionalGeneration,
+    Mistral3RMSNorm,
+)
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
     MixtralDecoderLayer,
@@ -69,6 +73,7 @@
 from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
 from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
 from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
 from transformers.models.starcoder2.modeling_starcoder2 import (
     Starcoder2Attention,
@@ -87,6 +92,7 @@
 )
 
 from QEfficient.customop import CustomRMSNormAIC
+from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration
 
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
@@ -177,6 +183,7 @@
         GPTBigCodeForCausalLM.__name__,
         MllamaForCausalLM.__name__,
         WhisperForConditionalGeneration.__name__,
+        Mistral3ForConditionalGeneration.__name__,
     ]
 )
 
@@ -226,6 +233,9 @@
     MistralModel: QEffMistralModel,
     MistralForCausalLM: QEffMistralForCausalLM,
     MistralRMSNorm: CustomRMSNormAIC,
+    # Mistral3 model layers
+    Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
+    Mistral3RMSNorm: CustomRMSNormAIC,
     # Mixtral model layers
     MixtralAttention: QEffMixtralAttention,
     MixtralDecoderLayer: QeffMixtralDecoderLayer,
@@ -242,6 +252,8 @@
     PhiAttention: QEffPhiAttention,
     PhiModel: QEffPhiModel,
     PhiForCausalLM: QEffPhiForCausalLM,
+    # Pixtral model layers
+    PixtralRMSNorm: CustomRMSNormAIC,
     # Falcon model layers
     FalconAttention: QEffFalconAttention,
     FalconForCausalLM: QEffFalconForCausalLM,
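
These additions extend the table mapping stock Hugging Face modules to their QEfficient counterparts (for example, Mistral3RMSNorm and PixtralRMSNorm now resolve to CustomRMSNormAIC, and Mistral3ForConditionalGeneration to its QEff wrapper). A minimal, generic sketch of how such a class-to-class table can be applied to an instantiated model is shown below; the helper name and the in-place class swap are illustrative assumptions, not the repository's actual transform code.

# Generic sketch of applying a {source class -> replacement class} table.
from typing import Dict, Type

import torch.nn as nn


def swap_module_classes(model: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]) -> nn.Module:
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            # Re-point the class so parameters and buffers are kept and only
            # the forward implementation changes.
            module.__class__ = replacement
    return model
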
QEfficient/transformers/models/mistral3/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------

QEfficient/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
+
+from QEfficient.utils import constants
+from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
+
+BS = 1
+NUM_CHANNEL = 3
+SEQ_LEN = 3072
+CTX_LEN = 4096
+
+
+class QEFFMistral3EncoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.model.vision_model = self.model.vision_tower
+
+    def forward(self, pixel_values):
+        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
+        image_features = self.model.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.model.config.vision_feature_layer,
+            image_sizes=image_sizes,
+        )
+        return image_features
+
+
+class QEFFMistral3DecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.config = self.model.config
+        self.language_model = self.model.language_model
+
+    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        mask = input_ids == self.model.config.image_token_index
+        indices1 = mask.to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
+        image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
+        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        outputs = self.model.language_model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+        )
+
+        return outputs.logits, vit_embeds, outputs.past_key_values
+
+
+class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration):
+    def get_qeff_vision_encoder(self):
+        return QEFFMistral3EncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEFFMistral3DecoderWrapper(self)
+
+    def forward(self, pixel_values, input_ids, position_ids, past_key_values):
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        # Image features
+        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_feature_layer,
+            image_sizes=image_sizes,
+        )
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        mask = input_ids == self.config.image_token_index
+        indices1 = mask.to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
+        image_features_expanded = image_features.unsqueeze(0)[indices0, indices1]
+        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+        )
+        return outputs.logits, pixel_values, outputs.past_key_values
+
+    def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
+        inputs_shapes = {}
+        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+        inputs_shapes["vit_embeds"] = (
+            constants.MISTRAL3_FEATURE_SIZE,
+            self.language_model.config.hidden_size,
+        )
+        inputs_shapes["position_ids"] = (
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+        )
+        inputs_shapes["pixel_values"] = (
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            constants.MISTRAL3_NUM_CHANNELS,
+            constants.MISTRAL3_HEIGHT,
+            constants.MISTRAL3_WIDTH,
+        )
+
+        # Define inputs
+        vision_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["position_ids"] = (
+            torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
+            .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+            .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
+        )
+
+        # Add data for KV
+        kv_cache_shape = get_padding_shape_from_config(
+            config=self.language_model.config,
+            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+        )
+
+        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
+        for i in range(self.language_model.config.num_hidden_layers):
+            for kv in ["key", "value"]:
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+        else:
+            lang_inputs.pop("vit_embeds")
+            inputs = {**vision_inputs, **lang_inputs}
+
+        return inputs
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: int,
+        kv_offload: bool = False,
+        **compiler_options,
+    ):
+        prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
+        ctx_len = ctx_len if ctx_len else CTX_LEN
+        height = constants.MISTRAL3_HEIGHT
+        width = constants.MISTRAL3_WIDTH
+
+        vision = [
+            {
+                "batch_size": batch_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+                "height": height,
+                "width": width,
+            }
+        ]
+        lang = [
+            {
+                "batch_size": batch_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+                "height": height,
+                "width": width,
+            },
+            {
+                "batch_size": batch_size,
+                "seq_len": "1",
+                "ctx_len": ctx_len,
+                "height": height,
+                "width": width,
+            },
+        ]
+        specializations = {}
+
+        if kv_offload:
+            specializations["vision"] = vision
+            specializations["lang"] = lang
+            return specializations, compiler_options
+        else:
+            return lang, compiler_options
+
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+        # Define dynamic axes
+        num_layers = self.config.text_config.num_hidden_layers
+
+        vision_dynamic_axes = {
+            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
+        }
+        lang_dynamic_axes = {
+            "input_ids": {0: "batch_size", 1: "seq_len"},
+            "position_ids": {0: "batch_size", 1: "seq_len"},
+        }
+
+        for i in range(num_layers):
+            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
+            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+
+        dynamic_axes = {}
+        if kv_offload:
+            dynamic_axes["vision"] = vision_dynamic_axes
+            dynamic_axes["lang"] = lang_dynamic_axes
+        else:
+            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
+        return dynamic_axes
+
+    def get_output_names(self, kv_offload: bool = False):
+        vision_output_names = ["vit_embeds"]
+        lang_output_names = ["logits"]
+        for i in range(self.language_model.config.num_hidden_layers):
+            for kv in ["key", "value"]:
+                lang_output_names.append(f"past_{kv}.{i}_RetainedState")
+
+        output_names = {}
+        if kv_offload:
+            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            output_names["vision"] = vision_output_names
+            output_names["lang"] = lang_output_names
+        else:
+            lang_output_names.insert(1, "pixel_values_RetainedState")
+            return lang_output_names
+        return output_names
+
+    def get_inputs_info(self):
+        return [
+            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
+            IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
+        ]
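
The wrappers above splice projected image features into the text embedding stream: positions holding the image placeholder token (config.image_token_index) are overwritten with image embeddings selected through a running cumulative-sum index, which avoids data-dependent shapes in the exported graph. A small, self-contained demonstration of that mask/cumsum/where pattern, with toy sizes and a made-up image_token_index:

# Toy demonstration of the image-token scatter used in the decoder wrapper.
import torch

batch, seq_len, hidden = 1, 6, 4
image_token_index = 99  # arbitrary placeholder id for this demo

input_ids = torch.tensor([[99, 99, 99, 5, 6, 7]])            # 3 image slots, then text
inputs_embeds = torch.zeros(batch, seq_len, hidden)           # stand-in for text embeddings
vit_embeds = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)  # 3 image features

mask = input_ids == image_token_index                         # (batch, seq_len) bool
indices1 = mask.to(torch.int64).cumsum(1) - 1                 # n-th slot -> n-th image feature
indices0 = torch.arange(mask.shape[0]).view(-1, 1)            # batch index for gathering
image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

print(inputs_embeds)  # rows 0-2 hold image features, rows 3-5 keep the text embeddings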

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 10 additions & 0 deletions
@@ -67,6 +67,10 @@
     MistralModel,
     MistralRMSNorm,
 )
+from transformers.models.mistral3.modeling_mistral3 import (
+    Mistral3ForConditionalGeneration,
+    Mistral3RMSNorm,
+)
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
     MixtralDecoderLayer,
@@ -96,6 +100,7 @@
     Phi3Model,
     Phi3RMSNorm,
 )
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
 from transformers.models.qwen2.modeling_qwen2 import (
     Qwen2Attention,
     Qwen2DecoderLayer,
@@ -188,6 +193,7 @@
     QEffMistralForCausalLM,
     QEffMistralModel,
 )
+from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration
 from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
     QEffMixtralAttention,
     QeffMixtralDecoderLayer,
@@ -255,11 +261,13 @@ class CustomOpsTransform(ModuleMappingTransform):
         Gemma2RMSNorm: GemmaCustomRMSNormAIC,
         LlamaRMSNorm: CustomRMSNormAIC,
         MistralRMSNorm: CustomRMSNormAIC,
+        Mistral3RMSNorm: CustomRMSNormAIC,
         MixtralRMSNorm: CustomRMSNormAIC,
         Phi3RMSNorm: CustomRMSNormAIC,
         Qwen2RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
+        PixtralRMSNorm: CustomRMSNormAIC,
     }
 
 
@@ -321,6 +329,8 @@ class KVCacheTransform(ModuleMappingTransform):
         MistralDecoderLayer: QEffMistralDecoderLayer,
         MistralModel: QEffMistralModel,
         MistralForCausalLM: QEffMistralForCausalLM,
+        # Mistral3
+        Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
         # Mixtral
         MixtralAttention: QEffMixtralAttention,
         MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
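
CustomOpsTransform now maps Mistral3RMSNorm and PixtralRMSNorm onto CustomRMSNormAIC as well, and KVCacheTransform routes Mistral3ForConditionalGeneration to the QEff wrapper defined above. Whatever the backend lowering of CustomRMSNormAIC looks like, the swap has to preserve the standard RMSNorm computation; the sketch below is a plain PyTorch restatement of that reference computation for clarity, not the custom op's implementation.

# Reference RMSNorm: the function a CustomRMSNormAIC swap must match numerically.
import torch
import torch.nn as nn


class ReferenceRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize by the root mean square over the last dimension, then scale.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states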

QEfficient/utils/constants.py

Lines changed: 7 additions & 0 deletions
@@ -74,6 +74,13 @@ def get_models_dir():
 INTERN_NUM_CHANNELS = 3
 INTERN_IMG_CONTEXT_TOKEN = 151667
 
+# MISTRAL3 Constants
+# Fixing the feature size with reference to mistralai/Mistral-Small-3.1-24B-Instruct-2503
+MISTRAL3_FEATURE_SIZE = 2255
+MISTRAL3_NUM_CHANNELS = 3
+MISTRAL3_HEIGHT = 1540
+MISTRAL3_WIDTH = 1162
+
 
 class Constants:
     # Export Constants.
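
MISTRAL3_FEATURE_SIZE = 2255 is the number of vision embeddings the language decoder receives for a MISTRAL3_HEIGHT x MISTRAL3_WIDTH (1540 x 1162) image. As a rough sanity check, the count is consistent with a 14-pixel patch and a spatial merge factor of 2 in the Pixtral-style vision tower; note that those two values are inferred from the numbers rather than stated anywhere in this commit.

# Back-of-the-envelope check of MISTRAL3_FEATURE_SIZE for a 1540x1162 input.
# patch_size=14 and spatial_merge_size=2 are assumptions, not read from this commit.
height, width = 1540, 1162
patch_size, spatial_merge_size = 14, 2

merged_h = (height // patch_size) // spatial_merge_size   # 110 // 2 = 55
merged_w = (width // patch_size) // spatial_merge_size    # 83 // 2 = 41
print(merged_h * merged_w)                                 # 2255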
