Skip to content

Commit 2b01f8b

Browse files
authored
[Tosa] : Add support for negative indices in index.tensor and index.Tensor_hacked_twin for TorchToTosa lowering. (llvm#3790)
1. Negative indices for tensor indexing are handled by wrapping the index values around, checking them at run time. Without the fix, there was a runtime error. 2. Added a lit test to lock down the behavior. 3. Updated the `xfails_set` for the `fx_importer_tosa` config to lock down the behavior with an e2e test as well.
1 parent 54d9e24 commit 2b01f8b

File tree

3 files changed

+81
-38
lines changed

3 files changed

+81
-38
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+48-35
Original file line numberDiff line numberDiff line change
@@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
40934093
return success();
40944094
}
40954095

4096+
// Wraps negative index values into the valid [0, maxIndex) range at runtime:
// for each element i of `index`, emits TOSA ops computing
//   (i < 0) ? i + maxIndex : i
// Only a single wrap is applied, matching PyTorch semantics where a valid
// negative index satisfies -maxIndex <= i < 0.
//
// `index`    : an i32 ranked-tensor Value holding the indices.
// `maxIndex` : the (static) size of the dimension being indexed. Taken as
//              int64_t because callers pass ShapedType dimension sizes;
//              must be non-negative (i.e. not ShapedType::kDynamic).
// Returns the wrapped index tensor with the same type as `index`.
Value wrapNegativeIndices(Value index, int64_t maxIndex, Operation *op,
                          ConversionPatternRewriter &rewriter) {
  assert(maxIndex >= 0 && "dimension being indexed must have a static size");

  // Rank-0 i32 constants for the comparison and the wrap offset.
  // getConstTensor returns std::optional; constant creation of a scalar
  // splat cannot fail here, so .value() is safe.
  auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
  auto maxIndexValue =
      tosa::getConstTensor<int32_t>(rewriter, op,
                                    static_cast<int32_t>(maxIndex), {})
          .value();

  // `index` is always a ranked tensor by construction at the call sites;
  // cast (rather than unchecked dyn_cast) asserts that invariant instead of
  // dereferencing null on a violation.
  auto indexType = cast<RankedTensorType>(index.getType());

  // index + maxIndex: the wrapped value, selected only where index < 0.
  auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), indexType, maxIndexValue, index);
  auto boolType = indexType.clone(rewriter.getIntegerType(1));
  auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
      rewriter, op->getLoc(), boolType, zeroValue, index);
  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
                                                indexType, isNegativeIndices,
                                                wrappedIndicesOp, index);
}
4114+
40964115
template <>
40974116
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
40984117
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
41244143

41254144
auto outType = getTypeConverter()->convertType(op.getType());
41264145

4146+
Operation *indicesTf;
4147+
41274148
// Support for multiple indexes
41284149
if (indexTensors.size() > 1) {
41294150
// t[i, i]
@@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
41574178
index);
41584179
}
41594180

4181+
index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
4182+
rewriter);
41604183
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
41614184
SmallVector<int64_t> indiceShapeOneDim;
41624185
for (auto shape : indexShape) {
@@ -4299,57 +4322,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
42994322
auto indicesShapeConcat = indexesShape[0];
43004323
uint64_t lastDim = indexesRank[0];
43014324
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
4302-
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4325+
indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
43034326
rewriter, op->getLoc(),
43044327
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
43054328
indicesTfConcatTensors, lastDim);
43064329

4307-
if (!indicesTf) {
4308-
return rewriter.notifyMatchFailure(
4309-
op, "Convert TorchIndex To TfIndices fail.");
4310-
}
4311-
// do the tf gathernp algorithm with tf style indices as input.
4312-
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
4313-
indicesTf.getResult());
4330+
} else {
43144331

4315-
if (!result) {
4316-
return rewriter.notifyMatchFailure(
4317-
op, "Convert GatherNdOp fail for index tensor.");
4332+
// Single index
4333+
auto index = indexTensors[0];
4334+
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4335+
auto indexShape = indexType.getShape();
4336+
// index i64 to i32 for tosa compatible
4337+
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
4338+
index = rewriter.create<tosa::CastOp>(
4339+
op->getLoc(),
4340+
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
4341+
index);
43184342
}
4319-
rewriter.replaceOp(op, {result.value()});
43204343

4321-
return success();
4322-
}
4344+
index =
4345+
wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
43234346

4324-
// Support for multiple index
4325-
auto index = indexTensors[0];
4326-
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4327-
auto indexShape = indexType.getShape();
4328-
// index i64 to i32 for tosa compatible
4329-
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
4330-
index = rewriter.create<tosa::CastOp>(
4331-
op->getLoc(),
4332-
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
4333-
}
4334-
4335-
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
4336-
SmallVector<int64_t> indicesShape;
4337-
for (auto shape : indexShape) {
4338-
indicesShape.push_back(shape);
4347+
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
4348+
SmallVector<int64_t> indicesShape;
4349+
for (auto shape : indexShape) {
4350+
indicesShape.push_back(shape);
4351+
}
4352+
indicesShape.push_back(1);
4353+
indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4354+
rewriter, op->getLoc(),
4355+
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
4356+
rewriter.getDenseI64ArrayAttr(indicesShape));
43394357
}
4340-
indicesShape.push_back(1);
4341-
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4342-
rewriter, op->getLoc(),
4343-
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
4344-
rewriter.getDenseI64ArrayAttr(indicesShape));
43454358

43464359
if (!indicesTf) {
43474360
return rewriter.notifyMatchFailure(op,
43484361
"Convert TorchIndex To TfIndices fail.");
43494362
}
43504363
// do the tf gathernp algorithm with tf style indices as input.
43514364
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
4352-
indicesTf.getResult());
4365+
indicesTf->getResult(0));
43534366

43544367
if (!result) {
43554368
return rewriter.notifyMatchFailure(

projects/pt1/e2e_testing/xfail_sets.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1698,15 +1698,13 @@
16981698
"ArangeStartOutModule_basic",
16991699
"ScatterSrcStaticModule_basic",
17001700
# Runtime op verification: Out of bounds access
1701-
"IndexTensorNegativeIndexModule_basic",
17021701
"ReduceAllDimEmpty_basic",
17031702
}
17041703

17051704
FX_IMPORTER_TOSA_CRASHING_SET = {
17061705
"ScatterSrcModule_basic",
17071706
"ScatterSrcStaticModule_basic",
17081707
"HBC_basic",
1709-
"IndexTensorNegativeIndexModule_basic",
17101708
"InterpolateDynamicModule_scales_recompute_bilinear",
17111709
"InterpolateDynamicModule_sizes_bilinear",
17121710
"InterpolateDynamicModule_sizes_nearest",
@@ -2162,6 +2160,7 @@
21622160
"HardswishRandomModule_basic",
21632161
"HardtanhBackward_basic",
21642162
"IndexTensorMultiIndexStaticModule_basic",
2163+
"IndexTensorNegativeIndexModule_basic",
21652164
"IndexTensorStaticModule_basic",
21662165
"IscloseStaticModuleTrue_basic",
21672166
"IscloseStaticModule_basic",
@@ -3635,7 +3634,6 @@
36353634
"IndexPutImpl3DFloatNonAccumulateModule_basic",
36363635
"IndexPutImplIndexWithNoneModule_basic",
36373636
"IndexSelectRank0IdxModule_basic",
3638-
"IndexTensorNegativeIndexModule_basic",
36393637
"InterpolateDynamicModule_sizes_bilinear",
36403638
"InterpolateDynamicModule_sizes_nearest",
36413639
"InterpolateStaticModule_scales_bilinear_align_corners",

test/Conversion/TorchToTosa/basic.mlir

+32
Original file line numberDiff line numberDiff line change
@@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
21312131
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
21322132
return %0 : !torch.vtensor<[2,3,4,4],f32>
21332133
}
2134+
2135+
// -----
2136+
2137+
// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
2138+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
2139+
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
2140+
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
2141+
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
2142+
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
2143+
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
2144+
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
2145+
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
2146+
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
2147+
// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
2148+
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
2149+
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
2150+
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
2151+
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
2152+
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
2153+
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
2154+
// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
2155+
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
2156+
// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
2157+
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
2158+
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
2159+
// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>
2160+
2161+
func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
2162+
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
2163+
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
2164+
return %1 : !torch.vtensor<[4,2],si64>
2165+
}

0 commit comments

Comments
 (0)