@@ -41,133 +41,6 @@ def _raise_timeout(signum, frame):
     raise TimeoutError()
 
 
-@create_backend
-def fx2trt(subgraph, **kwargs):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    from torch_tensorrt.fx.fx2trt import (  # type: ignore[import]
-        InputTensorSpec,
-        TRTInterpreter,
-    )
-    from torch_tensorrt.fx.passes.lower_basic_pass import (  # type: ignore[import]
-        transform_setitem,
-    )
-    from torch_tensorrt.fx.tools.trt_splitter import (  # type: ignore[import]
-        TRTSplitter,
-        TRTSplitterSetting,
-    )
-    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer  # type: ignore[import]
-    from torch_tensorrt.fx.trt_module import TRTModule  # type: ignore[import]
-    from torch_tensorrt.fx.utils import LowerPrecision  # type: ignore[import]
-
-    try:
-        model = subgraph.model
-        inputs = subgraph.example_inputs
-        # pass rewrite
-        model = transform_setitem(model, inputs)
-        acc_model = acc_tracer.trace(model, inputs)
-        # Split out unsupported ops
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is larger than the \
-threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if "fp16_mode" in kwargs and kwargs["fp16_mode"]:
-            precision = LowerPrecision.FP16
-        else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # print("acc=", submod.code)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # For profile
-                )
-                # For profile
-                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-                # print("gpu=", submod.code)
-        return subgraph.wrap_returns(split_mod)
-    except Exception:
-        log.exception("FX2TRT conversion error")
-        return None
-
-
-@create_backend
-def torch2trt(subgraph):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    from torch2trt import torch2trt  # type: ignore[import]
-
-    inputs = subgraph.example_inputs
-    trt_mod = torch2trt(
-        subgraph.model,
-        inputs,
-        max_batch_size=len(inputs[0]),
-        strict_type_constraints=True,
-    )
-    return subgraph.wrap_returns(trt_mod)
-
-
-@create_backend
-def tensorrt(subgraph):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    model = fx2trt(subgraph)
-    if model is None:
-        model = torch2trt(subgraph)
-    return model
-
-
 def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
     if jit_mod is None:
         return None
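The removed fx2trt backend's get_submod_inputs helper captures the concrete inputs of a split-out submodule by registering a forward pre-hook on it, running the parent module once, and recording the arguments the hook observes. Below is a minimal, self-contained sketch of that technique using only stock torch.nn APIs; the capture_submod_inputs name and the toy nn.Sequential model are illustrative assumptions, not part of this patch:

import torch
import torch.nn as nn

def capture_submod_inputs(parent, submod, inputs):
    # Run `parent` once and record the positional args `submod` receives.
    captured = None

    def pre_hook(module, args):
        # Forward pre-hooks are invoked as hook(module, positional_args).
        nonlocal captured
        captured = args

    handle = submod.register_forward_pre_hook(pre_hook)
    try:
        parent(*inputs)  # one forward pass; the hook fires when submod runs
    finally:
        handle.remove()  # detach the hook even if the forward pass raises
    return captured

# Hypothetical usage: capture what the last layer sees during a forward pass.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
args = capture_submod_inputs(net, net[2], (torch.randn(1, 4),))
print([t.shape for t in args])  # [torch.Size([1, 8])]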
@@ -403,27 +276,3 @@ def ipex(subgraph):
     except Exception:
         log.warning("JIT trace failed during the 'ipex' optimize process.")
     return model
-
-
-def fx2trt_compiler_fp16(gm: torch.fx.GraphModule, example_inputs):
-    kwargs_fx2trt = {"fp16_mode": True}
-    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
-    if trt_compiled is not None:
-        return trt_compiled
-    else:
-        print(
-            "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
-        )
-        return gm.forward
-
-
-def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
-    kwargs_fx2trt = {"fp16_mode": False}
-    trt_compiled = fx2trt(gm, example_inputs, **kwargs_fx2trt)
-    if trt_compiled is not None:
-        return trt_compiled
-    else:
-        print(
-            "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
-        )
-        return gm.forward
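Both removed wrappers share one defensive shape: attempt the TensorRT compile, and if it yields nothing, hand back the unmodified gm.forward so a conversion failure never breaks the model. A hedged sketch of that fallback pattern for a dynamo-style backend follows; with_eager_fallback and compile_fn are hypothetical names introduced for illustration, not identifiers from this patch:

import logging
from typing import Callable, Optional

import torch

log = logging.getLogger(__name__)

def with_eager_fallback(compile_fn):
    # Wrap a fallible compiler so failures degrade to eager execution.
    def backend(gm: torch.fx.GraphModule, example_inputs):
        try:
            compiled: Optional[Callable] = compile_fn(gm, example_inputs)
        except Exception:
            log.exception("compile failed; falling back to eager")
            compiled = None
        # None mirrors the removed fx2trt contract: "could not convert".
        return compiled if compiled is not None else gm.forward

    return backend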