Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onnx] Fix onnx-to-torch lowering for flatten shape #2834

Merged
merged 5 commits into from
Feb 5, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
llvm::errs() << message << "\n";
auto y = rewriter.notifyMatchFailure(binder.op, message);

return y;
Expand Down Expand Up @@ -1412,16 +1411,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();

auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
llvm::SmallVector<int64_t> shape(operandTy.getSizes());
int64_t rank = shape.size();

// If axis is negative, count from the right instead of left
int64_t rank =
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
if (axis < 0)
axis = rank + axis;

Value collapsedRight;
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
newling marked this conversation as resolved.
Show resolved Hide resolved
binder.op->getContext());
// We collapse in the dimensions to the right of the axis.
for (int i = axis + 1; i < rank; ++i) {
shape[axis] *= shape[i];
}

shape.resize(axis + 1, 1);

auto baseType = rewriter.getType<Torch::ValueTensorType>(
shape, operandTy.getDtype());
Value collapsedRight;
if (axis >= rank) {
// If the right range is empty, add a dim of size 1 to the
// right side of the shape:
Expand Down Expand Up @@ -1455,10 +1462,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

// Otherwise, collapse the left range into a single dimension:
// torch._prims.collapse(cr, 0, axis - 1)
Value axisLess1Const = rewriter.create<Torch::ConstantIntOp>(
newling marked this conversation as resolved.
Show resolved Hide resolved
binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1));
rewriter.replaceOpWithNewOp<Torch::PrimsCollapseOp>(
binder.op, resultType, collapsedRight, zero, axisLess1Const);
rewriter.replaceOp(binder.op, collapsedRight);
return success();
});
patterns.onOp("Floor", 13,
Expand Down
Loading