@@ -231,14 +231,18 @@ def __init__(self, tensor):
231231 self .shape = get_onnx_tensor_shape (self .tensor )
232232 self .dtype = get_onnx_tensor_dtype (self .tensor )
233233 self .nbytes = misc .volume (self .shape ) * get_itemsize (self .dtype )
234+ self ._cached_values = None # Initialize the cache
234235
235236 def load (self ):
236237 """
237- Load a numpy array from the underlying tensor values.
238+ Load a numpy array from the underlying tensor values, using cache .
238239
239240 Returns:
240241 np.array: A numpy array containing the values of the tensor.
241242 """
243+ if self ._cached_values is not None :
244+ return self ._cached_values # Return cached data if available
245+
242246 import onnx
243247 import onnx .numpy_helper
244248 from onnx_graphsurgeon .importers .onnx_importer import (
@@ -254,7 +258,8 @@ def load(self):
254258 f"If this is not what you intended, please avoid accessing the values of this constant tensor."
255259 )
256260
257- return np .array (onnx .numpy_helper .to_array (self .tensor ))
261+ self ._cached_values = np .array (onnx .numpy_helper .to_array (self .tensor ))
262+ return self ._cached_values
258263
259264 def __str__ (self ):
260265 return "LazyValues (shape={:}, dtype={:})" .format (self .shape , self .dtype )
@@ -268,13 +273,20 @@ class SparseValues(LazyValues):
268273 A special object that represents constant tensor values that is sparse
269274 """
270275
276+ def __init__ (self , tensor ):
277+ super ().__init__ (tensor )
278+ self ._cached_values = None # Initialize the cache
279+
271280 def load (self ):
272281 """
273- Load a numpy array from the sparse structure.
282+ Load a numpy array from the sparse structure, using cache .
274283
275284 Returns:
276285 np.array: A numpy array containing the values of the tensor.
277286 """
287+ if self ._cached_values is not None :
288+ return self ._cached_values # Return cached data if available
289+
278290 import onnx
279291 import onnx .numpy_helper
280292 from onnx_graphsurgeon .importers .onnx_importer import (
@@ -316,7 +328,8 @@ def load(self):
316328 f"Unsupported index data dims { self .tensor .indices .dims } in { self .tensor .values .name } "
317329 )
318330
319- return values
331+ self ._cached_values = values
332+ return self ._cached_values
320333
321334 def __str__ (self ):
322335 return "SparseValues (shape={:}, dtype={:})" .format (self .shape , self .dtype )
0 commit comments