From 8705af0effcc56bf66f0cca64bf9f92e18cf18e7 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Fri, 11 Apr 2025 10:22:33 +0000 Subject: [PATCH] Onboarding Mistral3.1_24B Signed-off-by: Mohit Soni --- QEfficient/transformers/modeling_utils.py | 12 + .../transformers/models/mistral3/__init__.py | 6 + .../models/mistral3/modeling_mistral3.py | 237 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 10 + QEfficient/utils/constants.py | 7 + 5 files changed, 272 insertions(+) create mode 100644 QEfficient/transformers/models/mistral3/__init__.py create mode 100644 QEfficient/transformers/models/mistral3/modeling_mistral3.py diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index ccad5e020..25da9b741 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -57,6 +57,10 @@ MistralModel, MistralRMSNorm, ) +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3RMSNorm, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -69,6 +73,7 @@ from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, @@ -87,6 +92,7 @@ ) from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, @@ -177,6 +183,7 @@ GPTBigCodeForCausalLM.__name__, MllamaForCausalLM.__name__, WhisperForConditionalGeneration.__name__, + Mistral3ForConditionalGeneration.__name__, ] ) @@ -226,6 +233,9 @@ MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, MistralRMSNorm: CustomRMSNormAIC, + # Mistral3 model layers + Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, + Mistral3RMSNorm: CustomRMSNormAIC, # Mixtral model layers MixtralAttention: QEffMixtralAttention, MixtralDecoderLayer: QeffMixtralDecoderLayer, @@ -242,6 +252,8 @@ PhiAttention: QEffPhiAttention, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, + # Pixtral model layers + PixtralRMSNorm: CustomRMSNormAIC, # Falcon model layers FalconAttention: QEffFalconAttention, FalconForCausalLM: QEffFalconForCausalLM, diff --git a/QEfficient/transformers/models/mistral3/__init__.py b/QEfficient/transformers/models/mistral3/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/mistral3/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py new file mode 100644 index 000000000..a4f77f82b --- /dev/null +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -0,0 +1,237 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration + +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config + +BS = 1 +NUM_CHANNEL = 3 +SEQ_LEN = 3072 +CTX_LEN = 4096 + + +class QEFFMistral3EncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.vision_tower + + def forward(self, pixel_values): + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]) + image_features = self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.model.config.vision_feature_layer, + image_sizes=image_sizes, + ) + return image_features + + +class QEFFMistral3DecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.config = self.model.config + self.language_model = self.model.language_model + + def forward(self, input_ids, vit_embeds, position_ids, past_key_values): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + mask = input_ids == self.model.config.image_token_index + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1] + inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + ) + + return outputs.logits, vit_embeds, outputs.past_key_values + + +class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEFFMistral3EncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEFFMistral3DecoderWrapper(self) + + def forward(self, pixel_values, input_ids, position_ids, past_key_values): + inputs_embeds = self.get_input_embeddings()(input_ids) + # Image features + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]) + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.config.vision_feature_layer, + image_sizes=image_sizes, + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + mask = input_ids == self.config.image_token_index + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + image_features_expanded = image_features.unsqueeze(0)[indices0, indices1] + inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + outputs = self.language_model( + 
inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + ) + return outputs.logits, pixel_values, outputs.past_key_values + + def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + inputs_shapes["vit_embeds"] = ( + constants.MISTRAL3_FEATURE_SIZE, + self.language_model.config.hidden_size, + ) + inputs_shapes["position_ids"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.MISTRAL3_NUM_CHANNELS, + constants.MISTRAL3_HEIGHT, + constants.MISTRAL3_WIDTH, + ) + + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.language_model.config, + batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vit_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: int, + kv_offload: bool = False, + **compiler_options, + ): + prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN + ctx_len = ctx_len if ctx_len else CTX_LEN + height = constants.MISTRAL3_HEIGHT + width = constants.MISTRAL3_WIDTH + + vision = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "height": height, + "width": width, + } + ] + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "height": height, + "width": width, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "height": height, + "width": width, + }, + ] + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, + } + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + dynamic_axes = 
{} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vit_embeds"] + lang_output_names = ["logits"] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vit_embeds_RetainedState") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + return lang_output_names + return output_names + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")), + ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index bcedd4a27..5a8331289 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -67,6 +67,10 @@ MistralModel, MistralRMSNorm, ) +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3RMSNorm, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -96,6 +100,7 @@ Phi3Model, Phi3RMSNorm, ) +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -188,6 +193,7 @@ QEffMistralForCausalLM, QEffMistralModel, ) +from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -255,11 +261,13 @@ class CustomOpsTransform(ModuleMappingTransform): Gemma2RMSNorm: GemmaCustomRMSNormAIC, LlamaRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, + Mistral3RMSNorm: CustomRMSNormAIC, MixtralRMSNorm: CustomRMSNormAIC, Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, + PixtralRMSNorm: CustomRMSNormAIC, } @@ -321,6 +329,8 @@ class KVCacheTransform(ModuleMappingTransform): MistralDecoderLayer: QEffMistralDecoderLayer, MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, + # Mistral3 + Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, # Mixtral MixtralAttention: QEffMixtralAttention, MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index c2663594f..34dd72006 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -74,6 +74,13 @@ def get_models_dir(): INTERN_NUM_CHANNELS = 3 INTERN_IMG_CONTEXT_TOKEN = 151667 +# MISTRAL3 Constants +# Fixing the feature size with reference to mistralai/Mistral-Small-3.1-24B-Instruct-2503 +MISTRAL3_FEATURE_SIZE = 2255 +MISTRAL3_NUM_CHANNELS = 3 +MISTRAL3_HEIGHT = 1540 +MISTRAL3_WIDTH = 1162 + class Constants: # Export Constants.
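
Note for reviewers: with kv_offload=True the model is split into two QPCs. The vision
encoder (QEFFMistral3EncoderWrapper) emits vit_embeds, and the language decoder
(QEFFMistral3DecoderWrapper) scatters those embeddings into the text embeddings at the
image-token positions using a cumsum-based index, then returns them as a RetainedState
output. A minimal usage sketch follows; it assumes the existing
QEFFAutoModelForImageTextToText entry point handles this model the same way as the other
onboarded vision-language models, and the compile/generate arguments shown are
illustrative only, not part of this patch.

    # Sketch only: the entry point and argument names below are assumptions based on
    # how other VLMs are driven in QEfficient; they are not introduced by this patch.
    from transformers import AutoProcessor

    from QEfficient import QEFFAutoModelForImageTextToText

    model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
    processor = AutoProcessor.from_pretrained(model_id)

    # kv_offload=True exports the vision encoder and language decoder as separate QPCs.
    qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=True)

    # Defaults in this patch: prefill_seq_len=3072 (SEQ_LEN) and ctx_len=4096 (CTX_LEN);
    # num_cores/num_devices depend on the target hardware.
    qeff_model.compile(prefill_seq_len=3072, ctx_len=4096, num_cores=16, num_devices=4)

    # Inference would then follow the usual processor -> generate flow, e.g.:
    # inputs = processor(images=image, text=prompt, return_tensors="pt")
    # qeff_model.generate(inputs=inputs, generation_len=128)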