-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[CIR] Upstream ComplexRealPtrOp for ComplexType #144235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[CIR] Upstream ComplexRealPtrOp for ComplexType #144235
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesFull diff: https://github.com/llvm/llvm-project/pull/144235.diff 8 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index bd36d228578b7..17279f0a9985a 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2385,4 +2385,33 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// ComplexRealPtrOp
+//===----------------------------------------------------------------------===//
+
+def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
+ let summary = "Derive a pointer to the real part of a complex value";
+ let description = [{
+ `cir.complex.real_ptr` operation takes a pointer operand that points to a
+ complex value of type `!cir.complex` and yields a pointer to the real part
+ of the operand.
+
+ Example:
+
+ ```mlir
+ %1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
+ ```
+ }];
+
+ let results = (outs CIR_PtrToIntOrFloatType:$result);
+ let arguments = (ins CIR_PtrToComplexType:$operand);
+
+ let assemblyFormat = [{
+ $operand `:`
+ qualified(type($operand)) `->` qualified(type($result)) attr-dict
+ }];
+
+ let hasVerifier = 1;
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index bcd516e27cc76..d59ff62248e1f 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -159,6 +159,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
}
+//===----------------------------------------------------------------------===//
+// Complex Type predicates
+//===----------------------------------------------------------------------===//
+
+def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
+
//===----------------------------------------------------------------------===//
// Pointer Type predicates
//===----------------------------------------------------------------------===//
@@ -180,6 +186,19 @@ class CIR_PtrToPtrTo<code type, string summary>
: CIR_ConfinedType<CIR_AnyPtrType, [CIR_IsPtrToPtrToPred<type>],
"pointer to pointer to " # summary>;
+// Pointee type constraint bases
+class CIR_PointeePred<Pred pred> : SubstLeaves<"$_self",
+ "::mlir::cast<::cir::PointerType>($_self).getPointee()", pred>;
+
+class CIR_PtrToAnyOf<list<Type> types, string summary = "">
+ : CIR_ConfinedType<CIR_AnyPtrType,
+ [Or<!foreach(type, types, CIR_PointeePred<type.predicate>)>],
+ !if(!empty(summary),
+ "pointer to " # CIR_TypeSummaries<types>.value,
+ summary)>;
+
+class CIR_PtrToType<Type type> : CIR_PtrToAnyOf<[type]>;
+
// Void pointer type constraints
def CIR_VoidPtrType
: CIR_PtrTo<"::cir::VoidType", "void type">,
@@ -192,6 +211,11 @@ def CIR_PtrToVoidPtrType
"$_builder.getType<" # cppType # ">("
"cir::VoidType::get($_builder.getContext())))">;
+// Pointer to type constraints
+def CIR_PtrToIntOrFloatType : CIR_PtrToType<CIR_AnyIntOrFloatType>;
+
+def CIR_PtrToComplexType : CIR_PtrToType<CIR_AnyComplexType>;
+
//===----------------------------------------------------------------------===//
// Vector Type predicates
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index e38faba83b80c..3f7ea5bccb6d5 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -366,6 +366,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}
+ /// Create a cir.complex.real_ptr operation that derives a pointer to the real
+ /// part of the complex value pointed to by the specified pointer value.
+ mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
+ auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
+ auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
+ return create<cir::ComplexRealPtrOp>(
+ loc, getPointerTo(srcComplexTy.getElementType()), value);
+ }
+
+ Address createRealPtr(mlir::Location loc, Address addr) {
+ return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
+ }
+
/// Create a cir.ptr_stride operation to get access to an array element.
/// \p idx is the index of the element to access, \p shouldDecay is true if
/// the result should decay to a pointer to the element type.
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 2e43f10be132c..a682586562e04 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -541,8 +541,29 @@ LValue CIRGenFunction::emitUnaryOpLValue(const UnaryOperator *e) {
}
case UO_Real:
case UO_Imag: {
- cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
- return LValue();
+ if (op == UO_Imag) {
+ cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
+ return LValue();
+ }
+
+ LValue lv = emitLValue(e->getSubExpr());
+ assert(lv.isSimple() && "real/imag on non-ordinary l-value");
+
+ // __real is valid on scalars. This is a faster way of testing that.
+ // __imag can only produce an rvalue on scalars.
+ if (e->getOpcode() == UO_Real &&
+ !mlir::isa<cir::ComplexType>(lv.getAddress().getElementType())) {
+ assert(e->getSubExpr()->getType()->isArithmeticType());
+ return lv;
+ }
+
+ QualType exprTy = getContext().getCanonicalType(e->getSubExpr()->getType());
+ QualType elemTy = exprTy->castAs<clang::ComplexType>()->getElementType();
+ mlir::Location loc = getLoc(e->getExprLoc());
+ Address component = builder.createRealPtr(loc, lv.getAddress());
+ LValue elemLV = makeAddrLValue(component, elemTy);
+ elemLV.getQuals().addQualifiers(lv.getQuals());
+ return elemLV;
}
case UO_PreInc:
case UO_PreDec: {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 5578d4f5825a9..99ae4dd59120a 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1775,6 +1775,25 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
return cir::ConstComplexAttr::get(realAttr, imagAttr);
}
+//===----------------------------------------------------------------------===//
+// ComplexRealPtrOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::ComplexRealPtrOp::verify() {
+ mlir::Type resultPointeeTy = getType().getPointee();
+ cir::PointerType operandPtrTy = getOperand().getType();
+ auto operandPointeeTy =
+ mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
+
+ if (resultPointeeTy != operandPointeeTy.getElementType()) {
+ emitOpError()
+ << "cir.complex.real_ptr result type does not match operand type";
+ return failure();
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 6a4e4e4a7df3b..c11992a4bdc61 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1836,7 +1836,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering,
- CIRToLLVMComplexCreateOpLowering
+ CIRToLLVMComplexCreateOpLowering,
+ CIRToLLVMComplexRealPtrOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -2140,6 +2141,23 @@ mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
+ cir::ComplexRealPtrOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ cir::PointerType operandTy = op.getOperand().getType();
+ mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
+ mlir::Type elementLLVMTy =
+ getTypeConverter()->convertType(operandTy.getPointee());
+
+ mlir::LLVM::GEPArg gepIndices[2] = {{0}, {0}};
+ mlir::LLVM::GEPNoWrapFlags inboundsNuw =
+ mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
+ rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
+ op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
+ inboundsNuw);
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index a809818063547..caee3e9cd6980 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -418,6 +418,16 @@ class CIRToLLVMComplexCreateOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMComplexRealPtrOpLowering
+ : public mlir::OpConversionPattern<cir::ComplexRealPtrOp> {
+public:
+ using mlir::OpConversionPattern<cir::ComplexRealPtrOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
index d193b9f32efbc..182673a69be90 100644
--- a/clang/test/CIR/CodeGen/complex.cpp
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -176,3 +176,16 @@ void foo7() {
// OGCG: store float %[[TMP_A]], ptr %[[C_REAL_PTR]], align 4
// OGCG: store float 2.000000e+00, ptr %[[C_IMAG_PTR]], align 4
+void foo10() {
+ double _Complex c;
+ double *realPtr = &__real__ c;
+}
+
+// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
+// CIR: %[[REAL_PTR:.*]] = cir.complex.real_ptr %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
+
+// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
+// LLVM: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
+
+// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
+// OGCG: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have concerns about the CIR representation here. I think we should be aligning our representation of complex operations with the MLIR complex dialect. As such, we want __real__
to be lowerable to complex.re
but the explicit representation of intermediate pointer loads and stores is going to make that difficult.
The test case you have in this PR makes it difficult to reason about what I'm interested in here because it is dealing so explicitly with pointers. I'd like to consider something more like this:
double f(double _Complex a, double _Complex b) {
return __real__ a + __real__ b;
}
The incubator will generate the following CIR for this case:
cir.func dso_local @_Z1fCdS_(%arg0: !cir.complex<!cir.double>, %arg1: !cir.complex<!cir.double> -> !cir.double extra(#fn_attr) {
%0 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["a", init] {alignment = 8 : i64}
%1 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["b", init] {alignment = 8 : i64}
%2 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["__retval"] {alignment = 8 : i64}
cir.store %arg0, %0 : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
cir.store %arg1, %1 : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
%3 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
%4 = cir.load align(8) %3 : !cir.ptr<!cir.double>, !cir.double
%5 = cir.complex.real_ptr %1 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
%6 = cir.load align(8) %5 : !cir.ptr<!cir.double>, !cir.double
%7 = cir.binop(add, %4, %6) : !cir.double
cir.store %7, %2 : !cir.double, !cir.ptr<!cir.double>
%8 = cir.load %2 : !cir.ptr<!cir.double>, !cir.double
cir.return %8 : !cir.double
}
That's not what I'd like to see. What I'd like to see is something more like this:
cir.func dso_local @_Z1fCdS_(%arg0: !cir.complex<!cir.double>, %arg1: !cir.complex<!cir.double> -> !cir.double extra(#fn_attr) {
%0 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["a", init] {alignment = 8 : i64}
%1 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["b", init] {alignment = 8 : i64}
%2 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["__retval"] {alignment = 8 : i64}
cir.store %arg0, %0 : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
cir.store %arg1, %1 : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
%3 = cir.load align(8) %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.complex<!cir.double>
%4 = cir.complex.real %3 : !cir.complex<!cir.double> -> !cir.double
%5 = cir.load align(8) %1 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.complex<!cir.double>
%6 = cir.complex.real %5 : !cir.complex<!cir.double> -> !cir.double
%7 = cir.binop(add, %4, %6) : !cir.double
cir.store %7, %2 : !cir.double, !cir.ptr<!cir.double>
%8 = cir.load %2 : !cir.ptr<!cir.double>, !cir.double
cir.return %8 : !cir.double
}
The idea is that the __real__
operation acts directly on the complex value rather than manipulating its memory representation. We'll still need to go to the memory representation when it gets lowered to LLVM IR, but I would like to keep it higher level until then.
For your test case:
void foo10() {
double _Complex c;
double *realPtr = &__real__ c;
}
the expression &__real__ c
is represented in the AST like this:
`-UnaryOperator <col:21, col:31> 'double *' prefix '&' cannot overflow
`-UnaryOperator <col:22, col:31> 'double' lvalue prefix '__real' cannot overflow
`-DeclRefExpr <col:31> '_Complex double' lvalue Var 0x10ed3bd8 'c' '_Complex double'
Hopefully, we could represent the __real__
operator with a cir.complex.real
operation as I've proposed and then use the &
operator to do something to get a pointer. I don't know exactly how that would look.
This is, of course, something we'd want to change in the incubator first to take advantage of the more comprehensive testing available there.
Example: | ||
|
||
```mlir | ||
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reformat this for 80 columns.
I thought |
+1, I think this is a reasonable approach! |
Update `__real__` operation to use ComplexRealOp and act directly on the complex value. Ref: llvm/llvm-project#144235 (review)
Update `__imag__` operation to use ComplexRealOp and act directly on the complex value. Ref: llvm/llvm-project#144235 (review)
This change adds support for ComplexRealPtrOp for ComplexType
#141365