@@ -718,12 +718,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
718718 ]
719719 _onnx_transforms = [FP16ClipTransform , SplitTensorsTransform ]
720720
721- def __init__ (
722- self ,
723- model ,
724- qaic_config : Optional [dict ] = None ,
725- ** kwargs
726- ):
721+ def __init__ (self , model , qaic_config : Optional [dict ] = None , ** kwargs ):
727722 """
728723 Initializes the language decoder component for multimodal models.
729724
@@ -732,7 +727,7 @@ def __init__(
732727 model : nn.Module
733728 The full HuggingFace multimodal model from which the language decoder is extracted.
734729 qaic_config : dict, optional
735- A dictionary for QAIC-specific configurations.
730+ A dictionary for QAIC-specific configurations.
736731 Only the following keys are supported by the text model of the dual QPC multimodal model:
737732 - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
738733 - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -773,7 +768,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
773768 Path to the generated ONNX graph file for the language decoder.
774769 """
775770 if self .model .qaic_config is not None and self .model .qaic_config .get ("include_sampler" , False ):
776- inputs , output_names , dynamic_axes = self .get_sampling_inputs_and_outputs (inputs , output_names , dynamic_axes )
771+ inputs , output_names , dynamic_axes = self .get_sampling_inputs_and_outputs (
772+ inputs , output_names , dynamic_axes
773+ )
777774 return self ._export (
778775 inputs , output_names , dynamic_axes , export_dir = export_dir , offload_pt_weights = offload_pt_weights
779776 )
@@ -804,7 +801,7 @@ def get_sampling_inputs_and_outputs(
804801 sampling-related parameters.
805802 """
806803 bs : int = constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE
807-
804+
808805 assert "logits" in output_names , "logits must be part of the output names to suport on-device sampling"
809806
810807 logits_index = output_names .index ("logits" )
@@ -856,7 +853,7 @@ def get_sampling_inputs_and_outputs(
856853 example_inputs ["min_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_MIN_PS
857854 dynamic_axes ["min_ps" ] = {0 : "batch_size" }
858855
859- example_inputs ["random_numbers" ] = torch .rand ((bs , 1 ), dtype = torch .float )
856+ example_inputs ["random_numbers" ] = torch .rand ((bs , max_top_k_ids ), dtype = torch .float )
860857 dynamic_axes ["random_numbers" ] = {0 : "batch_size" }
861858
862859 return example_inputs , output_names , dynamic_axes
@@ -2066,7 +2063,7 @@ def from_pretrained(
20662063 pretrained_model_name_or_path : str ,
20672064 kv_offload : Optional [bool ] = None ,
20682065 qaic_config : Optional [dict ] = None ,
2069- ** kwargs
2066+ ** kwargs ,
20702067 ):
20712068 """
20722069 Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.
@@ -2080,7 +2077,7 @@ def from_pretrained(
20802077 If False, uses the single QPC approach (entire model in one QPC).
20812078 If None, the default behavior of the internal classes is used (typically dual QPC).
20822079 qaic_config : dict, optional
2083- A dictionary for QAIC-specific configurations.
2080+ A dictionary for QAIC-specific configurations.
20842081 Only the following keys are supported by the text model of the dual QPC multimodal model:
20852082 - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
20862083 - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -2116,11 +2113,11 @@ def from_pretrained(
21162113 qaic_config ["pretrained_model_name_or_path" ] = pretrained_model_name_or_path
21172114 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
21182115 return cls (
2119- model ,
2120- kv_offload = kv_offload ,
2121- pretrained_model_name_or_path = pretrained_model_name_or_path ,
2122- qaic_config = qaic_config ,
2123- ** kwargs
2116+ model ,
2117+ kv_offload = kv_offload ,
2118+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
2119+ qaic_config = qaic_config ,
2120+ ** kwargs ,
21242121 )
21252122
21262123
@@ -2327,7 +2324,7 @@ def from_pretrained(
23272324 kv_offload = kv_offload ,
23282325 pretrained_model_name_or_path = pretrained_model_name_or_path ,
23292326 qaic_config = qaic_config ,
2330- ** kwargs
2327+ ** kwargs ,
23312328 )
23322329 return cls (
23332330 model ,
@@ -2519,7 +2516,7 @@ def get_sampling_inputs_and_outputs(
25192516 example_inputs ["min_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_MIN_PS
25202517 dynamic_axes ["min_ps" ] = {0 : "batch_size" }
25212518
2522- example_inputs ["random_numbers" ] = torch .rand ((bs , 1 ), dtype = torch .float )
2519+ example_inputs ["random_numbers" ] = torch .rand ((bs , max_top_k_ids ), dtype = torch .float )
25232520 dynamic_axes ["random_numbers" ] = {0 : "batch_size" }
25242521
25252522 return example_inputs , output_names , dynamic_axes
0 commit comments