4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py

@@ -606,7 +606,7 @@ def save(
     inputs: Optional[Sequence[torch.Tensor]] = None,
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
-    retrace: bool = False,
+    retrace: bool = True,
     pickle_protocol: int = 2,
     **kwargs: Any,
 ) -> None:
@@ -661,7 +661,7 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if not all([output_format == f for f in ["exported_program", "aot_inductor"]]):
+        if not all(output_format == f for f in ["exported_program", "aot_inductor"]):
            raise ValueError(
                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
            )
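Note: this flips the default save path. With retrace=True, torch_tensorrt.save re-exports the compiled graph module through torch.export, which requires example inputs; the previous in-place behavior stays available via retrace=False. A minimal usage sketch of both paths (the toy Linear model and shapes are illustrative, not part of this PR):

import torch
import torch_tensorrt as torchtrt

# Illustrative toy model; any dynamo-compiled module works the same way.
model = torch.nn.Linear(4, 4).cuda().eval()
inputs = (torch.randn(2, 4).cuda(),)

trt_gm = torchtrt.compile(model, ir="dynamo", arg_inputs=inputs)

# New default: retrace=True re-exports the module, so example inputs are required.
torchtrt.save(trt_gm, "trt_model.ep", arg_inputs=inputs)

# The previous behavior remains available by opting out explicitly.
torchtrt.save(trt_gm, "trt_model_noretrace.ep", arg_inputs=inputs, retrace=False)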
107 changes: 91 additions & 16 deletions py/torch_tensorrt/dynamo/_refit.py

@@ -27,6 +27,7 @@
 )
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
 from torch_tensorrt.dynamo.lowering import (
+    clean_up_graph_after_modifications,
     get_decompositions,
     post_lowering,
     pre_export_lowering,
@@ -94,6 +95,8 @@ def construct_refit_mapping_from_weight_name_map(
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
         # Add more constant folding converters here
+        trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
         if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
             # Batch Norm Layer
             params = {}
@@ -106,12 +109,12 @@ def construct_refit_mapping_from_weight_name_map(
             engine_weight_map[engine_weight_name] = eval(
                 engine_weight_name.split(" ")[-1].lower()
             )
 
         elif sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
-            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
 
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
@@ -272,12 +275,66 @@ def refit_module_weights(
             compiled_submodules_map[name] = submodule
 
     else:
+        # Handle torch modules
         compiled_submodules_map = {}
+        guard_fn_modules = []
         for name, submodule in compiled_module.named_children():
-            if not isinstance(
-                submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
+            if (
+                not isinstance(
+                    submodule,
+                    (
+                        PythonTorchTensorRTModule,
+                        TorchTensorRTModule,
+                        torch.nn.modules.module.Module,
+                    ),
+                )
+                or "_run_on_gpu" in name
             ):
                 continue
-            settings = submodule.settings
+
+            # When we re-export the graph module, torch.export._unlift.GuardsFn modules are being added as submodules.
+            if isinstance(submodule, torch.export._unlift.GuardsFn):
+                guard_fn_modules.append(name)
+                continue
+            # Obtain the settings
+
+            compiled_submodules = [
+                (name.replace("_engine", ""), engine)
+                for name, engine in submodule.__dict__.items()
+                if "engine" in name
+            ]
+
+            settings = None
+            try:
+                # If the gm is not inlined or transformed by retracing, the settings is stored in the submodule
+                settings = submodule.settings
+            except AttributeError:
+
+                encoded_metadata = [
+                    engine for name, engine in compiled_submodules if name == "engine"
+                ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                assert (
+                    encoded_metadata != ""
+                ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version"
+                settings = TorchTensorRTModule.decode_metadata(encoded_metadata)[
+                    "settings"
+                ]
+
             compiled_submodules_map[name] = submodule
 
+        # Delete the guard fn modules to avoid the guard fn modules being refitted
+        # First, remove nodes in the graph that reference the guard function modules
+        for node in list(compiled_module.graph.nodes):
+            if node.op == "call_module" and node.target in guard_fn_modules:
+                compiled_module.graph.erase_node(node)
+
+        # Now delete the submodules themselves
+        for guard_fn_module_name in guard_fn_modules:
+            # delattr(compiled_module, guard_fn_module_name)
+            compiled_module.delete_submodule(guard_fn_module_name)
+
+        # Clean up the graph
+        clean_up_graph_after_modifications(compiled_module)
+
+        assert settings is not None
+
@@ -411,11 +468,29 @@ def refit_module_weights(
                 )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                if "_run_on_acc" not in name:
+                    compiled_submodule.load_state_dict(new_submodule.state_dict())
+                    continue
+
                 weight_name_map = None
                 if use_weight_map_cache:
                     try:
                         weight_name_map = compiled_submodule.weight_name_map
                     except AttributeError:
+                        if isinstance(compiled_submodule, torch.nn.Module):
+                            # Torch retrace module
+                            assert (
+                                not settings.use_python_runtime
+                            ), "Refitting a torch retraced module is only supported with use_python_runtime=False"
+                            encoded_metadata = [
+                                engine
+                                for name, engine in compiled_submodules
+                                if name == "engine"
+                            ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                            weight_name_map = TorchTensorRTModule.decode_metadata(
+                                encoded_metadata
+                            )["weight_name_map"]
+
                         if not isinstance(
                             compiled_submodule, torch.fx.graph_module.GraphModule
                         ):
@@ -427,21 +502,16 @@ def refit_module_weights(
                     logger.warning(
                         "This engine does not have a weight map cache. Rebuilding the weight map"
                     )
-                if isinstance(compiled_submodule, PythonTorchTensorRTModule):
+
+                # Rexporting the TRT compiled graph module and loading it back doesn't preserve the instance type and registers
+                # the compiled submodule as torch.nn.Module. So we use settings.use_python_runtime to determine the instance type.
+                if settings.use_python_runtime:
                     engine = compiled_submodule.engine
-                elif isinstance(compiled_submodule, TorchTensorRTModule):
+                else:
                     engine_info = compiled_submodule.engine.__getstate__()[0]
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
-                elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
-                    # This is graph break resulted by unsupported ops
-                    compiled_submodule.load_state_dict(new_submodule.state_dict())
-                    continue
-                else:
-                    raise AssertionError(
-                        "The type of graph module is not supported for refitting."
-                    )
             except AttributeError:
                 raise AssertionError(
                     "The type of graph module is not supported for refitting or two compiled modules do not match."
@@ -500,7 +570,12 @@ def refit_module_weights(
                 new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
                 refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
                 setattr(compiled_module, f"{name}_engine", refitted_engine)
+
+            elif isinstance(compiled_submodule, torch.nn.Module):
+                # Torch retrace module
+                new_engine_info = list(engine_info)
+                new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
+                refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
+                compiled_submodule.engine = refitted_engine
             del engine
             gc.collect()
             torch.cuda.empty_cache()
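Taken together, the _refit.py changes let refit_module_weights handle a module that was saved with the new retrace=True default and loaded back, where the TRT submodules come back as plain torch.nn.Module instances and the settings and weight map must be recovered from the serialized engine metadata. A sketch of that end-to-end flow, modeled on the tests below; the toy model and shapes are illustrative:

import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo import refit_module_weights

model = torch.nn.Linear(4, 4).cuda().eval()
model2 = torch.nn.Linear(4, 4).cuda().eval()  # same architecture, different weights
inputs = (torch.randn(2, 4).cuda(),)

exp_program = torch.export.export(model, inputs, strict=False)
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    immutable_weights=False,   # the engine must be built refittable
    use_python_runtime=False,  # retraced-module refit requires the C++ runtime
)

# Saving with retrace=True (now the default) re-exports the graph module;
# the guard submodules (torch.export._unlift.GuardsFn) that re-export adds
# are stripped by refit_module_weights before refitting.
torchtrt.save(trt_gm, "trt_model.ep", arg_inputs=inputs)
loaded_gm = torch.export.load("trt_model.ep")

exp_program2 = torch.export.export(model2, inputs, strict=False)
refitted_gm = refit_module_weights(
    compiled_module=loaded_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
)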
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py

@@ -1,3 +1,4 @@
 from ._aten_lowering_pass import *
+from .pass_utils import clean_up_graph_after_modifications
 from .remove_sym_nodes import remove_sym_nodes
 from .repair_input_aliasing import repair_input_aliasing
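Re-exporting clean_up_graph_after_modifications from the lowering package is what lets _refit.py above lint and recompile the graph after erasing the guard-function nodes. A condensed sketch of that pattern; the standalone helper name here is hypothetical:

from torch_tensorrt.dynamo.lowering import clean_up_graph_after_modifications

def drop_call_module_nodes(gm, submodule_names):
    # Erase call_module nodes that target the given submodules, then run the
    # shared dead-code-elimination / lint / recompile cleanup on the graph.
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and node.target in submodule_names:
            gm.graph.erase_node(node)
    return clean_up_graph_after_modifications(gm)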
10 changes: 5 additions & 5 deletions tests/py/dynamo/models/test_export_kwargs_serde.py

@@ -76,7 +76,7 @@ def forward(self, x, b=5, c=None, d=None):
 
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
 
@@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None):
 
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
 
@@ -209,7 +209,7 @@ def forward(self, x, b=5, c=None, d=None):
 
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
 
@@ -299,7 +299,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
 
@@ -389,7 +389,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
 
22 changes: 11 additions & 11 deletions tests/py/dynamo/models/test_export_serde.py

@@ -56,7 +56,7 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -279,7 +279,7 @@ def test_resnet18(ir):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
         msg="Model should be offloaded to CPU",
     )
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     # TODO: Enable this serialization issues are fixed
     # deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = deser_trt_module(input)
     outputs_trt = trt_module(input)
@@ -463,7 +463,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -525,7 +525,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -584,7 +584,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
10 changes: 6 additions & 4 deletions tests/py/dynamo/models/test_model_refit.py

@@ -540,8 +540,8 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
     min_block_size = 1
     use_python_runtime = False
 
-    exp_program = torch.export.export(model, tuple(inputs))
-    exp_program2 = torch.export.export(model2, tuple(inputs))
+    exp_program = torch.export.export(model, tuple(inputs), strict=False)
+    exp_program2 = torch.export.export(model2, tuple(inputs), strict=False)
 
     trt_gm = torchtrt.dynamo.compile(
         exp_program,
@@ -551,8 +551,9 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
         min_block_size=min_block_size,
         immutable_weights=False,
     )
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs, retrace=True)
     trt_gm = torch.export.load(trt_ep_path)
+
     new_trt_gm = refit_module_weights(
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
@@ -565,6 +566,7 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
     expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
         *inputs
     )
+
     for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
         assertions.assertTrue(
             torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
@@ -906,7 +908,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         min_block_size=min_block_size,
         immutable_weights=False,
     )
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
     new_trt_gm = refit_module_weights(
         compiled_module=trt_gm,
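These tests now export with strict=False and pass arg_inputs to save, since the retrace=True default re-exports the module and torch.export.export needs example inputs. A condensed before/after of the save call sites, for reference:

# Before: save() serialized the compiled module as-is; no inputs were needed.
torchtrt.save(trt_gm, trt_ep_path)

# After: retrace=True is the default, so example inputs are required for the
# re-export (or pass retrace=False to keep the old in-place path).
torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)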
4 changes: 3 additions & 1 deletion tests/py/dynamo/runtime/test_002_lazy_engine_init.py

@@ -314,7 +314,9 @@ def test_lazy_engine_init_cpp_serialization(self):
        trt_mod = torchtrt.compile(model, **compile_spec)
 
        with tempfile.TemporaryDirectory() as tmpdir:
-            torch_tensorrt.save(trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"))
+            torch_tensorrt.save(
+                trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), arg_inputs=(input,)
+            )
            new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep"))
 
            loaded_trt_mod = new_trt_mod.module()