diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 0dc4654db0..130d693b60 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -693,6 +693,7 @@ def compile(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
+        deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 73af09448e..a329f692a1 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
 
 
 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +591,11 @@ def _save_weight_mapping(self) -> None:
         torch.cuda.empty_cache()
 
     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
+        serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
         #     # set EXCLUDE_WEIGHTS flag to strip weights
-        #     runtime = trt.Runtime(TRT_LOGGER)
-        #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-        #     serialization_config = engine.create_serialization_config()
         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
         #     serialized_engine = engine.serialize_with_config(
@@ -750,16 +747,15 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        serialized_engine = self.builder.build_serialized_network(
+
+        cuda_engine = self.builder.build_engine_with_config(
             self.ctx.net, builder_config
         )
-        assert serialized_engine
+        assert cuda_engine
 
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
         self.ctx.clear_cpu_weights_reference_holder()
 
         self._save_timing_cache(
@@ -772,14 +768,10 @@ def run(
             and self.compilation_settings.cache_built_engines
             and self.engine_cache is not None
         ):
-            self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+            self._insert_engine_to_cache(hash_val, cuda_engine)
 
         return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
             self._input_names,
             self._output_names,
             self.weight_name_map,
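With the interpreter changes above, `TRTInterpreterResult.engine` carries the live `trt.ICudaEngine`, and serialization happens only where bytes are genuinely needed (engine-cache insertion above, and the `TorchTensorRTModule` path below). A minimal sketch of that deferred-serialization pattern, assuming the TensorRT 10 Python API; the helper name `engine_to_bytes` is illustrative and not part of this patch:

```python
import io

import tensorrt as trt


def engine_to_bytes(engine: trt.ICudaEngine) -> bytes:
    # engine.serialize() returns a trt.IHostMemory buffer view; copying it
    # through BytesIO yields an owned bytes object, mirroring what
    # _insert_engine_to_cache and TorchTensorRTModule.__init__ now do.
    serialized = engine.serialize()
    with io.BytesIO() as buf:
        buf.write(serialized)
        return buf.getvalue()
```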
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 35b6c26617..f0519eb263 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -104,7 +104,7 @@ def convert_module(
     )
 
     return rt_cls(
-        serialized_engine=interpreter_result.serialized_engine,
+        cuda_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 5ba84b09b0..9b821df906 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
             gm,
-            node, torch.nn.Parameter(constant, requires_grad=False)
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )
 
     erased_params = []
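The `constant.cpu().contiguous()` change above keeps folded constants out of GPU memory and guarantees a flat host buffer when the constant is later handed to TRT weight construction. A standalone sketch of the normalization it performs (illustrative only, not part of the patch):

```python
import torch

# A folded constant may still live on the GPU and may be a non-contiguous
# view (e.g. the result of a transpose); .cpu().contiguous() normalizes both
# before the tensor is wrapped as a Parameter on the graph module.
t = torch.arange(12, dtype=torch.float32).reshape(3, 4).t()  # non-contiguous view
param = torch.nn.Parameter(t.cpu().contiguous(), requires_grad=False)
assert param.device.type == "cpu" and param.is_contiguous()
```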
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index d18a5674e0..5a935e5c79 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -123,6 +123,7 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
 
     def __init__(
         self,
+        cuda_engine: Optional[trt.ICudaEngine] = None,
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
@@ -182,7 +183,19 @@ def __init__(
 
         # Unused currently - to be used by Dynamic Shape support implementation
         self.memory_pool = None
-        self.serialized_engine = serialized_engine
+        if cuda_engine:
+            assert isinstance(
+                cuda_engine, trt.ICudaEngine
+            ), "Cuda engine must be a trt.ICudaEngine object"
+            self.engine = cuda_engine
+        elif serialized_engine:
+            assert isinstance(
+                serialized_engine, bytes
+            ), "Serialized engine must be a bytes object"
+            self.engine = serialized_engine
+        else:
+            raise ValueError("Serialized engine or cuda engine must be provided")
+
         self.input_names = (
             input_binding_names if input_binding_names is not None else []
         )
@@ -204,7 +217,6 @@ def __init__(
             else False
         )
         self.settings = settings
-        self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
@@ -219,7 +231,7 @@ def __init__(
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
 
-        if self.serialized_engine is not None and not self.settings.lazy_engine_init:
+        if self.engine and not self.settings.lazy_engine_init:
            self.setup_engine()
 
     def get_streamable_device_memory_budget(self) -> Any:
@@ -260,13 +272,22 @@ def set_default_device_memory_budget(self) -> int:
         return self._set_device_memory_budget(budget_bytes)
 
     def setup_engine(self) -> None:
+
+        if isinstance(self.engine, trt.ICudaEngine):
+            pass
+        elif isinstance(self.engine, bytes):
+            runtime = trt.Runtime(TRT_LOGGER)
+            self.engine = runtime.deserialize_cuda_engine(self.engine)
+        else:
+            raise ValueError(
+                "Expected engine as trt.ICudaEngine or serialized engine as bytes"
+            )
+
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
 
         self.initialized = True
-        runtime = trt.Runtime(TRT_LOGGER)
-        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
@@ -302,7 +323,7 @@ def _check_initialized(self) -> None:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
 
     def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
-        state_dict[prefix + "engine"] = self.serialized_engine
+        state_dict[prefix + "engine"] = self.engine
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
         state_dict[prefix + "platform"] = self.target_platform
@@ -317,7 +338,7 @@ def _load_from_state_dict(
         unexpected_keys: Any,
         error_msgs: Any,
     ) -> None:
-        self.serialized_engine = state_dict[prefix + "engine"]
+        self.engine = state_dict[prefix + "engine"]
         self.input_names = state_dict[prefix + "input_names"]
         self.output_names = state_dict[prefix + "output_names"]
         self.target_platform = state_dict[prefix + "platform"]
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 95f1581881..be5d60ff58 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -2,10 +2,12 @@
 
 import base64
 import copy
+import io
 import logging
 import pickle
 from typing import Any, List, Optional, Tuple, Union
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform
@@ -76,6 +78,7 @@ class TorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
 
     def __init__(
         self,
+        cuda_engine: Optional[trt.ICudaEngine] = None,
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
@@ -123,8 +126,22 @@ def __init__(
         """
         super(TorchTensorRTModule, self).__init__()
 
-        if not isinstance(serialized_engine, bytearray):
-            ValueError("Expected serialized engine as bytearray")
+        if serialized_engine:
+            assert isinstance(
+                serialized_engine, bytes
+            ), "Serialized engine must be a bytes object"
+            self.serialized_engine = serialized_engine
+
+        elif cuda_engine:
+            assert isinstance(
+                cuda_engine, trt.ICudaEngine
+            ), "Cuda engine must be a trt.ICudaEngine object"
+            serialized_engine = cuda_engine.serialize()
+            with io.BytesIO() as engine_bytes:
+                engine_bytes.write(serialized_engine)  # type: ignore
+                self.serialized_engine = engine_bytes.getvalue()
+        else:
+            raise ValueError("Serialized engine or cuda engine must be provided")
 
         self.input_binding_names = (
             input_binding_names if input_binding_names is not None else []
@@ -136,12 +153,11 @@ def __init__(
         self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
         self.weight_name_map = weight_name_map
-        self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
 
         if (
-            serialized_engine
+            self.serialized_engine
             and not self.settings.lazy_engine_init
             and not self.settings.enable_cross_compile_for_windows
         ):
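Taken together, the runtime modules now accept either a live engine or serialized bytes. A hedged usage sketch of the two construction paths; `interpreter_result`, `settings`, and the helper names are placeholders standing in for an ordinary compile flow, not part of this patch:

```python
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule


def build_runtime_module(interpreter_result, settings):
    # Fast path after this patch: hand the built trt.ICudaEngine to the
    # module directly, skipping the serialize/deserialize round-trip.
    return PythonTorchTensorRTModule(
        cuda_engine=interpreter_result.engine,
        input_binding_names=list(interpreter_result.input_names),
        output_binding_names=list(interpreter_result.output_names),
        settings=settings,
    )


def build_runtime_module_from_bytes(engine_bytes, input_names, output_names, settings):
    # Backward-compatible path: bytes still work; setup_engine() now
    # deserializes lazily based on the runtime type of self.engine.
    return PythonTorchTensorRTModule(
        serialized_engine=engine_bytes,
        input_binding_names=input_names,
        output_binding_names=output_names,
        settings=settings,
    )
```

Note the asymmetry: `PythonTorchTensorRTModule` holds the `ICudaEngine` as-is, while `TorchTensorRTModule` serializes a passed `cuda_engine` once at construction time, since the C++ runtime it wraps consumes engine bytes.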