diff --git a/e2e_testing/torchscript/index_select.py b/e2e_testing/torchscript/index_select.py
new file mode 100644
index 00000000000..c4693cc1a42
--- /dev/null
+++ b/e2e_testing/torchscript/index_select.py
@@ -0,0 +1,138 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class IndexSelectSingleIdxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([1], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 1, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
+def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([2]))
+
+
+class IndexSelectTwoIdxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([2], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 2, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
+def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))
+
+
+class IndexSelectWholeDimensionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([4], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 0, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
+def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))
+
+
+class IndexSelectWholeTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3], torch.float32, True),
+        ([3], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 0, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
+def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3), torch.tensor([0, 1, 2]))
+
+
+class IndexSelectDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 2, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
+def IndexSelectDynamicModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))
+
+
+class IndexSelectDynamicInputSizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([2], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 2, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
+def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))
+
+
+class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, indices):
+        return torch.index_select(input, 1, indices)
+
+@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
+def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))
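
For orientation, a minimal standalone sketch (plain PyTorch, not part of the patch)
of the behavior the modules above exercise: torch.index_select copies whole slices
of the input along dim, so the output differs from the input only in that one
dimension's size. The dynamic-shape modules run the same computation; only their
annotate_args shape annotations differ.

    import torch

    input = torch.randn(4, 5, 6)
    indices = torch.tensor([0, 2])
    out = torch.index_select(input, 1, indices)  # keep slices 0 and 2 of dim 1
    assert out.shape == (4, 2, 6)
    # index_select along a dim is equivalent to advanced indexing on that dim:
    assert torch.equal(out, input[:, indices, :])
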
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 61b9ded03b1..2a07c76f1ea 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -44,6 +44,7 @@
 from . import squeeze
 from . import slice_like
 from . import nll_loss
+from . import index_select
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 30a21a6f2ca..b5ea1232f3a 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3439,6 +3439,82 @@ class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
 };
 } // namespace
 
+namespace {
+// Let's say we have an input tensor initialized with some random values of
+// size [4, 5, 6], an index tensor (always 1-d) [0, 2] of size [2], and an
+// integer argument dim = 1. The size of the output tensor will be [4, 2, 6].
+// The approach is as follows:
+//
+// for i in range(input.size[0])
+//   for j in range(index.size[0])
+//     for k in range(input.size[2])
+//       indexValue = index[j]
+//       output[i,j,k] = input[i,indexValue,k]
+class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.self();
+    Value indices = adaptor.index();
+    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    Type elementType = resultType.getElementType();
+    unsigned inputRank = inputType.getRank();
+
+    int64_t dimInt;
+    if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
+      return op->emitError("unimplemented: dim is not constant");
+
+    SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
+    resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
+    Value initTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
+
+    SmallVector<AffineExpr> resultExpr;
+    AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
+    SmallVector<StringRef> iteratorTypes;
+
+    for (unsigned i = 0; i < inputRank; i++) {
+      resultExpr.push_back(rewriter.getAffineDimExpr(i));
+      iteratorTypes.push_back(getParallelIteratorTypeName());
+    }
+
+    auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initTensor.getType(), ValueRange{indices}, initTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  SmallVector<Value> indexTarget;
+                  for (unsigned i = 0; i < inputRank; i++)
+                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
+                  indexTarget[dimInt] = index;
+                  Value extractedElement =
+                      b.create<tensor::ExtractOp>(loc, input, indexTarget);
+                  b.create<linalg::YieldOp>(loc, extractedElement);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
@@ -3539,6 +3615,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
     target.addIllegalOp<AtenNumelOp>();
     patterns.add<ConvertAtenNumelOp>(typeConverter, context);
+    target.addIllegalOp<AtenIndexSelectOp>();
+    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
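
For reference, the loop nest described in the comment inside
ConvertAtenIndexSelectOp, transcribed into runnable Python. This is an
illustrative sketch only, covering the rank-3, dim = 1 case from the comment;
the lowering itself emits a single linalg.generic whose body does the same
per-element work with linalg.index, arith.index_cast, and tensor.extract.

    import torch

    def index_select_ref(input, dim, index):
        # Rank-3, dim == 1 only: each output element copies the input element
        # at the same coordinates, except that the dim-1 coordinate is
        # redirected through the index tensor.
        assert input.dim() == 3 and dim == 1
        output = torch.empty(input.size(0), index.size(0), input.size(2))
        for i in range(input.size(0)):
            for j in range(index.size(0)):
                for k in range(input.size(2)):
                    output[i, j, k] = input[i, index[j], k]
        return output

    input, index = torch.randn(4, 5, 6), torch.tensor([0, 2])
    assert torch.allclose(index_select_ref(input, 1, index),
                          torch.index_select(input, 1, index))

Every output element is computed independently, which is why the pattern marks
all iterators of the generic op as parallel and needs no reduction; the trailing
tensor.cast only reconciles the statically-known result type with the
dynamically-sized tensor produced by the generic op.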