feat: TensorRT AOT Plugin #3504

Open · wants to merge 3 commits into base: main
144 changes: 144 additions & 0 deletions examples/dynamo/aot_plugin.py
@@ -0,0 +1,144 @@
import argparse
from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

trt_logger = trt.Logger(trt.Logger.VERBOSE)


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Element-wise add-one over a flat buffer; the mask guards the tail block.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
def add_one(X: torch.Tensor) -> torch.Tensor:
    # Ensure the tensor is on the GPU
    assert X.is_cuda

    # Create the output tensor
    Y = torch.empty_like(X)

    # Define the block size
    BLOCK_SIZE = 256

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    # The fake (meta) impl must return a fresh tensor with the right metadata;
    # returning the input itself would make the op appear to alias X, which
    # custom ops are not allowed to do.
    return torch.empty_like(X)


@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return X.like()


@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    # Compile ahead of time: the kernel name and PTX returned below are
    # embedded in the TensorRT engine, so no Triton JIT occurs at runtime.
    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # N is passed to the kernel at runtime as a symbolic extra argument.
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )
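
# Hedged sketch (not part of the original example): the JIT-style counterpart
# that the converter falls back to when no AOT impl is registered. It wraps
# the plugin tensors as torch tensors and launches the same Triton kernel at
# enqueue time, on the stream TensorRT provides.
@trtp.impl("my::add_one")
def add_plugin_impl(X: trtp.Tensor, outputs: Tuple[trtp.Tensor], stream: int) -> None:
    x_t = torch.as_tensor(X, device="cuda")
    y_t = torch.as_tensor(outputs[0], device="cuda")
    n = x_t.numel()
    with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
        add_one_kernel[(triton.cdiv(n, 256),)](x_t, n, y_t, BLOCK_SIZE=256)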


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,
)
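# With use_aot_if_available=True, the generated converter checks the plugin
# registry for a registered AOT impl and, if one is found, adds the plugin
# layer with aot=True so the precompiled PTX path above is used.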


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        res = torch.ops.my.add_one.default(X)

        return res


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--aot", action="store_true", help="Try to use AOT compilation", default=False
    )
    args = parser.parse_args()
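    # NOTE: args.aot is parsed but not consumed below; this example fixes the
    # AOT preference via use_aot_if_available=True at converter registration.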

    my_model = MyModel().to("cuda")
    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

    assert my_model(X=m)[0][0] == 3.0

    with torch_tensorrt.logging.debug():
        trt_inputs = [m]
        model_trt = torch_tensorrt.compile(
            my_model,
            inputs=trt_inputs,
            debug=True,
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        for i in range(10):
            res = model_trt(m)
            assert torch.allclose(res, my_model(m)), "Results do not match!"

    print("Inference successful!")
py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -31,6 +31,7 @@ def _generate_plugin_converter(
    priority: ConverterPriority = ConverterPriority.STANDARD,
    supports_dynamic_shapes: bool = False,
    requires_output_allocator: bool = False,
    use_aot_if_available: bool = False,
Review comment (Collaborator): Default to true

) -> DynamoConverterImplSignature:
    torch_target = getattr(getattr(torch.ops, namespace), op_name)
    overload_str = overload if overload else ""
@@ -41,6 +42,16 @@
), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter"
torch_schema = torch_target._schemas[overload_str]

    use_aot_plugin = use_aot_if_available

    if use_aot_if_available:
        desc = QDP_REGISTRY[f"{namespace}::{op_name}"]
        if desc.aot_impl_func is None:
            use_aot_plugin = False
            _LOGGER.debug(
                f"AOT impl func not found for {namespace}::{op_name}, using JIT plugin instead"
            )

    def custom_kernel_converter(
        ctx: ConversionContext,
        target: Target,
@@ -80,7 +91,7 @@ def custom_kernel_converter(
            if isinstance(v, torch.fx.immutable_collections.immutable_list):
                kwargs[k] = np.array(v)

        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=use_aot_plugin)
        assert layer, f"{namespace}::{name} plugin layer was not able to be created"
        _LOGGER.debug(
            f"Adding generated plugin for {namespace}::{name} to tensorrt network"
        )
@@ -107,6 +118,7 @@ def generate_plugin_converter(
    priority: ConverterPriority = ConverterPriority.STANDARD,
    supports_dynamic_shapes: bool = False,
    requires_output_allocator: bool = False,
    use_aot_if_available: bool = False,
) -> DynamoConverterImplSignature:
    plugin_ns, plugin_name = plugin_id.split("::")
    return _generate_plugin_converter(
@@ -116,4 +128,5 @@ def generate_plugin_converter(
        priority=priority,
        supports_dynamic_shapes=supports_dynamic_shapes,
        requires_output_allocator=requires_output_allocator,
        use_aot_if_available=use_aot_if_available,
    )
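
A hedged usage sketch of the entry point changed in this diff: callers can keep the JIT plugin path even when an AOT impl is registered by passing use_aot_if_available=False, in which case the generated converter adds the plugin layer with aot=False.

from torch_tensorrt.dynamo.conversion.plugins import generate_plugin_converter

# Opt out of the AOT path: the plugin layer is added with aot=False, so the
# JIT (@trtp.impl) implementation is used even if an AOT impl exists.
generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=False,
)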