@@ -1,112 +1,77 @@
-
+# Internal function, used only for layers defined in this file.
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active
 
-_dropout_shape(s, ::Colon) = size(s)
-_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
-
-_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
-
-"""
-    dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)
-
-The dropout function. If `active` is `true`,
-for each input, either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
-e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
-If `active` is `false`, it just returns the input `x`.
-
-Specify `rng` for custom RNGs instead of the default RNG.
-Note that custom RNGs are only supported on the CPU.
-
-Warning: when using this function, you have to manually manage the activation
-state. Usually in fact, dropout is used while training
-but is deactivated in the inference phase. This can be
-automatically managed using the [`Dropout`](@ref) layer instead of the
-`dropout` function.
-
-The [`Dropout`](@ref) layer is what you should use in most scenarios.
-"""
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
-dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
-
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
-  realfptype = float(real(eltype(x)))
-  y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
-  y .= _dropout_kernel.(y, p, 1 - p)
-  return y
-end
-
-# TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
-
 """
-    Dropout(p; dims=:, rng = default_rng_value())
+    Dropout(p; [dims, rng])
 
-Dropout layer.
+Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
+This is used as a regularisation, i.e. to reduce overfitting.
 
-While training, for each input, this layer either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
-`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
-(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
-training.
+While training, it sets each input to `0` (with probability `p`)
+or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
+While testing, it has no effect.
 
-In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
-details.
+By default the mode will switch automatically, but it can also
+be controlled manually via [`Flux.testmode!`](@ref).
 
-Specify `rng` to use a custom RNG instead of the default.
-Custom RNGs are only supported on the CPU.
+By default every input is treated independently. With the `dims` keyword,
+instead it takes a random choice only along that dimension.
+For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
+(also called 2D dropout).
 
-Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
+Keyword `rng` lets you specify a custom random number generator.
+(Only supported on the CPU.)
 
 # Examples
-```jldoctest
-julia> m = Chain(Dense(1 => 1), Dropout(1));
+```julia
+julia> m = Chain(Dense(ones(3,2)), Dropout(0.4))
+Chain(
+  Dense(2 => 3),                        # 9 parameters
+  Dropout(0.4),
+)
 
-julia> Flux.trainmode!(m);
+julia> m(ones(2, 7))  # test mode, no effect
+3×7 Matrix{Float64}:
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0
+ 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 
-julia> y = m([1]);
+julia> Flux.trainmode!(m);  # would happen within gradient
 
-julia> y == [0]
-true
+julia> m(ones(2, 7))
+3×7 Matrix{Float64}:
+ 0.0      0.0      3.33333  0.0      0.0      0.0  0.0
+ 3.33333  0.0      3.33333  0.0      3.33333  0.0  3.33333
+ 3.33333  3.33333  0.0      3.33333  0.0      0.0  3.33333
 
-julia> m = Chain(Dense(1000 => 1000), Dropout(0.5));
+julia> y = m(ones(2, 10_000));
 
-julia> Flux.trainmode!(m);
+julia> using Statistics
 
-julia> y = m(ones(1000));
+julia> mean(y)  # is about 2.0, as for test mode
+1.9892222222222182
 
-julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
-true
+julia> mean(iszero, y)  # is about 0.4
+0.40323333333333333
 ```
 """
-mutable struct Dropout{F,D,R<:AbstractRNG}
+mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
   p::F
   dims::D
   active::Union{Bool, Nothing}
   rng::R
 end
-Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
+Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
 
-function Dropout(p; dims=:, rng = default_rng_value())
-  @assert 0 ≤ p ≤ 1
+function Dropout(p::Real; dims=:, rng = default_rng_value())
+  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
   Dropout(p, dims, nothing, rng)
 end
 
 @functor Dropout
 trainable(a::Dropout) = (;)
 
-function (a::Dropout)(x)
-  _isactive(a, x) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)
 
 testmode!(m::Dropout, mode=true) =
  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
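A note on the new one-line forward pass: `a.p * _isactive(a, x)` exploits Julia's `Bool` arithmetic, so the layer asks for dropout with probability `a.p` when active and probability `0` when inactive, instead of branching as the deleted method did. The sketch below is illustrative only, not the NNlib implementation; `toy_shape` and `toy_dropout` are made-up names mirroring the deleted `_dropout_shape` / `_dropout_kernel` helpers.

```julia
using Random

# Mirror of the deleted _dropout_shape helper: the mask has size 1 on every
# dimension not listed in `dims`, so one random choice broadcasts across them.
toy_shape(x, ::Colon) = size(x)
toy_shape(x, dims) = ntuple(i -> i in dims ? size(x, i) : 1, ndims(x))

# Toy dropout: keep each entry with probability 1 - p and rescale the survivors
# by 1 / (1 - p), so the expected value of the output matches the input.
function toy_dropout(rng, x, p; dims = :)
    keep = rand(rng, float(eltype(x)), toy_shape(x, dims)) .> p
    return x .* keep ./ (1 - p)
end

x = ones(Float32, 2, 7)
toy_dropout(Random.default_rng(), x, 0.4 * true)    # "training": zeros and survivors of 1/0.6 ≈ 1.67
toy_dropout(Random.default_rng(), x, 0.4 * false)   # "testing": p becomes 0.0, x comes back unchanged
```

One way to read the change: folding the active flag into `p` keeps a single code path through `dropout`, while `testmode!` still only flips `m.active`.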
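The reworked docstring mentions `Dropout(p; dims = 3)` for channel-wise ("2D") dropout but does not demonstrate it. A small usage sketch, assuming a Flux version with this layer; the count printed will vary from run to run:

```julia
using Flux, Statistics

# One keep/drop choice per channel: the mask has size 1 on the other dimensions,
# so here the choice is shared across the batch as well.
d = Dropout(0.5; dims = 3)
Flux.trainmode!(d)             # force the active state, as a gradient call would

x = rand(Float32, 4, 4, 8, 2)  # WHCN: 4×4 images, 8 channels, batch of 2
y = d(x)

# Each kept channel is scaled by 1 / (1 - 0.5) == 2; dropped channels are all zero.
dropped = count(iszero, mean(abs, y; dims = (1, 2, 4)))
println("dropped $dropped of 8 channels")
```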