@@ -76,11 +76,6 @@ struct GPULaunchKernelConversion
76
76
mlir::LogicalResult
77
77
matchAndRewrite (mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor,
78
78
mlir::ConversionPatternRewriter &rewriter) const override {
79
-
80
- if (op.hasClusterSize ()) {
81
- return mlir::failure ();
82
- }
83
-
84
79
mlir::Location loc = op.getLoc ();
85
80
auto *ctx = rewriter.getContext ();
86
81
mlir::ModuleOp mod = op->getParentOfType <mlir::ModuleOp>();
@@ -107,37 +102,65 @@ struct GPULaunchKernelConversion
107
102
rewriter.create <LLVM::AddressOfOp>(loc, ptrTy, kernel.getName ());
108
103
}
109
104
110
- auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
111
- RTNAME_STRING (CUFLaunchKernel));
112
-
113
105
auto llvmIntPtrType = mlir::IntegerType::get (
114
106
ctx, this ->getTypeConverter ()->getPointerBitwidth (0 ));
115
107
auto voidTy = mlir::LLVM::LLVMVoidType::get (ctx);
116
- auto funcTy = mlir::LLVM::LLVMFunctionType::get (
117
- voidTy,
118
- {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
119
- llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
120
- /* isVarArg=*/ false );
121
-
122
- auto cufLaunchKernel = mlir::SymbolRefAttr::get (
123
- mod.getContext (), RTNAME_STRING (CUFLaunchKernel));
124
- if (!funcOp) {
125
- mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
126
- rewriter.setInsertionPointToStart (mod.getBody ());
127
- auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
128
- loc, RTNAME_STRING (CUFLaunchKernel), funcTy);
129
- launchKernelFuncOp.setVisibility (mlir::SymbolTable::Visibility::Private);
130
- }
131
108
132
109
mlir::Value nullPtr = rewriter.create <LLVM::ZeroOp>(loc, ptrTy);
133
110
134
- rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
135
- op, funcTy, cufLaunchKernel,
136
- mlir::ValueRange{kernelPtr, adaptor.getGridSizeX (),
137
- adaptor.getGridSizeY (), adaptor.getGridSizeZ (),
138
- adaptor.getBlockSizeX (), adaptor.getBlockSizeY (),
139
- adaptor.getBlockSizeZ (), dynamicMemorySize, kernelArgs,
140
- nullPtr});
111
+ if (op.hasClusterSize ()) {
112
+ auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
113
+ RTNAME_STRING (CUFLaunchClusterKernel));
114
+ auto funcTy = mlir::LLVM::LLVMFunctionType::get (
115
+ voidTy,
116
+ {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
117
+ llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
118
+ llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
119
+ /* isVarArg=*/ false );
120
+ auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get (
121
+ mod.getContext (), RTNAME_STRING (CUFLaunchClusterKernel));
122
+ if (!funcOp) {
123
+ mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
124
+ rewriter.setInsertionPointToStart (mod.getBody ());
125
+ auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
126
+ loc, RTNAME_STRING (CUFLaunchClusterKernel), funcTy);
127
+ launchKernelFuncOp.setVisibility (
128
+ mlir::SymbolTable::Visibility::Private);
129
+ }
130
+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
131
+ op, funcTy, cufLaunchClusterKernel,
132
+ mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX (),
133
+ adaptor.getClusterSizeY (), adaptor.getClusterSizeZ (),
134
+ adaptor.getGridSizeX (), adaptor.getGridSizeY (),
135
+ adaptor.getGridSizeZ (), adaptor.getBlockSizeX (),
136
+ adaptor.getBlockSizeY (), adaptor.getBlockSizeZ (),
137
+ dynamicMemorySize, kernelArgs, nullPtr});
138
+ } else {
139
+ auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
140
+ RTNAME_STRING (CUFLaunchKernel));
141
+ auto funcTy = mlir::LLVM::LLVMFunctionType::get (
142
+ voidTy,
143
+ {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
144
+ llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
145
+ /* isVarArg=*/ false );
146
+ auto cufLaunchKernel = mlir::SymbolRefAttr::get (
147
+ mod.getContext (), RTNAME_STRING (CUFLaunchKernel));
148
+ if (!funcOp) {
149
+ mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
150
+ rewriter.setInsertionPointToStart (mod.getBody ());
151
+ auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
152
+ loc, RTNAME_STRING (CUFLaunchKernel), funcTy);
153
+ launchKernelFuncOp.setVisibility (
154
+ mlir::SymbolTable::Visibility::Private);
155
+ }
156
+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
157
+ op, funcTy, cufLaunchKernel,
158
+ mlir::ValueRange{kernelPtr, adaptor.getGridSizeX (),
159
+ adaptor.getGridSizeY (), adaptor.getGridSizeZ (),
160
+ adaptor.getBlockSizeX (), adaptor.getBlockSizeY (),
161
+ adaptor.getBlockSizeZ (), dynamicMemorySize,
162
+ kernelArgs, nullPtr});
163
+ }
141
164
142
165
return mlir::success ();
143
166
}
0 commit comments