@@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
4093
4093
return success ();
4094
4094
}
4095
4095
4096
+ Value wrapNegativeIndices (Value index, int maxIndex, Operation *op,
4097
+ ConversionPatternRewriter &rewriter) {
4098
+
4099
+ auto zeroValue = tosa::getConstTensor<int32_t >(rewriter, op, 0 , {}).value ();
4100
+ auto maxIndexValue =
4101
+ tosa::getConstTensor<int32_t >(rewriter, op, maxIndex, {}).value ();
4102
+
4103
+ auto indexType = dyn_cast<RankedTensorType>(index .getType ());
4104
+
4105
+ auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
4106
+ rewriter, op->getLoc (), indexType, maxIndexValue, index );
4107
+ auto boolType = indexType.clone (rewriter.getIntegerType (1 ));
4108
+ auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
4109
+ rewriter, op->getLoc (), boolType, zeroValue, index );
4110
+ return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc (),
4111
+ indexType, isNegativeIndices,
4112
+ wrappedIndicesOp, index );
4113
+ }
4114
+
4096
4115
template <>
4097
4116
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
4098
4117
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
4124
4143
4125
4144
auto outType = getTypeConverter ()->convertType (op.getType ());
4126
4145
4146
+ Operation *indicesTf;
4147
+
4127
4148
// Support for multiple indexes
4128
4149
if (indexTensors.size () > 1 ) {
4129
4150
// t[i, i]
@@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
4157
4178
index );
4158
4179
}
4159
4180
4181
+ index = wrapNegativeIndices (index , inputTensorType.getShape ()[i], op,
4182
+ rewriter);
4160
4183
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
4161
4184
SmallVector<int64_t > indiceShapeOneDim;
4162
4185
for (auto shape : indexShape) {
@@ -4299,57 +4322,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
4299
4322
auto indicesShapeConcat = indexesShape[0 ];
4300
4323
uint64_t lastDim = indexesRank[0 ];
4301
4324
indicesShapeConcat.push_back (indicesTfConcatTensors.size ());
4302
- auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4325
+ indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4303
4326
rewriter, op->getLoc (),
4304
4327
GetTypeFromTensorShape (indicesShapeConcat, rewriter.getIntegerType (32 )),
4305
4328
indicesTfConcatTensors, lastDim);
4306
4329
4307
- if (!indicesTf) {
4308
- return rewriter.notifyMatchFailure (
4309
- op, " Convert TorchIndex To TfIndices fail." );
4310
- }
4311
- // do the tf gathernp algorithm with tf style indices as input.
4312
- auto result = tosa::convertGatherNdOp (rewriter, op, outType, input,
4313
- indicesTf.getResult ());
4330
+ } else {
4314
4331
4315
- if (!result) {
4316
- return rewriter.notifyMatchFailure (
4317
- op, " Convert GatherNdOp fail for index tensor." );
4332
+ // Single index
4333
+ auto index = indexTensors[0 ];
4334
+ auto indexType = dyn_cast<RankedTensorType>(index .getType ());
4335
+ auto indexShape = indexType.getShape ();
4336
+ // index i64 to i32 for tosa compatible
4337
+ if (indexType.getElementType () != rewriter.getIntegerType (32 )) {
4338
+ index = rewriter.create <tosa::CastOp>(
4339
+ op->getLoc (),
4340
+ RankedTensorType::get (indexShape, rewriter.getIntegerType (32 )),
4341
+ index );
4318
4342
}
4319
- rewriter.replaceOp (op, {result.value ()});
4320
4343
4321
- return success ();
4322
- }
4344
+ index =
4345
+ wrapNegativeIndices ( index , inputTensorType. getShape ()[ 0 ], op, rewriter);
4323
4346
4324
- // Support for multiple index
4325
- auto index = indexTensors[0 ];
4326
- auto indexType = dyn_cast<RankedTensorType>(index .getType ());
4327
- auto indexShape = indexType.getShape ();
4328
- // index i64 to i32 for tosa compatible
4329
- if (indexType.getElementType () != rewriter.getIntegerType (32 )) {
4330
- index = rewriter.create <tosa::CastOp>(
4331
- op->getLoc (),
4332
- RankedTensorType::get (indexShape, rewriter.getIntegerType (32 )), index );
4333
- }
4334
-
4335
- // Expand last dim of index to tf indices [2,3] -> [2,3,1]
4336
- SmallVector<int64_t > indicesShape;
4337
- for (auto shape : indexShape) {
4338
- indicesShape.push_back (shape);
4347
+ // Expand last dim of index to tf indices [2,3] -> [2,3,1]
4348
+ SmallVector<int64_t > indicesShape;
4349
+ for (auto shape : indexShape) {
4350
+ indicesShape.push_back (shape);
4351
+ }
4352
+ indicesShape.push_back (1 );
4353
+ indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4354
+ rewriter, op->getLoc (),
4355
+ RankedTensorType::get (indicesShape, rewriter.getIntegerType (32 )), index ,
4356
+ rewriter.getDenseI64ArrayAttr (indicesShape));
4339
4357
}
4340
- indicesShape.push_back (1 );
4341
- auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4342
- rewriter, op->getLoc (),
4343
- RankedTensorType::get (indicesShape, rewriter.getIntegerType (32 )), index ,
4344
- rewriter.getDenseI64ArrayAttr (indicesShape));
4345
4358
4346
4359
if (!indicesTf) {
4347
4360
return rewriter.notifyMatchFailure (op,
4348
4361
" Convert TorchIndex To TfIndices fail." );
4349
4362
}
4350
4363
// do the tf gathernp algorithm with tf style indices as input.
4351
4364
auto result = tosa::convertGatherNdOp (rewriter, op, outType, input,
4352
- indicesTf. getResult ());
4365
+ indicesTf-> getResult (0 ));
4353
4366
4354
4367
if (!result) {
4355
4368
return rewriter.notifyMatchFailure (
0 commit comments