# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config

# Default export shapes; SEQ_LEN and CTX_LEN are used as fallbacks in get_specializations.
BS = 1
NUM_CHANNEL = 3
SEQ_LEN = 3072
CTX_LEN = 4096


class QEFFMistral3EncoderWrapper(nn.Module):
    """Wraps the Mistral3 vision tower so the image encoder can be exported as a standalone graph."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.vision_model = self.model.vision_tower

    def forward(self, pixel_values):
        # The vision tower needs the (height, width) of each image to build its patch grid.
        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
        image_features = self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.model.config.vision_feature_layer,
            image_sizes=image_sizes,
        )
        return image_features


class QEFFMistral3DecoderWrapper(nn.Module):
    """Wraps the language-model side; consumes precomputed vision embeddings (`vit_embeds`) along with text inputs and the KV cache."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = self.model.config
        self.language_model = self.model.language_model

    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
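        # Splice image embeddings into the text embeddings at image-token positions.
        # Illustrative example (not from the source): for input_ids = [[im, im, txt, im]],
        # mask = [[1, 1, 0, 1]] and mask.cumsum(1) - 1 = [[0, 1, 1, 2]], so the three image
        # positions gather rows 0, 1, 2 of vit_embeds, while torch.where keeps the original
        # text embedding at the non-image position.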
        mask = input_ids == self.model.config.image_token_index
        indices1 = mask.to(torch.int64).cumsum(1) - 1
        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
        image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
        outputs = self.model.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        return outputs.logits, vit_embeds, outputs.past_key_values


class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration):
    """Mistral3 wrapper exposing export helpers (dummy inputs, specializations, dynamic axes, output names) for the single-graph and kv_offload (separate vision/language graph) paths."""

    def get_qeff_vision_encoder(self):
        return QEFFMistral3EncoderWrapper(self)

    def get_qeff_language_decoder(self):
        return QEFFMistral3DecoderWrapper(self)

    def forward(self, pixel_values, input_ids, position_ids, past_key_values):
        inputs_embeds = self.get_input_embeddings()(input_ids)
        # Image features
        image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.config.vision_feature_layer,
            image_sizes=image_sizes,
        )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        # Splice image embeddings into text embeddings at image-token positions
        # (same gather as in QEFFMistral3DecoderWrapper.forward).
        mask = input_ids == self.config.image_token_index
        indices1 = mask.to(torch.int64).cumsum(1) - 1
        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
        image_features_expanded = image_features.unsqueeze(0)[indices0, indices1]
        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return outputs.logits, pixel_values, outputs.past_key_values

    def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
        inputs_shapes = {}
        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
        inputs_shapes["vit_embeds"] = (
            constants.MISTRAL3_FEATURE_SIZE,
            self.language_model.config.hidden_size,
        )
        inputs_shapes["position_ids"] = (
            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )
        inputs_shapes["pixel_values"] = (
            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            constants.MISTRAL3_NUM_CHANNELS,
            constants.MISTRAL3_HEIGHT,
            constants.MISTRAL3_WIDTH,
        )

        # Define inputs
        vision_inputs = {}
        lang_inputs = {}
        vision_inputs["pixel_values"] = torch.zeros(inputs_shapes["pixel_values"], dtype=torch.float32)
        lang_inputs["input_ids"] = torch.zeros(inputs_shapes["input_ids"], dtype=torch.int64)
        lang_inputs["vit_embeds"] = torch.zeros(inputs_shapes["vit_embeds"], dtype=torch.float32)
        lang_inputs["position_ids"] = (
            torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
            .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
            .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
        )

        # Add data for KV
        kv_cache_shape = get_padding_shape_from_config(
            config=self.language_model.config,
            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
        )

        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))

        inputs = {}
        if kv_offload:
            inputs["vision"] = vision_inputs
            inputs["lang"] = lang_inputs
        else:
            lang_inputs.pop("vit_embeds")
            inputs = {**vision_inputs, **lang_inputs}

        return inputs

    def get_specializations(
        self,
        batch_size: int,
        prefill_seq_len: int,
        ctx_len: int,
        img_size: int,
        kv_offload: bool = False,
        **compiler_options,
    ):
        prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
        ctx_len = ctx_len if ctx_len else CTX_LEN
        height = constants.MISTRAL3_HEIGHT
        width = constants.MISTRAL3_WIDTH

        vision = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            }
        ]
        lang = [
            # Prefill specialization
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            },
            # Decode specialization (one token at a time)
            {
                "batch_size": batch_size,
                "seq_len": "1",
                "ctx_len": ctx_len,
                "height": height,
                "width": width,
            },
        ]
        specializations = {}

        if kv_offload:
            specializations["vision"] = vision
            specializations["lang"] = lang
            return specializations, compiler_options
        else:
            return lang, compiler_options

    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
        # Define dynamic axes
        num_layers = self.config.text_config.num_hidden_layers

        vision_dynamic_axes = {
            "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        }
        lang_dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "position_ids": {0: "batch_size", 1: "seq_len"},
        }

        for i in range(num_layers):
            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

        dynamic_axes = {}
        if kv_offload:
            dynamic_axes["vision"] = vision_dynamic_axes
            dynamic_axes["lang"] = lang_dynamic_axes
        else:
            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
        return dynamic_axes

    def get_output_names(self, kv_offload: bool = False):
        vision_output_names = ["vit_embeds"]
        lang_output_names = ["logits"]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_output_names.append(f"past_{kv}.{i}_RetainedState")

        output_names = {}
        if kv_offload:
            lang_output_names.insert(1, "vit_embeds_RetainedState")
            output_names["vision"] = vision_output_names
            output_names["lang"] = lang_output_names
        else:
            lang_output_names.insert(1, "pixel_values_RetainedState")
            return lang_output_names
        return output_names

    def get_inputs_info(self):
        return [
            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
        ]
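

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the modeling code above). It shows how the
# export helpers defined in QEffMistral3ForConditionalGeneration might be exercised
# for the kv_offload (separate vision/language graph) path. The checkpoint id below
# is a placeholder assumption, not a reference to a specific published model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = QEffMistral3ForConditionalGeneration.from_pretrained(
        "<mistral3-checkpoint>"  # placeholder: substitute a real Mistral3 checkpoint
    )

    dummy_inputs = model.get_dummy_inputs(kv_offload=True)
    specializations, _ = model.get_specializations(
        batch_size=1,
        prefill_seq_len=SEQ_LEN,
        ctx_len=CTX_LEN,
        img_size=None,
        kv_offload=True,
    )
    dynamic_axes = model.get_onnx_dynamic_axes(kv_offload=True)
    output_names = model.get_output_names(kv_offload=True)

    print("vision inputs:", list(dummy_inputs["vision"]))
    print("lang inputs:", list(dummy_inputs["lang"]))
    print("vision specialization:", specializations["vision"])
    print("lang specializations:", specializations["lang"])
    print("lang dynamic axes (sample):", list(dynamic_axes["lang"])[:4])
    print("vision outputs:", output_names["vision"])
    print("lang outputs (sample):", output_names["lang"][:3])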