Skip to content

Commit 1b8d7e0

Browse files
giacs-epiczjgarvey
andauthored
[Torch Dialect] Add torch.aten.mul.int_float (required to simplify shape calculation of upsample_nearest2d) (llvm#3764)
As per title. See also [PR](llvm#3750) for `torch.aten.mul.float_int`. --------- Co-authored-by: zjgarvey <[email protected]>
1 parent bdbc64a commit 1b8d7e0

File tree

7 files changed

+73
-6
lines changed

7 files changed

+73
-6
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+25
Original file line numberDiff line numberDiff line change
@@ -15885,6 +15885,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
1588515885
let hasCanonicalizer = 1;
1588615886
}
1588715887

15888+
def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [
15889+
AllowsTypeRefinement,
15890+
HasValueSemantics,
15891+
ReadOnly
15892+
]> {
15893+
let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`";
15894+
let arguments = (ins
15895+
Torch_IntType:$a,
15896+
Torch_FloatType:$b
15897+
);
15898+
let results = (outs
15899+
Torch_FloatType:$result
15900+
);
15901+
let hasCustomAssemblyFormat = 1;
15902+
let extraClassDefinition = [{
15903+
ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
15904+
return parseDefaultTorchOp(parser, result, 2, 1);
15905+
}
15906+
void AtenMulIntFloatOp::print(OpAsmPrinter &printer) {
15907+
printDefaultTorchOp(printer, *this, 2, 1);
15908+
}
15909+
}];
15910+
let hasFolder = 1;
15911+
}
15912+
1588815913
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
1588915914
AllowsTypeRefinement,
1589015915
HasValueSemantics,

lib/Conversion/TorchToArith/TorchToArith.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
7676
Value b = adaptor.getB();
7777
if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
7878
b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
79+
if (llvm::is_one_of<AtenOp, AtenMulIntFloatOp>::value)
80+
a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType());
7981
rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
8082
return success();
8183
}
@@ -487,7 +489,7 @@ class ConvertTorchToArith
487489
target.addIllegalOp<AtenNegIntOp>();
488490
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
489491
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
490-
AtenMulIntOp, AtenRemainderIntOp>();
492+
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>();
491493
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
492494
typeConverter, context);
493495
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
@@ -498,6 +500,8 @@ class ConvertTorchToArith
498500
typeConverter, context);
499501
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
500502
typeConverter, context);
503+
patterns.add<ConvertAtenBinaryOp<AtenMulIntFloatOp, arith::MulFOp>>(
504+
typeConverter, context);
501505
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
502506
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
503507
typeConverter, context);

lib/Dialect/Torch/IR/TorchOps.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -4219,6 +4219,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
42194219
[](double a, double b) -> double { return a * b; });
42204220
}
42214221

4222+
//===----------------------------------------------------------------------===//
4223+
// AtenMulIntFloatOp
4224+
//===----------------------------------------------------------------------===//
4225+
4226+
OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) {
4227+
if (!adaptor.getA() || !adaptor.getB()) {
4228+
return nullptr;
4229+
}
4230+
return atenBinaryFloatOperatorFoldHelper(
4231+
adaptor.getOperands(),
4232+
[](double a, double b) -> double { return a * b; });
4233+
}
4234+
42224235
//===----------------------------------------------------------------------===//
42234236
// AtenSubOp
42244237
//===----------------------------------------------------------------------===//

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
55075507
" }\n"
55085508
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
55095509
" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
5510-
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
5510+
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
55115511
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
55125512
" %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
55135513
" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
55145514
" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
5515-
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n"
5515+
" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n"
55165516
" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n"
55175517
" %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
55185518
" torch.prim.If.yield\n"
@@ -11184,7 +11184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1118411184
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
1118511185
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
1118611186
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
11187-
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
11187+
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
1118811188
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
1118911189
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
1119011190
" torch.prim.If.yield %19 : !torch.list<int>\n"
@@ -11264,11 +11264,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1126411264
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
1126511265
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
1126611266
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
11267-
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
11267+
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
1126811268
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
1126911269
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
1127011270
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
11271-
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n"
11271+
" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n"
1127211272
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
1127311273
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
1127411274
" torch.prim.If.yield %23 : !torch.list<int>\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,7 @@ def emit_with_mutating_variants(key, **kwargs):
11181118
has_folder=True,
11191119
has_canonicalizer=True,
11201120
)
1121+
emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True)
11211122
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
11221123
emit("aten::neg.int : (int) -> (int)", has_folder=True)
11231124
emit("aten::log.int : (int) -> (float)")

test/Conversion/TorchToArith/basic.mlir

+14
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
236236
return %0 : !torch.int
237237
}
238238

239+
// CHECK-LABEL: func.func @torch.aten.mul.int_float(
240+
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
241+
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
242+
// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
243+
// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
244+
// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64
245+
// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
246+
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
247+
// CHECK: return %[[OUT]] : !torch.float
248+
func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
249+
%0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float
250+
return %0 : !torch.float
251+
}
252+
239253
// CHECK-LABEL: func.func @torch.aten.div.float(
240254
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
241255
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {

test/Dialect/Torch/canonicalize.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,16 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
12351235
return %ret : !torch.int
12361236
}
12371237

1238+
// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float {
1239+
// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00
1240+
// CHECK: return %[[CST6]] : !torch.float
1241+
func.func @torch.aten.mul.int_float() -> !torch.float {
1242+
%cst2 = torch.constant.int 2
1243+
%cst3 = torch.constant.float 3.0
1244+
%ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float
1245+
return %ret : !torch.float
1246+
}
1247+
12381248
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
12391249
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
12401250
// CHECK: return %[[CST30]] : !torch.float

0 commit comments

Comments
 (0)