Skip to content

Commit fe5abf0

Browse files
authored
Convert float exponents in aten.pow to int when possible (#4029)
Pure floating-point pow operations no longer support negative base values (see <llvm/llvm-project#126338>), but many models coming from ONNX use floating-point representations of integers as the exponent. This change: 1. matches constant rank-0 exponent tensors and converts them to scalar constants; 2. matches constant floating-point scalar exponents and converts them to ints when possible; 3. lowers `Tensor(float)^int` cases to `math.fpowi`. Addresses some remaining test failures related to <iree-org/iree#19996>.
1 parent 98b6aee commit fe5abf0

File tree

5 files changed

+97
-2
lines changed

5 files changed

+97
-2
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+2
Original file line numberDiff line numberDiff line change
@@ -5236,6 +5236,7 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
52365236
printDefaultTorchOp(printer, *this, 2, 1);
52375237
}
52385238
}];
5239+
let hasCanonicalizer = 1;
52395240
}
52405241

52415242
def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
@@ -5260,6 +5261,7 @@ def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
52605261
printDefaultTorchOp(printer, *this, 2, 1);
52615262
}
52625263
}];
5264+
let hasCanonicalizer = 1;
52635265
}
52645266

52655267
def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [

lib/Conversion/TorchToLinalg/Uncategorized.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10121012
pow.emitError("unimplemented: non-floating point dtype");
10131013
return nullptr;
10141014
}
1015+
Value exp = operands[1];
1016+
Type expType = exp.getType();
1017+
if (!expType.isIntOrFloat()) {
1018+
pow.emitError("unimplemented: exp type neither float nor int");
1019+
return nullptr;
1020+
}
1021+
if (isa<mlir::IntegerType>(expType)) {
1022+
return b.create<math::FPowIOp>(loc, payloadArgs[0], exp);
1023+
}
10151024
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
10161025
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
10171026
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);

lib/Dialect/Torch/IR/TorchOps.cpp

+60
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

#include <cmath>
#include <cstdint>
#include <limits>
@@ -2421,6 +2422,65 @@ OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
24212422
return nullptr;
24222423
}
24232424

2425+
//===----------------------------------------------------------------------===//
2426+
// AtenPowTensorScalarOp
2427+
//===----------------------------------------------------------------------===//
2428+
2429+
void AtenPowTensorScalarOp::getCanonicalizationPatterns(
2430+
RewritePatternSet &patterns, MLIRContext *context) {
2431+
// If the exponent is a float representation of an int,
2432+
// convert the exponent to an int
2433+
patterns.add(+[](AtenPowTensorScalarOp op, PatternRewriter &rewriter) {
2434+
auto exp = getAsOpFoldResult(op.getExponent());
2435+
auto baseAttr = dyn_cast<mlir::Attribute>(exp);
2436+
auto floatAttr = dyn_cast_or_null<mlir::FloatAttr>(baseAttr);
2437+
if (!floatAttr)
2438+
return failure();
2439+
double expValue = floatAttr.getValueAsDouble();
2440+
auto truncValue = static_cast<int64_t>(expValue);
2441+
if (expValue != static_cast<double>(truncValue))
2442+
return failure();
2443+
Value IRScalar =
2444+
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), truncValue);
2445+
op->setOperand(1, IRScalar);
2446+
return success();
2447+
});
2448+
}
2449+
2450+
//===----------------------------------------------------------------------===//
2451+
// AtenPowTensorTensorOp
2452+
//===----------------------------------------------------------------------===//
2453+
2454+
void AtenPowTensorTensorOp::getCanonicalizationPatterns(
2455+
RewritePatternSet &patterns, MLIRContext *context) {
2456+
// If the exponent is a single element constant, convert to
2457+
// AtenPowTensorScalar.
2458+
patterns.add(+[](AtenPowTensorTensorOp op, PatternRewriter &rewriter) {
2459+
OpFoldResult exp = getAsOpFoldResult(op.getExponent());
2460+
auto expAttr = dyn_cast<Attribute>(exp);
2461+
auto attr = dyn_cast_or_null<DenseElementsAttr>(expAttr);
2462+
if (!attr || attr.getNumElements() != 1)
2463+
return failure();
2464+
auto elem = *attr.value_begin<Attribute>();
2465+
auto intAttr = dyn_cast<mlir::IntegerAttr>(elem);
2466+
auto floatAttr = dyn_cast<mlir::FloatAttr>(elem);
2467+
if (!intAttr && !floatAttr)
2468+
return failure();
2469+
Value IRScalar;
2470+
if (intAttr)
2471+
IRScalar = rewriter.create<Torch::ConstantIntOp>(
2472+
op.getLoc(), getIntAttrAsSigned(intAttr));
2473+
if (floatAttr) {
2474+
double expValue = floatAttr.getValueAsDouble();
2475+
IRScalar = rewriter.create<Torch::ConstantFloatOp>(op.getLoc(),
2476+
APFloat(expValue));
2477+
}
2478+
rewriter.replaceOpWithNewOp<AtenPowTensorScalarOp>(op, op.getType(),
2479+
op.getSelf(), IRScalar);
2480+
return success();
2481+
});
2482+
}
2483+
24242484
//===----------------------------------------------------------------------===//
24252485
// AtenSelectIntOp
24262486
//===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,12 @@ def emit_with_mutating_variants(key, **kwargs):
498498
has_canonicalizer=True,
499499
)
500500
emit("aten::gelu : (Tensor, str) -> (Tensor)")
501-
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
502-
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
501+
emit(
502+
"aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True
503+
)
504+
emit(
505+
"aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True
506+
)
503507
emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)")
504508
emit("aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
505509
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

+20
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,26 @@ func.func @torch.aten.remainder.int() -> !torch.int {
13141314
return %ret : !torch.int
13151315
}
13161316

// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$canonicalize
// CHECK: %[[EXP:.*]] = torch.constant.float 3.5
// CHECK: %[[RES:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[EXP]]
// CHECK: return %[[RES]]
func.func @torch.aten.pow.Tensor_Tensor$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %exp = torch.vtensor.literal(dense<3.500000e+00> : tensor<f32>) : !torch.vtensor<[],f32>
  %result = torch.aten.pow.Tensor_Tensor %arg0, %exp : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
  return %result : !torch.vtensor<[?,?],f32>
}
1326+
// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$canonicalize
// CHECK: %[[EXP:.*]] = torch.constant.int 3
// CHECK: %[[RES:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[EXP]]
// CHECK: return %[[RES]]
func.func @torch.aten.pow.Tensor_Scalar$canonicalize(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %exp = torch.constant.float 3.0
  %result = torch.aten.pow.Tensor_Scalar %arg0, %exp : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
  return %result : !torch.vtensor<[?,?],f32>
}
1336+
13171337
// CHECK-LABEL: func.func @torch.aten.pow.int_float() -> !torch.float {
13181338
// CHECK: %[[FLOAT_8:.*]] = torch.constant.float 8.000000e+00
13191339
// CHECK: return %[[FLOAT_8]] : !torch.float

0 commit comments

Comments
 (0)