Merge pull request #1495 from FluxML/bc/gradtuple
Un-collapse nothings in `gradient`
ToucheSir authored Jan 17, 2024
2 parents 5b61724 + 070466d commit 54f1e80
Showing 4 changed files with 64 additions and 10 deletions.
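The gist of the change (a sketch based on the test updates in this diff, not on any additional API): `gradient` and `withgradient` previously collapsed a pullback result with no defined gradients into a bare `nothing`; they now expand it to one `nothing` per argument, so explicit-argument calls always return a tuple.

```julia
using Zygote

# Before this commit, an all-`nothing` gradient collapsed to `nothing`:
#   gradient(hash, 1)   === nothing
#   gradient(div, 1, 2) === nothing

# After this commit, `gradient` keeps one entry per argument:
gradient(hash, 1)    # (nothing,)
gradient(div, 1, 2)  # (nothing, nothing)
```

The expansion is handled by the new `_project_all` helper in `src/compiler/interface.jl` below; implicit-parameter `Grads` pass through unchanged.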
2 changes: 1 addition & 1 deletion Project.toml
@@ -57,7 +57,7 @@ Requires = "1.1"
SpecialFunctions = "1.6, 2"
Statistics = "1"
Tracker = "0.2"
ZygoteRules = "0.2.4"
ZygoteRules = "0.2.5"
julia = "1.6"

[extras]
66 changes: 60 additions & 6 deletions src/compiler/interface.jl
@@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)

"""
pullback(f, args...)
pullback(f, ::Params)
Returns the value of the function `f` and a back-propagator function,
which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or gradient.
```julia
y, back = pullback(f, args...)
∇ = back(seed)
```
`back` must be called with a start value `seed` matching the output of `f(args...)`.
If `f(args...)` returns a number, `seed` should be a number.
If `f(args...)` returns an array, `seed` should be an equally-sized array.
See also [`withgradient`](@ref) to obtain the value and gradients in one call,
and [`gradient`](@ref) for obtaining just the gradients.
```jldoctest; setup=:(using Zygote)
julia> y, back = pullback(*, 2.0, 3.0, 5.0);
julia> y
30.0
julia> back(1.0)
(15.0, 10.0, 6.0)
julia> back(2.0)
(30.0, 20.0, 12.0)
julia> y, back = pullback(x -> [x, x], 1.0);
julia> y
2-element Vector{Float64}:
1.0
1.0
julia> back([1.0, 1.0])
(2.0,)
julia> back([2.0, nothing])
(2.0,)
```
"""
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
@@ -67,11 +113,16 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")

# Preserves output as tuple when gradients are collapsed
_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)
_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)

"""
gradient(f, args...)
Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or the gradient.
If no gradient is defined, `∂f/∂x` will be `nothing`.
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
@@ -95,7 +146,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
return _project_all(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
@@ -109,7 +160,7 @@ end
withgradient(f, ::Params)
Returns both the value of the function and the [`gradient`](@ref),
as a named tuple.

```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
@@ -161,7 +212,7 @@ function withgradient(f, args...)
else
back(sensitivity(y))
end
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
results = _project_all(args, grad)
(val=y, grad=results)
end

@@ -304,7 +355,7 @@ end
Grads(...)
Dictionary-like container returned when taking gradients with
respect to implicit parameters. For an array `W`, appearing
within `Params([W, A, B...])`, the gradient is `g[W]`.
"""
struct Grads
@@ -321,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}

# Dictionary interface.
# Don't use the IdDict directly since it may contain some spurious pairs.
Base.haskey(gs::Grads, x) = x ∈ gs.params
Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

@@ -381,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)

function materialize!(gs1::Grads, gs2::Grads)
issetequal(gs1.params, gs2.params) ||
throw(ArgumentError("Expected Grads objects with the same Params."))
for p in gs1.params
gs1[p] = gs2[p]
@@ -421,6 +472,9 @@ function pullback(f, ps::Params)
end
end

# No conversion required here
_project_all(_, dx::Grads) = dx

# Code Reflection

function code_ir(f, T)
4 changes: 2 additions & 2 deletions test/lib/number.jl
@@ -3,8 +3,8 @@
@test gradient(floor, 1) === (0.0,)
@test gradient(ceil, 1) === (0.0,)
@test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
@test gradient(hash, 1) === (nothing,)
@test gradient(div, 1, 2) === (nothing, nothing)
end

@testset "basics" begin
2 changes: 1 addition & 1 deletion test/structures.jl
@@ -64,5 +64,5 @@ end
end

m, b = Zygote._pullback(Zygote.Context(), nameof, M)
@test b(m) == (nothing, nothing)
@test b(m) === nothing
end
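For reference, a condensed view of the dispatch that implements this change, assembled from the `src/compiler/interface.jl` hunks above (no code beyond what the diff adds):

```julia
# When the pullback collapses all gradients to `nothing`, expand it back
# into one `nothing` per argument, preserving the tuple shape of the output:
_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N)

# Otherwise, project each gradient onto its corresponding argument:
_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)

# Implicit-parameter gradients (`Grads`) need no conversion and pass through:
_project_all(_, dx::Grads) = dx
```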
