Skip to content

Commit 07e4643

Browse files
remove torchvision dependency from build, optional for test (#3598)
1 parent 6041890 commit 07e4643

20 files changed

+317
-88
lines changed

.github/scripts/install-torch-tensorrt.sh

100644100755
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#set -exou pipefail
22
set -x
33

4-
TORCH_TORCHVISION=$(grep "^torch" ${PWD}/py/requirements.txt)
4+
TORCH=$(grep "^torch>" ${PWD}/py/requirements.txt)
5+
TORCHVISION=$(grep "^torchvision" ${PWD}/py/requirements.txt)
56
INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}
67
PLATFORM=$(python -c "import sys; print(sys.platform)")
78

@@ -12,8 +13,10 @@ if [[ $(uname -m) == "aarch64" ]]; then
1213
fi
1314

1415
# Install all the dependencies required for Torch-TensorRT
15-
pip install --pre ${TORCH_TORCHVISION} --index-url ${INDEX_URL}
1616
pip install --pre -r ${PWD}/tests/py/requirements.txt
17+
pip install --force-reinstall --pre ${TORCH} --index-url ${INDEX_URL}
18+
pip install --force-reinstall --pre ${TORCHVISION} --index-url ${INDEX_URL}
19+
1720

1821
# Install Torch-TensorRT
1922
if [[ ${PLATFORM} == win32 ]]; then

packaging/pre_build_script.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.26.0/bazelis
4242
pip uninstall -y torch torchvision
4343

4444
if [[ ${IS_JETPACK} == true ]]; then
45-
# install torch 2.7 torchvision 0.22.0 for jp6.2
46-
pip install torch==2.7.0 torchvision==0.22.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
45+
# install torch 2.7 for jp6.2
46+
pip install torch==2.7.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
4747
else
48-
TORCH_TORCHVISION=$(grep "^torch" py/requirements.txt)
48+
TORCH=$(grep "^torch>" py/requirements.txt)
4949
INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}
5050

5151
# Install all the dependencies required for Torch-TensorRT
52-
pip install --force-reinstall --pre ${TORCH_TORCHVISION} --index-url ${INDEX_URL}
52+
pip install --force-reinstall --pre ${TORCH} --index-url ${INDEX_URL}
5353
fi
5454

5555
export TORCH_BUILD_NUMBER=$(python -c "import torch, urllib.parse as ul; print(ul.quote_plus(torch.__version__))")

packaging/pre_build_script_windows.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ if [[ "${CU_VERSION::4}" < "cu12" ]]; then
2525
pyproject.toml
2626
fi
2727

28-
TORCH_TORCHVISION=$(grep "^torch" py/requirements.txt)
28+
TORCH=$(grep "^torch>" py/requirements.txt)
2929
INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}
3030

3131
# Install all the dependencies required for Torch-TensorRT
3232
pip uninstall -y torch torchvision
33-
pip install --force-reinstall --pre ${TORCH_TORCHVISION} --index-url ${INDEX_URL}
33+
pip install --force-reinstall --pre ${TORCH} --index-url ${INDEX_URL}
3434

3535
export CUDA_HOME="$(echo ${CUDA_PATH} | sed -e 's#\\#\/#g')"
3636
export TORCH_INSTALL_PATH="$(python -c "import torch, os; print(os.path.dirname(torch.__file__))" | sed -e 's#\\#\/#g')"

tests/modules/hub.py

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import importlib
12
import json
23
import os
34

45
import custom_models as cm
5-
import timm
66
import torch
7-
import torchvision.models as models
7+
8+
if importlib.util.find_spec("torchvision"):
9+
import timm
10+
import torchvision.models as models
811

912
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
1013

@@ -19,53 +22,7 @@
1922
# Downloads all model files again if manifest file is not present
2023
MANIFEST_FILE = "model_manifest.json"
2124

22-
models = {
23-
"alexnet": {"model": models.alexnet(pretrained=True), "path": "both"},
24-
"vgg16": {"model": models.vgg16(pretrained=True), "path": "both"},
25-
"squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"},
26-
"densenet": {"model": models.densenet161(pretrained=True), "path": "both"},
27-
"inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"},
28-
"shufflenet": {"model": models.shufflenet_v2_x1_0(pretrained=True), "path": "both"},
29-
"mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"},
30-
"resnext50_32x4d": {
31-
"model": models.resnext50_32x4d(pretrained=True),
32-
"path": "both",
33-
},
34-
"wideresnet50_2": {
35-
"model": models.wide_resnet50_2(pretrained=True),
36-
"path": "both",
37-
},
38-
"mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"},
39-
"resnet18": {
40-
"model": torch.hub.load("pytorch/vision:v0.9.0", "resnet18", pretrained=True),
41-
"path": "both",
42-
},
43-
"resnet50": {
44-
"model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True),
45-
"path": "both",
46-
},
47-
"efficientnet_b0": {
48-
"model": timm.create_model("efficientnet_b0", pretrained=True),
49-
"path": "script",
50-
},
51-
"vit": {
52-
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
53-
"path": "script",
54-
},
55-
"pooling": {"model": cm.Pool(), "path": "trace"},
56-
"module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"},
57-
"loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"},
58-
"loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"},
59-
"conditional": {"model": cm.FallbackIf(), "path": "script"},
60-
"inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"},
61-
"standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"},
62-
"tuple_input": {"model": cm.TupleInput(), "path": "script"},
63-
"list_input": {"model": cm.ListInput(), "path": "script"},
64-
"tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"},
65-
"list_input_output": {"model": cm.ListInputOutput(), "path": "script"},
66-
"list_input_tuple_output": {"model": cm.ListInputTupleOutput(), "path": "script"},
67-
# "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
68-
}
25+
models = {}
6926

7027

7128
def get(n, m, manifest):
@@ -120,6 +77,68 @@ def download_models(version_matches, manifest):
12077

12178

12279
def main():
80+
if not importlib.util.find_spec("torchvision"):
81+
print(f"torchvision is not installed, skip models download")
82+
return
83+
84+
models = {
85+
"alexnet": {"model": models.alexnet(pretrained=True), "path": "both"},
86+
"vgg16": {"model": models.vgg16(pretrained=True), "path": "both"},
87+
"squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"},
88+
"densenet": {"model": models.densenet161(pretrained=True), "path": "both"},
89+
"inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"},
90+
"shufflenet": {
91+
"model": models.shufflenet_v2_x1_0(pretrained=True),
92+
"path": "both",
93+
},
94+
"mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"},
95+
"resnext50_32x4d": {
96+
"model": models.resnext50_32x4d(pretrained=True),
97+
"path": "both",
98+
},
99+
"wideresnet50_2": {
100+
"model": models.wide_resnet50_2(pretrained=True),
101+
"path": "both",
102+
},
103+
"mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"},
104+
"resnet18": {
105+
"model": torch.hub.load(
106+
"pytorch/vision:v0.9.0", "resnet18", pretrained=True
107+
),
108+
"path": "both",
109+
},
110+
"resnet50": {
111+
"model": torch.hub.load(
112+
"pytorch/vision:v0.9.0", "resnet50", pretrained=True
113+
),
114+
"path": "both",
115+
},
116+
"efficientnet_b0": {
117+
"model": timm.create_model("efficientnet_b0", pretrained=True),
118+
"path": "script",
119+
},
120+
"vit": {
121+
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
122+
"path": "script",
123+
},
124+
"pooling": {"model": cm.Pool(), "path": "trace"},
125+
"module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"},
126+
"loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"},
127+
"loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"},
128+
"conditional": {"model": cm.FallbackIf(), "path": "script"},
129+
"inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"},
130+
"standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"},
131+
"tuple_input": {"model": cm.TupleInput(), "path": "script"},
132+
"list_input": {"model": cm.ListInput(), "path": "script"},
133+
"tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"},
134+
"list_input_output": {"model": cm.ListInputOutput(), "path": "script"},
135+
"list_input_tuple_output": {
136+
"model": cm.ListInputTupleOutput(),
137+
"path": "script",
138+
},
139+
# "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
140+
}
141+
123142
manifest = None
124143
version_matches = False
125144
manifest_exists = False

tests/py/core/test_classes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
import unittest
33
from typing import Dict
44

5+
import tensorrt as trt
56
import torch
67
import torch_tensorrt
78
import torch_tensorrt as torchtrt
8-
import torchvision.models as models
99
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule
1010

11-
import tensorrt as trt
12-
1311

1412
class TestDevice(unittest.TestCase):
1513
def test_from_string_constructor(self):

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# type: ignore
2-
2+
import importlib
33
import unittest
44

55
import pytest
6-
import timm
76
import torch
87
import torch_tensorrt as torchtrt
98
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
@@ -175,6 +174,9 @@ def forward(self, x):
175174
)
176175

177176

177+
@unittest.skipIf(
178+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
179+
)
178180
@pytest.mark.unit
179181
def test_resnet_dynamic(ir):
180182
"""

tests/py/dynamo/models/test_engine_cache.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# type: ignore
2+
import importlib
23
import os
34
import shutil
45
import unittest
@@ -7,7 +8,6 @@
78
import pytest
89
import torch
910
import torch_tensorrt as torch_trt
10-
import torchvision.models as models
1111
from torch.testing._internal.common_utils import TestCase
1212
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
1313
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -16,6 +16,9 @@
1616

1717
assertions = unittest.TestCase()
1818

19+
if importlib.util.find_spec("torchvision"):
20+
import torchvision.models as models
21+
1922

2023
class MyEngineCache(BaseEngineCache):
2124
def __init__(
@@ -57,6 +60,9 @@ def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
5760

5861

5962
class TestHashFunction(TestCase):
63+
@unittest.skipIf(
64+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
65+
)
6066
def test_reexport_is_equal(self):
6167
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
6268
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
@@ -94,6 +100,9 @@ def test_reexport_is_equal(self):
94100

95101
self.assertEqual(hash1, hash2)
96102

103+
@unittest.skipIf(
104+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
105+
)
97106
def test_input_shape_change_is_not_equal(self):
98107
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
99108
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
@@ -131,6 +140,9 @@ def test_input_shape_change_is_not_equal(self):
131140

132141
self.assertNotEqual(hash1, hash2)
133142

143+
@unittest.skipIf(
144+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
145+
)
134146
def test_engine_settings_is_not_equal(self):
135147
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
136148
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
@@ -177,6 +189,9 @@ def test_engine_settings_is_not_equal(self):
177189

178190
class TestEngineCache(TestCase):
179191
@pytest.mark.xfail
192+
@unittest.skipIf(
193+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
194+
)
180195
def test_dynamo_compile_with_default_disk_engine_cache(self):
181196
model = models.resnet18(pretrained=True).eval().to("cuda")
182197
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
@@ -254,6 +269,9 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
254269
not torch_trt.ENABLED_FEATURES.refit,
255270
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
256271
)
272+
@unittest.skipIf(
273+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
274+
)
257275
def test_dynamo_compile_with_custom_engine_cache(self):
258276
model = models.resnet18(pretrained=True).eval().to("cuda")
259277

@@ -322,6 +340,9 @@ def test_dynamo_compile_with_custom_engine_cache(self):
322340
not torch_trt.ENABLED_FEATURES.refit,
323341
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
324342
)
343+
@unittest.skipIf(
344+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
345+
)
325346
def test_dynamo_compile_change_input_shape(self):
326347
"""Runs compilation 3 times, the cache should miss each time"""
327348
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -358,6 +379,9 @@ def test_dynamo_compile_change_input_shape(self):
358379
not torch_trt.ENABLED_FEATURES.refit,
359380
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
360381
)
382+
@unittest.skipIf(
383+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
384+
)
361385
@pytest.mark.xfail
362386
def test_torch_compile_with_default_disk_engine_cache(self):
363387
# Custom Engine Cache
@@ -430,6 +454,9 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
430454
msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
431455
)
432456

457+
@unittest.skipIf(
458+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
459+
)
433460
def test_torch_compile_with_custom_engine_cache(self):
434461
# Custom Engine Cache
435462
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -501,6 +528,9 @@ def test_torch_compile_with_custom_engine_cache(self):
501528
not torch_trt.ENABLED_FEATURES.refit,
502529
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
503530
)
531+
@unittest.skipIf(
532+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
533+
)
504534
def test_torch_trt_compile_change_input_shape(self):
505535
# Custom Engine Cache
506536
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -631,6 +661,9 @@ def forward(self, c, d):
631661
not torch_trt.ENABLED_FEATURES.refit,
632662
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
633663
)
664+
@unittest.skipIf(
665+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
666+
)
634667
def test_caching_small_model(self):
635668
from torch_tensorrt.dynamo._refit import refit_module_weights
636669

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
import unittest
55

66
import pytest
7-
import timm
87
import torch
98
import torch.nn.functional as F
109
import torch_tensorrt as torchtrt
11-
import torchvision.models as models
1210
from torch import nn
1311
from torch_tensorrt.dynamo._compiler import (
1412
convert_exported_program_to_serialized_trt_engine,

0 commit comments

Comments
 (0)