Skip to content

Commit 4fdc6d0

Browse files
committed
slice scatter support for dynamic cases
1 parent a8ecd79 commit 4fdc6d0

File tree

2 files changed

+93
-16
lines changed

2 files changed

+93
-16
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,26 @@ def slice_scatter_decomposition(
213213
return src_tensor
214214

215215
# Ensure start, end, and step are all integers
216-
assert isinstance(start, int), "start must be an integer"
217-
assert isinstance(end, int), "end must be an integer"
218-
assert isinstance(step, int), "step must be an integer"
219-
220-
cat_tensors = []
221-
index_tensor_shape = []
222-
for i, src_each_dim in enumerate(list(src_dim)):
223-
if i != dim:
224-
index_tensor_shape.append(src_each_dim)
225-
for index in range(start, end, step):
226-
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
227-
index_tensor = torch.stack(cat_tensors, dim)
228-
index_tensor = index_tensor.to(device_input_tensor)
229-
index_tensor_64 = index_tensor.to(torch.int64)
230-
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
231-
return output_tensor
216+
# Ensure start, end, and step are all integers
217+
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
218+
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
219+
assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"
220+
221+
src_dim = src_tensor.shape
222+
# step == 0 is not a valid torch case
223+
# also src_dim should be equal to slice dimension
224+
225+
if start == 0 and end == dim_size and step == 1:
226+
return src_tensor
227+
228+
indices = torch.arange(
229+
start, end, step, device=device_input_tensor, dtype=torch.int64
230+
)
231+
index_tensor = indices.view(
232+
[-1 if i == dim else 1 for i in range(input_tensor.dim())]
233+
)
234+
index_tensor = index_tensor.expand_as(src_tensor)
235+
return torch.scatter(input_tensor.clone(), dim, index_tensor, src_tensor)
232236

233237

234238
@register_torch_trt_decomposition(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,79 @@ def forward(self, x, src, dim, start, end, step):
812812
f"Slice_scatter TRT outputs don't match with the original model.",
813813
)
814814

815+
def test_lowering_slice_scatter_dynamic_module(self):
816+
class sliceScatter(torch.nn.Module):
817+
def __init__(self, *args, **kwargs) -> None:
818+
super().__init__(*args, **kwargs)
819+
820+
def forward(self, x, src, dim, start=None, end=None, step=1):
821+
y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
822+
return y
823+
824+
# Operations expected to be removed in the traced graph after decompositions
825+
expected_ops = {
826+
torch.ops.aten.scatter.src,
827+
}
828+
unexpected_ops = {torch.ops.aten.select_scatter}
829+
830+
a = torch.zeros(8, 8).cuda()
831+
b = torch.ones(8, 2).cuda()
832+
833+
# 0-D tensors for dynamic scalar values
834+
start = torch.tensor(1, dtype=torch.int64).cuda()
835+
end = torch.tensor(6, dtype=torch.int64).cuda()
836+
step = torch.tensor(1, dtype=torch.int64).cuda()
837+
838+
# Mark scalar tensors as dynamic (note: shape = ())
839+
torch._dynamo.mark_dynamic(start, (), min=1, max=3)
840+
torch._dynamo.mark_dynamic(end, (), min=4, max=6)
841+
842+
inputs = (a, b, start, end, None, step)
843+
fx_graph = torch.fx.symbolic_trace(sliceScatter())
844+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
845+
fx_graph,
846+
inputs,
847+
expected_ops=expected_ops,
848+
unexpected_ops=unexpected_ops,
849+
min_block_size=1,
850+
)
851+
852+
self.assertEqual(
853+
len(unexpected_ops_seen),
854+
0,
855+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
856+
)
857+
858+
self.assertEqual(
859+
len(expected_ops_unseen),
860+
0,
861+
f"The following expected ops were not encountered: {expected_ops_unseen}",
862+
)
863+
864+
torch._dynamo.reset()
865+
866+
# Validate that the results between Torch and Torch-TRT are similar
867+
optimized_model = torch_tensorrt.compile(
868+
fx_graph,
869+
"torch_compile",
870+
inputs,
871+
min_block_size=1,
872+
truncate_double=True,
873+
pass_through_build_failures=True,
874+
)
875+
optimized_model_results = optimized_model(*inputs).detach().cpu()
876+
torch_model_results = fx_graph(*inputs).detach().cpu()
877+
878+
max_diff = float(
879+
torch.max(torch.abs(optimized_model_results - torch_model_results))
880+
)
881+
self.assertAlmostEqual(
882+
max_diff,
883+
0,
884+
DECIMALS_OF_AGREEMENT,
885+
f"Slice_scatter TRT outputs don't match with the original model.",
886+
)
887+
815888
def test_lowering_select_scatter_dimZero_module(self):
816889
class selectScatter(torch.nn.Module):
817890
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)