Skip to content

Add EnzymeRules for batchnorm #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
143 changes: 143 additions & 0 deletions ext/NNlibCUDACUDNNExt/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
cudnnBatchNormalizationForwardTraining
import NNlib: batchnorm, ∇batchnorm

using EnzymeCore

# TODO: replace with new cudnn normalization interface
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl

Expand Down Expand Up @@ -153,3 +155,144 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
end



# Enzyme reverse-mode rule, forward ("augmented primal") half, for `cudnnBNForward!`.
#
# Runs the in-place forward batchnorm when the output `y` carries a shadow
# (`Duplicated` / `BatchDuplicated`), then saves copies of the inputs the reverse
# pass needs whenever Enzyme reports that they may be overwritten before the
# reverse pass executes.
#
# Arguments (all wrapped in Enzyme activity annotations):
#   y                          output buffer written by `cudnnBNForward!`
#   g, b                       scale / bias parameters
#   x                          input
#   running_mean, running_var  running statistics
#   momentum                   always `Const`
#   kws...                     forwarded verbatim to `cudnnBNForward!`
#
# Returns an `AugmentedReturn(primal, shadow, cache)` where `cache` is the tuple
# `(cache_g, cache_x, cache_running_mean, cache_running_var)`; an entry is
# `nothing` when no copy was required.
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
        y::OutType,
        g,
        b,
        x,
        running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        # Pass *all* positional arguments in the primal's order: y, g, b, x, ...
        # (the scale `g` must not be dropped, otherwise every later argument shifts).
        func.val(y.val, g.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...)
    end

    primal = EnzymeCore.EnzymeRules.needs_primal(config) ? y.val : nothing
    shadow = EnzymeCore.EnzymeRules.needs_shadow(config) ? y.dval : nothing

    cache_g = nothing
    cache_x = nothing
    cache_running_mean = nothing
    cache_running_var = nothing

    # The reverse pass only needs saved inputs if a gradient can actually flow:
    # `y` must be active and at least one of x / g / b must be active.
    any_active_input = !(x isa EnzymeCore.Const) || !(g isa EnzymeCore.Const) || !(b isa EnzymeCore.Const)
    if !(y isa EnzymeCore.Const) && any_active_input
        # `overwritten(config)` is indexed over (func, y, g, b, x, running_mean,
        # running_var, momentum), hence g -> 3, x -> 5, running_mean -> 6, running_var -> 7.
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if overwritten[3]
            cache_g = copy(g.val)
        end
        if overwritten[5]
            cache_x = copy(x.val)
        end
        if overwritten[6]
            cache_running_mean = copy(running_mean.val)
        end
        if overwritten[7]
            cache_running_var = copy(running_var.val)
        end
    end

    cache = (cache_g, cache_x, cache_running_mean, cache_running_var)

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

# Enzyme reverse-mode rule, backward half, for `cudnnBNForward!`.
#
# For every shadow in the batch this accumulates the batchnorm gradients into the
# shadows of `x`, `g` and `b` via `cudnnBNBackward!`, then zeroes the output
# shadow `dy` (the primal overwrites `y`, so its adjoint is consumed here).
#
# `cache` is the tuple produced by the matching `augmented_primal`:
# `(cache_g, cache_x, cache_running_mean, cache_running_var)`. An entry is
# `nothing` when the corresponding argument was not overwritten, in which case
# the live `.val` of that argument is used instead.
#
# Returns one `nothing` per differentiable positional argument
# (y, g, b, x, running_mean, running_var, momentum) because all gradients are
# accumulated in place into the shadows.
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
        cache,
        y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

    no_grads = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)

    # A constant output has no shadow (`y.dval` does not exist) and nothing to propagate.
    if y isa EnzymeCore.Const
        return no_grads
    end

    x_const = x isa EnzymeCore.Const
    g_const = g isa EnzymeCore.Const
    b_const = b isa EnzymeCore.Const

    cache_g, cache_x, cache_running_mean, cache_running_var = cache

    if !(x_const && g_const && b_const)
        # Mirror of the augmented primal: a copy was stored only for arguments that
        # may have been overwritten. For everything else the live value is still valid.
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if !overwritten[3]
            cache_g = g.val
        end
        if !overwritten[5]
            cache_x = x.val
        end
        if !overwritten[6]
            cache_running_mean = running_mean.val
        end
        if !overwritten[7]
            cache_running_var = running_var.val
        end
    end

    # Element type used for the cuDNN scaling parameters.
    T = eltype(y.val)

    # For constant arguments there is no shadow; `dys` is used purely as a
    # placeholder of matching batch width so the `zip` below lines up.
    dys = y.dval
    dgs = g_const ? dys : g.dval
    dbs = b_const ? dys : b.dval
    dxs = x_const ? dys : x.dval

    if EnzymeCore.EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
        dgs = (dgs,)
        dbs = (dbs,)
    end

    for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
        # A shadow that aliases its primal marks an inactive value.
        if dy !== y.val

            needs_dx = !(x_const || dx === x.val)
            needs_dg = !(g_const || dg === g.val)
            needs_db = !(b_const || db === b.val)

            if needs_dx || needs_dg || needs_db

                # dx = alpha * newVal + beta * old(dx)
                # Active x: accumulate (alpha = beta = 1). Otherwise write into a
                # scratch buffer and ignore its previous contents (beta = 0).
                alpha = one(T)
                beta = one(T)
                if !needs_dx
                    alpha = zero(T)
                    beta = zero(T)
                    dx = similar(cache_x)
                end

                # dg / db share one pair of scaling parameters: accumulate when
                # either is active, otherwise disable the update entirely.
                dalpha = one(T)
                dbeta = one(T)
                if !needs_dg && !needs_db
                    dalpha = zero(T)
                    dbeta = zero(T)
                end

                # Inactive parameter gradients go to zero-initialised scratch so
                # that `dbeta = 1` never reads uninitialised memory.
                if !needs_dg
                    dg = zero(cache_g)
                end
                if !needs_db
                    db = zero(cache_b_like(cache_g, b))
                end

                # NOTE(review): `kws` are the forward call's keywords; this assumes
                # `cudnnBNBackward!` accepts the same set — confirm against its signature.
                # Explicit scaling parameters come last so they override any in `kws`.
                cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy,
                                 cache_running_mean, cache_running_var,
                                 momentum.val; kws..., alpha, beta, dalpha, dbeta)

            end

            dy .= 0

        end
    end

    return no_grads
end

# Array to model the bias-gradient scratch buffer on: the bias value itself
# (`b.val` exists for every activity annotation). `g_like` is unused and only
# documents that the bias is expected to match the scale in shape.
cache_b_like(g_like, b) = b.val