PyTorch 2.5, CUDA 12.4, TensorRT 10.3, Python 3.12
Torch-TensorRT 2.5.0 targets PyTorch 2.5, TensorRT 10.3 and CUDA 12.4.
(builds for CUDA 11.8/12.1 are available via the PyTorch package index: https://download.pytorch.org/whl/cu118 and https://download.pytorch.org/whl/cu121)
Deprecation notice
The torchscript frontend will be deprecated in v2.6. Specifically, the following usage will no longer be supported and will issue a deprecation warning at runtime if used:
torch_tensorrt.compile(model, ir="torchscript")
Moving forward, we encourage users to transition to one of the supported options:
torch_tensorrt.compile(model)
torch_tensorrt.compile(model, ir="dynamo")
torch.compile(model, backend="tensorrt")
TorchScript will continue to be supported as a deployment format via post-compilation tracing:
dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
ts_model = torch.jit.trace(dynamo_model, inputs=[...])
ts_model(...)
Please refer to the README for more information regarding our deprecation policy.
Refit (Beta)
v2.5.0 introduces direct refitting of compiled Torch-TensorRT programs with new weights from PyTorch. Sometimes weights need to change over the course of inference, and previously a full recompilation was required to swap out the weights of the model, either automatically through torch.compile or manually through torch_tensorrt.compile. Now, using the refit_module_weights API, compiled modules can be refitted by providing a new PyTorch module (with identical structure) containing the new weights. Modules must be compiled with make_refittable to use this feature.
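As a minimal sketch of the compile-and-save step that the refit example below assumes (the torchvision model, input shapes, and output path here are illustrative):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

# Compile with refit enabled so the engine weights can be swapped later
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(exp_program, tuple(inputs), make_refittable=True)

# Serialize the compiled program to disk so it can be reloaded and refitted
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)

With the program saved, refitting then looks like the snippet below.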
from torch_tensorrt.dynamo import refit_module_weights

# Create and export the updated model
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))

compiled_trt_ep = torch_trt.load("./compiled.ep")

# This returns a new module with updated weights
new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
)
Some ops are not compatible with refit, such as ops that utilize the ILoop layer. When make_refittable is enabled, these ops are forced to run in PyTorch. Note also that refit-enabled engines may be slightly less performant than non-refittable engines, since TensorRT cannot tune for the specific weights it will see at execution time.
Refit Caching (Experimental)
Refitting on its own can speed up model swap times by 0.5-2x. However, refit can be made even faster by using refit caching. At compile time, refit caching stores hints for a direct mapping from PyTorch module members to TRT layer names in the metadata of TorchTensorRTModule. This caching can speed up refit by orders of magnitude, but it currently has limitations when dealing with layers that undergo compile-time optimization. The feature is still experimental, as some ops may not be amenable to refit caching. We nonetheless enable the cache by default when refitting in order to collect feedback on edge cases, and we provide an output validator that can be used to ensure the refit occurred properly. When verify_outputs is True and the refit fails verification, the refitter discards the cache and refits from scratch.
new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    verify_outputs=True,
)
MutableTorchTensorRTModule (Experimental)
torch.compile is incredibly useful for optimizing models that may change over time, since it can automatically recompile the module when something changes. However, its major limitation is that the result cannot be serialized. For users who want similar flexibility plus the ability to serialize and move their work, we have introduced the MutableTorchTensorRTModule. This module wraps a PyTorch module and exposes its members transparently, but it injects listeners on setattr and overrides the forward function to use TensorRT-accelerated subgraphs. This means you can make changes to your module, such as applying adapters, and the MutableTorchTensorRTModule will detect the change and mark the function for refit or recompilation as appropriate. As with torch.compile, this happens in a JIT manner, so the first inference after a change performs the refit or recompile operation.
import torch
import torch_tensorrt as torch_trt
from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refittable": True,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")
Engine Caching
In some scenarios, users may compile a module multiple times, and each compilation spends a long time building a TensorRT engine in the backend. Engine caching boosts performance by reusing previously built TensorRT engines rather than rebuilding them every time, thereby avoiding recompilation time. When a cached engine is loaded, it is refitted with the new module weights.
To make this more broadly applicable, two graph modules are considered equivalent as long as they have the same structure, even if their weights differ, i.e., they are isomorphic graph modules. Isomorphic graph modules compiled with the same compilation settings share cached engines.
We implemented DiskEngineCache so that users can directly use its APIs to control how and where cached engines are saved to and loaded from the local disk. For example,
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_dir="/tmp/torch_trt_engine_cache",
    engine_cache_size=1 << 30,  # 1GB
)
In addition, since some users may want to save engines to or load them from other servers, clusters, or the cloud, we also provide a base class BaseEngineCache so that users can easily implement their own logic to save and load engines. For example,
from typing import Optional

from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class MyEngineCache(BaseEngineCache):
    def __init__(
        self,
        addr: str,
    ) -> None:
        self.addr = addr

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ):
        # user's customized function to save engines
        write_to(self.addr, name=f"{prefix}_{hash}.bin", content=blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # user's customized function to load engines
        return read_from(self.addr, name=f"{prefix}_{hash}.bin")
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    custom_engine_cache=MyEngineCache("xxxxx"),
)
CUDA Graphs
v2.5.0 adds CUDA Graph support for in-engine kernel launch optimization through a new runtime mode. This mode can be activated from Python using:
import torch_tensorrt

my_torchtrt_model = torch_tensorrt.compile(...)

with torch_tensorrt.runtime.enable_cudagraphs():
    my_torchtrt_model(inputs)
This mode works by creating CUDA Graphs around individual TensorRT engines, which improves their launch efficiency. The graph is created through a capture phase that is tied to the engine's input shape; when the input shape changes, the graph is invalidated and automatically recaptured.
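As a minimal sketch of that behavior, continuing the snippet above and assuming the engine was compiled with dynamic shapes so both batch sizes below are valid (shapes are illustrative):

import torch

with torch_tensorrt.runtime.enable_cudagraphs():
    # First call at a given input shape captures the CUDA Graph for that shape
    out_a = my_torchtrt_model(torch.randn((8, 3, 224, 224)).cuda())
    # Same shape: the captured graph is replayed
    out_b = my_torchtrt_model(torch.randn((8, 3, 224, 224)).cuda())
    # New shape: the previous graph is invalidated and recaptured automatically
    out_c = my_torchtrt_model(torch.randn((16, 3, 224, 224)).cuda())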
Model Optimizer based Int8 Quantization (PTQ) support for Linux
This version introduces official support for int8 quantization via modelopt (https://github.com/NVIDIA/TensorRT-Model-Optimizer) 17.0 on Linux.
Full examples can be found at https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/vgg16_ptq.py
Running the vgg16 example for int8 PTQ:

Step 1: generate a checkpoint file for vgg16:

cd examples/int8/training/vgg16
python main.py --lr 0.01 --batch-size 128 --drop-ratio 0.15 \
    --ckpt-dir $(pwd)/vgg16_ckpts --epochs 20 --seed 545

This should produce a checkpoint file at examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth

Step 2: run int8 PTQ for vgg16:

python examples/dynamo/vgg16_fp8_ptq.py --batch-size 128 \
    --ckpt=examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth \
    --quantize-type=int8
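The example above follows the standard modelopt PTQ flow. A minimal sketch of that flow is shown below; the model, calibration dataloader, and input shapes are illustrative, and the linked vgg16_ptq.py is the authoritative version:

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torch_trt
from modelopt.torch.quantization.utils import export_torch_mode

# `model` is assumed to be the trained vgg16 from step 1, in eval mode on CUDA,
# and `calib_dataloader` an assumed calibration dataloader.
def calibrate_loop(model):
    # Feed calibration batches through the model to collect activation ranges
    for data, _ in calib_dataloader:
        model(data.cuda())

# Insert and calibrate int8 quantizers using Model Optimizer
quant_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=calibrate_loop)

with torch.no_grad(), export_torch_mode():
    input_tensor = torch.randn((128, 3, 32, 32)).cuda()
    exp_program = torch.export.export(quant_model, (input_tensor,))
    trt_model = torch_trt.dynamo.compile(
        exp_program,
        inputs=[input_tensor],
        enabled_precisions={torch.int8},
        min_block_size=1,
    )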
LLM examples
We now offer dynamic shape support for all converters (covering core ATen operations). Dynamic shapes are widely utilized in leading LLM models, where input sequence lengths may vary. With this release, we showcase full graph compilation for Llama2 and GPT2 models using Torch-TensorRT. For detailed examples, please refer to our documentation.
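As a minimal sketch of declaring a dynamic sequence length for a decoder-style model (the model object and the shape bounds here are illustrative, not from the release examples):

import torch
import torch_tensorrt

# Token IDs with a dynamic sequence-length dimension
seq_input = torch_tensorrt.Input(
    min_shape=(1, 1),
    opt_shape=(1, 64),
    max_shape=(1, 1024),
    dtype=torch.int64,
)

# `model` is assumed to be an eval-mode decoder (e.g. GPT2) already on CUDA
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[seq_input])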
What's Changed
- Template linux test Workspace in CI by @lanluo-nvidia in #2956
- chore: remove aten.full decomposition by @peri044 in #2954
- empty tensor moving to default device by @apbose in #2948
- fix: Add missing select in Bazel BUILD by @gs-olive in #2966
- chore: fix use_cache flag by @peri044 in #2965
- add dynamic support for floor/logical_not/sign/round/isinf/isnan by @lanluo-nvidia in #2963
- Implemented basic pipeline for Refitting by @cehongwang in #2886
- Add dynamic shape support for bitwise_and/or/xor/not, exp/expm1/recip/log/log2/log10 by @lanluo-nvidia in #2973
- feat: dynamic shape support for atan/asinh/acosh/atanh/atan2/ceil by @keehyuna in #2959
- chore: doc update by @peri044 in #2967
- Doc build pipeline issue fix by @lanluo-nvidia in #2985
- chore: fix bert example description in doc by @ispobock in #2984
- feat: add dynamic support for eq/ne/lt/le by @chohk88 in #2979
- Add support for prelu dynamo converter by @HolyWu in #2972
- refactor: Address some issues with enums and overhaul documentation by @narendasan in #2974
- feat: Cudagraphs integration for Torch-TRT + Non-default Stream Utilization by @gs-olive in #2881
- feat: dynamic shape support for pow/mod/eq operator by @keehyuna in #2982
- chore: dynamic shape support for clamp/min/max/floor_div/logical_and by @keehyuna in #2977
- dynamic shape for slice converter by @apbose in #2901
- add dynamic shape support for scaled_dot_product_attention, logical_or/xor by @lanluo-nvidia in #2975
- feat: dynamic shape support for squeeze ops by @keehyuna in #2994
- Lluo/auto release cherry pick main by @lanluo-nvidia in #2992
- add the sym_not / full operator to support dynamic shape by @lanluo-nvidia in #3013
- fix: fix exported_program import error by @peri044 in #3007
- feat: support dynamic shapes for avg poolNd by @chohk88 in #3010
- [WIP] upgrade to public TRT 10.1.0.27 by @zewenli98 in #2855
- feat: support dynamic shape for aten.linear by @chohk88 in #3011
- scatter_add_decomposition by @apbose in #2740
- Add dynamic support to roll/scaler_tensor by @lanluo-nvidia in #3023
- add dynamic support for embedding_bag/index_select by @lanluo-nvidia in #3032
- chore: doc fix by @peri044 in #3036
- Group norm bug fix by @cehongwang in #3014
- Fix the build name issue: cherry pick from 3041 back to main by @lanluo-nvidia in #3042
- refactor: Upgrade bazel and move to MODULE.bazel by @narendasan in #3012
- Overhaul upsample dynamo converter by @HolyWu in #2790
- chore: dynamic shape support for flip ops by @keehyuna in #3046
- feat: dynamic shape support for adaptive_avg_poolNd (partially) by @chohk88 in #3021
- feat: dynamic shape support for aten.select.int by @chohk88 in #2990
- chore: bug fixes for full and expand by @peri044 in #3019
- chore: re-enable serde tests by @peri044 in #2968
- chore: dynamic shape support for any/sort/trunc ops by @keehyuna in #3026
- chore: bug fix in runtime by @peri044 in #3054
- Add dynamic shape support for cumsum/grid by @lanluo-nvidia in #3051
- chore: fix attn test by @peri044 in #3055
- feat: Lazy engine initialization by @narendasan in #2997
- bugfix: WAR disable BERT TS test by @narendasan in #3057
- Added kwarg support for dynamo.compile by @cehongwang in #2970
- Renamed convert_exported_program_to_serialized_trt_engine by @cehongwang in #3066
- feat: dynamic support for pixel_suffle and pixel_unshuffle by @chohk88 in #3044
- chore: dynamic shape support for pdist ops by @keehyuna in #3068
- feat: Support aten.dot dynamo converter by @HolyWu in #3043
- dynamic shape argmax and argmin by @apbose in #3009
- feat: dynamic shape support for pad ops by @chohk88 in #3045
- Run all model tests by @narendasan in #3078
- Added refitting acceleration by @cehongwang in #2983
- fix converter test error by @lanluo-nvidia in #3083
- fix: Fix the CUDAGraphs C++ runtime implementation by @narendasan in #3067
- Fix: Layer norm Torchscript converter by @narendasan in #3062
- Implemented basic Mutable torch tensorrt module pipeline by @cehongwang in #2981
- fix: Adjust reflection pad test cases to prevent runtime errors by @chohk88 in #3088
- Fix docker build issue by @lanluo-nvidia in #3070
- fix the libtorch version mismatch issue by @lanluo-nvidia in #3086
- Fix TypeError in MutableTorchTensorRTModule on Python 3.9 by @HolyWu in #3094
- feat: lowering replace aten.full_like with aten.full by @chohk88 in #3077
- Notebook failure by @apbose in #3048
- Fix tests-py-dynamo-cudagraphs in Windows test by @HolyWu in #3096
- feat: Adding live progress monitoring to the engine building phase by @narendasan in #3087
- chore: fix group_norm test by @keehyuna in #3091
- feat: Add handling for ITensor mean and var in batch_norm by @chohk88 in #3099
- fix tensorrt dependencies issue by @lanluo-nvidia in #3105
- Fix assertEquals AttributeError on Python 3.12 by @HolyWu in #3112
- fix: Relax thresholds of dynamo converters' tests by @zewenli98 in #3061
- fix: TS test_scaled_dot_product_attention by @zewenli98 in #3117
- Added tensor_parallelism examples by @cehongwang in #3047
- chore: Bump TRT version to 10.3.0.26 by @zewenli98 in #3071
- feat: Save target platform as part of TRTEngine Metadata by @narendasan in #3106
- Docs mutable by @narendasan in #3121
- Refit bug fix by @cehongwang in #3097
- add int8 quantization support by @lanluo-nvidia in #3058
- chore: Fixes required for LLM models by @peri044 in #3002
- Fix typo in _compile.py by @juliusgh in #3128
- tile dynamic dim by @apbose in #3085
- Fixing slice scatter and select scatter decomposition by @apbose in #3093
- feat: engine caching by @zewenli98 in #2995
- Dynamic shape index by @apbose in #3039
- feat: Support aten.gelu dynamo converter by @HolyWu in #3134
- fix: get_padded_shape_tensors can now handle dynamic pads by @jiwoong-choi in #3123
- bugfix: allow empty tuple for inputs or arg_inputs by @jiwoong-choi in #3122
- docs: Adding words to the refit and engine caching tutorials by @narendasan in #3141
- Fix doc index by @HolyWu in #3130
- chunk converter validator by @apbose in #3120
- fix: remove global logging setup by @seymurkafkas in #3147
- register_jit_hooks: Remove confusing error message by @HolyWu in #3150
- _refit: Properly compare device type by @HolyWu in #3149
- scatter reduce decomposition by @apbose in #3008
- chore: make engine caching opt-in feature by @peri044 in #3152
- fix: distingush engines based on compilation settings in addition to … by @narendasan in #3155
- release 2.5 branch cut by @lanluo-nvidia in #3161
- Lluo/fix merge issue by @lanluo-nvidia in #3162
- feat: cherry pick of Refit fixes by @peri044 in #3166
- cherry pick 3203 from main to release/2.5 by @lanluo-nvidia in #3208
- cherry pick to release/2.5 fix: Fix static arange export 3194 by @lanluo-nvidia in #3207
- cherry pick #3218: extend windows build timeout from 60 min to 120 min by @lanluo-nvidia in #3219
- cherry pick fix global partitioner bug #3195 from main to release/2.5 branch by @lanluo-nvidia in #3209
- cherry pick #3193: add req_full_compilation_arg from main to release/2.5 by @lanluo-nvidia in #3213
- cherry pick #3225 change from main to release/2.5 branch by @lanluo-nvidia in #3228
- cherry pick refit error 3170 from main to release/2.5 branch by @lanluo-nvidia in #3236
- cherry pick #3191 from main to release/2.5 by @lanluo-nvidia in #3237
- cherry pick doc fix #3238 from main to release2.5 by @lanluo-nvidia in #3240
New Contributors
- @ispobock made their first contribution in #2984
- @juliusgh made their first contribution in #3128
- @jiwoong-choi made their first contribution in #3123
- @seymurkafkas made their first contribution in #3147
Full Changelog: v2.4.0...v2.5.0