diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 8037858151..c5c191bb7b 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -201,34 +201,29 @@ def slice_scatter_decomposition( start = get_positive_dim(start, input_tensor.shape[dim]) if end is None: # Ensure end is int end = dim_size - end = get_positive_dim(end, input_tensor.shape[dim]) + end = ( + get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end + ) if step is None: step = 1 - src_dim = src_tensor.shape # step == 0 is not a valid torch case - # also src_dim should be equal to slice dimension - if start == 0 and end == dim_size and step == 1: return src_tensor # Ensure start, end, and step are all integers - assert isinstance(start, int), "start must be an integer" - assert isinstance(end, int), "end must be an integer" - assert isinstance(step, int), "step must be an integer" - - cat_tensors = [] - index_tensor_shape = [] - for i, src_each_dim in enumerate(list(src_dim)): - if i != dim: - index_tensor_shape.append(src_each_dim) - for index in range(start, end, step): - cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) - index_tensor = torch.stack(cat_tensors, dim) - index_tensor = index_tensor.to(device_input_tensor) - index_tensor_64 = index_tensor.to(torch.int64) - output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) - return output_tensor + assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt" + assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt" + assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt" + + indices = torch.arange( + start, end, step, device=device_input_tensor, dtype=torch.int64 + ) + index_tensor = indices.view( + [-1 if i == dim else 1 for i in range(input_tensor.dim())] + ) + index_tensor = index_tensor.expand_as(src_tensor) + return torch.scatter(input_tensor, dim, index_tensor, src_tensor) @register_torch_trt_decomposition( diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index b63e0f3bf7..e7c7b33672 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -812,6 +812,40 @@ def forward(self, x, src, dim, start, end, step): f"Slice_scatter TRT outputs don't match with the original model.", ) + def test_lowering_slice_scatter_dynamic_module(self): + class sliceScatter(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, src): + y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1) + return y + + fx_graph = torch.fx.symbolic_trace(sliceScatter()) + + dim1 = torch.export.Dim("dim1", min=8, max=10) + dynamic_shapes = { + "x": [torch.export.Dim.STATIC, dim1], + "src": [torch.export.Dim.STATIC, None], + } + inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda()) + exported_program = torch.export.export( + sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes + ) + fx_graph = exported_program.module() + inputs = [ + torch_tensorrt.Input( + min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10] + ), + torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]), + ] + torch._dynamo.reset() + trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs) + inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda()) + torch.testing.assert_close( + trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL + ) + def test_lowering_select_scatter_dimZero_module(self): class selectScatter(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: