diff --git a/src/interface.jl b/src/interface.jl index 0ad0c9f6..157f1977 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -267,7 +267,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented") kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing # Does the target need to pass kernel arguments by value? -needs_byval(@nospecialize(job::CompilerJob)) = true +pass_by_value(@nospecialize(job::CompilerJob)) = true # whether pointer is a valid call target valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false diff --git a/src/irgen.jl b/src/irgen.jl index 15e22ef7..2d19961a 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob)) # minimal required optimization @tracepoint "rewrite" begin - if job.config.kernel && needs_byval(job) + if job.config.kernel && pass_by_value(job) # pass all bitstypes by value; by default Julia passes aggregates by reference # (this improves performance, and is mandated by certain back-ends like SPIR-V). args = classify_arguments(job, function_type(entry)) @@ -256,10 +256,11 @@ end ## kernel promotion @enum ArgumentCC begin - BITS_VALUE # bitstype, passed as value - BITS_REF # bitstype, passed as pointer - MUT_REF # jl_value_t*, or the anonymous equivalent - GHOST # not passed + BITS_VALUE # bitstype, passed as value + BITS_REF # bitstype, passed as pointer + MUT_REF # jl_value_t*, or the anonymous equivalent + GHOST # not passed + KERNEL_STATE # the kernel state argument end # Determine the calling convention of a the arguments of a Julia function, given the @@ -270,7 +271,8 @@ end # - `name`: the name of the argument # - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument # is not passed at the LLVM level. -function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType) +function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType; + post_optimization::Bool=false) source_sig = job.source.specTypes source_types = [source_sig.parameters...] @@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu codegen_types = parameters(codegen_ft) - args = [] - codegen_i = 1 - for (source_i, (source_typ, source_name)) in enumerate(zip(source_types, source_argnames)) + if post_optimization && kernel_state_type(job) !== Nothing + args = [] + push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1)) + codegen_i = 2 + else + args = [] + codegen_i = 1 + end + for (source_typ, source_name) in zip(source_types, source_argnames) if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ) push!(args, (cc=GHOST, typ=source_typ, name=source_name, idx=nothing)) continue @@ -817,3 +825,256 @@ function kernel_state_value(state) call_function(llvm_f, state) end end + +# convert kernel state argument from pass-by-value to pass-by-reference +# +# the kernel state argument is always passed by value to avoid codegen issues with byval. +# some back-ends however do not support passing kernel arguments by value, so this pass +# serves to convert that argument (and is conceptually the inverse of `lower_byval`). +function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, + f::LLVM.Function) + ft = function_type(f) + + # check if we even need a kernel state argument + state = kernel_state_type(job) + if state === Nothing + return f + end + + T_state = convert(LLVMType, state) + + # find the kernel state parameter (should be the first argument) + if isempty(parameters(ft)) || value_type(parameters(f)[1]) != T_state + return f + end + + @tracepoint "kernel state to reference" begin + # generate the new function type & definition + new_types = LLVM.LLVMType[] + # convert the first parameter (kernel state) to a pointer + push!(new_types, LLVM.PointerType(T_state)) + # keep all other parameters as-is + for i in 2:length(parameters(ft)) + push!(new_types, parameters(ft)[i]) + end + + new_ft = LLVM.FunctionType(return_type(ft), new_types) + new_f = LLVM.Function(mod, "", new_ft) + linkage!(new_f, linkage(f)) + + # name the parameters + LLVM.name!(parameters(new_f)[1], "state_ptr") + for (i, (arg, new_arg)) in enumerate(zip(parameters(f)[2:end], parameters(new_f)[2:end])) + LLVM.name!(new_arg, LLVM.name(arg)) + end + + # emit IR performing the "conversions" + new_args = LLVM.Value[] + @dispose builder=IRBuilder() begin + entry = BasicBlock(new_f, "conversion") + position!(builder, entry) + + # load the kernel state value from the pointer + state_val = load!(builder, T_state, parameters(new_f)[1], "state") + push!(new_args, state_val) + + # all other arguments are passed through directly + for i in 2:length(parameters(new_f)) + push!(new_args, parameters(new_f)[i]) + end + + # map the arguments + value_map = Dict{LLVM.Value, LLVM.Value}( + param => new_args[i] for (i,param) in enumerate(parameters(f)) + ) + value_map[f] = new_f + + clone_into!(new_f, f; value_map, + changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges) + + # fall through + br!(builder, blocks(new_f)[2]) + end + + # set the attributes for the state pointer parameter + attrs = parameter_attributes(new_f, 1) + # the pointer itself cannot be captured since we immediately load from it + push!(attrs, EnumAttribute("nocapture", 0)) + # each kernel state is separate + push!(attrs, EnumAttribute("noalias", 0)) + # the state is read-only + push!(attrs, EnumAttribute("readonly", 0)) + + # remove the old function + fn = LLVM.name(f) + @assert isempty(uses(f)) + replace_metadata_uses!(f, new_f) + erase!(f) + LLVM.name!(new_f, fn) + + # minimal optimization + @dispose pb=NewPMPassBuilder() begin + add!(pb, SimplifyCFGPass()) + run!(pb, new_f, llvm_machine(job.config.target)) + end + + return new_f + end +end + +function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function, kernel_intrinsics::Dict) + entry_fn = LLVM.name(entry) + + # figure out which intrinsics are used and need to be added as arguments + used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn + haskey(functions(mod), intr_fn) + end |> collect + nargs = length(used_intrinsics) + + # determine which functions need these arguments + worklist = Set{LLVM.Function}([entry]) + for intr_fn in used_intrinsics + push!(worklist, functions(mod)[intr_fn]) + end + worklist_length = 0 + while worklist_length != length(worklist) + # iteratively discover functions that use an intrinsic or any function calling it + worklist_length = length(worklist) + additions = LLVM.Function[] + for f in worklist, use in uses(f) + inst = user(use)::Instruction + bb = LLVM.parent(inst) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + end + for f in additions + push!(worklist, f) + end + end + for intr_fn in used_intrinsics + delete!(worklist, functions(mod)[intr_fn]) + end + + # add the arguments + # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt + workmap = Dict{LLVM.Function, LLVM.Function}() + for f in worklist + fn = LLVM.name(f) + ft = function_type(f) + LLVM.name!(f, fn * ".orig") + # create a new function + new_param_types = LLVMType[parameters(ft)...] + + for intr_fn in used_intrinsics + llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ) + push!(new_param_types, llvm_typ) + end + new_ft = LLVM.FunctionType(return_type(ft), new_param_types) + new_f = LLVM.Function(mod, fn, new_ft) + linkage!(new_f, linkage(f)) + for (arg, new_arg) in zip(parameters(f), parameters(new_f)) + LLVM.name!(new_arg, LLVM.name(arg)) + end + for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end]) + LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name) + end + + workmap[f] = new_f + end + + # clone and rewrite the function bodies. + # we don't need to rewrite much as the arguments are added last. + for (f, new_f) in workmap + # map the arguments + value_map = Dict{LLVM.Value, LLVM.Value}() + for (param, new_param) in zip(parameters(f), parameters(new_f)) + LLVM.name!(new_param, LLVM.name(param)) + value_map[param] = new_param + end + + value_map[f] = new_f + clone_into!(new_f, f; value_map, + changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) + + # we can't remove this function yet, as we might still need to rewrite any called, + # but remove the IR already + empty!(f) + end + + # drop unused constants that may be referring to the old functions + # XXX: can we do this differently? + for f in worklist + prune_constexpr_uses!(f) + end + + # update other uses of the old function, modifying call sites to pass the arguments + function rewrite_uses!(f, new_f) + # update uses + @dispose builder=IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst + callee_f = LLVM.parent(LLVM.parent(val)) + # forward the arguments + position!(builder, val) + new_val = if val isa LLVM.CallInst + call!(builder, function_type(new_f), new_f, + [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], + operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast + # XXX: why isn't this caught by the value materializer above? + target = operands(val)[1] + @assert target == f + new_val = LLVM.const_bitcast(new_f, value_type(val)) + rewrite_uses!(val, new_val) + # we can't simply replace this constant expression, as it may be used + # as a call, taking arguments (so we need to rewrite it to pass the input arguments) + + # drop the old constant if it is unused + # XXX: can we do this differently? + if isempty(uses(val)) + LLVM.unsafe_destroy!(val) + end + else + error("Cannot rewrite unknown use of function: $val") + end + end + end + end + for (f, new_f) in workmap + rewrite_uses!(f, new_f) + @assert isempty(uses(f)) + erase!(f) + end + + # replace uses of the intrinsics with references to the input arguments + for (i, intr_fn) in enumerate(used_intrinsics) + intr = functions(mod)[intr_fn] + for use in uses(intr) + val = user(use) + callee_f = LLVM.parent(LLVM.parent(val)) + if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst + replace_uses!(val, parameters(callee_f)[end-nargs+i]) + else + error("Cannot rewrite unknown use of function: $val") + end + + @assert isempty(uses(val)) + erase!(val) + end + @assert isempty(uses(intr)) + erase!(intr) + end + + return functions(mod)[entry_fn] +end diff --git a/src/metal.jl b/src/metal.jl index 4f303196..200af854 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -35,7 +35,7 @@ llvm_datalayout(target::MetalCompilerTarget) = "-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024"* "-n8:16:32" -needs_byval(job::CompilerJob{MetalCompilerTarget}) = false +pass_by_value(job::CompilerJob{MetalCompilerTarget}) = false ## job @@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo # update calling conventions if job.config.kernel entry = pass_by_reference!(job, mod, entry) - - add_input_arguments!(job, mod, entry) - entry = LLVM.functions(mod)[entry_fn] + entry = add_input_arguments!(job, mod, entry, kernel_intrinsics) end # emit the AIR and Metal version numbers as constants in the module. this makes it @@ -160,6 +158,11 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L entry::LLVM.Function) entry_fn = LLVM.name(entry) + # convert the kernel state argument to a reference + if job.config.kernel && kernel_state_type(job) !== Nothing + entry = kernel_state_to_reference!(job, mod, entry) + end + # add kernel metadata if job.config.kernel entry = add_parameter_address_spaces!(job, mod, entry) @@ -235,12 +238,12 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV # find the byref parameters byref = BitVector(undef, length(parameters(ft))) - args = classify_arguments(job, ft) + args = classify_arguments(job, ft; post_optimization=job.config.optimize) filter!(args) do arg arg.cc != GHOST end for arg in args - byref[arg.idx] = (arg.cc == BITS_REF) + byref[arg.idx] = (arg.cc == BITS_REF || arg.cc == KERNEL_STATE) end function remapType(src) @@ -318,6 +321,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV # remove the old function fn = LLVM.name(f) + prune_constexpr_uses!(f) @assert isempty(uses(f)) replace_metadata_uses!(f, new_f) erase!(f) @@ -418,7 +422,7 @@ end # value-to-reference conversion # -# Metal doesn't support passing valuse, so we need to convert those to references instead +# Metal doesn't support passing values, so we need to convert those to references instead function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function) ft = function_type(f) @@ -547,164 +551,6 @@ function argument_type_name(typ) end end -function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - entry::LLVM.Function) - entry_fn = LLVM.name(entry) - - # figure out which intrinsics are used and need to be added as arguments - used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn - haskey(functions(mod), intr_fn) - end |> collect - nargs = length(used_intrinsics) - - # determine which functions need these arguments - worklist = Set{LLVM.Function}([entry]) - for intr_fn in used_intrinsics - push!(worklist, functions(mod)[intr_fn]) - end - worklist_length = 0 - while worklist_length != length(worklist) - # iteratively discover functions that use an intrinsic or any function calling it - worklist_length = length(worklist) - additions = LLVM.Function[] - for f in worklist, use in uses(f) - inst = user(use)::Instruction - bb = LLVM.parent(inst) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) - end - for f in additions - push!(worklist, f) - end - end - for intr_fn in used_intrinsics - delete!(worklist, functions(mod)[intr_fn]) - end - - # add the arguments - # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt - workmap = Dict{LLVM.Function, LLVM.Function}() - for f in worklist - fn = LLVM.name(f) - ft = function_type(f) - LLVM.name!(f, fn * ".orig") - # create a new function - new_param_types = LLVMType[parameters(ft)...] - - for intr_fn in used_intrinsics - llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ) - push!(new_param_types, llvm_typ) - end - new_ft = LLVM.FunctionType(return_type(ft), new_param_types) - new_f = LLVM.Function(mod, fn, new_ft) - linkage!(new_f, linkage(f)) - for (arg, new_arg) in zip(parameters(f), parameters(new_f)) - LLVM.name!(new_arg, LLVM.name(arg)) - end - for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end]) - LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name) - end - - workmap[f] = new_f - end - - # clone and rewrite the function bodies. - # we don't need to rewrite much as the arguments are added last. - for (f, new_f) in workmap - # map the arguments - value_map = Dict{LLVM.Value, LLVM.Value}() - for (param, new_param) in zip(parameters(f), parameters(new_f)) - LLVM.name!(new_param, LLVM.name(param)) - value_map[param] = new_param - end - - value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) - - # we can't remove this function yet, as we might still need to rewrite any called, - # but remove the IR already - empty!(f) - end - - # drop unused constants that may be referring to the old functions - # XXX: can we do this differently? - for f in worklist - prune_constexpr_uses!(f) - end - - # update other uses of the old function, modifying call sites to pass the arguments - function rewrite_uses!(f, new_f) - # update uses - @dispose builder=IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - callee_f = LLVM.parent(LLVM.parent(val)) - # forward the arguments - position!(builder, val) - new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, - [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], - operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast - # XXX: why isn't this caught by the value materializer above? - target = operands(val)[1] - @assert target == f - new_val = LLVM.const_bitcast(new_f, value_type(val)) - rewrite_uses!(val, new_val) - # we can't simply replace this constant expression, as it may be used - # as a call, taking arguments (so we need to rewrite it to pass the input arguments) - - # drop the old constant if it is unused - # XXX: can we do this differently? - if isempty(uses(val)) - LLVM.unsafe_destroy!(val) - end - else - error("Cannot rewrite unknown use of function: $val") - end - end - end - end - for (f, new_f) in workmap - rewrite_uses!(f, new_f) - @assert isempty(uses(f)) - erase!(f) - end - - # replace uses of the intrinsics with references to the input arguments - for (i, intr_fn) in enumerate(used_intrinsics) - intr = functions(mod)[intr_fn] - for use in uses(intr) - val = user(use) - callee_f = LLVM.parent(LLVM.parent(val)) - if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - replace_uses!(val, parameters(callee_f)[end-nargs+i]) - else - error("Cannot rewrite unknown use of function: $val") - end - - @assert isempty(uses(val)) - erase!(val) - end - @assert isempty(uses(intr)) - erase!(intr) - end - - return -end - - # argument metadata generation # # module metadata is used to identify buffers that are passed as kernel arguments. @@ -717,11 +563,15 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul arg_infos = Metadata[] # Iterate through arguments and create metadata for them - args = classify_arguments(job, entry_ft) + args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize) i = 1 for arg in args arg.idx === nothing && continue - @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType + if job.config.optimize + @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType + else + parameters(entry_ft)[arg.idx] isa LLVM.PointerType || continue + end # NOTE: we emit the bare minimum of argument metadata to support # bindless argument encoding. Actually using the argument encoder diff --git a/src/spirv.jl b/src/spirv.jl index 21d59a93..8eea92c6 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -62,7 +62,8 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ? runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) = "spirv-" * String(job.config.target.backend) -function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function) +function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, + entry::LLVM.Function) # update calling convention for f in functions(mod) # JuliaGPU/GPUCompiler.jl#97 @@ -72,6 +73,37 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, callconv!(entry, LLVM.API.LLVMSPIRKERNELCallConv) end + return entry +end + +function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module) + errors = IRError[] + + # support for half and double depends on the target + if !job.config.target.supports_fp16 + append!(errors, check_ir_values(mod, LLVM.HalfType())) + end + if !job.config.target.supports_fp64 + append!(errors, check_ir_values(mod, LLVM.DoubleType())) + end + + return errors +end + +function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, + entry::LLVM.Function) + # convert the kernel state argument to a byval reference + if job.config.kernel + state = kernel_state_type(job) + if state !== Nothing + entry = kernel_state_to_reference!(job, mod, entry) + + T_state = convert(LLVMType, state) + attr = TypeAttribute("byval", T_state) + push!(parameter_attributes(entry, 1), attr) + end + end + # HACK: Intel's compute runtime doesn't properly support SPIR-V's byval attribute. # they do support struct byval, for OpenCL, so wrap byval parameters in a struct. if job.config.kernel @@ -91,20 +123,6 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, return entry end -function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module) - errors = IRError[] - - # support for half and double depends on the target - if !job.config.target.supports_fp16 - append!(errors, check_ir_values(mod, LLVM.HalfType())) - end - if !job.config.target.supports_fp64 - append!(errors, check_ir_values(mod, LLVM.DoubleType())) - end - - return errors -end - @unlocked function mcgen(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, format=LLVM.API.LLVMAssemblyFile) # The SPIRV Tools don't handle Julia's debug info, rejecting DW_LANG_Julia... @@ -343,6 +361,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F # remove the old function # NOTE: if we ever have legitimate uses of the old function, create a shim instead fn = LLVM.name(f) + prune_constexpr_uses!(f) @assert isempty(uses(f)) replace_metadata_uses!(f, new_f) erase!(f)