Skip to content

[CIR] Upstream ComplexRealPtrOp for ComplexType #144235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2385,4 +2385,33 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// ComplexRealPtrOp
//===----------------------------------------------------------------------===//

def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
let summary = "Derive a pointer to the real part of a complex value";
let description = [{
`cir.complex.real_ptr` operation takes a pointer operand that points to a
complex value of type `!cir.complex` and yields a pointer to the real part
of the operand.

Example:

```mlir
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reformat this for 80 columns.

```
}];

let results = (outs CIR_PtrToIntOrFloatType:$result);
let arguments = (ins CIR_PtrToComplexType:$operand);

let assemblyFormat = [{
$operand `:`
qualified(type($operand)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
24 changes: 24 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
}

//===----------------------------------------------------------------------===//
// Complex Type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;

//===----------------------------------------------------------------------===//
// Pointer Type predicates
//===----------------------------------------------------------------------===//
Expand All @@ -180,6 +186,19 @@ class CIR_PtrToPtrTo<code type, string summary>
: CIR_ConfinedType<CIR_AnyPtrType, [CIR_IsPtrToPtrToPred<type>],
"pointer to pointer to " # summary>;

// Pointee type constraint bases
class CIR_PointeePred<Pred pred> : SubstLeaves<"$_self",
"::mlir::cast<::cir::PointerType>($_self).getPointee()", pred>;

class CIR_PtrToAnyOf<list<Type> types, string summary = "">
: CIR_ConfinedType<CIR_AnyPtrType,
[Or<!foreach(type, types, CIR_PointeePred<type.predicate>)>],
!if(!empty(summary),
"pointer to " # CIR_TypeSummaries<types>.value,
summary)>;

class CIR_PtrToType<Type type> : CIR_PtrToAnyOf<[type]>;

// Void pointer type constraints
def CIR_VoidPtrType
: CIR_PtrTo<"::cir::VoidType", "void type">,
Expand All @@ -192,6 +211,11 @@ def CIR_PtrToVoidPtrType
"$_builder.getType<" # cppType # ">("
"cir::VoidType::get($_builder.getContext())))">;

// Pointer to type constraints
def CIR_PtrToIntOrFloatType : CIR_PtrToType<CIR_AnyIntOrFloatType>;

def CIR_PtrToComplexType : CIR_PtrToType<CIR_AnyComplexType>;

//===----------------------------------------------------------------------===//
// Vector Type predicates
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}

/// Create a cir.complex.real_ptr operation that derives a pointer to the real
/// part of the complex value pointed to by the specified pointer value.
mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
return create<cir::ComplexRealPtrOp>(
loc, getPointerTo(srcComplexTy.getElementType()), value);
}

Address createRealPtr(mlir::Location loc, Address addr) {
return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
}

/// Create a cir.ptr_stride operation to get access to an array element.
/// \p idx is the index of the element to access, \p shouldDecay is true if
/// the result should decay to a pointer to the element type.
Expand Down
25 changes: 23 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,29 @@ LValue CIRGenFunction::emitUnaryOpLValue(const UnaryOperator *e) {
}
case UO_Real:
case UO_Imag: {
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
return LValue();
if (op == UO_Imag) {
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
return LValue();
}

LValue lv = emitLValue(e->getSubExpr());
assert(lv.isSimple() && "real/imag on non-ordinary l-value");

// __real is valid on scalars. This is a faster way of testing that.
// __imag can only produce an rvalue on scalars.
if (e->getOpcode() == UO_Real &&
!mlir::isa<cir::ComplexType>(lv.getAddress().getElementType())) {
assert(e->getSubExpr()->getType()->isArithmeticType());
return lv;
}

QualType exprTy = getContext().getCanonicalType(e->getSubExpr()->getType());
QualType elemTy = exprTy->castAs<clang::ComplexType>()->getElementType();
mlir::Location loc = getLoc(e->getExprLoc());
Address component = builder.createRealPtr(loc, lv.getAddress());
LValue elemLV = makeAddrLValue(component, elemTy);
elemLV.getQuals().addQualifiers(lv.getQuals());
return elemLV;
}
case UO_PreInc:
case UO_PreDec: {
Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,25 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
return cir::ConstComplexAttr::get(realAttr, imagAttr);
}

//===----------------------------------------------------------------------===//
// ComplexRealPtrOp
//===----------------------------------------------------------------------===//

LogicalResult cir::ComplexRealPtrOp::verify() {
mlir::Type resultPointeeTy = getType().getPointee();
cir::PointerType operandPtrTy = getOperand().getType();
auto operandPointeeTy =
mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());

if (resultPointeeTy != operandPointeeTy.getElementType()) {
emitOpError()
<< "cir.complex.real_ptr result type does not match operand type";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 19 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMComplexCreateOpLowering
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexRealPtrOpLowering
// clang-format on
>(converter, patterns.getContext());

Expand Down Expand Up @@ -2140,6 +2141,23 @@ mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
cir::ComplexRealPtrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
cir::PointerType operandTy = op.getOperand().getType();
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
mlir::Type elementLLVMTy =
getTypeConverter()->convertType(operandTy.getPointee());

mlir::LLVM::GEPArg gepIndices[2] = {{0}, {0}};
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
inboundsNuw);
return mlir::success();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,16 @@ class CIRToLLVMComplexCreateOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMComplexRealPtrOpLowering
: public mlir::OpConversionPattern<cir::ComplexRealPtrOp> {
public:
using mlir::OpConversionPattern<cir::ComplexRealPtrOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

} // namespace direct
} // namespace cir

Expand Down
13 changes: 13 additions & 0 deletions clang/test/CIR/CodeGen/complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,16 @@ void foo7() {
// OGCG: store float %[[TMP_A]], ptr %[[C_REAL_PTR]], align 4
// OGCG: store float 2.000000e+00, ptr %[[C_IMAG_PTR]], align 4

void foo10() {
double _Complex c;
double *realPtr = &__real__ c;
}

// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
// CIR: %[[REAL_PTR:.*]] = cir.complex.real_ptr %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>

// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
// LLVM: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0

// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
// OGCG: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
Loading