[mlir][amdgpu] Add rocdl.s.waitcnt wrapper #149670

Merged
merged 7 commits into from
Jul 22, 2025

Conversation

Hardcode84
Contributor

@Hardcode84 Hardcode84 commented Jul 19, 2025

The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly (similar to the asm format) and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering.

@llvmbot
Member

llvmbot commented Jul 19, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Ivan Butygin (Hardcode84)

Changes

The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly (similar to the asm format) and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering. Only gfx9 support is added as part of this commit.


Full diff: https://github.com/llvm/llvm-project/pull/149670.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+20)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+49-3)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir (+20)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 80959ffbaf426..cecb936e18ae3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -717,6 +717,26 @@ def AMDGPU_SchedBarrierOp :
   }];
 }
 
+def AMDGPU_WaitcntOp :
+  AMDGPU_Op<"waitcnt">,
+  Arguments<(ins
+      OptionalAttr<I32Attr>:$vmcnt,
+      OptionalAttr<I32Attr>:$expcnt,
+      OptionalAttr<I32Attr>:$lgkmcnt
+    )>
+  {
+  let summary = "Wrapper on ROCDL SWaitcntOp";
+  let description = [{
+    Covenience wrapper on `rocdl.s.waitcnt`. Hides the architecture specific
+    bitpacking from user. Missing values will be assumed maximum values supported
+    by the architecture. Large values will also be clamped to the maximum
+    supported values.
+  }];
+  let assemblyFormat = [{
+    (`vmcnt` `(` $vmcnt^ `)` )? (`expcnt` `(` $expcnt^ `)` )? (`lgkmcnt` `(` $lgkmcnt^ `)`)? attr-dict
+  }];
+}
+
 def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
     "The possible permutations of the lanes storing B available in an MFMA",
     [
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..af588d5b70a45 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -419,6 +419,52 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   }
 };
 
+// TODO: AMDGPU backend already have all this bitpacking logic, we should move
+// it to some common place.
+static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
+                                         unsigned expcnt, unsigned lgkmcnt) {
+  if (chipset.majorVersion == 9) {
+    vmcnt = std::min(63u, vmcnt);
+    expcnt = std::min(7u, expcnt);
+    lgkmcnt = std::min(15u, lgkmcnt);
+    unsigned lowBits = vmcnt & 0xF;
+    unsigned highBits = (vmcnt >> 4) << 14;
+    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+    return lowBits | highBits | otherCnts;
+  }
+  return failure();
+}
+
+struct WaitcntOpLowering : public ConvertOpToLLVMPattern<WaitcntOp> {
+  WaitcntOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<WaitcntOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(WaitcntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto getVal = [](Attribute attr) -> unsigned {
+      if (attr)
+        return cast<IntegerAttr>(attr).getInt();
+
+      // This value will be clamped to the maximum value for the chipset.
+      return 1024 * 1024;
+    };
+    unsigned vmcnt = getVal(adaptor.getVmcntAttr());
+    unsigned expcnt = getVal(adaptor.getExpcntAttr());
+    unsigned lgkmcnt = getVal(adaptor.getLgkmcntAttr());
+
+    FailureOr<unsigned> waitcnt =
+        encodeWaitcnt(chipset, vmcnt, expcnt, lgkmcnt);
+    if (failed(waitcnt))
+      return op.emitOpError("unsupported chipset");
+
+    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
+    return success();
+  }
+};
+
 struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
   LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1871,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicUminOp>,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
-           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
-           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
-           ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
+           AMDGPUDPPLowering, WaitcntOpLowering, LDSBarrierOpLowering,
+           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
            PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
            TransposeLoadOpLowering>(converter, chipset);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir
new file mode 100644
index 0000000000000..9c785670198ae
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
+// TODO: Add more chipsets support
+
+
+// CHECK-LABEL: func @waitcnt
+func.func @waitcnt() {
+  // GFX9: rocdl.s.waitcnt 53119
+  amdgpu.waitcnt
+
+  // GFX9: rocdl.s.waitcnt 3952
+  amdgpu.waitcnt vmcnt(0)
+
+  // GFX9: rocdl.s.waitcnt 53007
+  amdgpu.waitcnt expcnt(0)
+
+  // GFX9: rocdl.s.waitcnt 49279
+  amdgpu.waitcnt lgkmcnt(0)
+
+  return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 5559ac8f1a5c3..b126b23cb8156 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -504,3 +504,16 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
   amdgpu.gather_to_lds %mem1[%idx1],        %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>,    memref<32x32xf16, #gpu.address_space<workgroup>>
   func.return
 }
+
+// CHECK-LABEL: func @waitcnt
+func.func @waitcnt() {
+  // CHECK: amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
+  // CHECK: amdgpu.waitcnt vmcnt(1)
+  // CHECK: amdgpu.waitcnt expcnt(2)
+  // CHECK: amdgpu.waitcnt lgkmcnt(3)
+  amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
+  amdgpu.waitcnt vmcnt(1)
+  amdgpu.waitcnt expcnt(2)
+  amdgpu.waitcnt lgkmcnt(3)
+  func.return
+}
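
The gfx9 packing in `encodeWaitcnt` above can be checked in isolation. The following is a minimal standalone sketch, not part of the PR: the function name `encodeWaitcntGfx9` and the constant `kUnset` are illustrative, and the `Chipset` dispatch and `FailureOr` wrapper are dropped. The clamping limits and bit positions are copied from the diff, and the expected values are the ones checked in `waitcnt.mlir`.

```cpp
#include <algorithm>
#include <cassert>

// Sketch of the gfx9 s_waitcnt immediate packing, mirroring encodeWaitcnt
// in the diff. Resulting layout:
//   bits  3:0  -> vmcnt[3:0]
//   bits  6:4  -> expcnt
//   bits 11:8  -> lgkmcnt
//   bits 15:14 -> vmcnt[5:4]
constexpr unsigned encodeWaitcntGfx9(unsigned vmcnt, unsigned expcnt,
                                     unsigned lgkmcnt) {
  // Clamp each counter to the largest value its field can hold.
  vmcnt = std::min(63u, vmcnt);
  expcnt = std::min(7u, expcnt);
  lgkmcnt = std::min(15u, lgkmcnt);
  unsigned lowBits = vmcnt & 0xF;
  unsigned highBits = (vmcnt >> 4) << 14;
  unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
  return lowBits | highBits | otherCnts;
}

// Illustrative stand-in for an omitted counter; the lowering's getVal lambda
// returns the same large value and relies on clamping.
constexpr unsigned kUnset = 1024 * 1024;

// Expected immediates taken from the GFX9 lines in waitcnt.mlir.
static_assert(encodeWaitcntGfx9(kUnset, kUnset, kUnset) == 53119, "no counters");
static_assert(encodeWaitcntGfx9(0, kUnset, kUnset) == 3952, "vmcnt(0)");
static_assert(encodeWaitcntGfx9(kUnset, 0, kUnset) == 53007, "expcnt(0)");
static_assert(encodeWaitcntGfx9(kUnset, kUnset, 0) == 49279, "lgkmcnt(0)");
```

An omitted counter clamps to its field maximum, which is why a bare `amdgpu.waitcnt` produces an immediate with every field saturated and waits for nothing.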

Member

@kuhar kuhar left a comment

LGTM but let's wait for an approval from @krzysz00

Contributor

@krzysz00 krzysz00 left a comment

LGTM modulo assembly format nit

supported values.
}];
let assemblyFormat = [{
(`vmcnt` `(` $vmcnt^ `)` )? (`expcnt` `(` $expcnt^ `)` )? (`lgkmcnt` `(` $lgkmcnt^ `)`)? attr-dict
Contributor

Can we oilist this?

Contributor Author

Nice! I wasn't even aware of this feature.

lgkmcnt = std::min(63u, lgkmcnt);
return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
}
return failure();
Contributor

gfx12 has an entirely different system, yes? Should we support that here, or somewhere else?

Contributor Author

Yes, it uses different instructions that I'm not very familiar with. So, the high-level questions: can the amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3) format be semantically converted to these new instructions? Should we pick a more generic name for the amdgpu wrapper, one that doesn't reference a specific instruction? CC @kuhar

Contributor

From what I can tell, the gfx12 instructions are a more fine-grained version of the pre-gfx12 ones, so we could actually use the gfx12 categories and then combine them as needed for older architectures.

But that feels like it could be future work.

(Though if you want to do it now, I might call it amdgpu.memory_counter_wait and give it loads, ds, store, sample, exp, and so on, following gfx12 and then packing into the relevant fields of s_waitcnt as needed.)

The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering.
Only gfx9 bitpacking support is added as part of this commit.
Signed-off-by: Ivan Butygin <[email protected]>
@Hardcode84
Contributor Author

@kuhar @krzysz00 I've updated the op to be more generic, PTAL again.

Contributor

@krzysz00 krzysz00 left a comment

Minor note, otherwise LGTM

unsigned store = getVal(adaptor.getStoreAttr());
unsigned exp = getVal(adaptor.getExpAttr());

unsigned vmcnt = std::min(load, store);
Contributor

I'd argue for load + store here, since that'll be the total number of outstanding VMEM operations.

Contributor Author

@Hardcode84 Hardcode84 Jul 22, 2025

I don't think that will work: e.g. if the user specifies only load(0) without store, waitcnt will get the maximum value and won't wait for anything.

Contributor

Ah. Yeah, I see. In which case, what you've done is fine

Contributor

... Though I'd go with "if we've specified loads and stores, it's the sum", otherwise it's loads or stores

Contributor Author

It's kind of broken either way :)

Contributor

I do think that since vmcnt is the count of outstanding loads and stores ... we do want the sum of outstanding loads and stores (unless the user didn't specify a number of loads or stores)

Contributor Author

Ok, makes sense
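
The rule this thread converges on (sum when both counts are given, otherwise whichever one is given) can be sketched as follows. This is an illustration of the discussion only and may not match the merged code: `combineVmcnt`, `naiveSumClamped`, and `kUnset` are hypothetical names, and the clamp to 63 assumes the gfx9 vmcnt field width from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Hypothetical helper: fold optional load/store counts into one vmcnt value.
// Returns std::nullopt when neither is given, meaning "do not wait on vmcnt".
std::optional<unsigned> combineVmcnt(std::optional<unsigned> load,
                                     std::optional<unsigned> store) {
  if (load && store)
    return *load + *store; // both given: total outstanding VMEM operations
  if (load)
    return load;           // only loads given
  return store;            // only stores given, or neither (nullopt)
}

// The objection raised in the thread: when an omitted counter is represented
// by a large sentinel, an unconditional sum saturates the field.
constexpr unsigned kUnset = 1024 * 1024;
unsigned naiveSumClamped(unsigned load, unsigned store) {
  return std::min(63u, load + store); // 63 is the gfx9 vmcnt maximum
}
```

With this sketch, `combineVmcnt(0u, std::nullopt)` yields 0, so `load(0)` alone still waits for all outstanding loads, whereas `naiveSumClamped(0, kUnset)` yields 63 and waits for nothing.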

@krzysz00
Contributor

LGTM

@Hardcode84 Hardcode84 merged commit 4977100 into llvm:main Jul 22, 2025
9 checks passed
@Hardcode84 Hardcode84 deleted the waitcnt branch July 22, 2025 20:38