Skip to content

Commit 9e390da

Browse files
committed
added back control flag
1 parent c018151 commit 9e390da

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def compile(
422422
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
423423
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
424424
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
425+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
425426
**kwargs: Any,
426427
) -> torch.fx.GraphModule:
427428
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,6 +667,7 @@ def compile(
666667
"enable_weight_streaming": enable_weight_streaming,
667668
"tiling_optimization_level": tiling_optimization_level,
668669
"l2_limit_for_tiling": l2_limit_for_tiling,
670+
"offload_module_to_cpu": offload_module_to_cpu,
669671
}
670672

671673
settings = CompilationSettings(**compilation_options)
@@ -677,10 +679,6 @@ def compile(
677679

678680
gm = exported_program.module()
679681
# Move the weights in the state_dict to CPU
680-
exported_program.module().to("cpu")
681-
logger.info(
682-
"The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
683-
)
684682
logger.debug("Input graph: " + str(gm.graph))
685683

686684
# Apply lowering on the graph module

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = False
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

0 commit comments

Comments
 (0)