31 changes: 26 additions & 5 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -42,6 +42,7 @@
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_cpu_memory_usage,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -104,6 +105,7 @@ def cross_compile_for_windows(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -178,6 +180,7 @@ def cross_compile_for_windows(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -333,6 +336,7 @@ def cross_compile_for_windows(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

# Disable the following settings, which are not supported by the cross compilation for windows feature
@@ -434,6 +438,7 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -680,8 +685,9 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
@@ -695,14 +701,17 @@

# Apply lowering on the graph module
gm = post_lowering(gm, settings)
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(gm, delete_module=False)
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
@@ -829,6 +838,7 @@ def preserve_module_specs(
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
skip_fusion=(num_supported_ops == total_ops),
cpu_memory_budget=settings.cpu_memory_budget,
)

except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
@@ -857,17 +867,27 @@ def preserve_module_specs(
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

submodule_node_dict = {}
for node in partitioned_module.graph.nodes:
if "_run_on_acc" not in node.name:
for name, node in partitioned_module.named_children():
if "_run_on_acc" not in name:
continue
submodule_node_dict[node.name] = node
submodule_node_dict[name] = node

preserve_module_specs(original_in_spec, original_out_spec, partitioned_module)
# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

# Delete the frozen parameters from the graph module to release CPU memory.
# Note this does not affect the submodules; their frozen parameters are deleted
# later, in the convert_module function.
for attr in dir(gm):
if attr.startswith("_frozen_param"):
delattr(gm, attr)

from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

DYNAMO_CONVERTERS.disallowed_targets = set()

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
@@ -1056,6 +1076,7 @@ def convert_exported_program_to_serialized_trt_engine(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1243,7 +1264,7 @@ def convert_exported_program_to_serialized_trt_engine(

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}

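Note: the new cpu_memory_budget setting is threaded from compile() through CompilationSettings into the partitioner. A minimal usage sketch follows; the model, input, and 2 GiB figure are illustrative only, the docstring does not state the units (bytes are assumed here), and argument names other than cpu_memory_budget follow the current public API rather than this diff.

import torch
import torch_tensorrt

model = torch.nn.Linear(64, 64).eval().cuda()   # placeholder model for illustration
example_input = torch.randn(1, 64).cuda()
exp_program = torch.export.export(model, (example_input,))

trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[example_input],
    offload_module_to_cpu=True,      # move PyTorch weights to CPU while engines are built
    cpu_memory_budget=2 * 1024**3,   # assumed to be in bytes; -1 (the default) means no limit
)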
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
CPU_MEMORY_BUDGET = -1

if platform.system() == "Linux":
import pwd
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -7,6 +7,7 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
@@ -140,6 +141,7 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
cpu_memory_budget: int = CPU_MEMORY_BUDGET

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
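For context, the new field is also exposed on the CompilationSettings dataclass, so callers that construct settings directly pick up the same default; a minimal sketch (all other fields keep their defaults):

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(cpu_memory_budget=-1)  # -1: use all available CPU memory
print(settings.cpu_memory_budget)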
40 changes: 22 additions & 18 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -50,7 +50,12 @@
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.observer import Observer
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
from torch_tensorrt.dynamo.utils import (
DYNAMIC_DIM,
deallocate_module,
get_cpu_memory_usage,
to_torch_device,
)
from torch_tensorrt.logging import TRT_LOGGER

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
serialized_engine: bytes
engine: trt.ICudaEngine
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
self.module.to(torch_device)
sd = self.module.state_dict()
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
serialized_engine = engine.serialize()
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
# engine = runtime.deserialize_cuda_engine(serialized_engine)

# serialization_config = engine.create_serialization_config()
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
# serialized_engine = engine.serialize_with_config(
@@ -733,6 +735,9 @@ def run(
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.debug(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
)

if not self.compilation_settings.immutable_weights:
self._save_weight_mapping()
@@ -750,16 +755,19 @@
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)
serialized_engine = self.builder.build_serialized_network(

cuda_engine = self.builder.build_engine_with_config(
self.ctx.net, builder_config
)
assert serialized_engine
assert cuda_engine

_LOGGER.debug(
f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
)

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self.ctx.clear_cpu_weights_reference_holder()

self._save_timing_cache(
@@ -772,14 +780,10 @@
and self.compilation_settings.cache_built_engines
and self.engine_cache is not None
):
self._insert_engine_to_cache(hash_val, serialized_engine)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
self._insert_engine_to_cache(hash_val, cuda_engine)

return TRTInterpreterResult(
engine_str,
cuda_engine,
self._input_names,
self._output_names,
self.weight_name_map,
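With this change TRTInterpreterResult carries the live trt.ICudaEngine instead of serialized bytes, so serialization is deferred to the caller (see _conversion.py below). A sketch of the deferred-serialization pattern the new code uses, with `interpreter` standing in for an already-constructed TRTInterpreter:

import io

result = interpreter.run()                    # TRTInterpreterResult; result.engine is a live ICudaEngine
serialized = result.engine.serialize()        # IHostMemory holding the engine plan
with io.BytesIO() as engine_bytes:
    engine_bytes.write(serialized)
    engine_payload = engine_bytes.getvalue()  # plain bytes, e.g. for caching or saving to disk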
53 changes: 44 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -1,24 +1,34 @@
from __future__ import annotations

import io
import logging
from typing import Any, List, Optional, Sequence
from typing import Any, List, NamedTuple, Optional, Sequence

import torch
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
TRTInterpreter,
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_output_dtypes
from torch_tensorrt.dynamo.utils import (
get_cpu_memory_usage,
get_output_dtypes,
release_memory,
)

logger = logging.getLogger(__name__)


class SerializedInterpreterResult(NamedTuple):
serialized_engine: bytes
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
requires_output_allocator: bool


def infer_module_output_dtypes(
module: torch.fx.GraphModule,
truncate_double: bool = False,
@@ -29,7 +39,7 @@ def infer_module_output_dtypes(
"""
outputs = [node for node in module.graph.nodes if node.op == "output"]
outputs = outputs[0].args
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
return get_output_dtypes(outputs, truncate_double)


def interpret_module_to_result(
Expand All @@ -39,7 +49,7 @@ def interpret_module_to_result(
arg_inputs: Optional[Sequence[Input]] = None,
kwarg_inputs: Optional[dict[str, Any]] = None,
engine_cache: Optional[BaseEngineCache] = None,
) -> TRTInterpreterResult:
) -> SerializedInterpreterResult:
"""Interpret an FX module to a TRTInterpreterResult
Args:
module: FX GraphModule to interpret
@@ -65,7 +75,32 @@ def interpret_module_to_result(
)

interpreter_result = interpreter.run()
return interpreter_result
# Delete the frozen parameters from the module to release CPU memory
del interpreter
for attr in dir(module):
if attr.startswith("_frozen_param"):
delattr(module, attr)
release_memory()
logger.debug(
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
)

serialized_engine = interpreter_result.engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
serialized_engine = engine_bytes.getvalue()
logger.debug(
f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
)
serialized_interpreter_result = SerializedInterpreterResult(
serialized_engine=serialized_engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
weight_name_map=interpreter_result.weight_name_map,
requires_output_allocator=interpreter_result.requires_output_allocator,
)

return serialized_interpreter_result


def convert_module(
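get_cpu_memory_usage and release_memory are imported from torch_tensorrt.dynamo.utils, but their bodies are not part of this diff. A plausible psutil-based sketch consistent with the MB units in the log messages (the actual implementation may differ):

import gc
import os

import psutil  # assumed dependency; not confirmed by this diff
import torch


def get_cpu_memory_usage() -> float:
    """Return the resident set size of the current process in MB."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


def release_memory() -> None:
    """Run the garbage collector and free cached CUDA memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()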
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
"class": "logging.FileHandler",
"filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
"formatter": "standard",
"mode": "w", # This will clear the previous content
}
config["loggers"][""]["handlers"].append("file")
return config
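For context, the handler that gains "mode": "w" sits inside a logging.config.dictConfig dictionary built by the debugger; a minimal standalone equivalent, with an illustrative path and format string:

import logging
import logging.config

config = {
    "version": 1,
    "formatters": {
        "standard": {"format": "%(asctime)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "file": {
            "class": "logging.FileHandler",
            "filename": "/tmp/torch_tensorrt_logging.log",
            "formatter": "standard",
            "mode": "w",  # truncate the log file on each run instead of appending
        }
    },
    "loggers": {"": {"handlers": ["file"], "level": "DEBUG"}},
}
logging.config.dictConfig(config)
logging.getLogger(__name__).debug("debug output goes to a fresh log file")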
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
for node, constant in cf.node_replacements.items():
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant, requires_grad=False)
gm,
node,
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
)

erased_params = []