Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented")
kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing

# Does the target need to pass kernel arguments by value?
needs_byval(@nospecialize(job::CompilerJob)) = true
pass_by_value(@nospecialize(job::CompilerJob)) = true

# whether pointer is a valid call target
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false
Expand Down
279 changes: 270 additions & 9 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))

# minimal required optimization
@tracepoint "rewrite" begin
if job.config.kernel && needs_byval(job)
if job.config.kernel && pass_by_value(job)
# pass all bitstypes by value; by default Julia passes aggregates by reference
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
args = classify_arguments(job, function_type(entry))
Expand Down Expand Up @@ -256,10 +256,11 @@ end
## kernel promotion

@enum ArgumentCC begin
BITS_VALUE # bitstype, passed as value
BITS_REF # bitstype, passed as pointer
MUT_REF # jl_value_t*, or the anonymous equivalent
GHOST # not passed
BITS_VALUE # bitstype, passed as value
BITS_REF # bitstype, passed as pointer
MUT_REF # jl_value_t*, or the anonymous equivalent
GHOST # not passed
KERNEL_STATE # the kernel state argument
end

# Determine the calling convention of the arguments of a Julia function, given the
Expand All @@ -270,7 +271,8 @@ end
# - `name`: the name of the argument
# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
# is not passed at the LLVM level.
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
post_optimization::Bool=false)
source_sig = job.source.specTypes
source_types = [source_sig.parameters...]

Expand All @@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu

codegen_types = parameters(codegen_ft)

args = []
codegen_i = 1
for (source_i, (source_typ, source_name)) in enumerate(zip(source_types, source_argnames))
if post_optimization && kernel_state_type(job) !== Nothing
args = []
push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1))
codegen_i = 2
else
args = []
codegen_i = 1
end
for (source_typ, source_name) in zip(source_types, source_argnames)
if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
push!(args, (cc=GHOST, typ=source_typ, name=source_name, idx=nothing))
continue
Expand Down Expand Up @@ -817,3 +825,256 @@ function kernel_state_value(state)
call_function(llvm_f, state)
end
end

# rewrite the kernel state argument from pass-by-value into pass-by-reference
#
# the kernel state argument is always passed by value to avoid codegen issues with byval.
# some back-ends however do not support passing kernel arguments by value, so this pass
# converts that one argument for them (it is conceptually the inverse of `lower_byval`).
#
# returns `f` untouched when the job has no kernel state (`kernel_state_type(job)` is
# `Nothing`), or when `f` has no parameters or its first parameter does not have the LLVM
# type of the state. otherwise `f` is erased and the replacement function, which takes over
# the name of `f`, is returned.
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    old_ft = function_type(f)

    # jobs without a kernel state have nothing to convert
    state_typ = kernel_state_type(job)
    state_typ === Nothing && return f

    T_state = convert(LLVMType, state_typ)

    # the kernel state is expected to be the very first parameter; bail out otherwise
    if isempty(parameters(old_ft)) || value_type(parameters(f)[1]) != T_state
        return f
    end

    @tracepoint "kernel state to reference" begin
        # new signature: a pointer to the state, followed by the untouched remainder
        param_types = LLVM.LLVMType[LLVM.PointerType(T_state)]
        append!(param_types, parameters(old_ft)[2:end])

        ref_ft = LLVM.FunctionType(return_type(old_ft), param_types)
        ref_f = LLVM.Function(mod, "", ref_ft)
        linkage!(ref_f, linkage(f))

        # carry over the parameter names, giving the new pointer a name of its own
        LLVM.name!(parameters(ref_f)[1], "state_ptr")
        for (old_param, new_param) in zip(parameters(f)[2:end], parameters(ref_f)[2:end])
            LLVM.name!(new_param, LLVM.name(old_param))
        end

        # emit a prologue block performing the "conversion" for the cloned body
        forwarded = LLVM.Value[]
        @dispose builder=IRBuilder() begin
            prologue = BasicBlock(ref_f, "conversion")
            position!(builder, prologue)

            # the old body expects a value, so load it from the pointer
            push!(forwarded, load!(builder, T_state, parameters(ref_f)[1], "state"))

            # every other argument is handed through as-is
            append!(forwarded, parameters(ref_f)[2:end])

            # map each old parameter onto the value standing in for it
            value_map = Dict{LLVM.Value, LLVM.Value}()
            for (old_param, new_val) in zip(parameters(f), forwarded)
                value_map[old_param] = new_val
            end
            value_map[f] = ref_f

            clone_into!(ref_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through from the prologue into the first cloned block
            br!(builder, blocks(ref_f)[2])
        end

        # annotate the state pointer parameter:
        # - nocapture: the pointer cannot be captured since we immediately load from it
        # - noalias:   each kernel state is separate
        # - readonly:  the state is read-only
        state_attrs = parameter_attributes(ref_f, 1)
        for attr_name in ("nocapture", "noalias", "readonly")
            push!(state_attrs, EnumAttribute(attr_name, 0))
        end

        # swap out the old function; the new one inherits its name
        old_name = LLVM.name(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, ref_f)
        erase!(f)
        LLVM.name!(ref_f, old_name)

        # minimal optimization
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, ref_f, llvm_machine(job.config.target))
        end

        return ref_f
    end
end

# turn kernel intrinsics into explicit function arguments
#
# `kernel_intrinsics` maps the name of an intrinsic function to a value providing `typ`
# (converted to the LLVM type of the added argument) and `name` (used to name that argument).
# only intrinsics that exist as functions in `mod` are considered. for all of them, extra
# trailing parameters are added to `entry` and to every function containing an instruction
# that uses such an intrinsic or another affected function. calls between affected functions
# are rewritten to forward the new arguments, and calls to the intrinsics are replaced by the
# matching argument, after which the intrinsic function is erased.
#
# returns the function in `mod` carrying the original name of `entry` (`entry` itself gets
# replaced). throws when it encounters a use it cannot rewrite, e.g. `invoke` or `callbr`
# call sites of an affected function.
#
# NOTE: `job` is not referenced by the body of this function.
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function, kernel_intrinsics::Dict)
# remember the name, as `entry` is replaced by a new function below
entry_fn = LLVM.name(entry)

# figure out which intrinsics are used and need to be added as arguments
used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
haskey(functions(mod), intr_fn)
end |> collect
# every affected function receives this many extra trailing parameters
nargs = length(used_intrinsics)

# determine which functions need these arguments
# (seeded with the entry function and the intrinsics themselves)
worklist = Set{LLVM.Function}([entry])
for intr_fn in used_intrinsics
push!(worklist, functions(mod)[intr_fn])
end
# iterate to a fixed point: stop once a sweep no longer grows the worklist
worklist_length = 0
while worklist_length != length(worklist)
# iteratively discover functions that use an intrinsic or any function calling it
worklist_length = length(worklist)
additions = LLVM.Function[]
for f in worklist, use in uses(f)
# the type assertion throws if the user is not an instruction
inst = user(use)::Instruction
bb = LLVM.parent(inst)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
end
for f in additions
push!(worklist, f)
end
end
# the intrinsics only served as seeds; they are not given new arguments themselves
for intr_fn in used_intrinsics
delete!(worklist, functions(mod)[intr_fn])
end

# add the arguments
# NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
workmap = Dict{LLVM.Function, LLVM.Function}()
for f in worklist
fn = LLVM.name(f)
ft = function_type(f)
# rename the old function so that the new one can be created under the original name
LLVM.name!(f, fn * ".orig")
# create a new function
new_param_types = LLVMType[parameters(ft)...]

# one extra parameter per used intrinsic, appended after the existing ones
for intr_fn in used_intrinsics
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
push!(new_param_types, llvm_typ)
end
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
new_f = LLVM.Function(mod, fn, new_ft)
linkage!(new_f, linkage(f))
# existing parameters keep their names
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_arg, LLVM.name(arg))
end
# the trailing `nargs` parameters are named after their intrinsic's `name` entry
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
end

workmap[f] = new_f
end

# clone and rewrite the function bodies.
# we don't need to rewrite much as the arguments are added last.
for (f, new_f) in workmap
# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}()
# `zip` stops at the shorter list, so only the original parameters are mapped
for (param, new_param) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_param, LLVM.name(param))
value_map[param] = new_param
end

value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

# we can't remove this function yet, as we might still need to rewrite any called,
# but remove the IR already
empty!(f)
end

# drop unused constants that may be referring to the old functions
# XXX: can we do this differently?
for f in worklist
prune_constexpr_uses!(f)
end

# update other uses of the old function, modifying call sites to pass the arguments
#
# `f` is the value whose uses get rewritten (an old function, or a bitcast constant
# expression of one when called recursively), `new_f` is its replacement.
function rewrite_uses!(f, new_f)
# update uses
@dispose builder=IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
# NOTE: despite its name, this is the function that contains the call instruction;
# its trailing `nargs` parameters are what gets forwarded to `new_f`
callee_f = LLVM.parent(LLVM.parent(val))
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
call!(builder, function_type(new_f), new_f,
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
# the replacement call keeps the calling convention of the original call
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
# XXX: why isn't this caught by the value materializer above?
target = operands(val)[1]
@assert target == f
# bitcast the new function to the same type, and recurse into the uses of the old cast
new_val = LLVM.const_bitcast(new_f, value_type(val))
rewrite_uses!(val, new_val)
# we can't simply replace this constant expression, as it may be used
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)

# drop the old constant if it is unused
# XXX: can we do this differently?
if isempty(uses(val))
LLVM.unsafe_destroy!(val)
end
else
error("Cannot rewrite unknown use of function: $val")
end
end
end
end
# with all uses rewritten, the old functions can finally be erased
for (f, new_f) in workmap
rewrite_uses!(f, new_f)
@assert isempty(uses(f))
erase!(f)
end

# replace uses of the intrinsics with references to the input arguments
for (i, intr_fn) in enumerate(used_intrinsics)
intr = functions(mod)[intr_fn]
for use in uses(intr)
val = user(use)
# the function containing the use (evaluated before the kind of `val` is checked)
callee_f = LLVM.parent(LLVM.parent(val))
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
# `end-nargs+i` selects the trailing parameter that was added for the i-th intrinsic
replace_uses!(val, parameters(callee_f)[end-nargs+i])
else
error("Cannot rewrite unknown use of function: $val")
end

@assert isempty(uses(val))
erase!(val)
end
@assert isempty(uses(intr))
erase!(intr)
end

# look up the entry function by name, as the original `entry` has been erased
return functions(mod)[entry_fn]
end
Loading