|
12 | 12 | #include "llvm/Support/Debug.h"
|
13 | 13 |
|
14 | 14 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
| 15 | +#include "mlir/Dialect/Utils/StaticValueUtils.h" |
15 | 16 | #include "mlir/IR/Builders.h"
|
16 | 17 | #include "mlir/IR/BuiltinOps.h"
|
17 | 18 | #include "mlir/IR/PatternMatch.h"
|
@@ -2421,6 +2422,65 @@ OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
|
2421 | 2422 | return nullptr;
|
2422 | 2423 | }
|
2423 | 2424 |
|
| 2425 | +//===----------------------------------------------------------------------===// |
| 2426 | +// AtenPowTensorScalarOp |
| 2427 | +//===----------------------------------------------------------------------===// |
| 2428 | + |
| 2429 | +void AtenPowTensorScalarOp::getCanonicalizationPatterns( |
| 2430 | + RewritePatternSet &patterns, MLIRContext *context) { |
| 2431 | + // If the exponent is a float representation of an int, |
| 2432 | + // convert the exponent to an int |
| 2433 | + patterns.add(+[](AtenPowTensorScalarOp op, PatternRewriter &rewriter) { |
| 2434 | + auto exp = getAsOpFoldResult(op.getExponent()); |
| 2435 | + auto baseAttr = dyn_cast<mlir::Attribute>(exp); |
| 2436 | + auto floatAttr = dyn_cast_or_null<mlir::FloatAttr>(baseAttr); |
| 2437 | + if (!floatAttr) |
| 2438 | + return failure(); |
| 2439 | + double expValue = floatAttr.getValueAsDouble(); |
| 2440 | + auto truncValue = static_cast<int64_t>(expValue); |
| 2441 | + if (expValue != static_cast<double>(truncValue)) |
| 2442 | + return failure(); |
| 2443 | + Value IRScalar = |
| 2444 | + rewriter.create<Torch::ConstantIntOp>(op.getLoc(), truncValue); |
| 2445 | + op->setOperand(1, IRScalar); |
| 2446 | + return success(); |
| 2447 | + }); |
| 2448 | +} |
| 2449 | + |
| 2450 | +//===----------------------------------------------------------------------===// |
| 2451 | +// AtenPowTensorTensorOp |
| 2452 | +//===----------------------------------------------------------------------===// |
| 2453 | + |
| 2454 | +void AtenPowTensorTensorOp::getCanonicalizationPatterns( |
| 2455 | + RewritePatternSet &patterns, MLIRContext *context) { |
| 2456 | + // If the exponent is a single element constant, convert to |
| 2457 | + // AtenPowTensorScalar. |
| 2458 | + patterns.add(+[](AtenPowTensorTensorOp op, PatternRewriter &rewriter) { |
| 2459 | + OpFoldResult exp = getAsOpFoldResult(op.getExponent()); |
| 2460 | + auto expAttr = dyn_cast<Attribute>(exp); |
| 2461 | + auto attr = dyn_cast_or_null<DenseElementsAttr>(expAttr); |
| 2462 | + if (!attr || attr.getNumElements() != 1) |
| 2463 | + return failure(); |
| 2464 | + auto elem = *attr.value_begin<Attribute>(); |
| 2465 | + auto intAttr = dyn_cast<mlir::IntegerAttr>(elem); |
| 2466 | + auto floatAttr = dyn_cast<mlir::FloatAttr>(elem); |
| 2467 | + if (!intAttr && !floatAttr) |
| 2468 | + return failure(); |
| 2469 | + Value IRScalar; |
| 2470 | + if (intAttr) |
| 2471 | + IRScalar = rewriter.create<Torch::ConstantIntOp>( |
| 2472 | + op.getLoc(), getIntAttrAsSigned(intAttr)); |
| 2473 | + if (floatAttr) { |
| 2474 | + double expValue = floatAttr.getValueAsDouble(); |
| 2475 | + IRScalar = rewriter.create<Torch::ConstantFloatOp>(op.getLoc(), |
| 2476 | + APFloat(expValue)); |
| 2477 | + } |
| 2478 | + rewriter.replaceOpWithNewOp<AtenPowTensorScalarOp>(op, op.getType(), |
| 2479 | + op.getSelf(), IRScalar); |
| 2480 | + return success(); |
| 2481 | + }); |
| 2482 | +} |
| 2483 | + |
2424 | 2484 | //===----------------------------------------------------------------------===//
|
2425 | 2485 | // AtenSelectIntOp
|
2426 | 2486 | //===----------------------------------------------------------------------===//
|
|
0 commit comments