diff --git a/docs/src/reference.md b/docs/src/reference.md
index c01db6b24..ba9cc751c 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -95,6 +95,22 @@ NNlib.unfold
 NNlib.fold
 ```
 
+## Normalization
+
+These roughly correspond to Flux's `*Norm` layers.
+
+
+```@docs
+NNlib.layernorm
+NNlib.batchnorm
+NNlib.instancenorm
+NNlib.groupnorm
+NNlib.norm_stats
+NNlib.norm_helper
+NNlib.RunningStats
+NNlib.update_running_stats!
+```
+
 ## Upsampling
 
 `Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend.
diff --git a/src/NNlib.jl b/src/NNlib.jl
index 8450a0261..04be89387 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -1,7 +1,7 @@
 module NNlib
 
 import Atomix
-import ChainRulesCore: rrule
+import ChainRulesCore: rrule, @ignore_derivatives
 
 using Base.Broadcast: broadcasted
 using Base.Threads
@@ -16,7 +16,6 @@ using Pkg
 using Random
 using Requires
 using Statistics
-using Statistics: mean
 
 const libblas = Base.libblas_name
 
diff --git a/src/normalization.jl b/src/normalization.jl
index c06843d38..2b6e19727 100644
--- a/src/normalization.jl
+++ b/src/normalization.jl
@@ -12,3 +12,302 @@ function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, runnin
     end
     y, batchnorm_pullback
 end
+
+"""
+    norm_stats(x, dims)
+
+Calculates sample mean and (uncorrected) variance of `x` along `dims`.
+
+ - `dims=(1,...,N-2,N)` for batchnorm
+ - `dims=(1,...,N-2)` for instancenorm and groupnorm
+ - `dims=(1,...,S)` where S < N for layernorm
+
+This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
+because it can share some computation across both.
+Implementors may want to overload this function to use custom kernels and more.
+"""
+function norm_stats(x, dims)
+    μ = mean(x; dims)
+    σ² = var(x; dims, mean = μ, corrected = false)
+    return μ, σ²
+end
+
+function rrule(::typeof(norm_stats), x, dims)
+    μ, mean_pullback = rrule(mean, x; dims)
+    σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
+    function norm_stats_pullback(dargs)
+        dμ, dσ² = unthunk(dargs)
+        dx = ChainRulesCore.add!!(var_pullback(dσ²)[2], mean_pullback(dμ)[2])
+        return (NoTangent(), dx, NoTangent())
+    end
+    return (μ, σ²), norm_stats_pullback
+end
+
+_maybe_reshape(::Nothing, _) = nothing
+_maybe_reshape(x, dims) = reshape(x, dims)
+_apply_scale_bias(x, ::Nothing, ::Nothing) = x
+_apply_scale_bias(x, scale, bias) = x .* scale .+ bias
+
+"""
+    norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
+                bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
+
+Shared code path for all built-in norm functions.
+
+`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
+or extracted from an existing collection such as [`RunningStats`](@ref).
+`bias` and `scale` are consistent with cuDNN and Flux.Scale.
+We opt for `scale` over `weight` to avoid confusion with dense layers.
+If the sizes of the statistics and affine parameters differ,
+use `affine_size` to add padding dimensions as required to match the input.
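+
+A rough usage sketch (illustrative only), normalizing over the first two dimensions
+with a per-channel affine transform:
+
+```julia
+x = rand(Float32, 4, 4, 3, 2)            # W × H × C × N
+μ, σ² = NNlib.norm_stats(x, (1, 2))      # statistics have size (1, 1, 3, 2)
+scale, bias = ones(Float32, 3), zeros(Float32, 3)
+y = NNlib.norm_helper(x, μ, σ², scale, bias, 1f-5, (1, 1, 3, 1))
+```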
+""" +function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, + bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ)) + @ignore_derivatives if isnothing(scale) != isnothing(bias) + error("both scale and bias must be provided or left as nothing") + end + scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size) + denom = inv.(sqrt.(σ² .+ ϵ)) + return _apply_scale_bias((x .- μ) .* denom, scale′, bias′) +end + +""" + RunningStats(mean, variance, momentum) + +Contains running mean and variance estimates for stateful norm functions. +`momentum` controls the strength of the moving average update. + +Parameters should be mutable and will be updated in-place. + +See also [`update_running_stats!`](@ref). +""" +struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} + mean::M + variance::V + momentum::MT +end + +# Conditionally pulls running stats or calculates them on the fly. +# Part of the reason this is a dedicated function is to have a more type stable pullback. +function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims, + use_running_stats::Bool) + if stats !== nothing && use_running_stats + # Maintains consistency with mean/var + sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1) + return reshape(stats.mean, sz), reshape(stats.variance, sz) + end + # No running stats exist or are disabled in inference mode + return norm_stats(x, dims) +end + +# Kludge so we can close over a Union inner pullback type +struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}} + back::B + projector::P +end +function (pb::MaybeNormStatsPullback)(dargs) + _, dx = unthunk(pb.back(dargs)) + return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent()) +end +function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims, + use_running_stats::Bool) + project = ProjectTo(x) + noop_back(_) = (NoTangent(), NoTangent()) + if stats === nothing || !use_running_stats + (μ, σ²), back = rrule(norm_stats, x, dims) + else + # The default is to track, so this only happens when a layer is frozen + sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1) + μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back + end + back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)} + return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project) +end + +""" + update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ², + reduce_dims) where {N} + +Performs a moving average update for layers with tracked statistics. +`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref). +`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref). + +See also [`RunningStats`](@ref). +""" +function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims) + V = eltype(σ²) + momentum = stats.momentum + res_mtm = one(V) - momentum + m = prod(size(x, i) for i in reduce_dims; init = 1) + correction = m / (m - one(V)) + + running_mean, running_var = stats.mean, stats.variance + stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ) + stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²) + return +end + +# Convenience functions +# We follow roughly the same arg order as torch.nn.functional.*_norm: +# input, unique args for this particular norm type, bias + scale, eps; kwargs... 
+
+"""
+    layernorm(x::AbstractArray{<:Any,N}, ::Val{S}, scale = nothing, bias = nothing,
+              ϵ=ofeltype(x, 1e-5)) where {N, S}
+
+Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation.
+
+Normalizes `x` along the first `S` dimensions.
+
+For an additional learned affine transform, provide an `S`-dimensional `scale` and `bias`.
+
+See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
+
+# Examples
+
+```jldoctest
+julia> using Statistics
+
+julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels
+
+julia> y = NNlib.layernorm(xs, Val(3));
+
+julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) &&
+           std(y; dims = 1:3) != std(xs; dims = 1:3)
+true
+```
+"""
+function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
+                   ϵ = ofeltype(x, 1e-5)) where {N, S}
+    @ignore_derivatives if S > N
+        throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
+    end
+    μ, σ² = norm_stats(x, ntuple(identity, S))
+    return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]::Dims{S})
+end
+
+"""
+    batchnorm(x::AbstractArray{<:Any, N},
+              running_stats::Union{RunningStats, Nothing} = nothing,
+              scale::Union{AbstractVector, Nothing} = nothing,
+              bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+              training::Bool) where {N}
+
+Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.
+
+Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
+where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
+
+Provide a [`RunningStats`](@ref) to track running estimates of the mean and variance.
+`batchnorm` will normalize the input using these statistics during inference,
+and update them using batch-level statistics when training.
+To override this behaviour, manually set a value for `training`.
+
+If specified, `scale` and `bias` will be applied as an additional learned affine transform.
+
+See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
+"""
+function batchnorm(x::AbstractArray{<:Any, N},
+                   running_stats::Union{RunningStats, Nothing} = nothing,
+                   scale::Union{AbstractVector, Nothing} = nothing,
+                   bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+                   training::Bool = within_gradient(x)) where {N}
+    reduce_dims = ((1:(N - 2))..., N)
+    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
+    # Because μ and σ² could be updated in-place, we compute the output first
+    y = norm_helper(x, μ, σ², scale, bias, ϵ)
+    @ignore_derivatives if running_stats !== nothing && training
+        update_running_stats!(running_stats, x, μ, σ², reduce_dims)
+    end
+    return y
+end
+
+"""
+    instancenorm(x::AbstractArray{<:Any, N},
+                 running_stats::Union{RunningStats, Nothing} = nothing,
+                 scale::Union{AbstractVector, Nothing} = nothing,
+                 bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+                 training::Bool) where {N}
+
+Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation.
+
+Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice.
+
+Provide a [`RunningStats`](@ref) to track running estimates of the mean and variance.
+`instancenorm` will normalize the input using these statistics during inference,
+and update them using batch-level statistics when training.
+To override this behaviour, manually set a value for `training`.
+
+If specified, `scale` and `bias` will be applied as an additional learned affine transform.
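+
+A rough usage sketch (illustrative only, not a doctest):
+
+```julia
+x = rand(Float32, 8, 8, 3, 2)                       # W × H × C × N
+stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
+y = NNlib.instancenorm(x, stats; training = true)   # also updates `stats` in-place
+```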
+ +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref). +""" +function instancenorm(x::AbstractArray{<:Any, N}, + running_stats::Union{RunningStats, Nothing} = nothing, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); + training::Bool = within_gradient(x)) where {N} + affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :) + reduce_dims = ((1:(N - 2))...,) + μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) + # Because μ and σ² could be updated in-place, we compute the output first + y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size) + ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training + μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to sum (C, N) -> (C,) + update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims) + end + return y +end + +""" + groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, + ϵ = ofeltype(x, 1e-5)) where {N} + +Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation. + +Normalizes `x` along the first `N - 2` (spatial) dimensions, +where `N-1` is the "channel" (or "feature", for 2D inputs) dimension, +and the channel dimension is divided into `groups` groups along which statistics are computed. +The number of channels must be an integer multiple of the number of groups. + +If specified, `scale` and `bias` will be applied as an additional learned affine transform. + +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). + +# Examples + +```jldoctest +julia> using Statistics + +julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels + +julia> y = NNlib.groupnorm(xs, 4); + +julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) && + std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1]) +true + +julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) && + std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2]) +true +``` +""" +function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, + scale::Union{AbstractVector, Nothing} = nothing, + bias::Union{AbstractVector, Nothing} = nothing, + ϵ = ofeltype(x, 1e-5)) where {N} + sz = size(x) + channels = @ignore_derivatives begin + ch = sz[max(1, N - 1)] + newch, remainder = divrem(ch, groups) + remainder == 0 ? newch : + throw(ArgumentError("channels $ch should be multiple of groups $groups")) + end + affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :) + grouped_size = (sz[1:(N - 2)]..., channels, groups, :) + x′ = reshape(x, grouped_size) + μ, σ² = norm_stats(x′, ((1:(N - 2))...,)) + return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz) +end diff --git a/src/utils.jl b/src/utils.jl index 3d23e7383..9edfb2112 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -162,3 +162,15 @@ if VERSION < v"1.7.0-DEV.793" end end + +# This is a terrible hack to prevent the spread of type instabilities +# when the pullback type changes depending on runtime information, +# e.g. when a normalization layer is "active" vs "inactive". +function _rrule_pullback_rt(@nospecialize(fn), args...) + rt = Base.promote_op(rrule, typeof(fn), map(typeof, args)...) + rt <: Tuple{<:Any,<:Any} && return rt.parameters[2] + return rt +end + +# Extracted from Flux. Should this have a docstring and/or be in the docs? 
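+# For illustration (assumed behaviour, not a doctest):
+#   ofeltype(rand(Float32, 3), 1e-5) == 1.0f-5
+# i.e. the default `ϵ` used by the normalization functions is converted to the
+# input's floating-point eltype.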
+ofeltype(x, y) = convert(float(eltype(x)), y) \ No newline at end of file diff --git a/test/normalization.jl b/test/normalization.jl new file mode 100644 index 000000000..b38b1ca71 --- /dev/null +++ b/test/normalization.jl @@ -0,0 +1,275 @@ +FiniteDifferences.to_vec(stats::NNlib.RunningStats) = [], _ -> stats + +randn_sample(shape, μ, σ) = randn(rng, shape) .* σ .+ μ +f32_arange(shape...) = Float32.(reshape(1:prod(shape), shape)) + +function make_bn(ch; training = true) + stats, bias, scale = NNlib.RunningStats(zeros(ch), ones(ch), 0.1), zeros(ch), ones(ch) + return x -> NNlib.batchnorm(x, stats, scale, bias; training) +end +function make_in(ch; training = true) + stats, bias, scale = NNlib.RunningStats(zeros(ch), ones(ch), 0.1), zeros(ch), ones(ch) + return x -> NNlib.instancenorm(x, stats, scale, bias; training) +end +function make_gn(ch, groups) + bias, scale = zeros(ch), ones(ch) + return x -> NNlib.groupnorm(x, groups, scale, bias) +end + +@testset "Helpers" begin + # BatchNorm dimensions + let W = 128, C = 4, N = 64 + x = cat([randn_sample((W, W, 1, N), i, i) for i in 1:C]...; dims = 3) + μ, σ² = NNlib.norm_stats(x, (1, 2, 4)) + @test vec(μ)≈1:C rtol=0.05 + @test vec(σ²)≈abs2.(1:C) rtol=0.05 + end + + # LayerNorm dimensions + let W = 128, C = 64, N = 4 + x = cat([randn_sample((W, W, C, 1), i, i) for i in 1:N]...; dims = 4) + μ, σ² = NNlib.norm_stats(x, (1, 2, 3)) + @test vec(μ)≈1:N rtol=0.05 + @test vec(σ²)≈abs2.(1:N) rtol=0.05 + end + + # Group/InstanceNorm dimensions + let W = 128, C = 2, N = 2, shape = (W, W, 1, 1) + # Tile to W x W x 2 x 2 + x = cat(cat(randn_sample(shape, 1, 1), randn_sample(shape, 2, 2); dims = 3), + cat(randn_sample(shape, 3, 3), randn_sample(shape, 4, 4); dims = 3); + dims = 4) + μ, σ² = NNlib.norm_stats(x, (1, 2)) + @test vec(μ)≈1:(C * N) rtol=0.05 + @test vec(σ²)≈abs2.(1:(C * N)) rtol=0.05 + end + + x = rand(rng, 16, 16, 3, 4) + @testset "dims = $dims" for (dims, tsize) in [ + (1, 2, 4) => (1, 1, size(x, 3), 1), + (1, 2, 3) => (1, 1, 1, size(x, 4)), + (1, 2) => (1, 1, size(x, 3), size(x, 4)), + ] + meanvar = (ones(tsize), ones(tsize)) + test_rrule(NNlib.norm_stats, x, dims ⊢ NoTangent(); output_tangent = meanvar) + + running_stats = NNlib.RunningStats(meanvar..., 0.1) + y_ns, back_ns = rrule(NNlib.norm_stats, x, dims) + dx_ns = back_ns(meanvar)[2] + for (stats, training, y, y_ad, dx) in [ + (nothing, true, y_ns, y_ns, dx_ns), + (nothing, false, y_ns, y_ns, dx_ns), + (running_stats, true, y_ns, y_ns, dx_ns), + (running_stats, false, meanvar, meanvar, NoTangent()), + ] + ŷ = NNlib.maybe_norm_stats(stats, x, dims, !training) + @test ŷ[1]≈y[1] rtol=1e-5 + @test ŷ[2]≈y[2] rtol=1e-5 + ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training) + @test ŷ == y_ad + @test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent()) + + test_rrule(NNlib.maybe_norm_stats, stats ⊢ NoTangent(), x, dims ⊢ NoTangent(), + !training; output_tangent = meanvar, check_inferred = false) + end + + ps = ntuple(_ -> rand(rng, tsize...), 4) + gradtest((args...) -> NNlib.norm_helper(args..., size(ps[1])), x, ps..., 1e-5) + end + + p = ones(16, 16) + @test_throws ErrorException NNlib.norm_helper(x, p, p, nothing, p, 1e-5) + @test_throws ErrorException NNlib.norm_helper(x, p, p, p, nothing, 1e-5) +end + +@testset "Layer Norm" begin + full_size = (2, 3, 4, 5) + @testset for xdims in 2:4, kdims in 1:(xdims - 1) + x = rand(rng, full_size[1:xdims]...) 
+        bias, scale = ntuple(_ -> rand(rng, full_size[1:kdims]...), 2)
+        dims = Val(kdims)
+
+        y = @inferred NNlib.layernorm(x, dims)
+        @test size(y) == size(x)
+        y = @inferred NNlib.layernorm(x, dims, scale, bias)
+        @test size(y) == size(x)
+
+        # FiniteDifferences gives incorrect results on some but not all args, why?
+        gradtest(x -> NNlib.layernorm(x, dims), x; broken = true)
+        gradtest((x, s, b) -> NNlib.layernorm(x, dims, s, b), x, scale, bias; skip = true)
+    end
+end
+
+@testset "Batch Norm" begin
+    let x = [1.0 3.0 5.0; 2.0 4.0 6.0], bias = zeros(2), scale = ones(2)
+        @testset for use_stats in (true, false)
+            stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing
+            y, back = Zygote.pullback(NNlib.batchnorm, x, stats, scale, bias, 1e-5)
+            @test y≈[-1.22474 0 1.22474; -1.22474 0 1.22474] atol=1e-5
+
+            expected_mean, expected_var = [0.3, 0.4], [1.3, 1.3]
+            if use_stats
+                # μ of batch will be
+                # (1. + 3. + 5.) / 3 = 3
+                # (2. + 4. + 6.) / 3 = 4
+                #
+                # ∴ update rule with momentum:
+                # .1 * 3 + 0 = .3
+                # .1 * 4 + 0 = .4
+                @test stats.mean ≈ expected_mean
+                # σ² of batch will be (uncorrected)
+                # sum(abs2, [1., 3., 5.] .- 3) / 3 = 2.6...
+                # sum(abs2, [2., 4., 6.] .- 4) / 3 = 2.6...
+                #
+                # ∴ update rule with momentum:
+                # .1 * (3 / (3 - 1)) * 2.6 + (1 - .1) * 1 = 1.3
+                @test stats.variance ≈ expected_var
+            end
+
+            dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1))
+            @test dx≈[3.06186 0.612371 -1.83711; 3.06186 0.612371 -1.83711] atol=1e-5
+            @test dscale == zeros(2)
+            @test dbias == fill(3.0, 2)
+            @test dstats === nothing
+
+            if use_stats
+                tmp_mean, tmp_var = copy(stats.mean), copy(stats.variance)
+                x′ = @inferred NNlib.batchnorm(x, stats, scale, bias, 1e-5)
+                @test x′[1]≈((1 - expected_mean[1]) / sqrt(expected_var[1])) atol=1e-5
+                # Stats should be unchanged
+                @test stats.mean == tmp_mean
+                @test stats.variance == tmp_var
+            end
+        end
+    end
+
+    let x = f32_arange(3, 2, 1), m = make_bn(2)
+        y = reshape(permutedims(x, [2, 1, 3]), 2, :)
+        y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
+        @test m(x) == y
+        @inferred m(x)
+    end
+
+    let x = f32_arange(2, 3, 2, 1), m = make_bn(2)
+        y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
+        y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
+        @test m(x) == y
+        @inferred m(x)
+    end
+
+    let x = f32_arange(2, 2, 3, 2, 1), m = make_bn(2)
+        y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
+        y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
+        @test m(x) == y
+        @inferred m(x)
+    end
+
+    let x = randn(Float32, 416, 416, 32, 1), m = make_bn(32; training = false)
+        @test (@allocated m(x)) < 100_000_000
+    end
+end
+
+@testset "Instance Norm" begin
+    let x = reshape(1.0:12.0, 3, 2, 2), bias = zeros(2), scale = ones(2)
+        @testset for use_stats in (true, false)
+            stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing
+            y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5)
+            @test y≈repeat([-1.22474, 0.0, 1.22474], 1, 2, 2) rtol=1e-5
+
+            expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0]
+            if use_stats
+                # μ will be
+                # (1. + 2. + 3.) / 3 = 2.
+                # (4. + 5. + 6.) / 3 = 5.
+                # (7. + 8. + 9.) / 3 = 8.
+                # (10. + 11. + 12.) / 3 = 11.
+                #
+                # ∴ update rule with momentum:
+                # .1 * (2. + 8.) / 2 + 0 = .5
+                # .1 * (5. + 11.) / 2 + 0 = .8
+                @test stats.mean ≈ expected_mean
+                # σ² will be (uncorrected)
+                # sum(abs2, [1., 2., 3.] .- 2) / 3 = 0.6...
+                # sum(abs2, [4., 5., 6.] .- 5) / 3 = 0.6...
+                # sum(abs2, [7., 8., 9.] .- 8) / 3 = 0.6...
+                # sum(abs2, [10., 11., 12.] .- 11) / 3 = 0.6...
+                #
+                # ∴ update rule with momentum:
+                # .1 * (3 / (3 - 1)) * 0.6... + (1 - .1) * 1 = 1.
+                @test stats.variance ≈ expected_var
+            end
+
+            dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1))
+            @test dx≈repeat([3.6742, 1.22474, -1.22474], 1, 2, 2) rtol=1e-5
+            @test dscale == zeros(2)
+            @test dbias == fill(6.0, 2)
+            @test dstats === nothing
+
+            if use_stats
+                tmp_mean, tmp_var = copy(stats.mean), copy(stats.variance)
+                x′ = @inferred NNlib.instancenorm(x, stats, scale, bias, 1e-5)
+                @test x′[1]≈((1 - expected_mean[1]) / sqrt(expected_var[1])) atol=1e-5
+                # Stats should be unchanged
+                @test stats.mean == tmp_mean
+                @test stats.variance == tmp_var
+            end
+        end
+    end
+
+    let m = make_in(2), shape = (2, 4, 1, 2, 3), x = f32_arange(shape...)
+        y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+        y = reshape(m(y), shape...)
+        @test m(x) == y
+        @inferred m(x)
+    end
+
+    # Instance norm == batch norm when channel and batch dims are squashed
+    let m_inorm = make_in(2; training = true), m_bnorm = make_bn(12; training = true),
+        shape = (5, 5, 3, 4, 2, 6), x = f32_arange(shape...)
+
+        x′ = reshape(x, (shape[1:(end - 2)]..., :, 1))
+        @test m_inorm(x) == reshape(m_bnorm(x′), shape)
+    end
+
+    let m = make_in(32), x = randn(Float32, 416, 416, 32, 1)
+        @test (@allocated m(x)) < 100_000_000
+    end
+end
+
+@testset "Group Norm" begin
+    full_size = (2, 3, 6, 5)
+    @testset for xdims in 1:3, groups in (1, 2, 3, 6)
+        x = rand(rng, full_size[(end - xdims):end]...)
+        bias, scale = ntuple(_ -> rand(rng, full_size[end - 1]), 2)
+
+        y = @inferred NNlib.groupnorm(x, groups)
+        @test size(y) == size(x)
+        y = @inferred NNlib.groupnorm(x, groups, scale, bias)
+        @test size(y) == size(x)
+
+        # FiniteDifferences gives incorrect results on some but not all args, why?
+        gradtest(x -> NNlib.groupnorm(x, groups), x; broken = true)
+        gradtest((x, s, b) -> NNlib.groupnorm(x, groups, s, b), x, scale, bias; skip = true)
+    end
+
+    let m = make_gn(4, 2), shape = (5, 5, 3, 4, 4, 6)
+        y = Zygote.pullback(m, f32_arange(shape...))[1]
+        @test size(y) == shape
+    end
+
+    let m = make_gn(2, 2), shape = (2, 4, 1, 2, 3), x = f32_arange(shape...)
+        y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+        y = reshape(m(y), shape...)
+        @test m(x) == y
+    end
+
+    # Group norm == instance norm when the group size == number of channels
+    let m_inorm = make_in(4), m_gnorm = make_gn(4, 4), x = f32_arange(2, 2, 3, 4, 5)
+        @test m_inorm(x) ≈ m_gnorm(x)
+    end
+
+    # Group norm == batch norm for a group of size 1 and batch of size 1
+    let m_bnorm = make_bn(4), m_gnorm = make_gn(4, 4), x = f32_arange(2, 2, 3, 4, 1)
+        @test m_bnorm(x) ≈ m_gnorm(x)
+    end
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 8b359ad87..4d0fc8bba 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -159,4 +159,8 @@ end
     else
        @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
    end
+
+    @testset "Normalization" begin
+        include("normalization.jl")
+    end
 end