-
Notifications
You must be signed in to change notification settings - Fork 57
Convert the kernel state back to a reference when needed. #715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/irgen.jl b/src/irgen.jl
index 2d19961..5a05c4a 100644
--- a/src/irgen.jl
+++ b/src/irgen.jl
@@ -271,8 +271,10 @@ end
# - `name`: the name of the argument
# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
# is not passed at the LLVM level.
-function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
- post_optimization::Bool=false)
+function classify_arguments(
+ @nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
+ post_optimization::Bool = false
+ )
source_sig = job.source.specTypes
source_types = [source_sig.parameters...]
@@ -286,7 +288,7 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
if post_optimization && kernel_state_type(job) !== Nothing
args = []
- push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1))
+ push!(args, (cc = KERNEL_STATE, typ = kernel_state_type(job), name = :kernel_state, idx = 1))
codegen_i = 2
else
args = []
@@ -831,8 +833,10 @@ end
# the kernel state argument is always passed by value to avoid codegen issues with byval.
# some back-ends however do not support passing kernel arguments by value, so this pass
# serves to convert that argument (and is conceptually the inverse of `lower_byval`).
-function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
- f::LLVM.Function)
+function kernel_state_to_reference!(
+ @nospecialize(job::CompilerJob), mod::LLVM.Module,
+ f::LLVM.Function
+ )
ft = function_type(f)
# check if we even need a kernel state argument
@@ -870,7 +874,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
# emit IR performing the "conversions"
new_args = LLVM.Value[]
- @dispose builder=IRBuilder() begin
+ @dispose builder = IRBuilder() begin
entry = BasicBlock(new_f, "conversion")
position!(builder, entry)
@@ -885,12 +889,14 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}(
- param => new_args[i] for (i,param) in enumerate(parameters(f))
+ param => new_args[i] for (i, param) in enumerate(parameters(f))
)
value_map[f] = new_f
- clone_into!(new_f, f; value_map,
- changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
+ clone_into!(
+ new_f, f; value_map,
+ changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+ )
# fall through
br!(builder, blocks(new_f)[2])
@@ -913,7 +919,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
LLVM.name!(new_f, fn)
# minimal optimization
- @dispose pb=NewPMPassBuilder() begin
+ @dispose pb = NewPMPassBuilder() begin
add!(pb, SimplifyCFGPass())
run!(pb, new_f, llvm_machine(job.config.target))
end
@@ -922,8 +928,10 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
end
end
-function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
- entry::LLVM.Function, kernel_intrinsics::Dict)
+function add_input_arguments!(
+ @nospecialize(job::CompilerJob), mod::LLVM.Module,
+ entry::LLVM.Function, kernel_intrinsics::Dict
+ )
entry_fn = LLVM.name(entry)
# figure out which intrinsics are used and need to be added as arguments
@@ -976,7 +984,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_arg, LLVM.name(arg))
end
- for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
+ for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[(end - nargs + 1):end])
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
end
@@ -994,8 +1002,10 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
end
value_map[f] = new_f
- clone_into!(new_f, f; value_map,
- changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)
+ clone_into!(
+ new_f, f; value_map,
+ changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly
+ )
# we can't remove this function yet, as we might still need to rewrite any called,
# but remove the IR already
@@ -1011,7 +1021,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
# update other uses of the old function, modifying call sites to pass the arguments
function rewrite_uses!(f, new_f)
# update uses
- @dispose builder=IRBuilder() begin
+ return @dispose builder = IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
@@ -1019,9 +1029,11 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
- call!(builder, function_type(new_f), new_f,
- [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
- operand_bundles(val))
+ call!(
+ builder, function_type(new_f), new_f,
+ [arguments(val)..., parameters(callee_f)[(end - nargs + 1):end]...],
+ operand_bundles(val)
+ )
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -1064,7 +1076,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
val = user(use)
callee_f = LLVM.parent(LLVM.parent(val))
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
- replace_uses!(val, parameters(callee_f)[end-nargs+i])
+ replace_uses!(val, parameters(callee_f)[end - nargs + i])
else
error("Cannot rewrite unknown use of function: $val")
end
diff --git a/src/metal.jl b/src/metal.jl
index 200af85..6f2e5c6 100644
--- a/src/metal.jl
+++ b/src/metal.jl
@@ -238,7 +238,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
# find the byref parameters
byref = BitVector(undef, length(parameters(ft)))
- args = classify_arguments(job, ft; post_optimization=job.config.optimize)
+ args = classify_arguments(job, ft; post_optimization = job.config.optimize)
filter!(args) do arg
arg.cc != GHOST
end
@@ -563,7 +563,7 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
arg_infos = Metadata[]
# Iterate through arguments and create metadata for them
- args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize)
+ args = classify_arguments(job, entry_ft; post_optimization = job.config.optimize)
i = 1
for arg in args
arg.idx === nothing && continue
diff --git a/src/spirv.jl b/src/spirv.jl
index 8eea92c..b11210b 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -62,8 +62,10 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ?
runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) =
"spirv-" * String(job.config.target.backend)
-function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
- entry::LLVM.Function)
+function finish_module!(
+ job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
+ entry::LLVM.Function
+ )
# update calling convention
for f in functions(mod)
# JuliaGPU/GPUCompiler.jl#97
@@ -90,8 +92,10 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
return errors
end
-function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
- entry::LLVM.Function)
+function finish_ir!(
+ job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
+ entry::LLVM.Function
+ )
# convert the kernel state argument to a byval reference
if job.config.kernel
state = kernel_state_type(job) |
This has been working great for OpenCL.jl, is this good to merge? |
dce63fc
to
e9ad136
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #715 +/- ##
==========================================
- Coverage 76.12% 73.45% -2.68%
==========================================
Files 24 24
Lines 3548 3613 +65
==========================================
- Hits 2701 2654 -47
- Misses 847 959 +112 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
0a16b76
to
66ca19e
Compare
We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there's known issues with `byval` lowering on NVPTX. For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of `lower_byval`, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel.
Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel.
78d9e8a
to
e5b0b0b
Compare
We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there's known issues with
byval
lowering on NVPTX.For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of
lower_byval
, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel.Possible alternatives include:
lower_byval
to avoid the back-end's codegen issues. This is tricky because that function now only operates on the outermost kernel function, and making it so that it lowers byval everywhere (which is necessary to avoidalloca
's on cases where the state is passed down, which is almost everywhere) is nontrivialcc @simeonschaub