+ import inspect
import logging
from copy import deepcopy
from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
        return self._state


+ class DynamicShapeOutOfRangeException(Exception):
+     pass
+
+
class MutableTorchTensorRTModule(object):
    """
    Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
            Union[torch.dtype, dtype]
        ] = _defaults.ENABLED_PRECISIONS,
        engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-         immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+         immutable_weights: bool = False,
        debug: bool = _defaults.DEBUG,
        num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
        workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,11 +194,13 @@ def __init__(
            "hardware_compatible": hardware_compatible,
            "timing_cache_path": timing_cache_path,
        }
+         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None

        self.settings = CompilationSettings(**compilation_options)
        self.run_info: Optional[tuple[Any, ...]] = None
        self.state_dict_metadata: dict[str, torch.Size] = {}
-         self.store_state_dict_metadata()
+         self._store_state_dict_metadata()

        cls = self.__class__
        self.__class__ = type(
@@ -203,7 +210,66 @@ def __init__(
        )
        self.init_finished = True

-     def store_state_dict_metadata(self) -> None:
+     def set_expected_dynamic_shape_range(
+         self,
+         args_dynamic_shape: tuple[dict[Any, Any]],
+         kwargs_dynamic_shape: dict[str, Any],
+     ) -> None:
+         """
+         Set the dynamic shape range. The shape hint should EXACTLY follow the arg_inputs and kwarg_inputs passed to the forward function
+         and should not omit any entries (except None in the kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list.
+         If a dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
+         Note that you should exclude keyword arguments with value None as those will be filtered out.
+
+         Example:
+             def forward(a, b, c=0, d=0):
+                 pass
+
+             seq_len = torch.export.Dim("seq_len", min=1, max=10)
+             args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
+             kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+             set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)
+             # Later when you call the function
+             forward(*(a, b), **{'c': ..., 'd': ...})
+
+         Reference: https://pytorch.org/docs/stable/export.html#expressing-dynamism
+         Arguments:
+             args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs
+             kwargs_dynamic_shape (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+         """
+         assert isinstance(
+             args_dynamic_shape, tuple
+         ), f"args dynamic shape has to be a tuple, but got {type(args_dynamic_shape)}"
+         assert isinstance(
+             kwargs_dynamic_shape, dict
+         ), f"kwargs dynamic shape has to be a dictionary, but got {type(kwargs_dynamic_shape)}"
+         self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+         self.arg_dynamic_shapes = args_dynamic_shape
+
+         # Clear cached inputs
+         self.arg_inputs = tuple()
+         self.kwarg_inputs = {}
+
+         self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+     def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+         if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
+             return None
+         total_dynamic_shape = {}
+         if self.arg_dynamic_shapes:
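+             # Map positional dynamic-shape hints to the corresponding parameter names in the
+             # original forward() signature so they can be passed to torch.export as a single dict.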
+             signature = list(
+                 inspect.signature(self.original_model.forward).parameters.keys()
+             )
+             for i, arg in enumerate(self.arg_dynamic_shapes):
+                 total_dynamic_shape[signature[i]] = arg
+
+         if self.kwarg_dynamic_shapes:
+             for kwargs, kwargs_dynamic_shape in self.kwarg_dynamic_shapes.items():
+                 total_dynamic_shape[kwargs] = kwargs_dynamic_shape
+
+         return total_dynamic_shape
+
+     def _store_state_dict_metadata(self) -> None:
        for k, v in self.original_model.state_dict().items():
            self.state_dict_metadata[k] = v.shape
@@ -295,6 +361,7 @@ def compile(self) -> None:
            self.original_model,
            self.arg_inputs,
            kwargs=self.kwarg_inputs,
+             dynamic_shapes=self._get_total_dynamic_shapes(),
        )
        self.gm = dynamo_compile(
            self.exp_program,
@@ -306,39 +373,89 @@ def compile(self) -> None:
        torch.cuda.empty_cache()

    def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-         if (
-             not self.arg_inputs
-             or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-             or not MutableTorchTensorRTModule.check_inputs_equal(
-                 self.kwarg_inputs, kwargs
-             )
-         ):
+
+         if not self.arg_inputs:
+             logger.info("First time compilation initiated. This may take some time.")
+             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+             self._store_inputs(args, kwargs)
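+             # Validate the user-provided dynamic shape hints against these first inputs;
+             # if the hints are malformed, fall back to compiling with static shapes.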
+             if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                 if not self._validates_dynamic_hints():
+                     logger.warning(
+                         "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                     )
+                     self.arg_dynamic_shapes = None
+                     self.kwarg_dynamic_shapes = None
+             return
+
+         # If the inputs do not match the cached ones or fall outside the dynamic shape range, recompile the engine
+         try:
+             if not MutableTorchTensorRTModule._check_inputs_shape(
+                 self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+             ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                 self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+             ):
+                 logger.info("Input change detected.")
+                 self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                 self._store_inputs(args, kwargs)
+         except DynamicShapeOutOfRangeException as e:
            logger.info("Input change detected.")
+             logger.warning(e)
+             logger.warning(
+                 "Provided inputs are outside the set expected shape range, recompiling module for the provided input shapes (static)"
+             )
+             self.arg_dynamic_shapes = None
+             self.kwarg_dynamic_shapes = None
            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-             self.store_inputs(args, kwargs)
+             self._store_inputs(args, kwargs)
+
+     def _validates_dynamic_hints(self) -> bool:
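+         # Check that the dynamic shape hints structurally match the cached inputs:
+         # same number of positional entries and the same keyword keys.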
+         if self.arg_dynamic_shapes is None:
+             if self.arg_inputs:
+                 logger.warning("arg_dynamic_shape is not provided!")
+         else:
+             if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                 logger.warning(
+                     f"The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                 )
+                 return False
+
+         if self.kwarg_dynamic_shapes is None:
+             if self.kwarg_inputs:
+                 logger.warning("kwarg_dynamic_shape is not provided!")
+         else:
+             if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                 logger.warning(
+                     f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}! You may need to exclude keyword arguments with value None."
+                 )
+                 return False

-     def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+         return True
+
+     def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
        self.arg_inputs = arg_inputs
        self.kwarg_inputs = kwarg_inputs

    @staticmethod
-     def process_kwarg_inputs(inputs: Any) -> Any:
+     def _process_kwarg_inputs(inputs: Any) -> Any:
        # Process kwarg inputs to be acceptable for Torch-TensorRT
        if isinstance(inputs, dict):
            # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
            return {
-                 k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                 k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                for k, v in inputs.items()
-                 if (v is not None and not isinstance(v, bool))
+                 if (v is not None)
            }
-         elif isinstance(inputs, torch.Tensor):
+         elif isinstance(inputs, (torch.Tensor, bool)):
            return inputs
        elif isinstance(inputs, (int, float, np.ndarray)):
            return torch.tensor(inputs)
        elif isinstance(inputs, (list, tuple)):
            if None not in inputs:
                return type(inputs)(
-                     [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                     [
+                         MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                         for v in inputs
+                     ]
                )

        raise ValueError(
@@ -348,7 +465,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Step 1: Check whether the input shape has changed
-         kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+         kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
        self._validate_inputs(*args, **kwargs)

        # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -360,7 +477,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
        if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
            logger.info("(Re)Compiling the engine...")
            self.compile()
-             self.store_state_dict_metadata()
+             self._store_state_dict_metadata()
            self.refit_state.set_state(RefitFlag.LIVE)

        elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -371,7 +488,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                logger.error(e)
                logger.error("Model refit failed. Recompiling the graph module.")
                self.compile()
-                 self.store_state_dict_metadata()
+                 self._store_state_dict_metadata()
            self.refit_state.set_state(RefitFlag.LIVE)

        result = self.gm(*args, **kwargs)
@@ -381,7 +498,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

    def to(self, device: str) -> None:
        logger.warning("Original PyTorch model is moved. CPU offload may failed.")
-         self.orignial_model.to(device)
+         self.original_model.to(device)

    def __deepcopy__(self, memo: Any) -> Any:
        cls = self.__class__
@@ -433,38 +550,80 @@ def __setattr__(self, name: str, value: Any) -> None:
            object.__setattr__(self, name, value)

    @staticmethod
-     def check_inputs_equal(
+     def _check_inputs_shape(
        input1: Any,
        input2: Any,
+         dynamic_shapes: Any = None,
    ) -> bool:
-         # TODO: Add support for dynamic shape
+
        if isinstance(input1, (tuple, list)):
            if len(input1) != len(input2):
                return False
-             for a, b in zip(input1, input2):
+             for (i, a), b in zip(enumerate(input1), input2):
                if type(a) != type(b):
                    return False
-                 if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                     return False
-                 elif isinstance(a, bool) and a != b:
+                 if isinstance(a, bool) and a != b:
                    return False
+                 elif isinstance(a, torch.Tensor) and a.shape != b.shape:
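+                     # The shape differs from the cached input; accept it only if the change is
+                     # covered by the dynamic-shape hint declared for this positional entry.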
+                     if dynamic_shapes is None:
+                         logger.warning(
+                             "Dynamic shape is not properly set but the input shape is changed!"
+                         )
+                         return False
+                     else:
+                         tensor_dynamic_shape = dynamic_shapes[i]
+                         if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                             a, b, tensor_dynamic_shape
+                         ):
+                             return False

        elif isinstance(input1, dict):
            if input1.keys() != input2.keys():
                return False
-             for a, b in zip(input1.values(), input2.values()):
-                 if type(a) != type(b):
-                     return False
-                 if isinstance(a, torch.Tensor) and a.shape != b.shape:
+             for (ka, va), vb in zip(input1.items(), input2.values()):
+                 if type(va) != type(vb):
                    return False
-                 elif isinstance(a, bool) and a != b:
+                 if isinstance(va, bool) and va != vb:
                    return False
+                 elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                     if dynamic_shapes is None:
+                         logger.warning(
+                             "Dynamic shape is not properly set but the input shape is changed!"
+                         )
+                         return False
+                     else:
+                         tensor_dynamic_shape = dynamic_shapes[ka]
+                         if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                             va, vb, tensor_dynamic_shape
+                         ):
+                             return False
                elif isinstance(
-                     a, (list, tuple, dict)
-                 ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                     va, (list, tuple, dict)
+                 ) and not MutableTorchTensorRTModule._check_inputs_shape(
+                     va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                 ):
                    return False
        return True

+     @staticmethod
+     def _check_tensor_shapes_with_dynamic_shapes(
+         t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+     ) -> bool:
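+         # Compare shapes axis by axis: a mismatch is only allowed on axes covered by the
+         # dynamic-shape hint, and the new size must lie within that hint's [min, max] range;
+         # otherwise a DynamicShapeOutOfRangeException is raised.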
+         for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+             if axis_0 != axis_1:
+                 if i not in dynamic_shape:
+                     logger.warning(
+                         "Dynamic shape does not include the axis on which input changes!"
+                     )
+                     return False
+                 dyn = dynamic_shape[i]
+                 if axis_1 > dyn.max or axis_1 < dyn.min:
+                     raise DynamicShapeOutOfRangeException(
+                         f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.min}, {dyn.max}]!"
+                     )
+
+         return True
+
    @staticmethod
    def save(module: Any, path: str) -> None:
        # Cast the object back to MutableTorchTensorRTModule to save
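
For reference, a minimal usage sketch of the dynamic-shape API added in this diff. The model, input shapes, and the top-level torch_tensorrt.MutableTorchTensorRTModule entry point below are illustrative assumptions, not part of the patch:

import torch
import torch_tensorrt

# Placeholder model, for illustration only (assumption)
model = torch.nn.Linear(16, 4).eval().cuda()
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(model)

# Declare a dynamic batch dimension for the single positional input; no keyword inputs are used
batch = torch.export.Dim("batch", min=1, max=8)
mutable_module.set_expected_dynamic_shape_range(({0: batch},), {})

out_a = mutable_module(torch.randn(2, 16).cuda())  # first call compiles the engine
out_b = mutable_module(torch.randn(8, 16).cuda())  # batch size within range: no recompilation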