5050from torch_tensorrt .dynamo .debug ._DebuggerConfig import DebuggerConfig
5151from torch_tensorrt .dynamo .debug ._supports_debugger import cls_supports_debugger
5252from torch_tensorrt .dynamo .observer import Observer
53- from torch_tensorrt .dynamo .utils import DYNAMIC_DIM , deallocate_module , to_torch_device
53+ from torch_tensorrt .dynamo .utils import (
54+ DYNAMIC_DIM ,
55+ deallocate_module ,
56+ get_cpu_memory_usage ,
57+ to_torch_device ,
58+ )
5459from torch_tensorrt .logging import TRT_LOGGER
5560
5661_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):
6570
6671
6772class TRTInterpreterResult (NamedTuple ):
68- serialized_engine : bytes
73+ engine : trt . ICudaEngine
6974 input_names : Sequence [str ]
7075 output_names : Sequence [str ]
7176 weight_name_map : Optional [dict [Any , Any ]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
512517 _LOGGER .info ("Building weight name mapping..." )
513518 # Stage 1: Name mapping
514519 torch_device = to_torch_device (self .compilation_settings .device )
515- self .module .to (torch_device )
516- sd = self .module .state_dict ()
520+ sd = {k : v .to (torch_device ) for k , v in self .module .state_dict ().items ()}
517521 weight_name_map : dict [str , Any ] = {}
518522 weight_refit_map = self .ctx .weight_refit_map
519523 constant_mapping = {k : v for k , v in weight_refit_map .items () if v .size == 1 }
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
592596 torch .cuda .empty_cache ()
593597
594598 @needs_refit # type: ignore[misc]
595- def _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) -> None :
599+ def _insert_engine_to_cache (self , hash_val : str , engine : trt .ICudaEngine ) -> None :
600+ serialized_engine = engine .serialize ()
596601 # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
597602 # if not self.compilation_settings.strip_engine_weights:
598603 # # set EXCLUDE_WEIGHTS flag to strip weights
599- # runtime = trt.Runtime(TRT_LOGGER)
600- # engine = runtime.deserialize_cuda_engine(serialized_engine)
601-
602604 # serialization_config = engine.create_serialization_config()
603605 # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
604606 # serialized_engine = engine.serialize_with_config(
@@ -733,6 +735,9 @@ def run(
733735 return interpreter_result # type: ignore[no-any-return]
734736
735737 self ._construct_trt_network_def ()
738+ _LOGGER .debug (
739+ f"CPU memory usage after network construction: { get_cpu_memory_usage ()} MB"
740+ )
736741
737742 if not self .compilation_settings .immutable_weights :
738743 self ._save_weight_mapping ()
@@ -750,16 +755,19 @@ def run(
750755 self ._create_timing_cache (
751756 builder_config , self .compilation_settings .timing_cache_path
752757 )
753- serialized_engine = self .builder .build_serialized_network (
758+
759+ cuda_engine = self .builder .build_engine_with_config (
754760 self .ctx .net , builder_config
755761 )
756- assert serialized_engine
762+ assert cuda_engine
763+
764+ _LOGGER .debug (
765+ f"CPU memory usage after engine building: { get_cpu_memory_usage ()} MB"
766+ )
757767
758768 _LOGGER .info (
759769 f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
760770 )
761- _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
762-
763771 self .ctx .clear_cpu_weights_reference_holder ()
764772
765773 self ._save_timing_cache (
@@ -772,14 +780,10 @@ def run(
772780 and self .compilation_settings .cache_built_engines
773781 and self .engine_cache is not None
774782 ):
775- self ._insert_engine_to_cache (hash_val , serialized_engine )
776-
777- with io .BytesIO () as engine_bytes :
778- engine_bytes .write (serialized_engine )
779- engine_str = engine_bytes .getvalue ()
783+ self ._insert_engine_to_cache (hash_val , cuda_engine )
780784
781785 return TRTInterpreterResult (
782- engine_str ,
786+ cuda_engine ,
783787 self ._input_names ,
784788 self ._output_names ,
785789 self .weight_name_map ,
0 commit comments