Commit dbba9aa

PreLayer, take 1
1 parent 8650d54 commit dbba9aa

File tree: 5 files changed (+352, −0 lines)


Project.toml

Lines changed: 4 additions & 0 deletions
@@ -3,10 +3,14 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
 version = "0.1.0"

 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]

README.md

Lines changed: 1 addition & 0 deletions
@@ -35,3 +35,4 @@ As will any features which migrate to Flux itself.
 
 * Layers `Split` and `Join`
 * A more advanced `train!`
+* Layers with pre-allocated working space

src/Fluxperimental.jl

Lines changed: 3 additions & 0 deletions
@@ -8,4 +8,7 @@ export Split, Join
 include("train.jl")
 export shinkansen!
 
+include("preallocated.jl")
+export pre, nopre
+
 end # module Fluxperimental
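
A rough usage sketch of the two new exports (not part of this commit; it assumes the package with the changes above is loaded, and `model`/`pmodel` are illustration names only):

using Flux, Fluxperimental

model  = Chain(Dense(2 => 3, relu), softmax)
pmodel = model |> pre                    # supported layers get wrapped in PreLayer

pmodel[1] isa Fluxperimental.PreLayer    # true: the Dense layer was wrapped
nopre(pmodel)[1] === model[1]            # true: unwrapping hands back the original layer objects

Note that `pre` only wraps layers for which a `PreLayer` constructor method exists; anything else passes through unchanged.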

src/preallocated.jl

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
using Flux, ChainRulesCore
using LinearAlgebra: mul!
# using FastBroadcast: @..
using Strided

const NoT = NoTangent()

"""
    PreLayer(Dense(2 => 3, relu))

Stores, along with the layer, pre-allocated space for its output,
and all gradient components. Only works on layers it understands.
"""
struct PreLayer{L,G,V}
  layer::L
  grad::G  # same fixed sizes as layer
  fwd::V   # vector of dynamic length
  rev::V
end

Flux.@functor PreLayer
Flux.trainable(p::PreLayer) = (; layer = p.layer)

"""
    model |> pre

Wrap as many layers as possible with `PreLayer`,
to store pre-allocated space for output & gradient.
Ignores layers it doesn't understand.
"""
pre(model) = fmap(PreLayer, model; exclude = x -> hasmethod(PreLayer, Tuple{typeof(x)}))

"""
    nopre(model)

Remove all `PreLayer`s & return the plain model.
"""
nopre(model) = fmap(x -> x.layer, model; exclude = x -> x isa PreLayer)


#####
##### Dense
#####

function PreLayer(d::Dense)
  grad = _struct_sim(d)
  fwd, rev = similar(d.weight, 0), similar(d.weight, 0)
  PreLayer(d, grad, fwd, rev)
end

function (p::PreLayer{<:Dense})(x::AbstractMatrix{<:Real})
  y, dx = _pre_setup(p, x)
  _densecall!(y, p, x, dx)
end

function _pre_setup(p::PreLayer{<:Dense}, x)  # this function @nograd
  _, b = size(x)
  o, i = size(p.layer.weight)
  if o*b != length(p.fwd)
    resize!(p.fwd, o*b)
    resize!(p.rev, i*b)
  end
  y = _pre_reshape(p.fwd, (o,b))
  dx = _pre_reshape(p.rev, (i,b))
  (; y, dx)
end

function _densecall!(y, p, x, dx)
  y .= p.layer.bias
  mul!(y, p.layer.weight, x, true, true)
  act!(y, p.layer.σ)
  y
end

function ChainRulesCore.rrule(::typeof(_densecall!), y, p, x, dx)
  y = _densecall!(y, p, x, dx)
  function back(dy)
    dy = unthunk(dy)
    dy = ∇act!(y, dy, p.layer.σ)
    # layer
    weight = mul!(p.grad.weight, dy, x')
    bias = ∇bias!(p.grad.bias, dy)
    tang = Tangent{Dense}(; weight, bias)
    # input
    dx = mul!(dx, p.layer.weight', dy)
    return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
  end
  y, back
end

#####
##### Scale
#####

scale!(y, (scale, ds), (x, dx), (bias, db)) = y .= scale .* x .+ bias
# scale!(y, (scale, ds), (x, dx), (bias, db)) = @strided y .= scale .* x .+ bias

function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias, db))
  y = scale!(y, (scale, ds), (x, dx), (bias, db))
  function back(dy)
    dy = unthunk(dy)
    @strided dx .= dy .* scale
    @strided ds .= dy .* x
    dbias = ∇bias!(bias, db)
    return (NoT, NoT, (ds, NoT), (dx, NoT), (dbias, NoT))
  end
  y, back
end

#####
##### softmax
#####

function PreLayer(::typeof(softmax))
  fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
  PreLayer(softmax, nothing, fwd, rev)
end

function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
  y, dx = _pre_setup(p, x)  # generic version
  _softmaxcall!(y, p, x, dx)
end

_softmaxcall!(y, p, x, dx) = softmax!(y, x)

function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
  y = _softmaxcall!(y, p, x, dx)
  function back(dy)
    # TODO: CHECK THIS!
    dx .= dy .* y
    dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
    return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
  end
  y, back
end

#####
##### BatchNorm
#####

function PreLayer(bn::BatchNorm)
  grad = (; β = similar(bn.β), γ = similar(bn.γ))  # only trainable fields
  fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
  PreLayer(bn, grad, fwd, rev)
end

function (p::PreLayer{<:BatchNorm})(x::AbstractArray{<:Real})
  y, dx = _pre_setup(p, x)
  # _batchnormcall!(y, p, x, dx)

  # from (BN::BatchNorm)(x)
  N = ndims(x)
  reduce_dims = [1:N-2; N]
  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
  _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
end

using Flux: _isactive, _track_stats!, hasaffine

function _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
  l = p.layer
  N = ndims(x)

  # This block verbatim from Flux. However, mean & var aren't in-place,
  # nor are their gradients... add more storage?

  if !_isactive(l) && l.track_stats  # testmode with tracked stats
    stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
    μ = reshape(l.μ, stats_shape)
    σ² = reshape(l.σ², stats_shape)
  else  # trainmode or testmode without tracked stats
    μ = mean(x; dims=reduce_dims)
    σ² = var(x; mean=μ, dims=reduce_dims, corrected=false)
    if l.track_stats
      _track_stats!(l, x, μ, σ², reduce_dims)  # update moving mean/std
    end
  end

  y = _norm_layer_forward!(y, x, dx, μ, σ², l.ϵ)
  hasaffine(l) || return act!(y, l.λ)

  γ = reshape(l.γ, affine_shape)
  β = reshape(l.β, affine_shape)
  # return l.λ.(γ .* y .+ β)
  y2 = scale!(y, (γ, p.grad.γ), (x, dx), (β, p.grad.β))
  return act!(y2, l.λ)
end

_norm_layer_forward!(y, x, dx, μ, σ², ϵ) = y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)
# _norm_layer_forward!(y, x, dx, μ, σ², ϵ) = @strided y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)

function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ², ϵ)
  y = _norm_layer_forward!(y, x, dx, μ, σ², ϵ)
  function back(dy)
    dx .= dy ./ sqrt.(σ² .+ ϵ)
    # TODO write gradients for mean & variance, these are WRONG!
    dμ = NoT
    dσ² = NoT
    return (NoT, NoT, dx, NoT, dμ, dσ², NoT)
  end
  y, back
end


#####
##### activation functions
#####

act!(y, ::typeof(identity)) = y
function act!(y, act::F) where F
  σ = Flux.NNlib.fast_act(act, y)
  # y .= σ.(y)
  # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
  # maybe you could patch Strided.jl to avoid it? Or use another package...
  @strided y .= σ.(y)
  # FastBroadcast.@.. y = σ(y)
end

# Piracy, disable @strided on CuArrays:
Strided.maybestrided(x::Flux.CuArray) = x

# For this rule, it's important to use what `act!` returns, not what it mutates
ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dy -> (NoT, ∇act!(y, dy, f), NoT)

∇act!(y, dy, ::typeof(identity)) = dy
∇act!(y, dy, ::typeof(relu)) = @. y = ifelse(y>0, dy, 0f0)
∇act!(y, dy, ::typeof(tanh)) = @. y = dy * (1 - y^2)
∇act!(y, dy, ::typeof(sigmoid)) = @. y = dy * y * (1 - y)

#####
##### PreLayer utils
#####

_struct_sim(x) = Flux.fmapstructure(x) do x
  x isa AbstractArray{<:Real} ? similar(x) : nothing
end

function _pre_setup(p::PreLayer, x)  # generic version
  if length(x) != length(p.fwd)
    resize!(p.fwd, length(x))
    resize!(p.rev, length(x))
  end
  y = _pre_reshape(p.fwd, size(x))
  dx = _pre_reshape(p.rev, size(x))
  (; y, dx)
end
ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)

# Cannot use reshape(::Array), as that prevents later resize!
_pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
# Must use reshape(::CuArray) as mul! rejects ReshapedArray
_pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
_pre_reshape(x, size::Tuple) = reshape(x, size)

∇bias!(::Bool, dx) = NoT
∇bias!(bias, dx) = sum!(bias, dx)

function Base.show(io::IO, p::PreLayer)
  show(io, p.layer)
  printstyled(io, " |> pre", color=:blue)
end

Flux._show_children(p::PreLayer) = Flux._show_children(p.layer)
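
A sketch of what the pre-allocation does on the CPU `Array` path (not part of the commit; it assumes the file above has been loaded, e.g. `include`d, so that `PreLayer` is in scope). `_pre_setup` only `resize!`s `fwd`/`rev` when the required length changes, and `_pre_reshape` wraps the same underlying vector, so repeated calls at one batch size write into one buffer:

using Flux, Test

p = PreLayer(Dense(2 => 3, relu))           # as in the docstring above
x = randn(Float32, 2, 8)

y1 = p(x)
y2 = p(x)
@test parent(y1) === parent(y2) === p.fwd   # both outputs view the same pre-allocated vector

xbig = randn(Float32, 2, 32)                # a new batch size triggers resize! of fwd & rev
@test size(p(xbig)) == (3, 32)
@test length(p.fwd) == 3 * 32

The flip side of reusing the buffer is that `y1` is overwritten by the second call, so an output must be consumed (or copied) before the wrapped layer runs again.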

test/preallocated.jl

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@

m1 = Chain(Dense(784 => 32, relu), Dense(32 => 10), softmax)
m2 = m1 |> pre

x = randn(Float32, 784, 64);

@test m1(x) ≈ m2(x)

g1 = gradient((m,x) -> m(x)[1], m1, x)
g2 = gradient((m,x) -> m(x)[1], m2, x)

@test g1[1].layers[1].bias ≈ g2[1].layers[1].layer.bias
@test g1[2] ≈ g2[2]


#=

julia> @btime gradient((m,x) -> m(x)[1], $m1, $x);
  min 52.167 μs, mean 2.519 ms (58 allocations, 355.41 KiB)

julia> @btime gradient((m,x) -> m(x)[1], $m2, $x);
  min 58.750 μs, mean 190.440 μs (109 allocations, 17.44 KiB)



let data = [(x,) for _ in 1:1000]
  o1 = Flux.setup(Adam(), m1)
  @btime Flux.train!((m,x) -> m(x)[1], $m1, $data, $o1)

  o2 = Flux.setup(Adam(), m2)
  @btime Flux.train!((m,x) -> m(x)[1], $m2, $data, $o2)

  nothing
end

# min 1.799 s, mean 1.802 s (177001 allocations, 352.94 MiB)
# min 146.713 ms, mean 251.041 ms (295001 allocations, 25.71 MiB)


m1cu = m1 |> gpu
m2cu = m2 |> gpu
xcu = x |> gpu


let data = [(xcu,) for _ in 1:1000]
  o1 = Flux.setup(Adam(), m1cu)
  CUDA.@time Flux.train!((m,x) -> sum(m(x)), m1cu, data, o1)

  o2 = Flux.setup(Adam(), m2cu)
  CUDA.@time Flux.train!((m,x) -> sum(m(x)), m2cu, data, o2)

  nothing
end
# 1.280640 seconds (1.86 M CPU allocations: 111.723 MiB, 10.99% gc time) (17.00 k GPU allocations: 340.008 MiB, 8.80% memmgmt time)
# 1.327849 seconds (1.73 M CPU allocations: 112.376 MiB, 6.70% gc time) (3.00 k GPU allocations: 2.689 MiB, 2.29% memmgmt time)


=#


m3 = Chain(Dense(784 => 1024, tanh), BatchNorm(1024), Dense(1024 => 10), softmax)
m4 = m3 |> pre

x = randn(Float32, 784, 64);

@test m3(x) ≈ m4(x)

@btime $m3($x);
@btime $m4($x);

#=

julia> @btime $m3($x);
  min 318.000 μs, mean 7.944 ms (31 allocations, 1.01 MiB)

julia> @btime $m4($x);
  min 410.459 μs, mean 440.106 μs (57 allocations, 3.55 KiB)

=#
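
The softmax pullback in src/preallocated.jl is still marked "TODO: CHECK THIS!". A hypothetical extra check for this test file (not in the commit) would compare it against the gradient of the plain `softmax`; in-place it computes `dx = y .* (dy .- sum(y .* dy; dims=1))`, which matches the usual softmax pullback, so the two should agree:

using Flux, Test        # Flux re-exports `gradient`

ps = softmax |> pre     # PreLayer{typeof(softmax)}
xs = randn(Float32, 10, 16)

g_plain = gradient(x -> sum(abs2, softmax(x)), xs)[1]
g_pre   = gradient(x -> sum(abs2, ps(x)),      xs)[1]
@test g_plain ≈ g_pre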