Add scalar type promotion for mul and div (#454)
dan-garvey committed Dec 3, 2021
1 parent: c9c9b68 · commit: b0cb49c
Showing 4 changed files with 96 additions and 15 deletions.
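For context (not part of the diff): a minimal PyTorch sketch of the promotion behavior the new end-to-end tests exercise. Mixed-dtype operands are promoted to the wider dtype in eager PyTorch, which is what the Torch-to-Linalg lowering below now mirrors by converting both payload arguments to the converted result element type.

import torch

# Elementwise mul/div on mixed dtypes promote to the wider type (eager PyTorch semantics).
a_f32 = torch.rand(4, dtype=torch.float32)
b_f64 = torch.rand(4, dtype=torch.float64)
print(torch.mul(a_f32, b_f64).dtype)  # torch.float64
print(torch.div(a_f32, b_f64).dtype)  # torch.float64

a_i32 = torch.randint(10, [4], dtype=torch.int32)
b_i64 = torch.randint(10, [4])  # default integer dtype is int64
print(torch.mul(a_i32, b_i64).dtype)  # torch.int64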
1 change: 1 addition & 0 deletions e2e_testing/torchscript/basic.py
@@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):

# ==============================================================================


class DropoutModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
76 changes: 75 additions & 1 deletion e2e_testing/torchscript/elementwise.py
@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))

# ==============================================================================


class ElementwiseMulScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -378,7 +380,52 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))




class ElementwiseMulTensorFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.mul(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseMulTensorFloatModule())
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(4),
        tu.rand(4).type(torch.float64))

class ElementwiseMulTensorIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.mul(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseMulTensorIntModule())
def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
    module.forward(
        torch.randint(10, [4]).type(torch.int32),
        torch.randint(10, [4]))


# ==============================================================================
class ElementwiseLogModule(torch.nn.Module):
    def __init__(self):
@@ -553,7 +600,32 @@ def forward(self, x):
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


class ElementwiseDivTensorFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.div(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseDivTensorFloatModule())
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(4),
        tu.rand(4).type(torch.float64))


# ==============================================================================


class ElementwiseAndIntegerModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -573,3 +645,5 @@ def forward(self, x, y):
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
                   torch.randint(-10, 10, (3, 4)))


2 changes: 2 additions & 0 deletions e2e_testing/torchscript/type_promotion.py
@@ -111,3 +111,5 @@ def forward(self, a, b):
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand())


32 changes: 18 additions & 14 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
   }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
-    if (!mul.getType()
-            .cast<ValueTensorType>()
-            .getDtype()
-            .isa<mlir::FloatType>()) {
-      mul.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
+    AtenMulTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(mul.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    } else {
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
     }
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
-    if (!div.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    AtenDivTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(div.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>())
       div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
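A rough Python model of the new payload logic above, a sketch only: the names mul_payload, div_payload, and result_dtype are illustrative and not part of the C++ diff. Both operands are converted to the already-promoted result element type (convertScalarToDtype), and the float vs. integer op is chosen from that dtype; div still supports only floating-point results.

import torch

def mul_payload(lhs, rhs, result_dtype):
    # Mirrors the lowered mul payload: convert both operands to the promoted
    # result dtype, then use the float or integer multiply (arith.mulf vs.
    # arith.muli); Python's * covers both cases here.
    return lhs.to(result_dtype) * rhs.to(result_dtype)

def div_payload(lhs, rhs, result_dtype):
    # Mirrors the lowered div payload: only floating-point result dtypes are
    # handled (the C++ emits "unimplemented: non-floating point dtype");
    # this sketch raises instead of emitting a compiler diagnostic.
    if not result_dtype.is_floating_point:
        raise NotImplementedError("unimplemented: non-floating point dtype")
    return lhs.to(result_dtype) / rhs.to(result_dtype)

# float32 * float64 is computed in float64, matching ElementwiseMulTensorFloatModule.
out = mul_payload(torch.rand(4), torch.rand(4, dtype=torch.float64), torch.float64)
assert out.dtype == torch.float64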
