Skip to content

Allow for nested targets #696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Allow for nested targets #696

wants to merge 3 commits into from

Conversation

vchuravy
Copy link
Member

@vchuravy vchuravy commented May 30, 2025

The crux for Enzyme GPU support is that we use a @generated deferred_codegen implementation in which we do not know what the calling environment is. We might be called from the CPU, CUDA, AMDGPU and so-forth.

GPUCompiler during the CUDA compilation then finds the Enzyme compilation job in the deferred_jobs dictionary,
and then asks Enzyme to codegen the adjoint code. During the code generation of the adjoint code, Enzyme must first codegen the primal/original code and thus must construct a compilation job for CUDA.

Previously we passed parent_job through for Enzyme to be able to perform the mode switch.

Here I propose that instead we support nesting both targets and params such that Enzyme can reuse those correctly instead of guessing.

Open to rename the function and I will add some tests here later. EnzymeAD/Enzyme.jl#2424 is the other side of this change.

x-ref: #668 (comment)

cc: @wsmoses

Copy link

codecov bot commented May 30, 2025

Codecov Report

Attention: Patch coverage is 60.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 52.38%. Comparing base (8b8c73f) to head (1662181).

Files with missing lines Patch % Lines
src/interface.jl 0.00% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master     #696       +/-   ##
===========================================
- Coverage   71.63%   52.38%   -19.26%     
===========================================
  Files          24       24               
  Lines        3519     3465       -54     
===========================================
- Hits         2521     1815      -706     
- Misses        998     1650      +652     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@vchuravy vchuravy marked this pull request as ready for review June 13, 2025 15:43
@vchuravy vchuravy force-pushed the vc/nested_targets branch from 9f205aa to 650b7f4 Compare June 13, 2025 15:43
Copy link
Contributor

github-actions bot commented Jun 13, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/driver.jl b/src/driver.jl
index f610611..e64791b 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -224,7 +224,7 @@ const __llvm_initialized = Ref(false)
                 dyn_entry_fn = get!(jobs, dyn_job) do
                     target = nest_target(dyn_job.config.target, job.config.target)
                     params = nest_params(dyn_job.config.params, job.config.params)
-                    config = CompilerConfig(dyn_job.config; toplevel=false, target, params)
+                    config = CompilerConfig(dyn_job.config; toplevel = false, target, params)
                     dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
                     dyn_entry_fn = LLVM.name(dyn_meta.entry)
                     merge!(compiled, dyn_meta.compiled)
diff --git a/test/helpers/enzyme.jl b/test/helpers/enzyme.jl
index c5af69d..9de0948 100644
--- a/test/helpers/enzyme.jl
+++ b/test/helpers/enzyme.jl
@@ -2,12 +2,12 @@ module Enzyme
 
 using ..GPUCompiler
 
-struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget
+struct EnzymeTarget{Target <: AbstractCompilerTarget} <: AbstractCompilerTarget
     target::Target
 end
 
-function EnzymeTarget(;kwargs...)
-    EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
+function EnzymeTarget(; kwargs...)
+    return EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
 end
 
 GPUCompiler.llvm_triple(target::EnzymeTarget) = GPUCompiler.llvm_triple(target.target)
@@ -18,7 +18,7 @@ GPUCompiler.have_fma(target::EnzymeTarget, T::Type) = GPUCompiler.have_fma(targe
 GPUCompiler.dwarf_version(target::EnzymeTarget) = GPUCompiler.dwarf_version(target.target)
 
 abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end
-struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
+struct EnzymeCompilerParams{Params <: AbstractCompilerParams} <: AbstractEnzymeCompilerParams
     params::Params
 end
 struct PrimalCompilerParams <: AbstractEnzymeCompilerParams
@@ -68,14 +68,18 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     sig = Tuple{ft, tt.parameters...}
     min_world = Ref{UInt}(typemin(UInt))
     max_world = Ref{UInt}(typemax(UInt))
-    match = ccall(:jl_gf_invoke_lookup_worlds, Any,
-                  (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
-                  sig, #=mt=# nothing, world, min_world, max_world)
+    match = ccall(
+        :jl_gf_invoke_lookup_worlds, Any,
+        (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
+        sig, #=mt=# nothing, world, min_world, max_world
+    )
     match === nothing && return stub(world, source, method_error)
 
     # look up the method and code instance
-    mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
-               (Any, Any, Any), match.method, match.spec_types, match.sparams)
+    mi = ccall(
+        :jl_specializations_get_linfo, Ref{Core.MethodInstance},
+        (Any, Any, Any), match.method, match.spec_types, match.sparams
+    )
     ci = CC.retrieve_code_info(mi, world)
 
     # prepare a new code info
@@ -83,10 +87,10 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     new_ci = copy(ci)
     empty!(new_ci.code)
     @static if isdefined(Core, :DebugInfo)
-      new_ci.debuginfo = Core.DebugInfo(:none)
+        new_ci.debuginfo = Core.DebugInfo(:none)
     else
-      empty!(new_ci.codelocs)
-      resize!(new_ci.linetable, 1)                # see note below
+        empty!(new_ci.codelocs)
+        resize!(new_ci.linetable, 1)                # see note below
     end
     empty!(new_ci.ssaflags)
     new_ci.ssavaluetypes = 0
@@ -99,13 +103,13 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
 
     # prepare the slots
     new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
-    new_ci.slotflags = UInt8[0x00 for i = 1:3]
+    new_ci.slotflags = UInt8[0x00 for i in 1:3]
     new_ci.nargs = 3
 
     # We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
     target = EnzymeTarget()
     params = EnzymeCompilerParams()
-    config = CompilerConfig(target, params; kernel=false)
+    config = CompilerConfig(target, params; kernel = false)
     job = CompilerJob(mi, config, world)
 
     id = length(deferred_codegen_jobs) + 1
@@ -114,9 +118,9 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     # return the deferred_codegen_id
     push!(new_ci.code, CC.ReturnNode(id))
     push!(new_ci.ssaflags, 0x00)
-        @static if isdefined(Core, :DebugInfo)
+    @static if isdefined(Core, :DebugInfo)
     else
-      push!(new_ci.codelocs, 1)   # see note below
+        push!(new_ci.codelocs, 1)   # see note below
     end
     new_ci.ssavaluetypes += 1
 
@@ -135,7 +139,7 @@ end
 
 @inline function deferred_codegen(f::Type, tt::Type)
     id = deferred_codegen_id(f, tt)
-    ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
+    return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
 end
 
-end
\ No newline at end of file
+end
diff --git a/test/native.jl b/test/native.jl
index cba496f..6122e15 100644
--- a/test/native.jl
+++ b/test/native.jl
@@ -659,14 +659,14 @@ end
         a[1] = a[1]^2
         return
     end
-    
+
     function dkernel(a)
         ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Vector{Float64}})
         ccall(ptr, Cvoid, (Vector{Float64},), a)
         return
     end
 
-    ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none))
+    ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo = :none))
     @test !occursin("deferred_codegen", ir)
     @test occursin("call void @julia_kernel", ir)
 end
diff --git a/test/ptx.jl b/test/ptx.jl
index 9e56ee5..c8e3a99 100644
--- a/test/ptx.jl
+++ b/test/ptx.jl
@@ -152,22 +152,22 @@ end
 end
 end
 
-@testset "Mock Enzyme" begin
-    function kernel(a)
-        unsafe_store!(a, unsafe_load(a)^2)
-        return
-    end
-    
-    function dkernel(a)
-        ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
-        ccall(ptr, Cvoid, (Ptr{Float64},), a)
-        return
-    end
+    @testset "Mock Enzyme" begin
+        function kernel(a)
+            unsafe_store!(a, unsafe_load(a)^2)
+            return
+        end
 
-    ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none))
-    @test !occursin("deferred_codegen", ir)
-    @test occursin("call void @julia_", ir)
-end
+        function dkernel(a)
+            ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
+            ccall(ptr, Cvoid, (Ptr{Float64},), a)
+            return
+        end
+
+        ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo = :none))
+        @test !occursin("deferred_codegen", ir)
+        @test occursin("call void @julia_", ir)
+    end
 
 end
 

@vchuravy vchuravy force-pushed the vc/nested_targets branch from 8a9ab5d to 1662181 Compare June 27, 2025 15:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant