Skip to content

Commit 1b8fd58

Browse files
committed
:Using torch.export workflow since compile is showing error in tensor guard
1 parent 4fdc6d0 commit 1b8fd58

File tree

2 files changed

+24
-72
lines changed

2 files changed

+24
-72
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,30 +201,21 @@ def slice_scatter_decomposition(
201201
start = get_positive_dim(start, input_tensor.shape[dim])
202202
if end is None: # Ensure end is int
203203
end = dim_size
204-
end = get_positive_dim(end, input_tensor.shape[dim])
204+
end = (
205+
get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end
206+
)
205207
if step is None:
206208
step = 1
207209

208-
src_dim = src_tensor.shape
209210
# step == 0 is not a valid torch case
210-
# also src_dim should be equal to slice dimension
211-
212211
if start == 0 and end == dim_size and step == 1:
213212
return src_tensor
214213

215-
# Ensure start, end, and step are all integers
216214
# Ensure start, end, and step are all integers
217215
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
218216
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
219217
assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"
220218

221-
src_dim = src_tensor.shape
222-
# step == 0 is not a valid torch case
223-
# also src_dim should be equal to slice dimension
224-
225-
if start == 0 and end == dim_size and step == 1:
226-
return src_tensor
227-
228219
indices = torch.arange(
229220
start, end, step, device=device_input_tensor, dtype=torch.int64
230221
)

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 21 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -817,72 +817,33 @@ class sliceScatter(torch.nn.Module):
817817
def __init__(self, *args, **kwargs) -> None:
818818
super().__init__(*args, **kwargs)
819819

820-
def forward(self, x, src, dim, start=None, end=None, step=1):
821-
y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
820+
def forward(self, x, src):
821+
y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)
822822
return y
823823

824-
# Operations expected to be removed in the traced graph after decompositions
825-
expected_ops = {
826-
torch.ops.aten.scatter.src,
827-
}
828-
unexpected_ops = {torch.ops.aten.select_scatter}
829-
830-
a = torch.zeros(8, 8).cuda()
831-
b = torch.ones(8, 2).cuda()
832-
833-
# 0-D tensors for dynamic scalar values
834-
start = torch.tensor(1, dtype=torch.int64).cuda()
835-
end = torch.tensor(6, dtype=torch.int64).cuda()
836-
step = torch.tensor(1, dtype=torch.int64).cuda()
837-
838-
# Mark scalar tensors as dynamic (note: shape = ())
839-
torch._dynamo.mark_dynamic(start, (), min=1, max=3)
840-
torch._dynamo.mark_dynamic(end, (), min=4, max=6)
841-
842-
inputs = (a, b, start, end, None, step)
843824
fx_graph = torch.fx.symbolic_trace(sliceScatter())
844-
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
845-
fx_graph,
846-
inputs,
847-
expected_ops=expected_ops,
848-
unexpected_ops=unexpected_ops,
849-
min_block_size=1,
850-
)
851825

852-
self.assertEqual(
853-
len(unexpected_ops_seen),
854-
0,
855-
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
856-
)
857-
858-
self.assertEqual(
859-
len(expected_ops_unseen),
860-
0,
861-
f"The following expected ops were not encountered: {expected_ops_unseen}",
826+
dim1 = torch.export.Dim("dim1", min=8, max=10)
827+
dynamic_shapes = {
828+
"x": [torch.export.Dim.STATIC, dim1],
829+
"src": [torch.export.Dim.STATIC, None],
830+
}
831+
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
832+
exported_program = torch.export.export(
833+
sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes
862834
)
863-
835+
fx_graph = exported_program.module()
836+
inputs = [
837+
torch_tensorrt.Input(
838+
min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]
839+
),
840+
torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
841+
]
864842
torch._dynamo.reset()
865-
866-
# Validate that the results between Torch and Torch-TRT are similar
867-
optimized_model = torch_tensorrt.compile(
868-
fx_graph,
869-
"torch_compile",
870-
inputs,
871-
min_block_size=1,
872-
truncate_double=True,
873-
pass_through_build_failures=True,
874-
)
875-
optimized_model_results = optimized_model(*inputs).detach().cpu()
876-
torch_model_results = fx_graph(*inputs).detach().cpu()
877-
878-
max_diff = float(
879-
torch.max(torch.abs(optimized_model_results - torch_model_results))
880-
)
881-
self.assertAlmostEqual(
882-
max_diff,
883-
0,
884-
DECIMALS_OF_AGREEMENT,
885-
f"Slice_scatter TRT outputs don't match with the original model.",
843+
trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs)
844+
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
845+
torch.testing.assert_close(
846+
trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL
886847
)
887848

888849
def test_lowering_select_scatter_dimZero_module(self):

0 commit comments

Comments
 (0)