Skip to content

Commit da18428

Browse files
committed
Add batchnorm derivatives
1 parent ee36810 commit da18428

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

ext/NNlibCUDACUDNNExt/batchnorm.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
33
cudnnBatchNormalizationForwardTraining
44
import NNlib: batchnorm, ∇batchnorm
55

6+
using EnzymeCore
7+
68
# TODO: replace with new cudnn normalization interface
79
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
810

@@ -153,3 +155,148 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
153155
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
154156
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
155157
end
158+
159+
160+
161+
# Enzyme reverse-mode rule (forward sweep) for the in-place `cudnnBNForward!`.
#
# The annotated arguments mirror the primal call
# `cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum; kws...)`;
# `momentum` is required to be `Const`.
#
# Behaviour:
#   * runs the primal call when `y` is `Duplicated`/`BatchDuplicated`;
#   * when `y` is not `Const` and at least one of `x`, `g`, `b` is not `Const`, copies
#     the inputs that the reverse rule passes to `cudnnBNBackward!` (`g`, `x`,
#     `running_mean`, `running_var`) if Enzyme reports they may be overwritten.
#
# Returns `AugmentedReturn(primal, shadow, cache)` where `cache` is the tuple
# `(cache_g, cache_x, cache_running_mean, cache_running_var)`; an entry is `nothing`
# when no copy was taken.
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
                                                 y::OutType,
                                                 g,
                                                 b,
                                                 x,
                                                 running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        # Forward every positional argument in order; `g` must not be dropped, otherwise
        # all following arguments are shifted by one position.
        func.val(y.val, g.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...)
    end

    primal = if EnzymeCore.EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end

    cache_g = nothing
    cache_x = nothing
    cache_running_mean = nothing
    cache_running_var = nothing

    # A gradient is only propagated when `y` carries a shadow and at least one of
    # `x`, `g`, `b` is active; only then do we need to preserve inputs.
    needs_backward = !(y isa EnzymeCore.Const) &&
                     (!(x isa EnzymeCore.Const) ||
                      !(g isa EnzymeCore.Const) ||
                      !(b isa EnzymeCore.Const))

    if needs_backward
        # `overwritten(config)` is indexed over (func, y, g, b, x, running_mean,
        # running_var, momentum), hence g = 3, x = 5, running_mean = 6, running_var = 7.
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if overwritten[3]
            cache_g = copy(g.val)
        end
        if overwritten[5]
            cache_x = copy(x.val)
        end
        if overwritten[6]
            cache_running_mean = copy(running_mean.val)
        end
        if overwritten[7]
            cache_running_var = copy(running_var.val)
        end
    end

    cache = (cache_g, cache_x, cache_running_mean, cache_running_var)

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
213+
214+
# Enzyme reverse-mode rule (reverse sweep) for the in-place `cudnnBNForward!`.
#
# `cache` is the tuple `(cache_g, cache_x, cache_running_mean, cache_running_var)`
# produced by the matching `augmented_primal`. Entries that were not copied there
# (because the argument is not overwritten) are taken from the live argument values.
#
# For every shadow in the batch this calls `cudnnBNBackward!` to accumulate into the
# shadows of `x`, `g` and `b`, then zeroes the shadow of `y`. Arguments annotated
# `Const` get a scratch output buffer so the backward call always has somewhere to write.
#
# Returns a tuple of seven `nothing`s, one per annotated argument
# (`y, g, b, x, running_mean, running_var, momentum`).
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
                                        cache,
                                        y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

    no_tangents = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)

    # A `Const` output has no shadow (`y.dval` does not exist): nothing to propagate.
    if y isa EnzymeCore.Const
        return no_tangents
    end

    cache_g, cache_x, cache_running_mean, cache_running_var = cache

    x_const = x isa EnzymeCore.Const
    g_const = g isa EnzymeCore.Const
    b_const = b isa EnzymeCore.Const

    if !(x_const && g_const && b_const)
        # Inputs that were NOT overwritten were not copied in the forward sweep, so the
        # live values are still valid and are used directly. Indices follow
        # (func, y, g, b, x, running_mean, running_var, momentum).
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if !overwritten[3]
            cache_g = g.val
        end
        if !overwritten[5]
            cache_x = x.val
        end
        if !overwritten[6]
            cache_running_mean = running_mean.val
        end
        if !overwritten[7]
            cache_running_var = running_var.val
        end
    end

    # For `Const` arguments `dys` is only a placeholder of matching batch width so that
    # `zip` below lines up; the placeholder is replaced by a scratch buffer before use.
    dys = y.dval
    dgs = g_const ? dys : g.dval
    dbs = b_const ? dys : b.dval
    dxs = x_const ? dys : x.dval

    if EnzymeCore.EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
        dgs = (dgs,)
        dbs = (dbs,)
    end

    T = eltype(y.val)

    for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
        # Skip shadows that alias the primal output.
        if dy !== y.val

            if !(x_const || dx === x.val) ||
               !(g_const || dg === g.val) ||
               !(b_const || db === b.val)

                # dx = alpha * newVal + beta * old(dx)
                # If x is constant the result is discarded, so zero both factors and
                # write into a scratch buffer; otherwise accumulate (alpha = beta = 1).
                alpha = T(1)
                beta = T(1)
                if x_const
                    alpha = T(0)
                    beta = T(0)
                    dx = similar(x.val)
                end

                # Same scheme for the parameter gradients dg / db.
                dalpha = T(1)
                dbeta = T(1)
                if g_const && b_const
                    dalpha = T(0)
                    dbeta = T(0)
                end

                if g_const
                    dg = similar(g.val)
                end

                if b_const
                    db = similar(b.val)
                end

                # The rule's scaling factors are listed after `kws...` so that they take
                # precedence over same-named keywords forwarded from the primal call.
                cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy,
                                 cache_running_mean, cache_running_var,
                                 momentum.val; kws..., alpha, beta, dalpha, dbeta)

            end

            # The output shadow has been consumed.
            dy .= 0

        end
    end

    return no_tangents
end

0 commit comments

Comments
 (0)