using Flux, ChainRulesCore
using LinearAlgebra: mul!
using Statistics: mean, var    # used in the BatchNorm forward pass below
using Flux.NNlib: softmax!     # softmax! may not be re-exported by Flux itself
# using FastBroadcast: @..
using Strided

const NoT = NoTangent()

"""
    PreLayer(Dense(2 => 3, relu))

Stores, along with the layer, pre-allocated space for its output
and all gradient components. Only works on layers it understands.
"""
struct PreLayer{L,G,V}
    layer::L
    grad::G    # same fixed sizes as the layer's parameters
    fwd::V     # vector of dynamic length, reshaped to hold the forward output
    rev::V     # likewise, reshaped to hold the gradient w.r.t. the input
end

Flux.@functor PreLayer
Flux.trainable(p::PreLayer) = (; layer = p.layer)
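# `trainable` restricts optimisation to the wrapped layer's own parameters, so the
# `grad`/`fwd`/`rev` buffers are never updated as if they were weights, while
# `@functor` still lets `gpu`, `f32`, etc. traverse and convert them with the layer.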

"""
    model |> pre

Wrap as many layers as possible with `PreLayer`,
to store pre-allocated space for output & gradient.
Ignores layers it doesn't understand.
"""
pre(model) = fmap(PreLayer, model; exclude = x -> hasmethod(PreLayer, Tuple{typeof(x)}))

"""
    nopre(model)

Remove all `PreLayer`s & return the plain model.
"""
nopre(model) = fmap(x -> x.layer, model; exclude = x -> x isa PreLayer)
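
# A minimal usage sketch (illustration only; the model, data, and `_pre_demo` name
# are made up, not part of the original file). Wrap once, reuse the buffers on every
# call with the same batch size, and unwrap with `nopre` when done:
function _pre_demo()
    model = Chain(Dense(2 => 3, relu), Dense(3 => 2), softmax) |> pre
    x = rand(Float32, 2, 16)
    grads = Flux.gradient(m -> sum(abs2, m(x)), model)  # hits the rrules defined below
    nopre(model), grads
end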


#####
##### Dense
#####

function PreLayer(d::Dense)
    grad = _struct_sim(d)
    fwd, rev = similar(d.weight, 0), similar(d.weight, 0)
    PreLayer(d, grad, fwd, rev)
end

function (p::PreLayer{<:Dense})(x::AbstractMatrix{<:Real})
    y, dx = _pre_setup(p, x)
    _densecall!(y, p, x, dx)
end

function _pre_setup(p::PreLayer{<:Dense}, x)  # non-differentiable, see @non_differentiable below
    _, b = size(x)
    o, i = size(p.layer.weight)
    if o*b != length(p.fwd)
        resize!(p.fwd, o*b)
        resize!(p.rev, i*b)
    end
    y = _pre_reshape(p.fwd, (o,b))
    dx = _pre_reshape(p.rev, (i,b))
    (; y, dx)
end

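# _densecall! computes y = σ.(W*x .+ b) entirely in place: broadcast the bias
# (or `false`, which just zeroes y) into y, accumulate W*x with 5-arg mul!,
# then apply the activation via act!.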
function _densecall!(y, p, x, dx)
    y .= p.layer.bias
    mul!(y, p.layer.weight, x, true, true)
    act!(y, p.layer.σ)
    y
end

function ChainRulesCore.rrule(::typeof(_densecall!), y, p, x, dx)
    y = _densecall!(y, p, x, dx)
    function back(dy)
        dy = unthunk(dy)
        dy = ∇act!(y, dy, p.layer.σ)
        # layer
        weight = mul!(p.grad.weight, dy, x')
        bias = ∇bias!(p.grad.bias, dy)
        tang = Tangent{Dense}(; weight, bias)
        # input
        dx = mul!(dx, p.layer.weight', dy)
        return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
    end
    y, back
end
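
# Rough correctness check for the Dense path (illustration only; `_check_dense` is
# not part of the original file). Gradients through the wrapped layer should match
# the plain layer up to floating-point error; the `.layer` access assumes Zygote
# materialises the Tangent{PreLayer} above as a nested NamedTuple.
function _check_dense(d = Dense(3 => 4, relu), x = randn(Float32, 3, 8))
    g1 = Flux.gradient(m -> sum(abs2, m(x)), d)[1]
    g2 = Flux.gradient(m -> sum(abs2, m(x)), PreLayer(d))[1]
    isapprox(g1.weight, g2.layer.weight) && isapprox(g1.bias, g2.layer.bias)
end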

#####
##### Scale
#####

scale!(y, (scale, ds), (x, dx), (bias, db)) = y .= scale .* x .+ bias
# scale!(y, (scale, ds), (x, dx), (bias, db)) = @strided y .= scale .* x .+ bias

function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias, db))
    y = scale!(y, (scale, ds), (x, dx), (bias, db))
    function back(dy)
        dy = unthunk(dy)
        @strided dx .= dy .* scale
        # reduce the broadcast back to the parameter shapes, into the pre-allocated buffers
        dscale = sum!(reshape(ds, size(scale)), dy .* x)   # dy .* x allocates; could be fused
        dbias = sum!(reshape(db, size(bias)), dy)
        return (NoT, NoT, (dscale, NoT), (dx, NoT), (dbias, NoT))
    end
    y, back
end

#####
##### softmax
#####

function PreLayer(::typeof(softmax))
    fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
    PreLayer(softmax, nothing, fwd, rev)
end

function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
    y, dx = _pre_setup(p, x)  # generic version
    _softmaxcall!(y, p, x, dx)
end

_softmaxcall!(y, p, x, dx) = softmax!(y, x)

function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
    y = _softmaxcall!(y, p, x, dx)
    function back(dy)
        dy = unthunk(dy)
        # TODO: CHECK THIS!
        dx .= dy .* y
        dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
        return (NoT, NoT, NoT, dx, NoT)   # last one could be NotImplemented?
    end
    y, back
end
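
# Note on the pullback above: writing s = sum(dy .* y; dims=1), the two in-place
# steps give dx = y .* dy .- y .* s = y .* (dy .- sum(dy .* y; dims=1)), which is
# the usual softmax vector-Jacobian product (the same formula NNlib's ∇softmax uses),
# so the TODO appears to be satisfied for the gradient w.r.t. x.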

#####
##### BatchNorm
#####

function PreLayer(bn::BatchNorm)
    grad = (β = similar(bn.β), γ = similar(bn.γ))  # only trainable fields
    fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
    PreLayer(bn, grad, fwd, rev)
end

function (p::PreLayer{<:BatchNorm})(x::AbstractArray{<:Real})
    y, dx = _pre_setup(p, x)
    # _batchnormcall!(y, p, x, dx)

    # from (BN::BatchNorm)(x)
    N = ndims(x)
    reduce_dims = [1:N-2; N]
    affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
    _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
end

using Flux: _isactive, _track_stats!, hasaffine

function _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
    l = p.layer
    N = ndims(x)

    # This block verbatim from Flux. However, mean & var aren't in-place,
    # nor are their gradients... add more storage?

    if !_isactive(l) && l.track_stats  # testmode with tracked stats
        stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
        μ = reshape(l.μ, stats_shape)
        σ² = reshape(l.σ², stats_shape)
    else  # trainmode or testmode without tracked stats
        μ = mean(x; dims=reduce_dims)
        σ² = var(x; mean=μ, dims=reduce_dims, corrected=false)
        if l.track_stats
            _track_stats!(l, x, μ, σ², reduce_dims)  # update moving mean/std
        end
    end

    y = _norm_layer_forward!(y, x, dx, μ, σ², l.ϵ)
    hasaffine(l) || return act!(y, l.λ)

    γ = reshape(l.γ, affine_shape)
    β = reshape(l.β, affine_shape)
    # return l.λ.(γ .* y .+ β)
    y2 = scale!(y, (γ, p.grad.γ), (x, dx), (β, p.grad.β))
    return act!(y2, l.λ)
end

_norm_layer_forward!(y, x, dx, μ, σ², ϵ) = y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)
# _norm_layer_forward!(y, x, dx, μ, σ², ϵ) = @strided y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)

function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ², ϵ)
    y = _norm_layer_forward!(y, x, dx, μ, σ², ϵ)
    function back(dy)
        dx .= dy ./ sqrt.(σ² .+ ϵ)
        # TODO write gradients for mean & variance, these are WRONG!
        dμ = NoT
        dσ² = NoT
        return (NoT, NoT, dx, NoT, dμ, dσ², NoT)
    end
    y, back
end
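
# Sketch of the missing train-mode pullback (an assumption, not the author's code):
# when μ and σ² are computed from x over `dims`, then with x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
# (which is exactly `y` above) the standard result is
#
#     dx = (dy .- mean(dy; dims) .- x̂ .* mean(dy .* x̂; dims)) ./ sqrt.(σ² .+ ϵ)
#
# The simple `dy ./ sqrt.(σ² .+ ϵ)` above is only correct in test mode, where μ and σ²
# are fixed statistics independent of x. A direct, allocating reference version
# (hypothetical name, for checking against) would be:
function _norm_layer_backward(dy, y, σ², ϵ; dims)
    (dy .- mean(dy; dims=dims) .- y .* mean(dy .* y; dims=dims)) ./ sqrt.(σ² .+ ϵ)
end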


#####
##### activation functions
#####

act!(y, ::typeof(identity)) = y
function act!(y, act::F) where F
    σ = Flux.NNlib.fast_act(act, y)
    # y .= σ.(y)
    # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
    # maybe you could patch Strided.jl to avoid it? Or use another package...
    @strided y .= σ.(y)
    # FastBroadcast.@.. y = σ(y)
end

# Piracy, disable @strided on CuArrays:
Strided.maybestrided(x::Flux.CuArray) = x

# For this rule, it's important to use what `act!` returns, not what it mutates
ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dy -> (NoT, ∇act!(y, unthunk(dy), f), NoT)

∇act!(y, dy, ::typeof(identity)) = dy
∇act!(y, dy, ::typeof(relu)) = @. y = ifelse(y > 0, dy, 0f0)
∇act!(y, dy, ::typeof(tanh)) = @. y = dy * (1 - y^2)
∇act!(y, dy, ::typeof(sigmoid)) = @. y = dy * y * (1 - y)
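
# Each ∇act! overwrites `y` (the saved activation output) with the input gradient,
# re-using the forward buffer rather than allocating. The formulas are the usual
# derivatives written in terms of the *output*: relu masks by y > 0, tanh uses
# 1 - y², and sigmoid uses y * (1 - y), each multiplied by the incoming dy.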


#####
##### PreLayer utils
#####

_struct_sim(x) = Flux.fmapstructure(x) do x
    x isa AbstractArray{<:Real} ? similar(x) : nothing
end

function _pre_setup(p::PreLayer, x)  # generic version
    if length(x) != length(p.fwd)
        resize!(p.fwd, length(x))
        resize!(p.rev, length(x))
    end
    y = _pre_reshape(p.fwd, size(x))
    dx = _pre_reshape(p.rev, size(x))
    (; y, dx)
end
ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)

# Cannot use reshape(::Array), as that prevents later resize!
_pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
# Must use reshape(::CuArray) as mul! rejects ReshapedArray
_pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
_pre_reshape(x, size::Tuple) = reshape(x, size)

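# Flux stores `false` for a Dense layer without bias, and `_struct_sim` maps that
# to `nothing`, hence the two no-op methods below; otherwise the pullback sums
# `dx` (over the batch) into the pre-allocated bias buffer.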
∇bias!(::Bool, dx) = NoT
∇bias!(::Nothing, dx) = NoT
∇bias!(bias, dx) = sum!(bias, dx)

function Base.show(io::IO, p::PreLayer)
    show(io, p.layer)
    printstyled(io, " |> pre", color=:blue)
end

Flux._show_children(p::PreLayer) = Flux._show_children(p.layer)