Skip to content

Commit

Permalink
[LINALG] Added support for conversion from float to complex. (#3595)
Browse files Browse the repository at this point in the history
  • Loading branch information
BaneTrifa authored Aug 7, 2024
1 parent b48e55c commit 2d6bfb2
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
}

if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {

// Complex to complex.
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
auto dtypeElemType = dtypeComplex.getElementType();

Expand All @@ -364,6 +366,27 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,

return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
}

// Float to complex type.
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(scalarType)) {
auto complexElementType =
cast<mlir::FloatType>(dtypeComplex.getElementType());
Value realVal;
Value imgVal =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));

if (complexElementType.getWidth() > dtypeFloat.getWidth()) {
realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
} else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
;
} else {
realVal = scalar;
}

return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
}

mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,8 @@
"TensorToFloatZeroRank_basic",
"TensorToIntZeroRank_basic",
"TensorsConcatModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex64FloatModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
Expand Down Expand Up @@ -2598,6 +2600,8 @@
"SubFloatModule_basic",
"SubIntModule_basic",
"TanhBackward_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex64FloatModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
Expand Down
62 changes: 62 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,68 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorsConcatComplex64FloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.complex64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float16, True),
]
)
def forward(self, a, b, c, d):
return torch.cat([a, b, c, d], 1)


@register_test_case(module_factory=lambda: TensorsConcatComplex64FloatModule())
def TensorsConcatComplex64FloatModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 1, 4, low=1, high=10).to(torch.complex64),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
)


# ==============================================================================


class TensorsConcatComplex128FloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.complex128, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float16, True),
]
)
def forward(self, a, b, c, d):
return torch.cat([a, b, c, d], 1)


@register_test_case(module_factory=lambda: TensorsConcatComplex128FloatModule())
def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
)


# ==============================================================================


class TensorsConcatNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 2d6bfb2

Please sign in to comment.