1
+ import inspect
1
2
import logging
2
3
from copy import deepcopy
3
4
from enum import Enum , auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
41
42
return self ._state
42
43
43
44
45
class DynamicShapeOutOfRangeException(Exception):
    """Raised when a changed input dimension is declared dynamic but its new
    size falls outside the registered [min, max] range, signalling that the
    engine must be recompiled."""
44
49
class MutableTorchTensorRTModule (object ):
45
50
"""
46
51
Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
65
70
Union [torch .dtype , dtype ]
66
71
] = _defaults .ENABLED_PRECISIONS ,
67
72
engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
68
- immutable_weights : bool = _defaults . IMMUTABLE_WEIGHTS ,
73
+ immutable_weights : bool = False ,
69
74
debug : bool = _defaults .DEBUG ,
70
75
num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
71
76
workspace_size : int = _defaults .WORKSPACE_SIZE ,
@@ -189,6 +194,9 @@ def __init__(
189
194
"hardware_compatible" : hardware_compatible ,
190
195
"timing_cache_path" : timing_cache_path ,
191
196
}
197
+ self .arg_dynamic_shapes : Optional [tuple [Any ]] = None
198
+ self .kwarg_dynamic_shapes : Optional [dict [Any , Any ]] = None
199
+ self .total_dynamic_shape : Optional [dict [Any , Any ]] = None
192
200
193
201
self .settings = CompilationSettings (** compilation_options )
194
202
self .run_info : Optional [tuple [Any , ...]] = None
@@ -203,6 +211,26 @@ def __init__(
203
211
)
204
212
self .init_finished = True
205
213
214
+ def set_dynamic_shape_hint (
215
+ self ,
216
+ args_dynamic_shape : tuple [dict [Any , Any ]],
217
+ kwargs_dynamic_shape : dict [str , Any ],
218
+ ) -> None :
219
+ assert isinstance (
220
+ args_dynamic_shape , tuple
221
+ ), "args dynamic shape has to be a tuple"
222
+ assert isinstance (
223
+ kwargs_dynamic_shape , dict
224
+ ), "args dynamic shape has to be a dictionary"
225
+ self .kwarg_dynamic_shapes = kwargs_dynamic_shape
226
+ self .arg_dynamic_shapes = args_dynamic_shape
227
+ self .total_dynamic_shape = self .kwarg_dynamic_shapes .copy ()
228
+ signature = list (
229
+ inspect .signature (self .original_model .forward ).parameters .keys ()
230
+ )
231
+ for i , arg in enumerate (self .arg_dynamic_shapes ):
232
+ self .total_dynamic_shape [signature [i ]] = arg
233
+
206
234
def store_state_dict_metadata (self ) -> None :
207
235
for k , v in self .original_model .state_dict ().items ():
208
236
self .state_dict_metadata [k ] = v .shape
@@ -295,6 +323,7 @@ def compile(self) -> None:
295
323
self .original_model ,
296
324
self .arg_inputs ,
297
325
kwargs = self .kwarg_inputs ,
326
+ dynamic_shapes = self .total_dynamic_shape ,
298
327
)
299
328
self .gm = dynamo_compile (
300
329
self .exp_program ,
@@ -306,14 +335,26 @@ def compile(self) -> None:
306
335
torch .cuda .empty_cache ()
307
336
308
337
def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
309
- if (
310
- not self .arg_inputs
311
- or not MutableTorchTensorRTModule .check_inputs_equal (self .arg_inputs , args )
312
- or not MutableTorchTensorRTModule .check_inputs_equal (
313
- self .kwarg_inputs , kwargs
314
- )
315
- ):
338
+ try :
339
+ if (
340
+ not self .arg_inputs
341
+ or not MutableTorchTensorRTModule .check_inputs_equal (
342
+ self .arg_inputs , args , dynamic_shapes = self .arg_dynamic_shapes
343
+ )
344
+ or not MutableTorchTensorRTModule .check_inputs_equal (
345
+ self .kwarg_inputs , kwargs , dynamic_shapes = self .kwarg_dynamic_shapes
346
+ )
347
+ ):
348
+ logger .info ("Input change detected." )
349
+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
350
+ self .store_inputs (args , kwargs )
351
+ except DynamicShapeOutOfRangeException as e :
316
352
logger .info ("Input change detected." )
353
+ logger .warning (e )
354
+ logger .warning ("Recompiling the engine with static shape" )
355
+ self .arg_dynamic_shapes = None
356
+ self .kwarg_dynamic_shapes = None
357
+ self .total_dynamic_shape = None
317
358
self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
318
359
self .store_inputs (args , kwargs )
319
360
@@ -436,33 +477,66 @@ def __setattr__(self, name: str, value: Any) -> None:
436
477
def check_inputs_equal (
437
478
input1 : Any ,
438
479
input2 : Any ,
480
+ dynamic_shapes : Any = None ,
439
481
) -> bool :
440
- # TODO: Add support for dynamic shape
482
+
441
483
if isinstance (input1 , (tuple , list )):
442
484
if len (input1 ) != len (input2 ):
443
485
return False
444
- for a , b in zip (input1 , input2 ):
486
+ for ( i , a ) , b in zip (enumerate ( input1 ) , input2 ):
445
487
if type (a ) != type (b ):
446
488
return False
447
- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
448
- return False
449
- elif isinstance (a , bool ) and a != b :
489
+ if isinstance (a , bool ) and a != b :
450
490
return False
491
+ elif isinstance (a , torch .Tensor ) and a .shape != b .shape :
492
+ if dynamic_shapes is None :
493
+ return False
494
+ else :
495
+ tensor_dynamic_shape = dynamic_shapes [i ]
496
+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
497
+ a , b , tensor_dynamic_shape
498
+ ):
499
+ return False
451
500
452
501
elif isinstance (input1 , dict ):
453
502
if input1 .keys () != input2 .keys ():
454
503
return False
455
- for a , b in zip (input1 .values (), input2 .values ()):
456
- if type (a ) != type (b ):
504
+ for ( ka , va ), vb in zip (input1 .items (), input2 .values ()):
505
+ if type (va ) != type (vb ):
457
506
return False
458
- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
459
- return False
460
- elif isinstance (a , bool ) and a != b :
507
+ if isinstance (va , bool ) and va != vb :
461
508
return False
509
+ elif isinstance (va , torch .Tensor ) and va .shape != vb .shape :
510
+ if dynamic_shapes is None :
511
+ return False
512
+ else :
513
+ tensor_dynamic_shape = dynamic_shapes [ka ]
514
+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
515
+ va , vb , tensor_dynamic_shape
516
+ ):
517
+ return False
462
518
elif isinstance (
463
- a , (list , tuple , dict )
464
- ) and not MutableTorchTensorRTModule .check_inputs_equal (a , b ):
519
+ va , (list , tuple , dict )
520
+ ) and not MutableTorchTensorRTModule .check_inputs_equal (
521
+ va , vb , dynamic_shapes [ka ] if dynamic_shapes else None
522
+ ):
523
+ return False
524
+ return True
525
+
526
+ @staticmethod
527
+ def check_tensor_shapes_with_dynamic_shapes (
528
+ t1 : torch .tensor , t2 : torch .tensor , dynamic_shape : dict [int , Any ]
529
+ ) -> bool :
530
+ for (i , axis_0 ), axis_1 in zip (enumerate (t1 .shape ), t2 .shape ):
531
+ if axis_0 != axis_1 :
532
+ if i not in dynamic_shape :
465
533
return False
534
+ dyn = dynamic_shape [i ]
535
+ if axis_1 > dyn .max or axis_1 < dyn .min :
536
+ raise DynamicShapeOutOfRangeException (
537
+ f"The input size ({ axis_1 } ) of dimension ({ i } ) is not in dynamic shape range [{ dyn .max } , { dyn .max } ]!"
538
+ )
539
+
466
540
return True
467
541
468
542
@staticmethod
0 commit comments