Convert float exponents in aten.pow to int when possible (#4029)
Pure floating-point pow operations no longer support negative base
values (see <llvm/llvm-project#126338>), but
many models coming from ONNX use floating-point representations of
integers as the exponent.

This change:

1. matches on constant rank-0 exponents and converts them to scalar
constants.
2. matches on constant floating-point scalar exponents and converts them
to ints if possible.
3. lowers `Tensor(float)^int` cases to `math.fpowi`.

Addresses some remaining test failures related to
<iree-org/iree#19996>.
zjgarvey authored Feb 19, 2025
1 parent 98b6aee commit fe5abf0
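
For illustration, a minimal hand-written sketch of the scalar-exponent rewrite (mirroring the new tests at the bottom of this diff; the function and value names are invented):

```mlir
// Before canonicalization: the exponent is a float literal holding an
// integral value.
func.func @pow_example(%x: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
  %f2 = torch.constant.float 2.000000e+00
  %0 = torch.aten.pow.Tensor_Scalar %x, %f2 : !torch.vtensor<[?],f32>, !torch.float -> !torch.vtensor<[?],f32>
  return %0 : !torch.vtensor<[?],f32>
}

// After `-canonicalize`, the exponent is an integer scalar, letting the
// linalg lowering pick `math.fpowi` (well-defined for negative bases):
//   %int2 = torch.constant.int 2
//   %0 = torch.aten.pow.Tensor_Scalar %x, %int2 : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
```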
Showing 5 changed files with 97 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5236,6 +5236,7 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
@@ -5260,6 +5261,7 @@ def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [
9 changes: 9 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1012,6 +1012,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value exp = operands[1];
Type expType = exp.getType();
if (!expType.isIntOrFloat()) {
pow.emitError("unimplemented: exp type neither float nor int");
return nullptr;
}
if (isa<mlir::IntegerType>(expType)) {
return b.create<math::FPowIOp>(loc, payloadArgs[0], exp);
}
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
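A minimal standalone sketch of the two payload forms this hunk selects between (hand-written, not compiler output; names are invented):

```mlir
func.func @pow_payloads(%base: f32, %iexp: i64, %fexp: f32) -> (f32, f32) {
  // Integer exponent: math.fpowi stays well-defined for negative %base.
  %a = math.fpowi %base, %iexp : f32, i64
  // Floating-point exponent: math.powf, undefined for negative %base.
  %b = math.powf %base, %fexp : f32
  return %a, %b : f32, f32
}
```

Note that the float path still promotes the exponent to the tensor's dtype via `convertScalarToDtype`, while the integer path passes the exponent through unchanged.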
60 changes: 60 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -12,6 +12,7 @@
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -2421,6 +2422,65 @@ OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenPowTensorScalarOp
//===----------------------------------------------------------------------===//

void AtenPowTensorScalarOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // If the exponent is a float representation of an int,
  // convert the exponent to an int.
  patterns.add(+[](AtenPowTensorScalarOp op, PatternRewriter &rewriter) {
    auto exp = getAsOpFoldResult(op.getExponent());
    auto baseAttr = dyn_cast<mlir::Attribute>(exp);
    auto floatAttr = dyn_cast_or_null<mlir::FloatAttr>(baseAttr);
    if (!floatAttr)
      return failure();
    double expValue = floatAttr.getValueAsDouble();
    auto truncValue = static_cast<int64_t>(expValue);
    if (expValue != static_cast<double>(truncValue))
      return failure();
    Value IRScalar =
        rewriter.create<Torch::ConstantIntOp>(op.getLoc(), truncValue);
    op->setOperand(1, IRScalar);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenPowTensorTensorOp
//===----------------------------------------------------------------------===//

void AtenPowTensorTensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // If the exponent is a single-element constant, convert to
  // AtenPowTensorScalar.
  patterns.add(+[](AtenPowTensorTensorOp op, PatternRewriter &rewriter) {
    OpFoldResult exp = getAsOpFoldResult(op.getExponent());
    auto expAttr = dyn_cast<Attribute>(exp);
    auto attr = dyn_cast_or_null<DenseElementsAttr>(expAttr);
    if (!attr || attr.getNumElements() != 1)
      return failure();
    auto elem = *attr.value_begin<Attribute>();
    auto intAttr = dyn_cast<mlir::IntegerAttr>(elem);
    auto floatAttr = dyn_cast<mlir::FloatAttr>(elem);
    if (!intAttr && !floatAttr)
      return failure();
    Value IRScalar;
    if (intAttr)
      IRScalar = rewriter.create<Torch::ConstantIntOp>(
          op.getLoc(), getIntAttrAsSigned(intAttr));
    if (floatAttr) {
      double expValue = floatAttr.getValueAsDouble();
      IRScalar = rewriter.create<Torch::ConstantFloatOp>(op.getLoc(),
                                                         APFloat(expValue));
    }
    rewriter.replaceOpWithNewOp<AtenPowTensorScalarOp>(op, op.getType(),
                                                       op.getSelf(), IRScalar);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//
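Two corner cases of these patterns, sketched as hand-written IR (the function and value names are invented; the integer-literal case exercises the `IntegerAttr` branch, which the new tests below don't cover):

```mlir
func.func @pow_pattern_corners(%x: !torch.vtensor<[?],f32>)
    -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) {
  // Tensor_Scalar guard: 3.5 does not survive the int64 round-trip
  // (3.5 != (double)(int64_t)3.5), so the pattern returns failure() and
  // the op is left untouched.
  %f35 = torch.constant.float 3.500000e+00
  %p0 = torch.aten.pow.Tensor_Scalar %x, %f35 : !torch.vtensor<[?],f32>, !torch.float -> !torch.vtensor<[?],f32>
  // Tensor_Tensor with a single-element integer literal: rewritten to
  // `torch.aten.pow.Tensor_Scalar %x, %int2` through the IntegerAttr branch.
  %exp = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
  %p1 = torch.aten.pow.Tensor_Tensor %x, %exp : !torch.vtensor<[?],f32>, !torch.vtensor<[],si64> -> !torch.vtensor<[?],f32>
  return %p0, %p1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>
}
```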
8 changes: 6 additions & 2 deletions …/build_tools/torch_ods_gen.py
@@ -498,8 +498,12 @@ def emit_with_mutating_variants(key, **kwargs):
        has_canonicalizer=True,
    )
    emit("aten::gelu : (Tensor, str) -> (Tensor)")
    emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
    emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
    emit(
        "aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True
    )
    emit(
        "aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True
    )
    emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)")
    emit("aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
    emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
20 changes: 20 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1314,6 +1314,26 @@ func.func @torch.aten.remainder.int() -> !torch.int {
  return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$canonicalize
// CHECK: %[[SCALAR_EXP:.*]] = torch.constant.float 3.5
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[SCALAR_EXP]]
// CHECK: return %[[POW]]
func.func @torch.aten.pow.Tensor_Tensor$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %exponent = torch.vtensor.literal(dense<3.500000e+00> : tensor<f32>) : !torch.vtensor<[],f32>
  %pow = torch.aten.pow.Tensor_Tensor %arg0, %exponent : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
  return %pow : !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$canonicalize
// CHECK: %[[INT_EXP:.*]] = torch.constant.int 3
// CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[INT_EXP]]
// CHECK: return %[[POW]]
func.func @torch.aten.pow.Tensor_Scalar$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %float_exponent = torch.constant.float 3.0
  %pow = torch.aten.pow.Tensor_Scalar %arg0, %float_exponent : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
  return %pow : !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.pow.int_float() -> !torch.float {
// CHECK: %[[FLOAT_8:.*]] = torch.constant.float 8.000000e+00
// CHECK: return %[[FLOAT_8]] : !torch.float
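These FileCheck tests are driven by the RUN line at the top of canonicalize.mlir, which this hunk doesn't show; it is roughly of this shape (exact flags may differ):

```mlir
// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s
```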
