diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 980307db315d9..de8e1c5c3fa3f 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -34,6 +34,7 @@ add_flang_library(FIRCodeGen MLIR_LIBS MLIRComplexToLLVM + MLIRComplexToROCDL MLIRComplexToStandard MLIRGPUDialect MLIRMathToFuncs diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index a3de3ae9d116a..b8c7cba80d863 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -33,6 +33,7 @@ #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" @@ -4098,30 +4099,32 @@ class FIRToLLVMLowering // conversions that affect the ModuleOp, e.g. create new // function operations in it. We have to run such conversions // as passes here. - mlir::OpPassManager mathConvertionPM("builtin.module"); + mlir::OpPassManager mathConversionPM("builtin.module"); bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); // If compiling for AMD target some math operations must be lowered to AMD // GPU library calls, the rest can be converted to LLVM intrinsics, which // is handled in the mathToLLVM conversion. The lowering to libm calls is // not needed since all math operations are handled this way. - if (isAMDGCN) - mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); + if (isAMDGCN) { + mathConversionPM.addPass(mlir::createConvertMathToROCDL()); + mathConversionPM.addPass(mlir::createConvertComplexToROCDL()); + } // Convert math::FPowI operations to inline implementation // only if the exponent's width is greater than 32, otherwise, // it will be lowered to LLVM intrinsic operation by a later conversion. mlir::ConvertMathToFuncsOptions mathToFuncsOptions{}; mathToFuncsOptions.minWidthOfFPowIExponent = 33; - mathConvertionPM.addPass( + mathConversionPM.addPass( mlir::createConvertMathToFuncs(mathToFuncsOptions)); - mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass()); + mathConversionPM.addPass(mlir::createConvertComplexToStandardPass()); // Convert Math dialect operations into LLVM dialect operations. // There is no way to prefer MathToLLVM patterns over MathToLibm // patterns (applied below), so we have to run MathToLLVM conversion here. - mathConvertionPM.addNestedPass( + mathConversionPM.addNestedPass( mlir::createConvertMathToLLVMPass()); - if (mlir::failed(runPipeline(mathConvertionPM, mod))) + if (mlir::failed(runPipeline(mathConversionPM, mod))) return signalPassFailure(); std::optional dl = diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h new file mode 100644 index 0000000000000..eb785080adab3 --- /dev/null +++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h @@ -0,0 +1,26 @@ +//===-- ComplexToROCDL.h - conversion from Complex to ROCDL calls ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ +#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +class RewritePatternSet; + +#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Complex to ROCDL +/// calls. +void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c9d2a54433736..67e8f5b99b67b 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -23,6 +23,7 @@ #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b496ee0114910..8ad2341f93a15 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> { let dependentDialects = ["func::FuncDialect"]; } +//===----------------------------------------------------------------------===// +// ComplexToROCDL +//===----------------------------------------------------------------------===// + +def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> { + let summary = "Convert Complex dialect to ROCDL calls"; + let description = [{ + This pass converts supported Complex ops to calls to the AMD device library. + }]; + let dependentDialects = ["func::FuncDialect"]; +} + //===----------------------------------------------------------------------===// // ComplexToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e4b4974600577..4ad81553a4fa8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM) add_subdirectory(BufferizationToMemRef) add_subdirectory(ComplexCommon) add_subdirectory(ComplexToLibm) +add_subdirectory(ComplexToROCDL) add_subdirectory(ComplexToLLVM) add_subdirectory(ComplexToSPIRV) add_subdirectory(ComplexToStandard) diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt new file mode 100644 index 0000000000000..133809ac32f0f --- /dev/null +++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_conversion_library(MLIRComplexToROCDL + ComplexToROCDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRComplexDialect + MLIRFuncDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp new file mode 100644 index 0000000000000..cfad9f5f6fa19 --- /dev/null +++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp @@ -0,0 +1,91 @@ +//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +template +// Pattern to convert Complex ops to ROCDL function calls. +struct ComplexOpToROCDLCall : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ComplexOpToROCDLCall(MLIRContext *context, StringRef funcName, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), funcName(funcName) {} + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final { + Operation *symTable = SymbolTable::getNearestSymbolTable(op); + Type resType = op.getType(); + if (auto complexType = dyn_cast(resType)) + resType = complexType.getElementType(); + if (!isa(resType)) + return failure(); + + auto opFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symTable, funcName)); + if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&symTable->getRegion(0).front()); + auto funcTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, + funcTy); + opFunc.setPrivate(); + } + rewriter.replaceOpWithNewOp(op, funcName, op.getType(), + op->getOperands()); + return success(); + } + +private: + std::string funcName; +}; +} // namespace + +void mlir::populateComplexToROCDLConversionPatterns( + RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext(), "__ocml_cabs_f32"); + patterns.add>( + patterns.getContext(), "__ocml_cabs_f64"); + patterns.add>( + patterns.getContext(), "__ocml_cexp_f32"); + patterns.add>( + patterns.getContext(), "__ocml_cexp_f64"); +} + +namespace { +struct ConvertComplexToROCDLPass + : public impl::ConvertComplexToROCDLBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertComplexToROCDLPass::runOnOperation() { + Operation *op = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateComplexToROCDLConversionPatterns(patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir new file mode 100644 index 0000000000000..23631e25e4588 --- /dev/null +++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s + +// CHECK-DAG: @__ocml_cabs_f32(complex) -> f32 +// CHECK-DAG: @__ocml_cabs_f64(complex) -> f64 +// CHECK-DAG: @__ocml_cexp_f32(complex) -> complex +// CHECK-DAG: @__ocml_cexp_f64(complex) -> complex + +//CHECK-LABEL: @abs_caller +func.func @abs_caller(%f: complex, %d: complex) -> (f32, f64) { + // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}}) + %rf = complex.abs %f : complex + // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}}) + %rd = complex.abs %d : complex + // CHECK: return %[[RF]], %[[RD]] + return %rf, %rd : f32, f64 +} + +//CHECK-LABEL: @exp_caller +func.func @exp_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) + %ef = complex.exp %f : complex + // CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}}) + %ed = complex.exp %d : complex + // CHECK: return %[[EF]], %[[ED]] + return %ef, %ed : complex, complex +}