
Commit f4219f7

cehongwang, Chengzhe Xu, and peri044 authored Feb 28, 2025
Mutable module improvement (#3394)
Co-authored-by: Chengzhe Xu <[email protected]>
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent b33f393 commit f4219f7

File tree

6 files changed (+515, -66 lines)

 

examples/dynamo/README.rst

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@ Model Zoo
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 144 additions & 16 deletions
@@ -11,11 +11,13 @@
 The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.

 In this tutorial, we are going to walk through
-1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
-2. Save a Mutable Torch TensorRT Module
-3. Integration with Huggingface pipeline in LoRA use case
+1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
+2. Save a Mutable Torch TensorRT Module
+3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """

+# %%
 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
@@ -63,16 +65,14 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# Currently, saving is only enabled for C++ runtime, not python runtime.
+# Currently, saving is only enabled when "use_python_runtime" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

 # %%
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
-
 from diffusers import DiffusionPipeline

 with torch.no_grad():
@@ -83,33 +83,161 @@
     "immutable_weights": False,
 }

-model_id = "runwayml/stable-diffusion-v1-5"
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 device = "cuda:0"

-prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
-negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
+prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

-pipe = DiffusionPipeline.from_pretrained(
-    model_id, revision="fp16", torch_dtype=torch.float16
-)
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
 pipe.to(device)

 # The only extra line you need
 pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
-
-image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+BATCH = torch.export.Dim("BATCH", min=2, max=24)
+_HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
+_WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
+HEIGHT = 4 * _HEIGHT
+WIDTH = 4 * _WIDTH
+args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
+kwargs_dynamic_shapes = {
+    "encoder_hidden_states": {0: BATCH},
+    "added_cond_kwargs": {
+        "text_embeds": {0: BATCH},
+        "time_ids": {0: BATCH},
+    },
+}
+pipe.unet.set_expected_dynamic_shape_range(
+    args_dynamic_shapes, kwargs_dynamic_shapes
+)
+image = pipe(
+    prompt,
+    negative_prompt=negative,
+    num_inference_steps=30,
+    height=1024,
+    width=768,
+    num_images_per_prompt=2,
+).images[0]
 image.save("./without_LoRA_mutable.jpg")

 # Standard Huggingface LoRA loading procedure
 pipe.load_lora_weights(
     "stablediffusionapi/load_lora_embeddings",
-    weight_name="moxin.safetensors",
+    weight_name="all-disney-princess-xl-lo.safetensors",
     adapter_name="lora1",
 )
 pipe.set_adapters(["lora1"], adapter_weights=[1])
 pipe.fuse_lora()
 pipe.unload_lora_weights()

 # Refit triggered
-image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+image = pipe(
+    prompt,
+    negative_prompt=negative,
+    num_inference_steps=30,
+    height=1024,
+    width=1024,
+    num_images_per_prompt=1,
+).images[0]
 image.save("./with_LoRA_mutable.jpg")
+
+
+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of arg_inputs and kwarg_inputs passed to the forward function
+# and should not omit any entries (except None in the kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list.
+# If the dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
+# Note that you should exclude keyword arguments with value None as those will be filtered out.
+
+
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        print(c["b"][0])
+        x = 2 * c["b"]
+        return x
+
+
+device = "cuda:0"
+model = Model().eval().to(device)
+inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
+}
+dim_0 = torch.export.Dim("dim", min=1, max=50)
+dim_1 = torch.export.Dim("dim", min=1, max=50)
+dim_2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
+kwarg_dynamic_shapes = {
+    "c": {
+        "a": {},
+        "b": {0: dim_2},
+    },  # a's shape does not change so we give it an empty dict
+}
+# Export the model first with custom dynamic shape constraints
+model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Compile
+model(*inputs, **kwargs)
+# Change input shape
+inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
+kwargs_2 = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
+}
+# Run without recompiling
+model(*inputs_2, **kwargs_2)
+
+# %%
+# Use Mutable Torch TensorRT module with persistent cache
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Leveraging engine caching, we are able to shortcut the engine compilation and save much time.
+import os
+
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+
+model = models.resnet18(pretrained=True).eval().to("cuda")
+
+times = []
+start = torch.cuda.Event(enable_timing=True)
+end = torch.cuda.Event(enable_timing=True)
+
+example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+model = torch_trt.MutableTorchTensorRTModule(
+    model,
+    use_python_runtime=True,
+    enabled_precisions={torch.float},
+    debug=True,
+    min_block_size=1,
+    immutable_weights=False,
+    cache_built_engines=True,
+    reuse_cached_engines=True,
+    engine_cache_size=1 << 30,  # 1GB
+)
+
+
+def remove_timing_cache(path=TIMING_CACHE_PATH):
+    if os.path.exists(path):
+        os.remove(path)
+
+
+remove_timing_cache()
+
+for i in range(4):
+    inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
+
+    start.record()
+    model(*inputs)  # Recompile
+    end.record()
+    torch.cuda.synchronize()
+    times.append(start.elapsed_time(end))
+
+print("----------------dynamo_compile----------------")
+print("Without engine caching, used:", times[0], "ms")
+print("With engine caching, used:", times[1], "ms")
+print("With engine caching, used:", times[2], "ms")
+print("With engine caching, used:", times[3], "ms")

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 6 additions & 3 deletions
@@ -395,9 +395,12 @@ def refit_module_weights(
         try:
             weight_name_map = compiled_submodule.weight_name_map
         except AttributeError:
-            logger.warning(
-                "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-            )
+            if not isinstance(
+                compiled_submodule, torch.fx.graph_module.GraphModule
+            ):
+                logger.warning(
+                    "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                )
         if not weight_name_map:
             use_weight_map_cache = False
             logger.warning(

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 7 deletions
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+        weight_name: str,
+        np_map: dict[str, Any],
+        state_dict: dict[str, Any],
+        device: torch.device,
     ) -> str:
         """
         We need to build map from engine weight name to state_dict weight name.
@@ -385,19 +388,21 @@ def find_weight(
         np_map: the map from weight name to np values in INetworkDefinition
         state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
         for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                 del state_dict[sd_w_name]
                 return sd_w_name
         return ""

     @staticmethod
     def check_weight_equal(
-        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+        sd_weight: torch.tensor,
+        network_weight: Union[torch.Tensor, np.ndarray],
+        device: torch.device,
     ) -> Any:
         if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).cuda()
+            network_weight = torch.from_numpy(network_weight).to(device)
         try:
             return sd_weight.shape == network_weight.shape and torch.all(
                 torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name]
+                sd[sd_weight_name], np_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd
+                    engine_weight_name, np_map, sd, torch_device
                 )
                 if (
                     weight_name_map[engine_weight_name] != ""

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 193 additions & 34 deletions
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,11 +194,13 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None

         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()

         cls = self.__class__
         self.__class__ = type(
@@ -203,7 +210,66 @@
         )
         self.init_finished = True

-    def store_state_dict_metadata(self) -> None:
+    def set_expected_dynamic_shape_range(
+        self,
+        args_dynamic_shape: tuple[dict[Any, Any]],
+        kwargs_dynamic_shape: dict[str, Any],
+    ) -> None:
+        """
+        Set the dynamic shape range. The shape hint should EXACTLY follow arg_inputs and kwarg_inputs passed to the forward function
+        and should not omit any entries (except None in the kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list.
+        If the dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
+        Note that you should exclude keyword arguments with value None as those will be filtered out.
+
+        Example:
+        def forward(a, b, c=0, d=0):
+            pass
+
+        seq_len = torch.export.Dim("seq_len", min=1, max=10)
+        args_dynamic_shape = ({0: seq_len}, {}) # b does not have a dynamic shape
+        kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}} # d does not have a dynamic shape
+        set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)
+        # Later when you call the function
+        forward(*(a, b), **{c:..., d:...})
+
+        Reference: https://pytorch.org/docs/stable/export.html#expressing-dynamism
+        Arguments:
+            args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs,
+            kwargs_dynamic_shape: (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+        """
+        assert isinstance(
+            args_dynamic_shape, tuple
+        ), f"args dynamic shape has to be a tuple, but got {type(args_dynamic_shape)}"
+        assert isinstance(
+            kwargs_dynamic_shape, dict
+        ), f"kwargs dynamic shape has to be a dictionary, but got {type(kwargs_dynamic_shape)}"
+        self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+        self.arg_dynamic_shapes = args_dynamic_shape
+
+        # Clear cached inputs
+        self.arg_inputs = tuple()
+        self.kwarg_inputs = {}
+
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+        if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
+            return None
+        total_dynamic_shape = {}
+        if self.arg_dynamic_shapes:
+            signature = list(
+                inspect.signature(self.original_model.forward).parameters.keys()
+            )
+            for i, arg in enumerate(self.arg_dynamic_shapes):
+                total_dynamic_shape[signature[i]] = arg
+
+        if self.kwarg_dynamic_shapes:
+            for kwargs, kwargs_dynamic_shape in self.kwarg_dynamic_shapes.items():
+                total_dynamic_shape[kwargs] = kwargs_dynamic_shape
+
+        return total_dynamic_shape
+
+    def _store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape

@@ -295,6 +361,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self._get_total_dynamic_shapes(),
         )
         self.gm = dynamo_compile(
             self.exp_program,
@@ -306,39 +373,89 @@
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+
+        if not self.arg_inputs:
+            logger.info("First time compilation initiated. This may take some time.")
+            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+            self._store_inputs(args, kwargs)
+            if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                if not self._validates_dynamic_hints():
+                    logger.warning(
+                        "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                    )
+                    self.arg_dynamic_shapes = None
+                    self.kwarg_dynamic_shapes = None
+            return
+
+        # If input does not equal or does not fall into dynamic shape range, recompile the engine
+        try:
+            if not MutableTorchTensorRTModule._check_inputs_shape(
+                self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+            ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self._store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning(
+                "Provided inputs are outside the set expected shape range, recompiling module for the provided input shapes (static)"
+            )
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-            self.store_inputs(args, kwargs)
+            self._store_inputs(args, kwargs)
+
+    def _validates_dynamic_hints(self) -> bool:
+        if self.arg_dynamic_shapes is None:
+            if self.arg_inputs:
+                logger.warning("arg_dynamic_shape is not provided!")
+        else:
+            if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                logger.warning(
+                    f"Warning: The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                )
+                return False
+
+        if self.kwarg_dynamic_shapes is None:
+            if self.kwarg_inputs:
+                logger.warning("kwarg_dynamic_shape is not provided!")
+        else:
+            if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                logger.warning(
+                    f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}! You may need to exclude keyword arguments with value None."
+                )
+                return False

-    def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+        return True
+
+    def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
         self.arg_inputs = arg_inputs
         self.kwarg_inputs = kwarg_inputs

     @staticmethod
-    def process_kwarg_inputs(inputs: Any) -> Any:
+    def _process_kwarg_inputs(inputs: Any) -> Any:
         # Process kwarg inputs to be acceptable for Torch-TensorRT
         if isinstance(inputs, dict):
             # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
             return {
-                k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                 for k, v in inputs.items()
-                if (v is not None and not isinstance(v, bool))
+                if (v is not None)
             }
-        elif isinstance(inputs, torch.Tensor):
+        elif isinstance(inputs, (torch.Tensor, bool)):
             return inputs
         elif isinstance(inputs, (int, float, np.ndarray)):
             return torch.tensor(inputs)
         elif isinstance(inputs, (list, tuple)):
             if None not in inputs:
                 return type(inputs)(
-                    [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                    [
+                        MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                        for v in inputs
+                    ]
                 )

         raise ValueError(
@@ -348,7 +465,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:

     def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 1: Check whether the input shape has changed
-        kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+        kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
         self._validate_inputs(*args, **kwargs)

         # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -360,7 +477,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
             self.compile()
-            self.store_state_dict_metadata()
+            self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)

         elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -371,7 +488,7 @@
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
                 self.compile()
-                self.store_state_dict_metadata()
+                self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)

         result = self.gm(*args, **kwargs)
@@ -381,7 +498,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

     def to(self, device: str) -> None:
         logger.warning("Original PyTorch model is moved. CPU offload may failed.")
-        self.orignial_model.to(device)
+        self.original_model.to(device)

     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
@@ -433,38 +550,80 @@ def __setattr__(self, name: str, value: Any) -> None:
             object.__setattr__(self, name, value)

     @staticmethod
-    def check_inputs_equal(
+    def _check_inputs_shape(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape is changed!"
+                        )
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
-                    return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape is changed!"
+                        )
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule._check_inputs_shape(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
                     return False
         return True

+    @staticmethod
+    def _check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
+                    logger.warning(
+                        "Dynamic shape does not include the axis on which input changes!"
+                    )
+                    return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+
+        return True
+
     @staticmethod
     def save(module: Any, path: str) -> None:
         # Cast the object back to MutableTorchTensorRTModule to save
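In user terms, the validation path above works as follows: once set_expected_dynamic_shape_range has been called, an input whose changed dimensions stay inside the hinted range reuses the compiled engine, while an out-of-range input raises DynamicShapeOutOfRangeException internally; _validate_inputs catches it, clears the hints, and recompiles the module statically for the new shapes. A minimal sketch of that behavior, assuming a single Linear layer and made-up sizes (not part of this commit):

import torch
import torch_tensorrt as torch_trt

model = torch.nn.Linear(16, 4).eval().to("cuda")
mutable = torch_trt.MutableTorchTensorRTModule(
    model, min_block_size=1, immutable_weights=False
)
batch = torch.export.Dim("batch", min=2, max=8)
mutable.set_expected_dynamic_shape_range(({0: batch},), {})

mutable(torch.rand(4, 16).to("cuda"))   # first call compiles with a dynamic batch axis
mutable(torch.rand(6, 16).to("cuda"))   # 6 is inside [2, 8]: engine reused, no recompile
mutable(torch.rand(32, 16).to("cuda"))  # 32 is out of range: hints dropped, static recompile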

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py

Lines changed: 159 additions & 5 deletions
@@ -35,6 +35,160 @@ def test_check_output_equal():
         msg=f"test_check_output_equal is not correct.",
     )

+    torch.manual_seed(1)
+    c = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    assertions.assertFalse(
+        check_output_equal(a, c),
+        msg=f"test_check_output_equal is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_check_input_shape_dynamic():
+    torch.manual_seed(0)
+    a = {
+        "a": torch.rand(10, 3),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    torch.manual_seed(0)
+    b = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
+    assertions.assertFalse(
+        torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b),
+        msg=f"test_check_input_shape_dynamic is not correct.",
+    )
+    assertions.assertTrue(
+        torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape),
+        msg=f"test_check_input_shape_dynamic is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_model_complex_dynamic_shape():
+    device = "cuda:0"
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b, c=None):
+            x = torch.matmul(a, b)
+            x = torch.matmul(c["a"], c["b"][0].T)
+            x = 2 * c["b"][1]
+            return x
+
+    model = Model().eval().to(device)
+    inputs = [torch.rand(10, 3).to(device)]
+    kwargs = {
+        "b": torch.rand(3, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 3).to(device)],
+        },
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dim2 = torch.export.Dim("dim2", min=1, max=50)
+    args_dynamic_shapes = ({1: dim},)
+    kwarg_dynamic_shapes = {
+        "b": {0: dim},
+        "c": {"a": {}, "b": [{}, {1: dim2}]},
+    }
+    # Export the model first with custom dynamic shape constraints
+    trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+    trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+    # Run inference
+    trt_gm(*inputs, **kwargs)
+
+    inputs_2 = [torch.rand(10, 9).to(device)]
+    kwargs_2 = {
+        "b": torch.rand(9, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_2)
+    trt_gm._validate_inputs(*inputs_2, **kwargs_2)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg=f"Dynamic shape support of inputs_2 is not correct.",
+    )
+    trt_gm(*inputs_2, **kwargs_2)
+
+    # Change does not align with Dynamic Shape Hint
+    inputs_3 = [torch.rand(7, 9).to(device)]
+    kwargs_3 = {
+        "b": torch.rand(9, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_3)
+    trt_gm._validate_inputs(*inputs_3, **kwargs_3)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg=f"Dynamic shape support of inputs_3 is not correct.",
+    )
+    trt_gm(*inputs_3, **kwargs_3)
+
+    # # Stored input is changed (inputs first dimension is 7)
+    inputs_4 = [torch.rand(7, 20).to(device)]
+    kwargs_4 = {
+        "b": torch.rand(20, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_4)
+    trt_gm._validate_inputs(*inputs_4, **kwargs_4)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg=f"Dynamic shape support of inputs_4 is not correct.",
+    )
+    trt_gm(*inputs_4, **kwargs_4)
+
+    # # Change outside of the dynamic range limit
+    inputs_5 = [torch.rand(7, 900).to(device)]
+    kwargs_5 = {
+        "b": torch.rand(900, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_5)
+    trt_gm._validate_inputs(*inputs_5, **kwargs_5)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )
+    assertions.assertTrue(
+        trt_gm.arg_dynamic_shapes == None,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )
+    assertions.assertTrue(
+        trt_gm.kwarg_dynamic_shapes == None,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )
+

 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
@@ -188,7 +342,7 @@ def test_resnet18_modify_attribute_no_refit():
     for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
         assertions.assertTrue(
             torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
-            msg=f"The output of refitted Mutable Module is not correct.",
+            msg=f"The output of original and refitted Mutable Module is not the same.",
         )

     # # Clean up model env
@@ -255,7 +409,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )

     # Clean up model env
@@ -318,7 +472,7 @@ def set_weights(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )

     # Clean up model env
@@ -381,7 +535,7 @@ def set_layer(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
    )

     # Clean up model env
@@ -451,7 +605,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )

     # Clean up model env
