@@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
3
3
cudnnBatchNormalizationForwardTraining
4
4
import NNlib: batchnorm, ∇batchnorm
5
5
6
+ using EnzymeCore
7
+
6
8
# TODO : replace with new cudnn normalization interface
7
9
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
8
10
@@ -153,3 +155,148 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
153
155
scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
154
156
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
155
157
end
158
+
159
+
160
+
161
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof(cudnnBNForward!)} , :: Type{RT} ,
162
+ y:: OutType ,
163
+ g,
164
+ b,
165
+ x,
166
+ running_mean, running_var, momentum:: EnzymeCore.Const{<:Real} ; kws... ) where {OutType, RT}
167
+
168
+ if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
169
+ func. val (y. val, b. val, x. val, running_mean. val, running_var. val, momentum. val; kws... )
170
+ end
171
+
172
+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
173
+ y. val
174
+ else
175
+ nothing
176
+ end
177
+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
178
+ y. dval
179
+ else
180
+ nothing
181
+ end
182
+
183
+ cache_g = nothing
184
+ cache_x = nothing
185
+ cache_running_mean = nothing
186
+ cache_running_var = nothing
187
+
188
+ if ! (typeof (y) <: EnzymeCore.Const )
189
+ if ! (typeof (x) <: EnzymeCore.Const )
190
+ || ! (typeof (g) <: EnzymeCore.Const )
191
+ || ! (typeof (b) <: EnzymeCore.Const )
192
+
193
+ if EnzymeCore. EnzymeRules. overwritten (config)[3 ]
194
+ cache_g = copy (g. val)
195
+ end
196
+ if EnzymeCore. EnzymeRules. overwritten (config)[5 ]
197
+ cache_x = copy (x. val)
198
+ end
199
+ if EnzymeCore. EnzymeRules. overwritten (config)[6 ]
200
+ cache_running_mean = copy (running_mean. val)
201
+ end
202
+ if EnzymeCore. EnzymeRules. overwritten (config)[7 ]
203
+ cache_running_var = copy (running_var. val)
204
+ end
205
+
206
+ end
207
+ end
208
+
209
+ cache = (cache_g, cache_x, cache_running_mean, cache_running_var)
210
+
211
+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache)
212
+ end
213
+
214
+ function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(cudnnBNForward!)} , :: Type{RT} ,
215
+ cache,
216
+ y:: OutType , g, b, x, running_mean, running_var, momentum:: EnzymeCore.Const{<:Real} ; kws... ) where {OutType, RT}
217
+
218
+ cache_g, cache_x, cache_running_mean, cache_running_var = cache
219
+
220
+ if ! (typeof (y) <: EnzymeCore.Const )
221
+ if ! (typeof (x) <: EnzymeCore.Const )
222
+ || ! (typeof (g) <: EnzymeCore.Const )
223
+ || ! (typeof (b) <: EnzymeCore.Const )
224
+
225
+ if EnzymeCore. EnzymeRules. overwritten (config)[3 ]
226
+ cache_g = g. val
227
+ end
228
+ if EnzymeCore. EnzymeRules. overwritten (config)[5 ]
229
+ cache_x = x. val
230
+ end
231
+ if EnzymeCore. EnzymeRules. overwritten (config)[6 ]
232
+ cache_running_mean = running_mean. val
233
+ end
234
+ if EnzymeCore. EnzymeRules. overwritten (config)[7 ]
235
+ cache_running_var = running_var. val
236
+ end
237
+
238
+ end
239
+ end
240
+
241
+ dys = y. dval
242
+ dgs = (typeof (g) <: EnzymeCore.Const ) ? dys : g. dval
243
+ dbs = (typeof (b) <: EnzymeCore.Const ) ? dbs : b. dval
244
+ dxs = (typeof (x) <: EnzymeCore.Const ) ? dxs : x. dval
245
+
246
+ if EnzymeCore. EnzymeRules. width (config) == 1
247
+ dys = (dys,)
248
+ dxs = (dxs,)
249
+ dgs = (dgs,)
250
+ dbs = (dbs,)
251
+ end
252
+
253
+ for (dy, dx, dg, db) in zip (dys, dxs, dgs, dbs)
254
+ if ! (typeof (y) <: EnzymeCore.Const ) && dy != = y. val
255
+
256
+ if ! ((typeof (x) <: EnzymeCore.Const ) || dx === x. val)
257
+ || ! ((typeof (g) <: EnzymeCore.Const ) || dg === g. val)
258
+ || ! ((typeof (b) <: EnzymeCore.Const ) || db === b. val)
259
+
260
+ # dx values
261
+ alpha = T (1 )
262
+ beta = T (1 )
263
+
264
+ # dx = alpha * newVal + beta old(dx)
265
+ # if x is constant, we can use zero for both
266
+ # otherwise we want to do dx += newVal, aka alpha=beta=1
267
+ if x <: EnzymeCore.Const
268
+ alpha = T (0 )
269
+ beta = T (0 )
270
+ dx = similar (x. val)
271
+ end
272
+
273
+ # dg / db values
274
+ alpha = T (1 )
275
+ beta = T (1 )
276
+
277
+ if g <: EnzymeCore.Const && b <: EnzymeCore.Const
278
+ dalpha = T (0 )
279
+ dbeta = T (0 )
280
+ end
281
+
282
+ if g <: EnzymeCore.Const
283
+ dg = similar (g. val)
284
+ end
285
+
286
+ if b <: EnzymeCore.Const
287
+ db = similar (b. val)
288
+ end
289
+
290
+ cudnnBNBackward! (dg, cache_g, db, dx, cache_x, dy,
291
+ cache_running_mean, cache_running_var,
292
+ momentum. val; alpha, beta, dalpha, dbeta; kw... )
293
+
294
+ end
295
+
296
+ dy .= 0
297
+
298
+ end
299
+ end
300
+
301
+ return (nothing , nothing , nothing , nothing , nothing , nothing , nothing )
302
+ end
0 commit comments