From ba1230d545d045623c1ab3e202b7e2b836efe02f Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Thu, 19 Jun 2025 17:25:48 +0100 Subject: [PATCH 1/5] [MLIR] Add ComplexToROCDL pass This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls. --- flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 + flang/lib/Optimizer/CodeGen/CodeGen.cpp | 5 +- .../ComplexToROCDL/ComplexToROCDL.h | 19 ++++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 12 +++ mlir/lib/Conversion/CMakeLists.txt | 1 + .../Conversion/ComplexToROCDL/CMakeLists.txt | 18 ++++ .../ComplexToROCDL/ComplexToROCDL.cpp | 95 +++++++++++++++++++ .../ComplexToROCDL/complex-to-rocdl.mlir | 13 +++ 9 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h create mode 100644 mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt create mode 100644 mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp create mode 100644 mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 980307db315d9..8b4ac18fba527 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen MLIRMathToLLVM MLIRMathToLibm MLIRMathToROCDL + MLIRComplexToROCDL MLIROpenMPToLLVM MLIROpenACCDialect MLIRBuiltinToLLVMIRTranslation diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index a3de3ae9d116a..f721b6232b0fb 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -33,6 +33,7 @@ #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" 
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" @@ -4105,8 +4106,10 @@ class FIRToLLVMLowering // GPU library calls, the rest can be converted to LLVM intrinsics, which // is handled in the mathToLLVM conversion. The lowering to libm calls is // not needed since all math operations are handled this way. - if (isAMDGCN) + if (isAMDGCN) { mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); + mathConvertionPM.addPass(mlir::createConvertComplexToROCDL()); + } // Convert math::FPowI operations to inline implementation // only if the exponent's width is greater than 32, otherwise, diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h new file mode 100644 index 0000000000000..ed65be9980408 --- /dev/null +++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h @@ -0,0 +1,19 @@ +#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ +#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +class RewritePatternSet; + +#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Complex to ROCDL +/// calls. 
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit); +} // namespace mlir + +#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c9d2a54433736..67e8f5b99b67b 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -23,6 +23,7 @@ #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b496ee0114910..8ad2341f93a15 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> { let dependentDialects = ["func::FuncDialect"]; } +//===----------------------------------------------------------------------===// +// ComplexToROCDL +//===----------------------------------------------------------------------===// + +def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> { + let summary = "Convert Complex dialect to ROCDL calls"; + let description = [{ + This pass converts supported Complex ops to calls to the AMD device library. 
+ }]; + let dependentDialects = ["func::FuncDialect"]; +} + //===----------------------------------------------------------------------===// // ComplexToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e4b4974600577..4ad81553a4fa8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM) add_subdirectory(BufferizationToMemRef) add_subdirectory(ComplexCommon) add_subdirectory(ComplexToLibm) +add_subdirectory(ComplexToROCDL) add_subdirectory(ComplexToLLVM) add_subdirectory(ComplexToSPIRV) add_subdirectory(ComplexToStandard) diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt new file mode 100644 index 0000000000000..54607250083d7 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRComplexToROCDL + ComplexToROCDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRComplexDialect + MLIRFuncDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp new file mode 100644 index 0000000000000..cdfe2a6dfe874 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp @@ -0,0 +1,95 @@ +//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" + +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include + +namespace mlir { +#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct FloatTypeResolver { + std::optional operator()(Type type) const { + auto elementType = cast(type); + if (!isa(elementType)) + return {}; + return elementType.getIntOrFloatBitWidth() == 64; + } +}; + +template +struct ScalarOpToROCDLCall : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc, + StringRef doubleFunc, PatternBenefit benefit) + : OpRewritePattern(context, benefit), floatFunc(floatFunc), + doubleFunc(doubleFunc) {} + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final { + auto module = SymbolTable::getNearestSymbolTable(op); + auto isDouble = TypeResolver()(op.getType()); + if (!isDouble.has_value()) + return failure(); + + auto name = *isDouble ? 
doubleFunc : floatFunc; + + auto opFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, name)); + if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + auto funcTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = + rewriter.create(rewriter.getUnknownLoc(), name, funcTy); + opFunc.setPrivate(); + } + rewriter.replaceOpWithNewOp(op, name, op.getType(), + op->getOperands()); + return success(); + } + +private: + std::string floatFunc, doubleFunc; +}; +} // namespace + +void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add>( + patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit); +} + +namespace { +struct ConvertComplexToROCDLPass + : public impl::ConvertComplexToROCDLBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertComplexToROCDLPass::runOnOperation() { + auto module = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir new file mode 100644 index 0000000000000..618e9c238378c --- /dev/null +++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s + +// CHECK-DAG: @__ocml_cabs_f32(complex) -> f32 +// CHECK-DAG: @__ocml_cabs_f64(complex) -> f64 + +func.func @abs_caller(%f: complex, %d: complex) -> (f32, f64) { + // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]]) + %rf = complex.abs %f : complex + // CHECK: %[[RD:.*]] = 
call @__ocml_cabs_f64(%[[D:.*]]) + %rd = complex.abs %d : complex + // CHECK: return %[[RF]], %[[RD]] + return %rf, %rd : f32, f64 +} From a10e50ad76963f65d3b84ff152159755a7efdcfe Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Thu, 19 Jun 2025 17:38:09 +0100 Subject: [PATCH 2/5] Added license to header. --- .../mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h | 8 ++++++++ mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h index ed65be9980408..96d5a352f54c8 100644 --- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h +++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h @@ -1,3 +1,11 @@ +//===-- ComplexToROCDL.h - conversion from Complex to ROCDL calls ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ #define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp index cdfe2a6dfe874..7c7510b5c4e10 100644 --- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp +++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" - #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" From bb88939bc53d9e88973749c63ba21dfa26279258 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Fri, 20 Jun 2025 16:22:07 +0100 Subject: [PATCH 3/5] Address reviewer changes. Add conversion for complex.exp. --- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 14 ++--- .../ComplexToROCDL/ComplexToROCDL.h | 3 +- .../Conversion/ComplexToROCDL/CMakeLists.txt | 3 - .../ComplexToROCDL/ComplexToROCDL.cpp | 57 ++++++++++--------- .../ComplexToROCDL/complex-to-rocdl.mlir | 19 ++++++- 5 files changed, 54 insertions(+), 42 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index f721b6232b0fb..b8c7cba80d863 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -4099,7 +4099,7 @@ class FIRToLLVMLowering // conversions that affect the ModuleOp, e.g. create new // function operations in it. We have to run such conversions // as passes here. 
- mlir::OpPassManager mathConvertionPM("builtin.module"); + mlir::OpPassManager mathConversionPM("builtin.module"); bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); // If compiling for AMD target some math operations must be lowered to AMD @@ -4107,8 +4107,8 @@ class FIRToLLVMLowering // is handled in the mathToLLVM conversion. The lowering to libm calls is // not needed since all math operations are handled this way. if (isAMDGCN) { - mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); - mathConvertionPM.addPass(mlir::createConvertComplexToROCDL()); + mathConversionPM.addPass(mlir::createConvertMathToROCDL()); + mathConversionPM.addPass(mlir::createConvertComplexToROCDL()); } // Convert math::FPowI operations to inline implementation @@ -4116,15 +4116,15 @@ class FIRToLLVMLowering // it will be lowered to LLVM intrinsic operation by a later conversion. mlir::ConvertMathToFuncsOptions mathToFuncsOptions{}; mathToFuncsOptions.minWidthOfFPowIExponent = 33; - mathConvertionPM.addPass( + mathConversionPM.addPass( mlir::createConvertMathToFuncs(mathToFuncsOptions)); - mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass()); + mathConversionPM.addPass(mlir::createConvertComplexToStandardPass()); // Convert Math dialect operations into LLVM dialect operations. // There is no way to prefer MathToLLVM patterns over MathToLibm // patterns (applied below), so we have to run MathToLLVM conversion here. 
- mathConvertionPM.addNestedPass( + mathConversionPM.addNestedPass( mlir::createConvertMathToLLVMPass()); - if (mlir::failed(runPipeline(mathConvertionPM, mod))) + if (mlir::failed(runPipeline(mathConversionPM, mod))) return signalPassFailure(); std::optional dl = diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h index 96d5a352f54c8..eb785080adab3 100644 --- a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h +++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h @@ -20,8 +20,7 @@ class RewritePatternSet; /// Populate the given list with patterns that convert from Complex to ROCDL /// calls. -void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit); +void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns); } // namespace mlir #endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt index 54607250083d7..133809ac32f0f 100644 --- a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt @@ -7,9 +7,6 @@ add_mlir_conversion_library(MLIRComplexToROCDL DEPENDS MLIRConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC MLIRComplexDialect MLIRFuncDialect diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp index 7c7510b5c4e10..98adb9fb1f607 100644 --- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp +++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp @@ -11,7 +11,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" -#include namespace mlir { #define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL @@ -21,36 +20,38 @@ namespace mlir { using namespace mlir; namespace { -struct 
FloatTypeResolver { - std::optional operator()(Type type) const { - auto elementType = cast(type); - if (!isa(elementType)) - return {}; - return elementType.getIntOrFloatBitWidth() == 64; - } -}; -template -struct ScalarOpToROCDLCall : public OpRewritePattern { +template +// Pattern to convert Complex ops to ROCDL function calls. +struct ComplexOpToROCDLCall : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc, PatternBenefit benefit) + ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc, + StringRef doubleFunc, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), floatFunc(floatFunc), doubleFunc(doubleFunc) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final { - auto module = SymbolTable::getNearestSymbolTable(op); - auto isDouble = TypeResolver()(op.getType()); - if (!isDouble.has_value()) + Operation *symTable = SymbolTable::getNearestSymbolTable(op); + Type resType = op.getType(); + if (auto complexType = dyn_cast(resType)) + resType = complexType.getElementType(); + FloatType floatTy = dyn_cast(resType); + if (!floatTy) return failure(); - auto name = *isDouble ? 
doubleFunc : floatFunc; + StringRef name; + if (floatTy.isF64()) + name = doubleFunc; + else if (floatTy.isF32()) + name = floatFunc; + else + return failure(); auto opFunc = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, name)); + SymbolTable::lookupSymbolIn(symTable, name)); if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + rewriter.setInsertionPointToStart(&symTable->getRegion(0).front()); auto funcTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); opFunc = @@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern { }; } // namespace -void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add>( - patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit); +void mlir::populateComplexToROCDLConversionPatterns( + RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64"); + patterns.add>( + patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64"); } namespace { @@ -81,14 +84,14 @@ struct ConvertComplexToROCDLPass } // namespace void ConvertComplexToROCDLPass::runOnOperation() { - auto module = getOperation(); + Operation *op = getOperation(); RewritePatternSet patterns(&getContext()); - populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1); + populateComplexToROCDLConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) + target.addIllegalOp(); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir index 618e9c238378c..23631e25e4588 100644 --- 
a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir +++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir @@ -1,13 +1,26 @@ -// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s +// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s // CHECK-DAG: @__ocml_cabs_f32(complex) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex) -> f64 +// CHECK-DAG: @__ocml_cexp_f32(complex) -> complex +// CHECK-DAG: @__ocml_cexp_f64(complex) -> complex +//CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex, %d: complex) -> (f32, f64) { - // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]]) + // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}}) %rf = complex.abs %f : complex - // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]]) + // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}}) %rd = complex.abs %d : complex // CHECK: return %[[RF]], %[[RD]] return %rf, %rd : f32, f64 } + +//CHECK-LABEL: @exp_caller +func.func @exp_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) + %ef = complex.exp %f : complex + // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}}) + %ed = complex.exp %d : complex + // CHECK: return %[[EF]], %[[ED]] + return %ef, %ed : complex, complex +} From d70bca22bae1d66f126b5a1832153810baf06ab4 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Fri, 20 Jun 2025 18:30:57 +0100 Subject: [PATCH 4/5] Correct alphabetical order for cmake. 
--- flang/lib/Optimizer/CodeGen/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 8b4ac18fba527..de8e1c5c3fa3f 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -34,13 +34,13 @@ add_flang_library(FIRCodeGen MLIR_LIBS MLIRComplexToLLVM + MLIRComplexToROCDL MLIRComplexToStandard MLIRGPUDialect MLIRMathToFuncs MLIRMathToLLVM MLIRMathToLibm MLIRMathToROCDL - MLIRComplexToROCDL MLIROpenMPToLLVM MLIROpenACCDialect MLIRBuiltinToLLVMIRTranslation From 0e9aef2910fee55664f7bbfd54ec7bbb1ee383f3 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 24 Jun 2025 16:34:11 +0100 Subject: [PATCH 5/5] Add FloatTy as a template parameter. --- .../ComplexToROCDL/ComplexToROCDL.cpp | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp index 98adb9fb1f607..cfad9f5f6fa19 100644 --- a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp +++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp @@ -21,59 +21,53 @@ using namespace mlir; namespace { -template +template // Pattern to convert Complex ops to ROCDL function calls. 
struct ComplexOpToROCDLCall : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), floatFunc(floatFunc), - doubleFunc(doubleFunc) {} + ComplexOpToROCDLCall(MLIRContext *context, StringRef funcName, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), funcName(funcName) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final { Operation *symTable = SymbolTable::getNearestSymbolTable(op); Type resType = op.getType(); if (auto complexType = dyn_cast(resType)) resType = complexType.getElementType(); - FloatType floatTy = dyn_cast(resType); - if (!floatTy) - return failure(); - - StringRef name; - if (floatTy.isF64()) - name = doubleFunc; - else if (floatTy.isF32()) - name = floatFunc; - else + if (!isa(resType)) return failure(); auto opFunc = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(symTable, name)); + SymbolTable::lookupSymbolIn(symTable, funcName)); if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&symTable->getRegion(0).front()); auto funcTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = - rewriter.create(rewriter.getUnknownLoc(), name, funcTy); + opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, + funcTy); opFunc.setPrivate(); } - rewriter.replaceOpWithNewOp(op, name, op.getType(), + rewriter.replaceOpWithNewOp(op, funcName, op.getType(), op->getOperands()); return success(); } private: - std::string floatFunc, doubleFunc; + std::string funcName; }; } // namespace void mlir::populateComplexToROCDLConversionPatterns( RewritePatternSet &patterns) { - patterns.add>( - patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64"); - patterns.add>( - patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64"); + patterns.add>( + 
patterns.getContext(), "__ocml_cabs_f32"); + patterns.add>( + patterns.getContext(), "__ocml_cabs_f64"); + patterns.add>( + patterns.getContext(), "__ocml_cexp_f32"); + patterns.add>( + patterns.getContext(), "__ocml_cexp_f64"); } namespace {