Skip to content

Allow for nested targets #696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ const __llvm_initialized = Ref(false)
for dyn_job in keys(worklist)
# cached compilation
dyn_entry_fn = get!(jobs, dyn_job) do
config = CompilerConfig(dyn_job.config; toplevel=false)
target = nest_target(dyn_job.config.target, job.config.target)
params = nest_params(dyn_job.config.params, job.config.params)
config = CompilerConfig(dyn_job.config; toplevel=false, target, params)
dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
dyn_entry_fn = LLVM.name(dyn_meta.entry)
merge!(compiled, dyn_meta.compiled)
Expand Down
5 changes: 5 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ have_fma(@nospecialize(target::AbstractCompilerTarget), T::Type) = false

# DWARF version to emit for `target`. The fallback is version 4, which seems to be
# supported by every target bar CUDA.
function dwarf_version(target::AbstractCompilerTarget)
    return Int32(4)
end

# Hook for targets that perform nested compilation. `target` belongs to the nested job and
# `parent` to the job it is nested in; a wrapping target should overload this method and
# return itself reconstructed with the new inner target. The fallback keeps `target` as is.
function nest_target(target::AbstractCompilerTarget, parent::AbstractCompilerTarget)
    return target
end

## params

export AbstractCompilerParams
Expand All @@ -56,6 +59,8 @@ export AbstractCompilerParams

abstract type AbstractCompilerParams end

# Counterpart of `nest_target` for compiler parameters: overload this to rebuild the
# params of a nested job from those of its `parent`. The fallback keeps `params` as is.
function nest_params(params::AbstractCompilerParams, parent::AbstractCompilerParams)
    return params
end


## config

Expand Down
143 changes: 143 additions & 0 deletions test/helpers/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
module Enzyme

using ..GPUCompiler

# Mock Enzyme target: a thin wrapper around the target that the primal code is compiled for.
struct EnzymeTarget{T<:AbstractCompilerTarget} <: AbstractCompilerTarget
    target::T
end

# Keyword constructor: wraps a `NativeCompilerTarget` built with `jlruntime = true`,
# forwarding any further keyword arguments to it.
EnzymeTarget(; kwargs...) =
    EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))

# Single-argument target queries are answered by the wrapped target.
for query in (:llvm_triple, :llvm_datalayout, :llvm_machine, :dwarf_version)
    @eval function GPUCompiler.$query(target::EnzymeTarget)
        return GPUCompiler.$query(target.target)
    end
end

function GPUCompiler.have_fma(target::EnzymeTarget, T::Type)
    return GPUCompiler.have_fma(target.target, T)
end

# When nested inside another job, wrap that job's target instead of our current one.
function GPUCompiler.nest_target(::EnzymeTarget, other::AbstractCompilerTarget)
    return EnzymeTarget(other)
end

# Compiler parameters used by the mock Enzyme compiler.
abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end

# Field-less placeholder wrapped by the zero-argument `EnzymeCompilerParams()` constructor.
struct PrimalCompilerParams <: AbstractEnzymeCompilerParams end

# Wrapper around the parameters that the primal job is compiled with.
struct EnzymeCompilerParams{P<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
    params::P
end

function EnzymeCompilerParams()
    return EnzymeCompilerParams(PrimalCompilerParams())
end

# When nested inside another job, wrap that job's params instead of our current ones.
function GPUCompiler.nest_params(::EnzymeCompilerParams, other::AbstractCompilerParams)
    return EnzymeCompilerParams(other)
end

# Compile a job that targets the mock Enzyme compiler: unwrap the inner ("primal") target
# and params, build a primal job from them and delegate to `compile_unhooked` for that job.
function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})
    outer = job.config
    inner_target = (outer.target::EnzymeTarget).target
    inner_params = (outer.params::EnzymeCompilerParams).params

    # TODO: unclear whether `entry_abi` should be carried over as well
    inner_config = CompilerConfig(inner_target, inner_params;
        toplevel = outer.toplevel,
        always_inline = outer.always_inline,
        kernel = false,
        libraries = true,
        optimize = false,
        cleanup = false,
        only_entry = false,
        validate = false,
    )
    inner_job = CompilerJob(job.source, inner_config, job.world)

    # Normally, Enzyme would run here and transform the output of the primal job;
    # the mock returns the primal result unchanged.
    return GPUCompiler.compile_unhooked(output, inner_job)
end

import GPUCompiler: deferred_codegen_jobs
import Core.Compiler as CC

# Generator behind `deferred_codegen_id(ft, tt)`.
#
# Resolves the method for signature `Tuple{ft, tt.parameters...}` in `world`. If nothing
# matches, it returns a stub whose body throws a `MethodError`. Otherwise it creates a
# `CompilerJob` for the matched method instance (default `EnzymeTarget` and
# `EnzymeCompilerParams`, `kernel=false`), stores that job in `deferred_codegen_jobs` under
# a fresh integer id, and returns a `CodeInfo` whose only statement returns that id.
function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)
@nospecialize
@assert CC.isType(ft) && CC.isType(tt)
# both arguments arrive wrapped as `Type{...}`; unwrap the function type and tuple type
ft = ft.parameters[1]
tt = tt.parameters[1]

# only used for the "no matching method" path below
stub = Core.GeneratedFunctionStub(identity, Core.svec(:deferred_codegen_id, :ft, :tt), Core.svec())

# look up the method match
method_error = :(throw(MethodError(ft, tt, $world)))
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
sig, #=mt=# nothing, world, min_world, max_world)
match === nothing && return stub(world, source, method_error)

# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)
ci = CC.retrieve_code_info(mi, world)

# prepare a new code info
# TODO: Can we create a new CI instead of copying a "wrong" one?
new_ci = copy(ci)
empty!(new_ci.code)
# Julia versions that define `Core.DebugInfo` carry location info in `debuginfo`;
# the other branch clears `codelocs` / `linetable` instead
@static if isdefined(Core, :DebugInfo)
new_ci.debuginfo = Core.DebugInfo(:none)
else
empty!(new_ci.codelocs)
resize!(new_ci.linetable, 1) # see note below
end
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0

# propagate edge metadata
# NOTE(review): the lower bound is `world` itself rather than the looked-up `min_world[]`
# (see the disabled line below) — presumably because the registered job is tied to
# `world`; confirm.
# new_ci.min_world = min_world[]
new_ci.min_world = world
new_ci.max_world = max_world[]
new_ci.edges = Core.MethodInstance[mi]

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]
@static if isdefined(Core, :DebugInfo)
new_ci.nargs = 3
end

# We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
target = EnzymeTarget()
params = EnzymeCompilerParams()
config = CompilerConfig(target, params; kernel=false)
job = CompilerJob(mi, config, world)

# register the job under the next free integer id in GPUCompiler's `deferred_codegen_jobs`
id = length(deferred_codegen_jobs) + 1
deferred_codegen_jobs[id] = job

# return the deferred_codegen_id
push!(new_ci.code, CC.ReturnNode(id))
push!(new_ci.ssaflags, 0x00)
@static if isdefined(Core, :DebugInfo)
else
push!(new_ci.codelocs, 1) # see note below
end
new_ci.ssavaluetypes += 1

# NOTE: we keep the first entry of the original linetable, and use it for location info
# on the call to check_cache. we can't not have a codeloc (using 0 causes
# corruption of the back trace), and reusing the target function's info
# has as advantage that we see the name of the kernel in the backtraces.

return new_ci
end

# Declare `deferred_codegen_id(ft, tt)` with no hand-written body: the `:generated_only`
# and `:generated` meta expressions make `deferred_codegen_id_generator` produce its code.
@eval function deferred_codegen_id(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, deferred_codegen_id_generator))
end

# Obtain a `Ptr{Cvoid}` for function type `f` called with argument tuple type `tt`:
# fetch the id of the registered deferred job and pass it to the `deferred_codegen` extern.
@inline function deferred_codegen(f::Type, tt::Type)
    job_id = deferred_codegen_id(f, tt)
    return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), job_id)
end

end
17 changes: 17 additions & 0 deletions test/native.jl
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,20 @@ end
Native.code_llvm(mod.parent, Tuple{}; debuginfo=:none, mod.method_table)
end
end

@testset "Mock Enzyme" begin
    # primal function; the IR check below looks for its name, so it must stay `kernel`
    function kernel(a)
        a[1] = a[1]^2
        return
    end

    # caller that reaches `kernel` only through a deferred-codegen function pointer
    function dkernel(a)
        fptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Vector{Float64}})
        ccall(fptr, Cvoid, (Vector{Float64},), a)
        return
    end

    llvm_ir = sprint() do io
        Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none)
    end

    # the placeholder must be gone, replaced by a direct call to the compiled primal
    @test !occursin("deferred_codegen", llvm_ir)
    @test occursin("call void @julia_kernel", llvm_ir)
end
17 changes: 17 additions & 0 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ end
end
end

# Mock Enzyme: a deferred-codegen call must be resolved into a direct call in the final IR.
@testset "Mock Enzyme" begin
# primal function operating on a raw pointer
function kernel(a)
unsafe_store!(a, unsafe_load(a)^2)
return
end

# caller that reaches `kernel` only through a deferred-codegen function pointer
function dkernel(a)
ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
ccall(ptr, Cvoid, (Ptr{Float64},), a)
return
end

# NOTE(review): this testset lives in test/ptx.jl but compiles through `Native.code_llvm`,
# so it exercises the native target rather than PTX. Confirm whether the PTX helper was
# intended here (possible copy-paste from test/native.jl).
ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none))
@test !occursin("deferred_codegen", ir)
@test occursin("call void @julia_", ir)
end

end

############################################################################################
Expand Down
4 changes: 4 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,7 @@ end
error("errors")
end
end

@testset "Mock Enzyme" begin
    # Calling the generated function runs the generator, which registers a deferred job
    # and returns its id (`length(deferred_codegen_jobs) + 1`, hence always >= 1).
    # Assert on the result so the testset records real checks instead of passing
    # vacuously whenever the call merely does not throw.
    id = Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
    @test id isa Integer
    @test id >= 1
end
Loading