1212import numpy as np
1313from onnx import ModelProto , external_data_helper , numpy_helper
1414
15- from QEfficient .utils .constants import ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL
15+ from QEfficient .utils .constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL
1616
1717logger = logging .getLogger (__name__ )
1818
@@ -22,6 +22,8 @@ class OnnxTransform:
2222 OnnxTransform is the base class for graph modifications on exported onnx.
2323 """
2424
25+ _external_data_loaded_cache = {} # Dict[int, bool]
26+
2527 def __init__ (self ):
2628 raise TypeError ("Transform classes are not to be instantiated. Directly use the `apply` method." )
2729
@@ -45,12 +47,54 @@ def _check_external_data_loaded(cls, model: ModelProto) -> bool:
4547 :param model: The ONNX model to check
4648 :returns: True if external data is already loaded, False otherwise
4749 """
50+ # Use object ID as key instead of the object itself
51+ model_id = id (model )
52+ # Return cached result if available
53+ if model_id in cls ._external_data_loaded_cache :
54+ return cls ._external_data_loaded_cache [model_id ]
55+
56+ # Load the model if not already loaded
4857 for tensor in external_data_helper ._get_all_tensors (model ):
4958 # Check if tensor has external data but no raw data loaded
5059 if len (tensor .external_data ) > 0 and not tensor .HasField ("raw_data" ):
60+ cls ._external_data_loaded_cache [model_id ] = False
5161 return False
62+
63+ cls ._external_data_loaded_cache [model_id ] = True
5264 return True
5365
66+ @classmethod
67+ def _load_external_data (cls , model : ModelProto , onnx_base_dir : Optional [str ] = None ):
68+ """
69+ Performs a bulk load of external data if it's not already loaded.
70+ Updates the cache upon successful load.
71+ """
72+ model_id = id (model )
73+ if not cls ._check_external_data_loaded (model ):
74+ logger .info ("External data not loaded. Performing bulk load." )
75+ external_data_helper .load_external_data_for_model (model , onnx_base_dir )
76+ cls ._external_data_loaded_cache [model_id ] = True
77+ else :
78+ logger .info ("External data already loaded (or cached). Skipping bulk load." )
79+
80+
81+ @classmethod
82+ def _cleanup_external_data_and_cache (cls , model : ModelProto ):
83+ """
84+ Combines clearing external data from the model and its cache entry.
85+ """
86+ # Remove the loaded raw data from tensors
87+ for tensor in external_data_helper ._get_all_tensors (model ):
88+ if tensor .HasField ("raw_data" ):
89+ tensor .ClearField ("raw_data" )
90+
91+ # Clear the cache entry for this model using its ID
92+ model_id = id (model )
93+ if model_id in cls ._external_data_loaded_cache :
94+ del cls ._external_data_loaded_cache [model_id ]
95+
96+ logger .info ("External data and cache cleaned up." )
97+
5498 @classmethod
5599 def _cleanup_memory (cls ):
56100 """
@@ -69,36 +113,42 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar
69113 """
70114 :param onnx_base_dir: Base directory to load tensors
71115 """
72- finfo = np .finfo (np .float16 )
73- fp16_max = finfo .max
74- fp16_min = finfo .min
75- transformed = False
116+ try :
117+ # --- FIX: Ensure external data is loaded efficiently BEFORE processing ---
118+ cls ._load_external_data (model , onnx_base_dir )
76119
77- processed_count = 0
78- for tensor in external_data_helper ._get_all_tensors (model ):
79- nptensor = numpy_helper .to_array (tensor , onnx_base_dir )
80- if nptensor .dtype == np .float32 and (np .any (nptensor > fp16_max ) or np .any (nptensor < fp16_min )):
81- neg_inf_mask = np .isinf (nptensor ) & (nptensor < 0 )
82- clipped_tensor = np .clip (nptensor , fp16_min , fp16_max )
120+ finfo = np .finfo (np .float16 )
121+ fp16_max = finfo .max
122+ fp16_min = finfo .min
123+ transformed = False
124+
125+ processed_count = 0
126+ for tensor in external_data_helper ._get_all_tensors (model ):
127+ nptensor = numpy_helper .to_array (tensor ) # Removed onnx_base_dir as data is already loaded
128+ if nptensor .dtype == np .float32 and (np .any (nptensor > fp16_max ) or np .any (nptensor < fp16_min )):
129+ neg_inf_mask = np .isinf (nptensor ) & (nptensor < 0 )
130+ clipped_tensor = np .clip (nptensor , fp16_min , fp16_max )
83131
84- # Restore -inf values
85- if neg_inf_mask .any ():
86- clipped_tensor = np .where (neg_inf_mask , np .float32 ("-inf" ), clipped_tensor )
132+ # Restore -inf values
133+ if neg_inf_mask .any ():
134+ clipped_tensor = np .where (neg_inf_mask , np .float32 ("-inf" ), clipped_tensor )
87135
88- new_tensor = numpy_helper .from_array (clipped_tensor , tensor .name )
89- tensor .CopyFrom (new_tensor )
90- transformed = True
136+ new_tensor = numpy_helper .from_array (clipped_tensor , tensor .name )
137+ tensor .CopyFrom (new_tensor )
138+ transformed = True
91139
92- del neg_inf_mask , clipped_tensor , new_tensor
140+ del neg_inf_mask , clipped_tensor , new_tensor
93141
94- del nptensor
95- processed_count += 1
142+ del nptensor
143+ processed_count += 1
96144
97- if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0 :
98- cls ._cleanup_memory ()
145+ if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0 :
146+ cls ._cleanup_memory ()
99147
100- cls ._cleanup_memory ()
101- return model , transformed
148+ return model , transformed
149+ finally :
150+ # Ensure cleanup happens even if an exception occurs
151+ cls ._cleanup_memory ()
102152
103153
104154class SplitTensorsTransform (OnnxTransform ):
@@ -123,32 +173,30 @@ def apply(
123173 :param file_chunk_size: Chunk size to split external files into.
124174 :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
125175 """
126- file_num = 0
127- current_file_size = 0
128- transformed = False
129-
130- # Check if external data is already loaded to avoid redundant loading
131- external_data_already_loaded = cls ._check_external_data_loaded (model )
132-
133- if not external_data_already_loaded :
134- external_data_helper .load_external_data_for_model (model , onnx_base_dir )
135- else :
136- logger .info ("External data already loaded, skipping redundant load operation" )
137-
138- processed_count = 0
139- for tensor in external_data_helper ._get_all_tensors (model ):
140- if tensor .HasField ("raw_data" ) and ((tsize := len (tensor .raw_data )) > size_threshold ):
141- transformed = True
142- current_file_size += tsize
143- if current_file_size > file_chunk_size :
144- file_num += 1
145- current_file_size = tsize
146- external_data_helper .set_external_data (tensor , f"{ model_name } _{ file_num } .onnx.data" )
147-
148- processed_count += 1
149- if processed_count % ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL == 0 :
150- cls ._cleanup_memory ()
151-
152- cls ._cleanup_memory ()
153-
154- return model , transformed
176+ try :
177+ file_num = 0
178+ current_file_size = 0
179+ transformed = False
180+
181+ # --- Adjustment: The initial check and load will now use the new bulk loader ---
182+ # This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself.
183+ cls ._load_external_data (model , onnx_base_dir )
184+
185+ processed_count = 0
186+ for tensor in external_data_helper ._get_all_tensors (model ):
187+ if tensor .HasField ("raw_data" ) and ((tsize := len (tensor .raw_data )) > size_threshold ):
188+ transformed = True
189+ current_file_size += tsize
190+ if current_file_size > file_chunk_size :
191+ file_num += 1
192+ current_file_size = tsize
193+ external_data_helper .set_external_data (tensor , f"{ model_name } _{ file_num } .onnx.data" )
194+
195+ processed_count += 1
196+ if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0 :
197+ cls ._cleanup_memory ()
198+
199+ return model , transformed
200+ finally :
201+ # Ensure cleanup happens even if an exception occurs
202+ cls ._cleanup_memory ()
0 commit comments