
Commit 203b2ca

jansel authored and pytorchmergebot committed
Remove fx2trt/torch2trt backends (pytorch#93822)
These backends have been broken for some time. I tried to get them running again, but as far as I can tell they are not maintained: installing torch_tensorrt downgrades PyTorch to 1.12, and if I manually bypass that downgrade I get import errors from inside fx2trt. Fixes that re-add these backends are welcome, but it might make sense to move these wrappers to the torch_tensorrt repo once PyTorch 2.0 support is added.

Pull Request resolved: pytorch#93822
Approved by: https://github.com/frank-wei
1 parent 5d709af commit 203b2ca

File tree

3 files changed: +5 −170 lines changed


benchmarks/dynamo/common.py

Lines changed: 0 additions & 19 deletions
@@ -2007,25 +2007,6 @@ def run(runner, args, original_dir=None):
         optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
         experiment = speedup_experiment
         output_filename = "speedup_dynamo_ts.csv"
-    elif args.speedup_fx2trt:
-        optimize_ctx = torch._dynamo.optimize(
-            backends.fx2trt_compiler, nopython=args.nopython
-        )
-        experiment = speedup_experiment_fx2trt
-        output_filename = "speedups_fx2trt.csv"
-        runner.skip_models.update(runner.failing_fx2trt_models)
-        args.float32 = True
-        args.float16 = False
-        args.cosine = True
-    elif args.speedup_fx2trt_fp16:
-        optimize_ctx = torch._dynamo.optimize(
-            backends.fx2trt_compiler_fp16, nopython=args.nopython
-        )
-        experiment = speedup_experiment_fx2trt
-        output_filename = "speedups_fx2trt_fp16.csv"
-        args.float32 = False
-        args.float16 = True
-        args.cosine = True
     elif args.prims_nvfuser:
         optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython)
         experiment = speedup_experiment

torch/_dynamo/backends/onnxrt.py

Lines changed: 5 additions & 0 deletions
@@ -107,3 +107,8 @@ def _call(*initial_args):
         return outputs
 
     return _call
+
+
+@register_backend
+def tensorrt(gm, example_inputs):
+    return onnxrt(gm, example_inputs, provider="TensorrtExecutionProvider")
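For reference, the replacement backend above is selected by name through dynamo. A minimal usage sketch (not part of this diff; it assumes a CUDA build of onnxruntime with TensorRT support is installed, and uses a toy function as a placeholder model):

import torch
import torch._dynamo

def fn(x):
    return torch.relu(x) + 1

# "tensorrt" now resolves to the wrapper registered above, which runs the
# captured graph through ONNX Runtime's TensorrtExecutionProvider.
opt_fn = torch._dynamo.optimize("tensorrt")(fn)
print(opt_fn(torch.randn(4, 8, device="cuda")))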

torch/_dynamo/optimizations/backends.py

Lines changed: 0 additions & 151 deletions
@@ -41,133 +41,6 @@ def _raise_timeout(signum, frame):
     raise TimeoutError()
 
 
-@create_backend
-def fx2trt(subgraph, **kwargs):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    from torch_tensorrt.fx.fx2trt import (  # type: ignore[import]
-        InputTensorSpec,
-        TRTInterpreter,
-    )
-    from torch_tensorrt.fx.passes.lower_basic_pass import (  # type: ignore[import]
-        transform_setitem,
-    )
-    from torch_tensorrt.fx.tools.trt_splitter import (  # type: ignore[import]
-        TRTSplitter,
-        TRTSplitterSetting,
-    )
-    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer  # type: ignore[import]
-    from torch_tensorrt.fx.trt_module import TRTModule  # type: ignore[import]
-    from torch_tensorrt.fx.utils import LowerPrecision  # type: ignore[import]
-
-    try:
-        model = subgraph.model
-        inputs = subgraph.example_inputs
-        # pass rewrite
-        model = transform_setitem(model, inputs)
-        acc_model = acc_tracer.trace(model, inputs)
-        # Split out unsupported ops
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is large than the \
-threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if "fp16_mode" in kwargs and kwargs["fp16_mode"]:
-            precision = LowerPrecision.FP16
-        else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # print("acc=",submod.code)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # For profile
-                )
-                # For profile
-                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-                # print("gpu=",submod.code)
-        return subgraph.wrap_returns(split_mod)
-    except Exception:
-        log.exception("FX2TRT conversion error")
-        return None
-
-
-@create_backend
-def torch2trt(subgraph):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    from torch2trt import torch2trt  # type: ignore[import]
-
-    inputs = subgraph.example_inputs
-    trt_mod = torch2trt(
-        subgraph.model,
-        inputs,
-        max_batch_size=len(inputs[0]),
-        strict_type_constraints=True,
-    )
-    return subgraph.wrap_returns(trt_mod)
-
-
-@create_backend
-def tensorrt(subgraph):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    model = fx2trt(subgraph)
-    if model is None:
-        model = torch2trt(subgraph)
-    return model
-
-
 def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
     if jit_mod is None:
         return None
@@ -403,27 +276,3 @@ def ipex(subgraph):
     except Exception:
         log.warning("JIT trace failed during the 'ipex' optimize process.")
         return model
-
-
-def fx2trt_compiler_fp16(gm: torch.fx.GraphModule, example_inputs):
-    kwargs_fx2trt = {"fp16_mode": True}
-    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
-    if trt_compiled is not None:
-        return trt_compiled
-    else:
-        print(
-            "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
-        )
-        return gm.forward
-
-
-def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
-    kwargs_fx2trt = {"fp16_mode": False}
-    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
-    if trt_compiled is not None:
-        return trt_compiled
-    else:
-        print(
-            "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
-        )
-        return gm.forward
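Although the wrappers are deleted, the shape of the removed fx2trt_compiler functions — try an external compiler, return gm.forward when it fails — remains the usual pattern for a fault-tolerant custom dynamo backend. A minimal runnable sketch of that pattern (the inner compile step is a stand-in, not a real TensorRT call):

import logging

import torch

log = logging.getLogger(__name__)

def compile_or_fallback(gm: torch.fx.GraphModule, example_inputs):
    try:
        # Stand-in for an external compiler hook (fx2trt, torch2trt, ...);
        # None signals "could not compile this subgraph".
        compiled = None
        if compiled is not None:
            return compiled
    except Exception:
        log.exception("conversion error")
    # Same graceful fallback as the removed wrappers: run the subgraph eagerly.
    return gm.forward

opt_fn = torch._dynamo.optimize(compile_or_fallback)(torch.sin)
print(opt_fn(torch.ones(3)))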
