Skip to content

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented Sep 3, 2025

We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there's known issues with byval lowering on NVPTX.

For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of lower_byval, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel.

Possible alternatives include:

  • always passing the object by reference, and relying on lower_byval to avoid the back-end's codegen issues. This is tricky because that function now only operates on the outermost kernel function, and making it so that it lowers byval everywhere (which is necessary to avoid alloca's on cases where the state is passed down, which is almost everywhere) is nontrivial
  • customizing the kernel state lowering to use a reference for the kernel, and a value for child functions

cc @simeonschaub

Copy link
Contributor

github-actions bot commented Sep 3, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/irgen.jl b/src/irgen.jl
index 2d19961..5a05c4a 100644
--- a/src/irgen.jl
+++ b/src/irgen.jl
@@ -271,8 +271,10 @@ end
 # - `name`: the name of the argument
 # - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
 #          is not passed at the LLVM level.
-function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
-                            post_optimization::Bool=false)
+function classify_arguments(
+        @nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
+        post_optimization::Bool = false
+    )
     source_sig = job.source.specTypes
     source_types = [source_sig.parameters...]
 
@@ -286,7 +288,7 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
 
     if post_optimization && kernel_state_type(job) !== Nothing
         args = []
-        push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1))
+        push!(args, (cc = KERNEL_STATE, typ = kernel_state_type(job), name = :kernel_state, idx = 1))
         codegen_i = 2
     else
         args = []
@@ -831,8 +833,10 @@ end
 # the kernel state argument is always passed by value to avoid codegen issues with byval.
 # some back-ends however do not support passing kernel arguments by value, so this pass
 # serves to convert that argument (and is conceptually the inverse of `lower_byval`).
-function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
-                                    f::LLVM.Function)
+function kernel_state_to_reference!(
+        @nospecialize(job::CompilerJob), mod::LLVM.Module,
+        f::LLVM.Function
+    )
     ft = function_type(f)
 
     # check if we even need a kernel state argument
@@ -870,7 +874,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
 
         # emit IR performing the "conversions"
         new_args = LLVM.Value[]
-        @dispose builder=IRBuilder() begin
+        @dispose builder = IRBuilder() begin
             entry = BasicBlock(new_f, "conversion")
             position!(builder, entry)
 
@@ -885,12 +889,14 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
 
             # map the arguments
             value_map = Dict{LLVM.Value, LLVM.Value}(
-                param => new_args[i] for (i,param) in enumerate(parameters(f))
+                param => new_args[i] for (i, param) in enumerate(parameters(f))
             )
             value_map[f] = new_f
 
-            clone_into!(new_f, f; value_map,
-                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
+            clone_into!(
+                new_f, f; value_map,
+                changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+            )
 
             # fall through
             br!(builder, blocks(new_f)[2])
@@ -913,7 +919,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
         LLVM.name!(new_f, fn)
 
         # minimal optimization
-        @dispose pb=NewPMPassBuilder() begin
+        @dispose pb = NewPMPassBuilder() begin
             add!(pb, SimplifyCFGPass())
             run!(pb, new_f, llvm_machine(job.config.target))
         end
@@ -922,8 +928,10 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
     end
 end
 
-function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
-                              entry::LLVM.Function, kernel_intrinsics::Dict)
+function add_input_arguments!(
+        @nospecialize(job::CompilerJob), mod::LLVM.Module,
+        entry::LLVM.Function, kernel_intrinsics::Dict
+    )
     entry_fn = LLVM.name(entry)
 
     # figure out which intrinsics are used and need to be added as arguments
@@ -976,7 +984,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
         for (arg, new_arg) in zip(parameters(f), parameters(new_f))
             LLVM.name!(new_arg, LLVM.name(arg))
         end
-        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
+        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[(end - nargs + 1):end])
             LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
         end
 
@@ -994,8 +1002,10 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
         end
 
         value_map[f] = new_f
-        clone_into!(new_f, f; value_map,
-                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)
+        clone_into!(
+            new_f, f; value_map,
+            changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly
+        )
 
         # we can't remove this function yet, as we might still need to rewrite any called,
         # but remove the IR already
@@ -1011,7 +1021,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
     # update other uses of the old function, modifying call sites to pass the arguments
     function rewrite_uses!(f, new_f)
         # update uses
-        @dispose builder=IRBuilder() begin
+        return @dispose builder = IRBuilder() begin
             for use in uses(f)
                 val = user(use)
                 if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
@@ -1019,9 +1029,11 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                     # forward the arguments
                     position!(builder, val)
                     new_val = if val isa LLVM.CallInst
-                        call!(builder, function_type(new_f), new_f,
-                              [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
-                              operand_bundles(val))
+                        call!(
+                            builder, function_type(new_f), new_f,
+                            [arguments(val)..., parameters(callee_f)[(end - nargs + 1):end]...],
+                            operand_bundles(val)
+                        )
                     else
                         # TODO: invoke and callbr
                         error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -1064,7 +1076,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
             val = user(use)
             callee_f = LLVM.parent(LLVM.parent(val))
             if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
-                replace_uses!(val, parameters(callee_f)[end-nargs+i])
+                replace_uses!(val, parameters(callee_f)[end - nargs + i])
             else
                 error("Cannot rewrite unknown use of function: $val")
             end
diff --git a/src/metal.jl b/src/metal.jl
index 200af85..6f2e5c6 100644
--- a/src/metal.jl
+++ b/src/metal.jl
@@ -238,7 +238,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
 
     # find the byref parameters
     byref = BitVector(undef, length(parameters(ft)))
-    args = classify_arguments(job, ft; post_optimization=job.config.optimize)
+    args = classify_arguments(job, ft; post_optimization = job.config.optimize)
     filter!(args) do arg
         arg.cc != GHOST
     end
@@ -563,7 +563,7 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
     arg_infos = Metadata[]
 
     # Iterate through arguments and create metadata for them
-    args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize)
+    args = classify_arguments(job, entry_ft; post_optimization = job.config.optimize)
     i = 1
     for arg in args
         arg.idx ===  nothing && continue
diff --git a/src/spirv.jl b/src/spirv.jl
index 8eea92c..b11210b 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -62,8 +62,10 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ?
 runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) =
     "spirv-" * String(job.config.target.backend)
 
-function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
-                        entry::LLVM.Function)
+function finish_module!(
+        job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
+        entry::LLVM.Function
+    )
     # update calling convention
     for f in functions(mod)
         # JuliaGPU/GPUCompiler.jl#97
@@ -90,8 +92,10 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
     return errors
 end
 
-function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
-                    entry::LLVM.Function)
+function finish_ir!(
+        job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
+        entry::LLVM.Function
+    )
     # convert the kernel state argument to a byval reference
     if job.config.kernel
         state = kernel_state_type(job)

@simeonschaub
Copy link
Member

simeonschaub commented Sep 10, 2025

This has been working great for OpenCL.jl, is this good to merge? The only thing that still needs to be addressed are the conflicts with #714, I wasn't sure how to handle that done

@simeonschaub simeonschaub force-pushed the tb/kernel_state_reference branch from dce63fc to e9ad136 Compare September 11, 2025 09:13
Copy link

codecov bot commented Sep 12, 2025

Codecov Report

❌ Patch coverage is 58.20106% with 79 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.45%. Comparing base (974755c) to head (e5b0b0b).
⚠️ Report is 2 commits behind head on master.

Files with missing lines Patch % Lines
src/irgen.jl 54.08% 73 Missing ⚠️
src/spirv.jl 77.77% 4 Missing ⚠️
src/metal.jl 81.81% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #715      +/-   ##
==========================================
- Coverage   76.12%   73.45%   -2.68%     
==========================================
  Files          24       24              
  Lines        3548     3613      +65     
==========================================
- Hits         2701     2654      -47     
- Misses        847      959     +112     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@simeonschaub simeonschaub force-pushed the tb/kernel_state_reference branch 2 times, most recently from 0a16b76 to 66ca19e Compare September 30, 2025 11:57
maleadt and others added 4 commits October 6, 2025 14:20
We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there's known issues with `byval` lowering on NVPTX.

For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of `lower_byval`, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel.
Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel.
@maleadt maleadt force-pushed the tb/kernel_state_reference branch from 78d9e8a to e5b0b0b Compare October 6, 2025 12:20
@maleadt maleadt merged commit 2d2acf4 into master Oct 6, 2025
34 checks passed
@maleadt maleadt deleted the tb/kernel_state_reference branch October 6, 2025 15:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants