Skip to content

[MLIR] Add ComplexTOROCDL pass #144926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Conversation

TIFitis
Copy link
Member

@TIFitis TIFitis commented Jun 19, 2025

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
@TIFitis TIFitis requested review from krzysz00, jsjodin and Copilot June 19, 2025 16:37
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Jun 19, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.


Full diff: https://github.com/llvm/llvm-project/pull/144926.diff

9 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-1)
  • (added) mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h (+19)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+12)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt (+18)
  • (added) mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp (+95)
  • (added) mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir (+13)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
+  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
+    if (isAMDGCN) {
       mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+  ComplexToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+  std::optional<bool> operator()(Type type) const {
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
+      return {};
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                      StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    auto module = SymbolTable::getNearestSymbolTable(op);
+    auto isDouble = TypeResolver()(op.getType());
+    if (!isDouble.has_value())
+      return failure();
+
+    auto name = *isDouble ? doubleFunc : floatFunc;
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(module, name));
+    if (!opFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc =
+          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::AbsOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces a new MLIR pass to convert complex.abs operations into corresponding ROCDL library calls (i.e. __ocml_cabs_f32/__ocml_cabs_f64), along with the necessary test and build system support.

  • Added a test to verify proper lowering in mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
  • Implemented the conversion pass in mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp and integrated it in the MLIR and Flang build systems
  • Updated CMake configurations and pass declarations to support the new pass

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir Adds a test verifying the conversion of complex.abs to ROCDL calls
mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp Implements the conversion pass from Complex dialect to ROCDL calls
mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt Configures build for the new conversion library
mlir/lib/Conversion/CMakeLists.txt Registers the ComplexToROCDL subdirectory
mlir/include/mlir/Conversion/Passes.td Adds the pass definitions and documentation for ComplexToROCDL
mlir/include/mlir/Conversion/Passes.h Updates header inclusions to support the new pass
mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h Declares the conversion population function
flang/lib/Optimizer/CodeGen/CodeGen.cpp Integrates the new conversion pass into the code generation pipeline for AMDGCN targets
flang/lib/Optimizer/CodeGen/CMakeLists.txt Adds MLIRComplexToROCDL as a dependency for code generation

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.


Full diff: https://github.com/llvm/llvm-project/pull/144926.diff

9 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-1)
  • (added) mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h (+19)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+12)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt (+18)
  • (added) mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp (+95)
  • (added) mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir (+13)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
+  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
+    if (isAMDGCN) {
       mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+  ComplexToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+  std::optional<bool> operator()(Type type) const {
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
+      return {};
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                      StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    auto module = SymbolTable::getNearestSymbolTable(op);
+    auto isDouble = TypeResolver()(op.getType());
+    if (!isDouble.has_value())
+      return failure();
+
+    auto name = *isDouble ? doubleFunc : floatFunc;
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(module, name));
+    if (!opFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc =
+          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::AbsOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the pass should clearly indicate what it does: convert operations to function calls. The current name suggests that it converts operations to operations from the ROCDL dialect.

using namespace mlir;

namespace {
struct FloatTypeResolver {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document top-level entities.

@ftynse
Copy link
Member

ftynse commented Jun 19, 2025

Also note that the NVGPU pipeline is doing something similar and it makes sense to align with that and reuse utilities like those from mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will note that MathToROCDL is already the "patterns for calling functions from the ROCm device libraries" pass, so ComplexToROCDL is at least consistent with existing practice.


void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a pattern parametrized by complex:: op and element type that takes one argument - the function name to rewrite to.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, sorry but I did't understand this. Both f64 and f32 are FloatType so we can't use template to differentiate them right?

@ftynse
Copy link
Member

ftynse commented Jun 20, 2025

I will note that MathToROCDL is already the "patterns for calling functions from the ROCm device libraries" pass, so ComplexToROCDL is at least consistent with existing practice.

At which point, we could merge this into that and rename it to something more obvious.

@TIFitis
Copy link
Member Author

TIFitis commented Jun 20, 2025

I will note that MathToROCDL is already the "patterns for calling functions from the ROCm device libraries" pass, so ComplexToROCDL is at least consistent with existing practice.

At which point, we could merge this into that and rename it to something more obvious.

Hi @ftynse,
Could you please clarify what you’d like me to do? Are you suggesting that ComplexToROCDL be renamed to something like ComplexToROCDLCall? Or are you proposing that we merge ComplexToROCDL and MathToROCDL into a single pass? If the latter, do you have a preferred name in mind for the merged pass?

Thanks!

@TIFitis
Copy link
Member Author

TIFitis commented Jun 20, 2025

Thanks @ftynse and @krzysz00 for the review, I've addressed the changes requested.

I haven't changed the pass name yet, I'll do so if we decide to keep the pass separate from MathToROCDL (does ComplexToROCDLCalls sound alright?). On that note, regarding merging this pass to MathToROCDL or making use of mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h, I hit the assert below. I'm not entirely sure why that assert is needed but for now it doesn't seem compatible with the Complex op. Let me know what are your thoughts.

Thanks.

assert((op->getResultTypes().front() == op->getOperand(0).getType() ||

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants