[ONNX] Add support for TfIdfVectorizer (#3553)
Supports 1-d and 2-d input and output. Implemented based on the description
and example test cases in
https://github.com/onnx/onnx/blob/main/docs/Operators.md#TfIdfVectorizer
and some notes from
https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_tfidf_vectorizer.py#L128
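
For intuition, a small hypothetical TF-mode case (not one of the spec's test
vectors): with 1-grams only, pool_int64s = [2, 3, 5], ngram_counts = [0],
ngram_indexes = [0, 1, 2], min/max_gram_length = 1, and input [2, 3, 5, 5, 3],
the value 2 occurs once and 3 and 5 each occur twice, so the output is
[1.0, 2.0, 2.0].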

---------

Co-authored-by: zjgarvey <[email protected]>
aldesilv and zjgarvey authored Aug 12, 2024
1 parent d3695a9 commit a4ba02e
Showing 2 changed files with 353 additions and 0 deletions.
302 changes: 302 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -4305,6 +4305,308 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
uniqueResults[1], uniqueResults[2]});
return success();
});
patterns.onOp(
"TfIdfVectorizer", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<int64_t> ngram_counts;
llvm::SmallVector<int64_t> ngram_indexes;
llvm::SmallVector<int64_t> pool_int64s;
std::string mode;
int64_t min_gram_length;
int64_t max_gram_length;
int64_t max_skip_count;
Value input;
Torch::ValueTensorType resultType;

if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
binder.customOpNameStringAttr(mode, "mode", "") ||
binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure();

if (mode != "TF")
return rewriter.notifyMatchFailure(binder.op,
"TF mode supported only");
if (pool_int64s.size() == 0)
return rewriter.notifyMatchFailure(
binder.op, "pool_int64s empty, only integers supported");
auto inputType = dyn_cast<Torch::ValueTensorType>(input.getType());
auto inputSizes = inputType.getSizes();
SmallVector<int64_t> inputShape(inputSizes);
bool is_2d = inputShape.size() > 1;
if (is_2d && inputShape[0] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
binder.op, "input batch dimension cannot be dynamic");
int batch_size = (is_2d) ? inputShape[0] : 1;

Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));

auto intType = rewriter.getType<Torch::IntType>();
Value loopConditionTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
Type loopIndexType = intType;
// create a zero tensor for output
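// (positions whose n-gram never occurs in the input keep this default of 0)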
SmallVector<int64_t> resultShape(resultType.getSizes());
int64_t rank = resultShape.size();
SmallVector<Value> zerosShapeValues;
for (int j = 0; j < rank; j++) {
Value dimSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j]));
zerosShapeValues.push_back(dimSize);
}
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);
Value output = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), resultType, zerosShapeList, none, none, none,
none);

Value batchSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(batch_size));
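// iterate over the batch dimension; the running output tensor is threaded
// through the loop as its iteration argument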
auto batchLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({output.getType()}), batchSize,
loopConditionTrue, ValueRange({output}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *batchLoopBody = rewriter.createBlock(
&batchLoop.getRegion(), batchLoop.getRegion().begin(),
TypeRange({loopIndexType, output.getType()}),
{binder.getLoc(), binder.getLoc()});
Value batchValue = batchLoopBody->getArgument(0);
Value output = batchLoopBody->getArgument(1);
Value outputForBatch = output;
Value inputSequence = input;
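// for 1-d input there is a single sequence, so the loop body works on the
// whole input and writes counts directly into the full output tensor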
if (is_2d) {
// get input sequence from input (ex: [[0,1],[2,3]] -> [[0,1]] ->
// [0,1])
SmallVector<int64_t> inputSequenceShape;
inputSequenceShape.push_back(1);
inputSequenceShape.push_back(inputShape[1]);
auto inputSequenceType = rewriter.getType<Torch::ValueTensorType>(
inputSequenceShape, inputType.getOptionalDtype());
Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), batchValue, one);
inputSequence = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), inputSequenceType, input, /*dim=*/zero,
batchValue, batchPlusOne, one);
inputSequence = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{inputShape[1]},
inputType.getOptionalDtype()),
inputSequence, zero);

SmallVector<int64_t> outputForBatchShape;
outputForBatchShape.push_back(1);
outputForBatchShape.push_back(resultShape[1]);
auto outputForBatchType = rewriter.getType<Torch::ValueTensorType>(
outputForBatchShape, resultType.getOptionalDtype());
outputForBatch = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), outputForBatchType, output,
/*dim=*/zero, batchValue, batchPlusOne, one);
outputForBatch = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{resultShape[1]},
resultType.getOptionalDtype()),
outputForBatch, zero);
}
// ngram_counts[j] records the starting position within pool_int64s of the
// n-grams of length j+1. The loop below iterates over the different
// n-gram sizes.
// ngram_i tracks which n-gram of the pool we are currently looking at;
// the frequency of that n-gram is stored in the output tensor at
// position ngram_indexes[ngram_i].
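// Illustration (hypothetical values): with pool_int64s = [2, 3, 5, 4, 5, 6, 7]
// and ngram_counts = [0, 3], the 1-grams are 2, 3, 5 (pool entries 0..2,
// ngram_i 0..2) and the 2-grams are (4,5) and (6,7) (pool entries 3..6,
// ngram_i 3..4).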
int ngram_i = 0;
for (int j = 0; j < (int)ngram_counts.size(); j++) {
int ngram_length = j + 1;
int start_idx = ngram_counts[j];
int end_idx = (j + 1) < (int)ngram_counts.size()
? ngram_counts[j + 1]
: pool_int64s.size();
if (j + 1 < min_gram_length || j + 1 > max_gram_length) {
// progress the ngram counter for the skipped (j+1)grams
ngram_i += (end_idx - start_idx) / ngram_length;
continue;
}

Value ngramLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length));
for (int start = start_idx; start < end_idx;
start += ngram_length, ngram_i++) {
Value count = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
// for 1-grams, there is no skipping (skip = gap between
// consecutive values in the n-gram pulled from the input
// sequence), so we default to skip_count_bound = 1 in that case
// to avoid repeating the same count multiple times.
int skip_count_bound =
(ngram_length == 1) ? 1 : (max_skip_count + 1);
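// e.g. with max_skip_count = 2 and a 2-gram, the two values may be
// adjacent (skip 0), separated by one element (skip 1), or by two
// elements (skip 2), so the loop below runs three times; for a 1-gram
// every skip yields the same matches, so it runs once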
Value skipCountBound = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), intType,
rewriter.getI64IntegerAttr(skip_count_bound));
// given a n-gram to search for, and the input sequence to search
// in, we need to count how many times that n-gram appears in the
// input for each skip between 0 and max_skip_count (inclusive).
auto skipLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({count.getType()}), skipCountBound,
loopConditionTrue, ValueRange({count}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *skipLoopBody = rewriter.createBlock(
&skipLoop.getRegion(), skipLoop.getRegion().begin(),
TypeRange({loopIndexType, count.getType()}),
{binder.getLoc(), binder.getLoc()});
Value skipCount = skipLoopBody->getArgument(0);
Value skipCountPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), skipCount, one);
count = skipLoopBody->getArgument(1);

// max_start_index =
// inputSizes.back() - ((ngram_length - 1) * (skip_count + 1));
// the index one higher than the last possible start index
// without the input ngram going out of bounds
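// e.g. (hypothetical sizes): sequence length 5, ngram_length 3,
// skip_count 1 -> max_start_index = 5 - 2 * 2 = 1, so the only valid
// start is 0, reading input indices 0, 2, and 4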
Value seqLen = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), intType,
rewriter.getI64IntegerAttr(inputSizes.back()));
Value ngramLengthMinusOne =
rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(),
ngramLength, one);
Value ngramSkipLength = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne);
Value maxStartIndex = rewriter.create<Torch::AtenSubIntOp>(
binder.getLoc(), seqLen, ngramSkipLength);
// This loop extracts, for each start index, the n-gram with the given
// skip_count from the input sequence and increments the count if it
// matches the n-gram taken from pool_int64s
auto countLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({count.getType()}),
maxStartIndex, loopConditionTrue, ValueRange({count}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *countLoopBody = rewriter.createBlock(
&countLoop.getRegion(), countLoop.getRegion().begin(),
TypeRange({loopIndexType, count.getType()}),
{binder.getLoc(), binder.getLoc()});

Value startInputIdx = countLoopBody->getArgument(0);
count = countLoopBody->getArgument(1);

// extract input ngram and compare to pool ngram
Torch::BaseTensorType inputSequenceType =
cast<Torch::BaseTensorType>(inputSequence.getType());
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType =
inputSequenceType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes),
inputSequenceType.getOptionalDtype());
Value foundNgram = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
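// foundNgram acts as a running AND: it is multiplied by the 0/1 result
// of each element comparison below, so it stays 1 only if every position
// of the candidate matches the pool n-gram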
for (int i = 0; i < ngram_length; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
i));
selectIndex = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), selectIndex, skipCountPlusOne);
selectIndex = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), selectIndex, startInputIdx);
Value inputExtract =
rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, inputSequence,
zero, selectIndex);
Value inputNgram_i = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
inputExtract);

Value poolNgram_i = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(pool_int64s[start + i]));
Value isEqual = rewriter.create<Torch::AtenEqIntOp>(
binder.getLoc(), inputNgram_i, poolNgram_i);
isEqual = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), isEqual);
foundNgram = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isEqual, foundNgram);
}

count = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), count, foundNgram);
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({count}));
}
count = countLoop.getResult(0);
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({count}));
}
count = skipLoop.getResult(0);
// insert count "tf" into output
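// the scalar count is wrapped in a one-element tensor and written into
// outputForBatch at position ngram_indexes[ngram_i] via slice_scatter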
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
binder.getLoc(), count);
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{countFloat});
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
SmallVector<int64_t> countShape{1};
auto countType = rewriter.getType<Torch::ValueTensorType>(
countShape, resultType.getOptionalDtype());
Value countTensor = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(), countType, dataList, /*dtype=*/cstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);

Value insertStart = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(ngram_indexes[ngram_i]));
Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), insertStart, one);
outputForBatch = rewriter.create<Torch::AtenSliceScatterOp>(
binder.getLoc(), outputForBatch.getType(), outputForBatch,
countTensor,
/*dim=*/zero, insertStart, insertEnd, /*step=*/one);
} // start
}
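// write the per-batch row back into the 2-d output at row batchValue;
// for 1-d input the row already is the full output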
if (is_2d) {
Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), batchValue, one);
outputForBatch = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(),
rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1, resultShape[1]},
resultType.getDtype()),
outputForBatch, zero);
output = rewriter.create<Torch::AtenSliceScatterOp>(
binder.getLoc(), resultType, output, outputForBatch,
/*dim=*/zero, batchValue, batchPlusOne, /*step=*/one);
} else {
output = outputForBatch;
}
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({output}));
}
output = batchLoop.getResult(0);
rewriter.replaceOp(binder.op, output);
return success();
});
patterns.onOp(
"Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();