Commit e79e08e

[Backend] Make sure membar works on warpgroup partitions (triton-lang#6441)
Alias analysis was not propagating shared memory aliases through warp specialize captures into partition regions. As a result, Membar was not actually analyzing the shared memory accesses within partitions or inserting barriers for them. This can sometimes cause kernels to hang, for example when they do not otherwise synchronize on mbarrier waits.
1 parent a0cc214 commit e79e08e
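
For context, a minimal illustrative sketch of the IR shape at issue (simplified from the lit test added to test/Analysis/test-membar.mlir below; the #shared encoding, the %c constant, and the single-partition form are abbreviations, not the commit's code). The partition block argument %arg0 is bound to the captured allocation %0, so Membar must learn that stores through %arg0 touch %0 and separate them with a barrier:

%0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
ttg.warp_specialize(%0)
default {
  ttg.warp_yield
}
partition0(%arg0: !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) num_warps(2) {
  %c = arith.constant dense<0.0> : tensor<16x16xf16>
  // Before this fix, Membar did not know %arg0 aliases %0 and would not place a
  // gpu.barrier between these two stores to the same shared-memory buffer.
  ttg.local_store %c, %arg0 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
  ttg.local_store %c, %arg0 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
  ttg.warp_return
} : (!ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) -> ()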

File tree

5 files changed: +73, -2 lines changed

include/triton/Analysis/Alias.h

Lines changed: 5 additions & 0 deletions
@@ -89,6 +89,11 @@ class SharedMemoryAliasAnalysis
   visitOperation(Operation *op,
                  ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
                  ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
+
+  void visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<dataflow::Lattice<AliasInfo> *> argLattices,
+      unsigned firstIndex) override;
 };
 
 } // namespace mlir

lib/Analysis/Alias.cpp

Lines changed: 24 additions & 0 deletions
@@ -58,6 +58,30 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
   return success();
 }
 
+void SharedMemoryAliasAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &successor,
+    ArrayRef<dataflow::Lattice<AliasInfo> *> argLattices, unsigned firstIndex) {
+  auto wsOp = dyn_cast<triton::gpu::WarpSpecializePartitionsOp>(op);
+  if (!wsOp) {
+    setAllToEntryStates(argLattices.take_front(firstIndex));
+    setAllToEntryStates(argLattices.drop_front(
+        firstIndex + successor.getSuccessorInputs().size()));
+    return;
+  }
+
+  // Propagate aliases from the parent operation's operands to the block
+  // arguments.
+  assert(!successor.isParent());
+  ProgramPoint *point = getProgramPointAfter(wsOp);
+
+  for (auto [capture, argLattice] :
+       llvm::zip(wsOp.getParentOp().getExplicitCaptures(), argLattices)) {
+    propagateIfChanged(
+        argLattice,
+        argLattice->join(getLatticeElementFor(point, capture)->getValue()));
+  }
+}
+
 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
   // TODO: implement
   return AliasResult::MayAlias;
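
As a usage note, the sketch below is my own and not part of the commit: it shows how a consumer such as AllocationAnalysis might query the propagated lattice after loading SharedMemoryAliasAnalysis into a DataFlowSolver, mirroring the setup visible in the Allocation.cpp hunk below. It assumes AliasInfo exposes its allocation set via getAllocs() and that any prerequisite analyses are loaded elsewhere; with this commit, a partition block argument bound to a captured shared-memory allocation should now report that allocation.

#include "mlir/Analysis/DataFlowFramework.h"
#include "triton/Analysis/Alias.h"

using namespace mlir;

// Hypothetical helper, for illustration only: after the solver runs, print the
// shared-memory allocations that the analysis considers aliased by value `v`
// (e.g. a warp-specialize partition block argument).
static void dumpSharedMemAliases(Operation *root, Value v) {
  DataFlowSolver solver;
  solver.load<SharedMemoryAliasAnalysis>();
  if (failed(solver.initializeAndRun(root)))
    return;
  if (const auto *lattice = solver.lookupState<dataflow::Lattice<AliasInfo>>(v))
    for (Value alloc : lattice->getValue().getAllocs())
      alloc.dump();
}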

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ class AllocationAnalysis {
     solver->load<SharedMemoryAliasAnalysis>();
     // Run the analysis rooted at every isolated from above operation, including
     // the top-level function but also any nested regions.
-    operation->walk([&](Operation *op) {
+    operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
       if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
           failed(solver->initializeAndRun(op))) {
         // TODO: return error instead of bailing out..
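
One note on the walk-order change above: mlir::Operation::walk defaults to WalkOrder::PostOrder, which reaches nested isolated-from-above operations before the enclosing one, whereas a pre-order walk visits the enclosing function first, presumably so its lattices already exist in the shared solver when nested isolated regions are initialized. Below is a small sketch of the resulting visitation order; the printAnalysisRoots helper is hypothetical and not part of the commit.

#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

// Illustration only: with a pre-order walk, the outermost isolated-from-above op
// (e.g. the top-level function) is printed before any nested isolated ops,
// matching the order in which the analysis is now rooted at them.
static void printAnalysisRoots(mlir::Operation *operation) {
  operation->walk<mlir::WalkOrder::PreOrder>([](mlir::Operation *op) {
    if (op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
      llvm::errs() << "rooting analysis at: " << op->getName() << "\n";
  });
}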

test/Analysis/test-membar.mlir

Lines changed: 42 additions & 0 deletions
@@ -1000,3 +1000,45 @@ module attributes {"ttg.num-warps" = 4 : i32} {
   tt.return
 }
 }
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
+
+// CHECK-LABEL: @membar_alias_through_warp_specialize
+tt.func @membar_alias_through_warp_specialize() {
+  %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+  ttg.warp_specialize(%0)
+  default {
+    ttg.warp_yield
+  }
+  // CHECK: partition0
+  partition0(%arg0: !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) num_warps(2) {
+    %c0 = arith.constant 0 : i32
+    %1 = ttg.memdesc_subview %arg0[%c0, %c0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    %c = arith.constant dense<0.0> : tensor<16x16xf16>
+    // CHECK: local_store
+    ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    // CHECK-NEXT: gpu.barrier
+    // CHECK-NEXT: local_store
+    ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.warp_return
+  }
+  // CHECK: partition1
+  partition1(%arg0: !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) num_warps(2) {
+    %c0 = arith.constant 0 : i32
+    %1 = ttg.memdesc_subview %arg0[%c0, %c0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    %c = arith.constant dense<0.0> : tensor<16x16xf16>
+    // CHECK: local_store
+    ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    // CHECK-NEXT: gpu.barrier
+    // CHECK-NEXT: local_store
+    ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.warp_return
+  } : (!ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) -> ()
+  tt.return
+}
+
+}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
 
   // Rewrite all warp returns.
   partition->walk([&](WarpReturnOp op) {
-    b.setInsertionPoint(op);
+    TritonLLVMIRRewriter b(op.getLoc(), op);
     createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
                   /*aligned=*/false);
     b.replaceOpWithNewOp<LLVM::BrOp>(op, switchLoop);
