
Commit e33de0c

mcabbott, CarloLucibello, and darsnack authored
Move dropout to NNlib (#2150)
* use NNlib.dropout, deprecate Flux.dropout
* improve Dropout's docstring
* make Dropout(0) === identity, cannot mutate
* NNlibCUDA = 0.2.5
* NNlibCUDA = 0.2.6
* simplify default_rng etc
* Revert "simplify default_rng etc"
  This reverts commit 0e396a6.
* un-revert the removal of the active=true method
* avoid a branch
* Update src/layers/normalise.jl
* Apply suggestions from code review

Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
1 parent dc21964 commit e33de0c
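In practical terms, the standalone function now lives in NNlib and no longer takes an `active` keyword; turning dropout off is expressed by passing a zero probability. A minimal sketch of the new calls (assuming NNlib ≥ 0.8.15 on a plain CPU array; the array `x` and probabilities here are only illustrative):

```julia
using NNlib  # the standalone dropout function now lives here (NNlib ≥ 0.8.15)

x = randn(Float32, 3, 4)

y_on  = NNlib.dropout(x, 0.5)          # replaces Flux.dropout(x, 0.5; active=true)
y_off = NNlib.dropout(x, 0.5 * false)  # p == 0 keeps every element, i.e. the old active=false

# The dims keyword shares the random choice across the other dimensions:
# here dims=2 zeroes whole columns; Dropout(p; dims=3) does this per channel on WHCN input.
y_cols = NNlib.dropout(x, 0.5; dims=2)
```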

File tree — 5 files changed: +50 −84 lines

* Project.toml
* docs/src/models/layers.md
* src/deprecations.jl
* src/layers/normalise.jl
* test/layers/normalisation.jl

Project.toml

Lines changed: 2 additions & 2 deletions
````diff
@@ -30,8 +30,8 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.14"
-NNlibCUDA = "0.2.4"
+NNlib = "0.8.15"
+NNlibCUDA = "0.2.6"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
 ProgressLogging = "0.1"
````

docs/src/models/layers.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -123,7 +123,7 @@ LayerNorm
 InstanceNorm
 GroupNorm
 Flux.normalise
-Flux.dropout
+NNlib.dropout
 ```

 ### Test vs. Train
````

src/deprecations.jl

Lines changed: 1 addition & 0 deletions
````diff
@@ -185,6 +185,7 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
 """)
 end

+
 # v0.14 deprecations

 # Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
````
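This hunk only adds a blank line near the `# v0.14 deprecations` marker; the forwarding method implied by "use NNlib.dropout, deprecate Flux.dropout" lives elsewhere in the file and is not shown here. Purely as a hypothetical illustration of such a shim (the keyword handling, warning text, and use of `rng_from_array` are assumptions, not the PR's code):

```julia
# Hypothetical deprecation shim -- NOT the code added by this PR, which is not shown above.
# It keeps the removed `active` keyword working while forwarding to NNlib.dropout.
function dropout(rng, x, p; dims=:, active::Union{Bool,Nothing}=nothing)
    if active !== nothing
        Base.depwarn("the `active` keyword of `Flux.dropout` is deprecated; " *
                     "multiply `p` by a Bool instead, e.g. `dropout(x, p * active)`", :dropout)
        active || return x   # old behaviour: active=false returned the input unchanged
    end
    return NNlib.dropout(rng, x, p; dims)
end
```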

src/layers/normalise.jl

Lines changed: 43 additions & 78 deletions
````diff
@@ -1,112 +1,77 @@
-
+# Internal function, used only for layers defined in this file.
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

-_dropout_shape(s, ::Colon) = size(s)
-_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
-
-_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
-
-"""
-    dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)
-
-The dropout function. If `active` is `true`,
-for each input, either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
-e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
-If `active` is `false`, it just returns the input `x`.
-
-Specify `rng` for custom RNGs instead of the default RNG.
-Note that custom RNGs are only supported on the CPU.
-
-Warning: when using this function, you have to manually manage the activation
-state. Usually in fact, dropout is used while training
-but is deactivated in the inference phase. This can be
-automatically managed using the [`Dropout`](@ref) layer instead of the
-`dropout` function.
-
-The [`Dropout`](@ref) layer is what you should use in most scenarios.
-"""
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
-dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
-
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
-  realfptype = float(real(eltype(x)))
-  y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
-  y .= _dropout_kernel.(y, p, 1 - p)
-  return y
-end
-
-# TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
-
 """
-    Dropout(p; dims=:, rng = default_rng_value())
+    Dropout(p; [dims, rng])

-Dropout layer.
+Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
+This is used as a regularisation, i.e. to reduce overfitting.

-While training, for each input, this layer either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
-`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
-(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
-training.
+While training, it sets each input to `0` (with probability `p`)
+or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
+While testing, it has no effect.

-In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
-details.
+By default the mode will switch automatically, but it can also
+be controlled manually via [`Flux.testmode!`](@ref).

-Specify `rng` to use a custom RNG instead of the default.
-Custom RNGs are only supported on the CPU.
+By default every input is treated independently. With the `dims` keyword,
+instead it takes a random choice only along that dimension.
+For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
+(also called 2D dropout).

-Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
+Keyword `rng` lets you specify a custom random number generator.
+(Only supported on the CPU.)

 # Examples
-```jldoctest
-julia> m = Chain(Dense(1 => 1), Dropout(1));
+```julia
+julia> m = Chain(Dense(ones(3,2)), Dropout(0.4))
+Chain(
+  Dense(2 => 3),                        # 9 parameters
+  Dropout(0.4),
+)

-julia> Flux.trainmode!(m);
+julia> m(ones(2, 7))  # test mode, no effect
+3×7 Matrix{Float64}:
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0

-julia> y = m([1]);
+julia> Flux.trainmode!(m);  # would happen within gradient

-julia> y == [0]
-true
+julia> m(ones(2, 7))
+3×7 Matrix{Float64}:
+ 0.0      0.0      3.33333  0.0      0.0      0.0  0.0
+ 3.33333  0.0      3.33333  0.0      3.33333  0.0  3.33333
+ 3.33333  3.33333  0.0      3.33333  0.0      0.0  3.33333

-julia> m = Chain(Dense(1000 => 1000), Dropout(0.5));
+julia> y = m(ones(2, 10_000));

-julia> Flux.trainmode!(m);
+julia> using Statistics

-julia> y = m(ones(1000));
+julia> mean(y)  # is about 2.0, as for test mode
+1.9892222222222182

-julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
-true
+julia> mean(iszero, y)  # is about 0.4
+0.40323333333333333
 ```
 """
-mutable struct Dropout{F,D,R<:AbstractRNG}
+mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
   p::F
   dims::D
   active::Union{Bool, Nothing}
   rng::R
 end
-Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
+Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

-function Dropout(p; dims=:, rng = default_rng_value())
-  @assert 0 ≤ p ≤ 1
+function Dropout(p::Real; dims=:, rng = default_rng_value())
+  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
   Dropout(p, dims, nothing, rng)
 end

 @functor Dropout
 trainable(a::Dropout) = (;)

-function (a::Dropout)(x)
-  _isactive(a, x) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)

 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
````
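The rewritten forward pass replaces the old early-return branch with arithmetic: `_isactive(a, x)` is a `Bool`, so `a.p * _isactive(a, x)` is either `p` or `0`, and calling `dropout` with probability zero leaves the input untouched (the "avoid a branch" item in the commit message). A standalone sketch of that identity, assuming NNlib ≥ 0.8.15 and a plain CPU array; `active` here stands in for whatever `_isactive` would return:

```julia
using NNlib

x = ones(Float32, 3, 4)
p = 0.5

active = false                    # what _isactive returns in test mode
y = NNlib.dropout(x, p * active)  # p * false == 0, so nothing is dropped or rescaled
@assert y == x

active = true                     # what _isactive returns while training
y = NNlib.dropout(x, p * active)  # each element is zeroed w.p. 0.5, survivors scaled by 1/(1-0.5)
@assert all(yi -> yi == 0 || yi ≈ 2, y)
```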

test/layers/normalisation.jl

Lines changed: 3 additions & 3 deletions
````diff
@@ -1,4 +1,4 @@
-using Flux, Test, Statistics
+using Flux, Test, Statistics, Random
 using Zygote: pullback, ForwardDiff

 evalwgrad(f, x...) = pullback(f, x...)[1]
@@ -56,10 +56,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
   y = m(x)
   @test count(a->a == 0, y) > 50

-  y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=true)
+  y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true)
   @test count(a->a == 0, y) > 50

-  y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=false)
+  y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false)
   @test count(a->a == 0, y) == 0

   # CPU RNGs map onto CPU ok
````
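In these tests, `values(rng_kwargs)...` splats an optional RNG into the positional arguments, so a single line exercises both the default-RNG and explicit-RNG paths, and `0.9 * 0` is the new spelling of the removed `active=false`. A small sketch of that pattern (the specific `rng_kwargs` tuples below are assumptions mirroring how the suite parametrises its RNGs):

```julia
using Flux, Random

x = rand(Float32, 100)

for rng_kwargs in ((;), (; rng = MersenneTwister(0)))
    # An empty NamedTuple splats to nothing; with `rng` set, it becomes the first
    # positional argument, i.e. Flux.dropout(rng, x, p).
    y = Flux.dropout(values(rng_kwargs)..., x, 0.9)
    @assert count(iszero, y) > 50          # most elements dropped at p = 0.9

    y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0)  # probability zero: nothing dropped
    @assert count(iszero, y) == 0
end
```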
