diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 9aec90425f56..dcb6e6763e56 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -4305,6 +4305,308 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                              uniqueResults[1], uniqueResults[2]});
         return success();
       });
+  patterns.onOp(
+      "TfIdfVectorizer", 9,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        llvm::SmallVector<int64_t> ngram_counts;
+        llvm::SmallVector<int64_t> ngram_indexes;
+        llvm::SmallVector<int64_t> pool_int64s;
+        std::string mode;
+        int64_t min_gram_length;
+        int64_t max_gram_length;
+        int64_t max_skip_count;
+        Value input;
+        Torch::ValueTensorType resultType;
+
+        if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
+            binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
+            binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
+            binder.customOpNameStringAttr(mode, "mode", "") ||
+            binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
+            binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
+            binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
+            binder.tensorOperand(input) || binder.tensorResultType(resultType))
+          return failure();
+
+        if (mode != "TF")
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "TF mode supported only");
+        if (pool_int64s.size() == 0)
+          return rewriter.notifyMatchFailure(
+              binder.op, "pool_int64s empty, only integers supported");
+        auto inputType = dyn_cast<Torch::ValueTensorType>(input.getType());
+        auto inputSizes =
+            dyn_cast<Torch::ValueTensorType>(input.getType()).getSizes();
+        SmallVector<int64_t> inputShape(inputSizes);
+        bool is_2d = inputShape.size() > 1;
+        if (is_2d && inputShape[0] == ShapedType::kDynamic)
+          return rewriter.notifyMatchFailure(
+              binder.op, "input batch dimension cannot be dynamic");
+        int batch_size = is_2d ? inputShape[0] : 1;
+
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(false));
+
+        auto intType = rewriter.getType<Torch::IntType>();
+        Value loopConditionTrue = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(true));
+        Type loopIndexType = intType;
+        // Create a zero tensor for the output.
+        SmallVector<int64_t> resultShape(resultType.getSizes());
+        int64_t rank = resultShape.size();
+        SmallVector<Value> zerosShapeValues;
+        for (int j = 0; j < rank; j++) {
+          Value dimSize = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j]));
+          zerosShapeValues.push_back(dimSize);
+        }
+        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            zerosShapeValues);
+        Value output = rewriter.create<Torch::AtenZerosOp>(
+            binder.getLoc(), resultType, zerosShapeList, none, none, none,
+            none);
+
+        Value batchSize = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(batch_size));
+        auto batchLoop = rewriter.create<Torch::PrimLoopOp>(
+            binder.getLoc(), TypeRange({output.getType()}), batchSize,
+            loopConditionTrue, ValueRange({output}));
+        {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          Block *batchLoopBody = rewriter.createBlock(
+              &batchLoop.getRegion(), batchLoop.getRegion().begin(),
+              TypeRange({loopIndexType, output.getType()}),
+              {binder.getLoc(), binder.getLoc()});
+          Value batchValue = batchLoopBody->getArgument(0);
+          Value output = batchLoopBody->getArgument(1);
+          Value outputForBatch = output;
+          Value inputSequence = input;
+          if (is_2d) {
+            // Get the input sequence for this batch from the input
+            // (e.g. [[0,1],[2,3]] -> [[0,1]] -> [0,1]).
+            SmallVector<int64_t> inputSequenceShape;
+            inputSequenceShape.push_back(1);
+            inputSequenceShape.push_back(inputShape[1]);
+            auto inputSequenceType = rewriter.getType<Torch::ValueTensorType>(
+                inputSequenceShape, inputType.getOptionalDtype());
+            Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
+                binder.getLoc(), batchValue, one);
+            inputSequence = rewriter.create<Torch::AtenSliceTensorOp>(
+                binder.getLoc(), inputSequenceType, input, /*dim=*/zero,
+                batchValue, batchPlusOne, one);
+            inputSequence = rewriter.create<Torch::AtenSqueezeDimOp>(
+                binder.getLoc(),
+                Torch::ValueTensorType::get(binder.op->getContext(),
+                                            ArrayRef<int64_t>{inputShape[1]},
+                                            inputType.getOptionalDtype()),
+                inputSequence, zero);
+
+            SmallVector<int64_t> outputForBatchShape;
+            outputForBatchShape.push_back(1);
+            outputForBatchShape.push_back(resultShape[1]);
+            auto outputForBatchType = rewriter.getType<Torch::ValueTensorType>(
+                outputForBatchShape, resultType.getOptionalDtype());
+            outputForBatch = rewriter.create<Torch::AtenSliceTensorOp>(
+                binder.getLoc(), outputForBatchType, output,
+                /*dim=*/zero, batchValue, batchPlusOne, one);
+            outputForBatch = rewriter.create<Torch::AtenSqueezeDimOp>(
+                binder.getLoc(),
+                Torch::ValueTensorType::get(binder.op->getContext(),
+                                            ArrayRef<int64_t>{resultShape[1]},
+                                            resultType.getOptionalDtype()),
+                outputForBatch, zero);
+          }
+          // ngram_counts[j] records the starting position within pool_int64s
+          // of the ngrams of length j+1. The loop below iterates through the
+          // different n-gram sizes.
+          // ngram_i keeps track of which ngram we are looking at in the pool.
+          // The frequency of this ngram will be stored in the output tensor
+          // at the position ngram_indexes[ngram_i].
+          int ngram_i = 0;
+          for (int j = 0; j < (int)ngram_counts.size(); j++) {
+            int ngram_length = j + 1;
+            int start_idx = ngram_counts[j];
+            int end_idx = (j + 1) < (int)ngram_counts.size()
+                              ? ngram_counts[j + 1]
+                              : pool_int64s.size();
+            if (j + 1 < min_gram_length || j + 1 > max_gram_length) {
+              // Progress the ngram counter past the skipped (j+1)-grams.
+              ngram_i += (end_idx - start_idx) / ngram_length;
+              continue;
+            }
+
+            Value ngramLength = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length));
+            for (int start = start_idx; start < end_idx;
+                 start += ngram_length, ngram_i++) {
+              Value count = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
+              // For 1-grams there is no skipping (skip = gap between
+              // consecutive values in the n-gram pulled from the input
+              // sequence), so we default to skip_count_bound = 1 in that case
+              // to avoid repeating the same count multiple times.
+              int skip_count_bound =
+                  (ngram_length == 1) ? 1 : (max_skip_count + 1);
+              Value skipCountBound = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), intType,
+                  rewriter.getI64IntegerAttr(skip_count_bound));
+              // Given an n-gram to search for and the input sequence to search
+              // in, we need to count how many times that n-gram appears in the
+              // input for each skip between 0 and max_skip_count (inclusive).
+              auto skipLoop = rewriter.create<Torch::PrimLoopOp>(
+                  binder.getLoc(), TypeRange({count.getType()}),
+                  skipCountBound, loopConditionTrue, ValueRange({count}));
+              {
+                PatternRewriter::InsertionGuard guard(rewriter);
+                Block *skipLoopBody = rewriter.createBlock(
+                    &skipLoop.getRegion(), skipLoop.getRegion().begin(),
+                    TypeRange({loopIndexType, count.getType()}),
+                    {binder.getLoc(), binder.getLoc()});
+                Value skipCount = skipLoopBody->getArgument(0);
+                Value skipCountPlusOne = rewriter.create<Torch::AtenAddIntOp>(
+                    binder.getLoc(), skipCount, one);
+                count = skipLoopBody->getArgument(1);
+
+                // max_start_index =
+                //   inputSizes.back() - ((ngram_length - 1) * (skip_count + 1));
+                // i.e. one above the last start index for which the input
+                // n-gram stays within bounds.
+                Value seqLen = rewriter.create<Torch::ConstantIntOp>(
+                    binder.getLoc(), intType,
+                    rewriter.getI64IntegerAttr(inputSizes.back()));
+                Value ngramLengthMinusOne =
+                    rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(),
+                                                         ngramLength, one);
+                Value ngramSkipLength = rewriter.create<Torch::AtenMulIntOp>(
+                    binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne);
+                Value maxStartIndex = rewriter.create<Torch::AtenSubIntOp>(
+                    binder.getLoc(), seqLen, ngramSkipLength);
+                // This loop extracts each n-gram with the given skip_count
+                // from the input sequence, starting at the start input index,
+                // and increments the count if the n-gram matches the one taken
+                // from pool_int64s.
+                auto countLoop = rewriter.create<Torch::PrimLoopOp>(
+                    binder.getLoc(), TypeRange({count.getType()}),
+                    maxStartIndex, loopConditionTrue, ValueRange({count}));
+                {
+                  PatternRewriter::InsertionGuard guard(rewriter);
+                  Block *countLoopBody = rewriter.createBlock(
+                      &countLoop.getRegion(), countLoop.getRegion().begin(),
+                      TypeRange({loopIndexType, count.getType()}),
+                      {binder.getLoc(), binder.getLoc()});
+
+                  Value startInputIdx = countLoopBody->getArgument(0);
+                  count = countLoopBody->getArgument(1);
+
+                  // Extract the input n-gram and compare it to the pool n-gram.
+                  Torch::BaseTensorType inputSequenceType =
+                      cast<Torch::BaseTensorType>(inputSequence.getType());
+                  SmallVector<int64_t> selectSizes;
+                  selectSizes.push_back(1);
+                  Type selectResultType =
+                      inputSequenceType.getWithSizesAndDtype(
+                          llvm::ArrayRef(selectSizes),
+                          inputSequenceType.getOptionalDtype());
+                  Value foundNgram = rewriter.create<Torch::ConstantIntOp>(
+                      binder.getLoc(), rewriter.getI64IntegerAttr(1));
+                  for (int i = 0; i < ngram_length; i++) {
+                    Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                        rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                                i));
+                    selectIndex = rewriter.create<Torch::AtenMulIntOp>(
+                        binder.getLoc(), selectIndex, skipCountPlusOne);
+                    selectIndex = rewriter.create<Torch::AtenAddIntOp>(
+                        binder.getLoc(), selectIndex, startInputIdx);
+                    Value inputExtract =
+                        rewriter.create<Torch::AtenSelectIntOp>(
+                            binder.getLoc(), selectResultType, inputSequence,
+                            zero, selectIndex);
+                    Value inputNgram_i = rewriter.create<Torch::AtenItemOp>(
+                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                        inputExtract);
+
+                    Value poolNgram_i = rewriter.create<Torch::ConstantIntOp>(
+                        binder.getLoc(),
+                        rewriter.getI64IntegerAttr(pool_int64s[start + i]));
+                    Value isEqual = rewriter.create<Torch::AtenEqIntOp>(
+                        binder.getLoc(), inputNgram_i, poolNgram_i);
+                    isEqual = rewriter.create<Torch::AtenIntBoolOp>(
+                        binder.getLoc(), isEqual);
+                    foundNgram = rewriter.create<Torch::AtenMulIntOp>(
+                        binder.getLoc(), isEqual, foundNgram);
+                  }
+
+                  count = rewriter.create<Torch::AtenAddIntOp>(
+                      binder.getLoc(), count, foundNgram);
+                  rewriter.create<Torch::PrimLoopConditionOp>(
+                      binder.getLoc(), loopConditionTrue, ValueRange({count}));
+                }
+                count = countLoop.getResult(0);
+                rewriter.create<Torch::PrimLoopConditionOp>(
+                    binder.getLoc(), loopConditionTrue, ValueRange({count}));
+              }
+              count = skipLoop.getResult(0);
+              // Insert the count ("tf") into the output.
+              Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
+                  binder.getLoc(), count);
+              Value dataList = rewriter.create<Torch::PrimListConstructOp>(
+                  binder.getLoc(),
+                  rewriter.getType<Torch::ListType>(
+                      rewriter.getType<Torch::FloatType>()),
+                  SmallVector<Value>{countFloat});
+              Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(
+                                       (int)torch_upstream::ScalarType::Float));
+              SmallVector<int64_t> countShape{1};
+              auto countType = rewriter.getType<Torch::ValueTensorType>(
+                  countShape, resultType.getOptionalDtype());
+              Value countTensor = rewriter.create<Torch::AtenTensorOp>(
+                  binder.getLoc(), countType, dataList, /*dtype=*/cstDtype,
+                  /*layout=*/none, /*requires_grad=*/cstFalse);
+
+              Value insertStart = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(),
+                  rewriter.getI64IntegerAttr(ngram_indexes[ngram_i]));
+              Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
+                  binder.getLoc(), insertStart, one);
+              outputForBatch = rewriter.create<Torch::AtenSliceScatterOp>(
+                  binder.getLoc(), outputForBatch.getType(), outputForBatch,
+                  countTensor,
+                  /*dim=*/zero, insertStart, insertEnd, /*step=*/one);
+            } // start
+          }
+          if (is_2d) {
+            Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
+                binder.getLoc(), batchValue, one);
+            outputForBatch = rewriter.create<Torch::AtenUnsqueezeOp>(
+                binder.getLoc(),
+                rewriter.getType<Torch::ValueTensorType>(
+                    llvm::SmallVector<int64_t>{1, resultShape[1]},
+                    resultType.getDtype()),
+                outputForBatch, zero);
+            output = rewriter.create<Torch::AtenSliceScatterOp>(
+                binder.getLoc(), resultType, output, outputForBatch,
+                /*dim=*/zero, batchValue, batchPlusOne, /*step=*/one);
+          } else {
+            output = outputForBatch;
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              binder.getLoc(), loopConditionTrue, ValueRange({output}));
+        }
+        output = batchLoop.getResult(0);
+        rewriter.replaceOp(binder.op, output);
+        return success();
+      });
   patterns.onOp(
       "Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Location loc = binder.getLoc();
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 984d32d5361e..3c37cc9c530f 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1915,6 +1915,57 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
   return %0 : !torch.vtensor<[2],si32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5
+func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5(%arg0: !torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[output_init:.*]] = torch.aten.zeros %[[x0:.*]], %[[none_0:.*]], %[[none_0]], %[[none_0]], %[[none_0]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,7],f32>
+  // CHECK: %[[int2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[batch_loop:.*]] = torch.prim.Loop %[[int2_1]], %[[true:.*]], init(%[[output_init]]) {
+  // CHECK: ^bb0(%[[arg1:.*]]: !torch.int, %[[arg2:.*]]: !torch.vtensor<[2,7],f32>):
+  // CHECK: %[[x3:.*]] = torch.aten.add.int %[[arg1]], %[[int1:.*]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x4:.*]] = torch.aten.slice.Tensor %arg0, %[[int0:.*]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,6],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,6],si32>
+  // CHECK: %[[inputbatch:.*]] = torch.aten.squeeze.dim %[[x4]], %[[int0]] : !torch.vtensor<[1,6],si32>, !torch.int -> !torch.vtensor<[6],si32>
+  // CHECK: %[[x6:.*]] = torch.aten.slice.Tensor %[[arg2]], %[[int0]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,7],f32>
+  // CHECK: %[[outputbatch:.*]] = torch.aten.squeeze.dim %[[x6]], %[[int0]] : !torch.vtensor<[1,7],f32>, !torch.int -> !torch.vtensor<[7],f32>
+  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[max_skip_count:.*]] = torch.constant.int 6
+  // CHECK: %[[skip_loop:.*]] = torch.prim.Loop %[[max_skip_count]], %[[true]], init(%[[int0_3]]) {
+  // CHECK: ^bb0(%[[arg3:.*]]: !torch.int, %[[arg4:.*]]: !torch.int):
+  // CHECK: %[[x29:.*]] = torch.aten.add.int %[[arg3]], %[[int1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[int6_12:.*]] = torch.constant.int 6
+  // CHECK: %[[x30:.*]] = torch.aten.sub.int %[[int2_2]], %[[int1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x31:.*]] = torch.aten.mul.int %[[x30]], %[[x29]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x32:.*]] = torch.aten.sub.int %[[int6_12]], %[[x31]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[count_loop:.*]] = torch.prim.Loop %[[x32]], %[[true]], init(%[[arg4]]) {
+  // CHECK: ^bb0(%[[arg5:.*]]: !torch.int, %[[arg6:.*]]: !torch.int):
+  // CHECK: %[[input_2gram0:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position0:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
+  // CHECK: %[[inputval0:.*]] = torch.aten.item %[[input_2gram0]] : !torch.vtensor<[1],si32> -> !torch.int
+  // CHECK: %[[eq0:.*]] = torch.aten.eq.int %[[inputval0]], %[[first2gram0:.*]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[eq0int:.*]] = torch.aten.Int.bool %[[eq0]] : !torch.bool -> !torch.int
+  // CHECK: %[[alleq0:.*]] = torch.aten.mul.int %[[eq0int]], %[[int1_13:.*]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[input_2gram1:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position1:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
+  // CHECK: %[[inputval1:.*]] = torch.aten.item %[[input_2gram1]] : !torch.vtensor<[1],si32> -> !torch.int
+  // CHECK: %[[eq1:.*]] = torch.aten.eq.int %[[inputval1]], %[[first2gram1:.*]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[eq1int:.*]] = torch.aten.Int.bool %[[eq1]] : !torch.bool -> !torch.int
+  // CHECK: %[[alleq1:.*]] = torch.aten.mul.int %[[eq1int]], %[[alleq0]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[newcount:.*]] = torch.aten.add.int %[[arg6]], %[[alleq1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.Loop.condition %[[true]], iter(%[[newcount]] : !torch.int)
+  // CHECK: } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
+  // CHECK: torch.prim.Loop.condition %[[true]], iter(%[[count_loop]] : !torch.int)
+  // CHECK: } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
+  // CHECK: %[[count_insert0:.*]] = torch.aten.slice_scatter %[[outputbatch]], %[[counttensor0:.*]], %[[int0]], %[[ngram_indices0:.*]], %[[ngram_indices0plus1:.*]], %[[int1]] : !torch.vtensor<[7],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[7],f32>
+  // The skip_loop and count_loop repeat for each ngram in pool_int64s; after the last ngram frequency is counted...
+  // CHECK: %[[unsqueezecounts:.*]] = torch.aten.unsqueeze %[[lastcountinsert:.*]], %[[int0]] : !torch.vtensor<[7],f32>, !torch.int -> !torch.vtensor<[1,7],f32>
+  // CHECK: %[[count_into_output:.*]] = torch.aten.slice_scatter %[[arg2]], %[[unsqueezecounts]], %[[int0]], %[[arg1]], %[[arg1plus1:.*]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.vtensor<[1,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,7],f32>
+  // CHECK: torch.prim.Loop.condition %[[true]], iter(%[[count_into_output]] : !torch.vtensor<[2,7],f32>)
+  // CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2,7],f32>) -> !torch.vtensor<[2,7],f32>
+  // CHECK: return %[[batch_loop]] : !torch.vtensor<[2,7],f32>
+  %0 = torch.operator "onnx.TfIdfVectorizer"(%arg0) {torch.onnx.max_gram_length = 2 : si64, torch.onnx.max_skip_count = 5 : si64, torch.onnx.min_gram_length = 2 : si64, torch.onnx.mode = "TF", torch.onnx.ngram_counts = [0 : si64, 4 : si64], torch.onnx.ngram_indexes = [0 : si64, 1 : si64, 2 : si64, 3 : si64, 4 : si64, 5 : si64, 6 : si64], torch.onnx.pool_int64s = [2 : si64, 3 : si64, 5 : si64, 4 : si64, 5 : si64, 6 : si64, 7 : si64, 8 : si64, 6 : si64, 7 : si64]} : (!torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32>
+  return %0 : !torch.vtensor<[2,7],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @test_range_int16_type