Skip to content

Cross compile guard #3486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion py/requirements.txt
Original file line number Diff line number Diff line change
@@ -5,4 +5,5 @@ pybind11==2.6.2
torch>=2.8.0.dev,<2.9.0
torchvision>=0.22.0.dev,<0.23.0
--extra-index-url https://pypi.ngc.nvidia.com
pyyaml
pyyaml
dllist
3 changes: 2 additions & 1 deletion py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
import torch.fx
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
@@ -301,6 +301,7 @@ def compile(
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


@needs_cross_compile
def cross_compile_for_windows(
module: torch.nn.Module,
file_path: str,
30 changes: 28 additions & 2 deletions py/torch_tensorrt/_features.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,10 @@
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt._utils import (
check_cross_compile_trt_win_lib,
sanitized_torch_version,
)

from packaging import version

@@ -15,6 +18,7 @@
"dynamo_frontend",
"fx_frontend",
"refit",
"windows_cross_compile",
],
)

@@ -38,9 +42,15 @@
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
_FX_FE_AVAIL = True
_REFIT_AVAIL = True
_WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()

ENABLED_FEATURES = FeatureSet(
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
_TS_FE_AVAIL,
_TORCHTRT_RT_AVAIL,
_DYNAMO_FE_AVAIL,
_FX_FE_AVAIL,
_REFIT_AVAIL,
_WINDOWS_CROSS_COMPILE,
)


@@ -80,6 +90,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return wrapper


def needs_cross_compile(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.windows_cross_compile:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"Windows cross compilation feature is not available"
)

return not_implemented(*args, **kwargs)

return wrapper


T = TypeVar("T")


15 changes: 15 additions & 0 deletions py/torch_tensorrt/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
import sys
import platform

import torch

@@ -9,3 +11,16 @@ def sanitized_torch_version() -> Any:
if ".nv" not in torch.__version__
else torch.__version__.split(".nv")[0]
)


def check_cross_compile_trt_win_lib() -> bool:
# cross compile feature is only available on linux
# build engine on linux and run on windows
import dllist

if sys.platform.startswith("linux"):
loaded_libs = dllist.dllist()
target_lib = "libnvinfer_builder_resource_win.so.*"
if target_lib in loaded_libs:
return True
return False
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
@@ -50,6 +51,7 @@
logger = logging.getLogger(__name__)


@needs_cross_compile
def cross_compile_for_windows(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
@@ -1223,6 +1225,7 @@ def convert_exported_program_to_serialized_trt_engine(
return serialized_engine


@needs_cross_compile
def save_cross_compiled_exported_program(
gm: torch.fx.GraphModule,
file_path: str,
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ requires = [
"pybind11==2.6.2",
"numpy",
"sympy",
"dllist",
]
build-backend = "setuptools.build_meta"

@@ -63,6 +64,7 @@ dependencies = [
"packaging>=23",
"numpy",
"typing-extensions>=4.7.0",
"dllist",
]

dynamic = ["version"]
13 changes: 13 additions & 0 deletions tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt.dynamo.utils import get_model_device
from torch_tensorrt._utils import check_cross_compile_trt_win_lib

from ..testing_utilities import DECIMALS_OF_AGREEMENT

@@ -17,6 +18,10 @@ class TestCrossCompileSaveForWindows(TestCase):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_cross_compile_for_windows(self):
class Add(torch.nn.Module):
@@ -41,6 +46,10 @@ def forward(self, a, b):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_dynamo_cross_compile_for_windows(self):
class Add(torch.nn.Module):
@@ -69,6 +78,10 @@ def forward(self, a, b):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_dynamo_cross_compile_for_windows_cpu_offload(self):
class Add(torch.nn.Module):
Loading