Skip to content

slice scatter support for dynamic cases #3513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -201,34 +201,29 @@ def slice_scatter_decomposition(
start = get_positive_dim(start, input_tensor.shape[dim])
if end is None: # Ensure end is int
end = dim_size
end = get_positive_dim(end, input_tensor.shape[dim])
end = (
get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end
)
if step is None:
step = 1

src_dim = src_tensor.shape
# step == 0 is not a valid torch case
# also src_dim should be equal to slice dimension

if start == 0 and end == dim_size and step == 1:
return src_tensor

# Ensure start, end, and step are all integers
assert isinstance(start, int), "start must be an integer"
assert isinstance(end, int), "end must be an integer"
assert isinstance(step, int), "step must be an integer"

cat_tensors = []
index_tensor_shape = []
for i, src_each_dim in enumerate(list(src_dim)):
if i != dim:
index_tensor_shape.append(src_each_dim)
for index in range(start, end, step):
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
index_tensor = torch.stack(cat_tensors, dim)
index_tensor = index_tensor.to(device_input_tensor)
index_tensor_64 = index_tensor.to(torch.int64)
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
return output_tensor
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"

indices = torch.arange(
start, end, step, device=device_input_tensor, dtype=torch.int64
)
index_tensor = indices.view(
[-1 if i == dim else 1 for i in range(input_tensor.dim())]
)
index_tensor = index_tensor.expand_as(src_tensor)
return torch.scatter(input_tensor, dim, index_tensor, src_tensor)


@register_torch_trt_decomposition(
34 changes: 34 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
@@ -812,6 +812,40 @@ def forward(self, x, src, dim, start, end, step):
f"Slice_scatter TRT outputs don't match with the original model.",
)

def test_lowering_slice_scatter_dynamic_module(self):
class sliceScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src):
y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)
return y

fx_graph = torch.fx.symbolic_trace(sliceScatter())

dim1 = torch.export.Dim("dim1", min=8, max=10)
dynamic_shapes = {
"x": [torch.export.Dim.STATIC, dim1],
"src": [torch.export.Dim.STATIC, None],
}
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
exported_program = torch.export.export(
sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes
)
fx_graph = exported_program.module()
inputs = [
torch_tensorrt.Input(
min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]
),
torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
]
torch._dynamo.reset()
trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs)
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
torch.testing.assert_close(
trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL
)

def test_lowering_select_scatter_dimZero_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None: