[MLIR][TORCH] Add E2E support for aten.index_select op
This commit adds lowering of the `aten.index_select` op.

Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 authored and Prashant Kumar committed Dec 9, 2021
1 parent 0a0a1b4 commit 8130354
Showing 3 changed files with 224 additions and 0 deletions.
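For context, `torch.index_select(input, dim, index)` gathers slices of `input` along dimension `dim` at the positions held in the 1-D `index` tensor; the result keeps the input shape except that the size along `dim` becomes the number of indices. A minimal sketch of the semantics the new lowering has to match (illustrative only, not part of the commit):

import torch

input = torch.randn(4, 5, 6)
indices = torch.tensor([0, 2])

# Select positions 0 and 2 along dim=1; the result shape is [4, 2, 6].
out = torch.index_select(input, 1, indices)
assert out.shape == (4, 2, 6)

# Element-wise, out[i, j, k] == input[i, indices[j], k], which is the same
# gather that advanced indexing performs along that dimension.
assert torch.equal(out, input[:, indices, :])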
145 changes: 145 additions & 0 deletions e2e_testing/torchscript/index_select.py
@@ -0,0 +1,145 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class IndexSelectSingleIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([1], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2]))


class IndexSelectTwoIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([2], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))


class IndexSelectWholeDimensionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([4], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))


class IndexSelectWholeTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3], torch.float32, True),
        ([3], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3), torch.tensor([0, 1, 2]))


class IndexSelectDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))


class IndexSelectDynamicInputSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([2], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))


class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([-1], torch.int64, True),
    ])

    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))
1 change: 1 addition & 0 deletions e2e_testing/torchscript/main.py
@@ -44,6 +44,7 @@
from . import squeeze
from . import slice_like
from . import nll_loss
from . import index_select

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
78 changes: 78 additions & 0 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3439,6 +3439,82 @@ class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
};
} // namespace

namespace {
// Let's say we have an input tensor of size [4, 5, 6] filled with some random
// values, an index tensor (always 1-d) [0, 2] of size [2], and an integer
// argument dim = 1. The size of the output tensor will then be [4, 2, 6].
// The approach is as follows:
//
// for i in range(input.size[0])
//   for j in range(index.size[0])
//     for k in range(input.size[2])
//       indexValue = index[j]
//       output[i,j,k] = input[i,indexValue,k]

class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    Value input = adaptor.self();
    Value indices = adaptor.index();
    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type elementType = resultType.getElementType();
    unsigned inputRank = inputType.getRank();

    int64_t dimInt;
    if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
      return op->emitError("unimplemented: dim is not constant");

    SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
    resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
    Value initTensor =
        rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);

    SmallVector<AffineExpr> resultExpr;
    AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
    SmallVector<StringRef> iteratorTypes;

    for (unsigned i = 0; i < inputRank; i++) {
      resultExpr.push_back(rewriter.getAffineDimExpr(i));
      iteratorTypes.push_back(getParallelIteratorTypeName());
    }

    auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});

    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), ValueRange{indices}, initTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value index = rewriter.create<arith::IndexCastOp>(
                      loc, rewriter.getIndexType(), args[0]);
                  SmallVector<Value> indexTarget;
                  for (unsigned i = 0; i < inputRank; i++)
                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
                  indexTarget[dimInt] = index;
                  Value extractedElement =
                      b.create<tensor::ExtractOp>(loc, input, indexTarget);
                  b.create<linalg::YieldOp>(loc, extractedElement);
                })
            .getResult(0);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -3539,6 +3615,8 @@ class ConvertTorchToLinalg
    patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
    target.addIllegalOp<AtenNllLossForwardOp>();
    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
    target.addIllegalOp<AtenIndexSelectOp>();
    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
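The loop nest sketched in the comment at the top of `ConvertAtenIndexSelectOp` maps directly onto the generated `linalg.generic`: the index tensor is the only input operand, every loop is parallel, and the body extracts from `input` after swapping the selected index into the coordinate for `dim`. As a sanity check, here is a plain-Python restatement of that rank-3 loop nest (an illustration under the same dim = 1 example, not code from the commit):

import torch

def index_select_reference(input, dim, index):
    # Output keeps the input shape, except the size along `dim` becomes the
    # length of the index tensor.
    sizes = list(input.shape)
    sizes[dim] = index.shape[0]
    output = torch.empty(sizes, dtype=input.dtype)
    for i in range(sizes[0]):
        for j in range(sizes[1]):
            for k in range(sizes[2]):
                coords = [i, j, k]
                # Replace the coordinate along `dim` with the gathered index.
                coords[dim] = int(index[coords[dim]])
                output[i, j, k] = input[tuple(coords)]
    return output

input = torch.randn(4, 5, 6)
index = torch.tensor([0, 2])
assert torch.equal(index_select_reference(input, 1, index),
                   torch.index_select(input, 1, index))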
