@@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
33 cudnnBatchNormalizationForwardTraining
44import NNlib: batchnorm, ∇batchnorm
55
6+ using EnzymeCore
7+
68# TODO : replace with new cudnn normalization interface
79# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
810
@@ -153,3 +155,148 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
153155 scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
154156 xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
155157end
158+
159+
160+
# Enzyme reverse-mode rule, forward ("augmented primal") half, for `cudnnBNForward!`.
#
# Positional arguments mirror the primal call:
#   `y, g, b, x, running_mean, running_var, momentum`
# each wrapped in an Enzyme activity annotation (`momentum` must be `Const`).
#
# What it does:
#   1. Runs the primal `cudnnBNForward!` when the output `y` is `Duplicated` or
#      `BatchDuplicated`.
#   2. Returns `y.val` / `y.dval` as primal / shadow when the config asks for them.
#   3. Saves copies of `g`, `x`, `running_mean` and `running_var` for the reverse pass,
#      but only for those arguments that Enzyme reports as overwritten before the
#      reverse pass runs. Entries that need no copy are stored as `nothing`.
#
# Returns an `AugmentedReturn(primal, shadow, cache)` where
# `cache == (cache_g, cache_x, cache_running_mean, cache_running_var)`.
function EnzymeCore.EnzymeRules.augmented_primal(
    config,
    func::EnzymeCore.Const{typeof(cudnnBNForward!)},
    ::Type{RT},
    y::OutType,
    g,
    b,
    x,
    running_mean,
    running_var,
    momentum::EnzymeCore.Const{<:Real};
    kws...,
) where {OutType,RT}
    # NOTE(review): the primal is only executed for an active output; a `Const` output
    # leaves `y` untouched here — confirm this is the intended contract.
    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        # Pass every positional argument in the same order as the rule signature
        # (including `g`, the scale parameter).
        func.val(
            y.val, g.val, b.val, x.val,
            running_mean.val, running_var.val, momentum.val;
            kws...,
        )
    end

    primal = EnzymeCore.EnzymeRules.needs_primal(config) ? y.val : nothing
    shadow = EnzymeCore.EnzymeRules.needs_shadow(config) ? y.dval : nothing

    cache_g = nothing
    cache_x = nothing
    cache_running_mean = nothing
    cache_running_var = nothing

    # Only cache when a gradient can actually flow: the output must be active and at
    # least one of x / g / b must be active.
    any_active_input = !(x isa EnzymeCore.Const) ||
                       !(g isa EnzymeCore.Const) ||
                       !(b isa EnzymeCore.Const)

    if !(y isa EnzymeCore.Const) && any_active_input
        # `overwritten(config)` is indexed with the function itself at position 1, so
        # y=2, g=3, b=4, x=5, running_mean=6, running_var=7.
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if overwritten[3]
            cache_g = copy(g.val)
        end
        if overwritten[5]
            cache_x = copy(x.val)
        end
        if overwritten[6]
            cache_running_mean = copy(running_mean.val)
        end
        if overwritten[7]
            cache_running_var = copy(running_var.val)
        end
    end

    cache = (cache_g, cache_x, cache_running_mean, cache_running_var)

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
213+
# Enzyme reverse-mode rule, backward half, for `cudnnBNForward!`.
#
# `cache` is the tuple produced by the matching `augmented_primal`:
# `(cache_g, cache_x, cache_running_mean, cache_running_var)`; an entry is `nothing`
# when the corresponding argument was not overwritten, in which case the live value
# (`g.val`, `x.val`, ...) is used instead.
#
# For every shadow in the batch (one when `width(config) == 1`), the rule calls
# `cudnnBNBackward!` to accumulate into the shadows of `x`, `g` and `b`, then zeroes the
# output shadow `dy`. Inactive inputs get a scratch buffer and zero scaling factors so
# nothing is accumulated for them.
#
# All results are written in place, so the returned tuple holds one `nothing` per
# positional argument (y, g, b, x, running_mean, running_var, momentum).
function EnzymeCore.EnzymeRules.reverse(
    config,
    func::EnzymeCore.Const{typeof(cudnnBNForward!)},
    ::Type{RT},
    cache,
    y::OutType,
    g,
    b,
    x,
    running_mean,
    running_var,
    momentum::EnzymeCore.Const{<:Real};
    kws...,
) where {OutType,RT}
    no_grads = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)

    # A constant output carries no shadow, so there is nothing to propagate.
    if y isa EnzymeCore.Const
        return no_grads
    end

    cache_g, cache_x, cache_running_mean, cache_running_var = cache

    any_active_input = !(x isa EnzymeCore.Const) ||
                       !(g isa EnzymeCore.Const) ||
                       !(b isa EnzymeCore.Const)

    if any_active_input
        # Arguments that were NOT overwritten were not copied in the forward pass;
        # their current values are still the forward-pass values.
        # Index 1 of `overwritten(config)` is the function itself.
        overwritten = EnzymeCore.EnzymeRules.overwritten(config)
        if !overwritten[3]
            cache_g = g.val
        end
        if !overwritten[5]
            cache_x = x.val
        end
        if !overwritten[6]
            cache_running_mean = running_mean.val
        end
        if !overwritten[7]
            cache_running_var = running_var.val
        end
    end

    # For constant inputs use `dys` as a same-length placeholder so that `zip` below
    # still iterates once per batch element; the placeholder is never written to.
    dys = y.dval
    dgs = g isa EnzymeCore.Const ? dys : g.dval
    dbs = b isa EnzymeCore.Const ? dys : b.dval
    dxs = x isa EnzymeCore.Const ? dys : x.dval

    if EnzymeCore.EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
        dgs = (dgs,)
        dbs = (dbs,)
    end

    # Element type used to build the cuDNN scaling factors.
    T = eltype(x.val)

    for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
        # A shadow that aliases the primal marks an inactive value: skip it.
        dy === y.val && continue

        x_active = !(x isa EnzymeCore.Const) && dx !== x.val
        g_active = !(g isa EnzymeCore.Const) && dg !== g.val
        b_active = !(b isa EnzymeCore.Const) && db !== b.val

        if x_active || g_active || b_active
            # dx = alpha * newVal + beta * old(dx)
            # Active x: accumulate, i.e. dx += newVal (alpha = beta = 1).
            # Inactive x: write zeros into a throw-away buffer (alpha = beta = 0).
            alpha = one(T)
            beta = one(T)
            if !x_active
                alpha = zero(T)
                beta = zero(T)
                dx = similar(x.val)
            end

            # dg / db share one pair of scaling factors; only disable accumulation
            # when neither of them is active.
            dalpha = one(T)
            dbeta = one(T)
            if !g_active && !b_active
                dalpha = zero(T)
                dbeta = zero(T)
            end

            # Scratch outputs for inactive parameters; their contents are discarded.
            if !g_active
                dg = similar(g.val)
            end
            if !b_active
                db = similar(b.val)
            end

            # NOTE(review): `kws` are the forward-call keywords and are splatted last,
            # so a caller-supplied `alpha`/`beta` would override the values above —
            # confirm that is acceptable.
            cudnnBNBackward!(
                dg, cache_g, db, dx, cache_x, dy,
                cache_running_mean, cache_running_var,
                momentum.val;
                alpha, beta, dalpha, dbeta, kws...,
            )
        end

        # The output shadow has been consumed; reset it for subsequent accumulation.
        dy .= 0
    end

    return no_grads
end
0 commit comments