Skip to content

Fix deferred_codegen registration #711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/GPUCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ function __init__()
global compile_cache = dir

Tracy.@register_tracepoints()

# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
@dispose jljit=JuliaOJIT() begin
jd = JITDylib(jljit)

address = LLVM.API.LLVMOrcJITTargetAddress(
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
flags = LLVM.API.LLVMJITSymbolFlags(
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
name = mangle(jljit, "deferred_codegen")
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
map = if LLVM.version() >= v"15"
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
else
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
end

mu = LLVM.absolute_symbols(Ref(map))
LLVM.define(jd, mu)
addr = lookup(jljit, "deferred_codegen")
@assert addr != C_NULL "Failed to register deferred_codegen"
end
end

end # module
3 changes: 2 additions & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from, while also enabling recursive compilation.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
# Julia 1.11 and co broke @ccallable so we have to do this manually in __init__
function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid}
ptr
end

Expand Down
2 changes: 2 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,6 @@ end

@testset "Mock Enzyme" begin
Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
end
Loading