Commit 5745555

finish up
1 parent 2b9b219 commit 5745555

8 files changed: +173 additions, -424 deletions

Project.toml
Lines changed: 0 additions & 5 deletions

@@ -5,17 +5,13 @@ version = "0.13.14"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -26,7 +22,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
src/Flux.jl
Lines changed: 4 additions & 1 deletion

@@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
       RNN, LSTM, GRU, GRUv3,
       SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
       AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
-      Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+      Dropout, AlphaDropout,
+      LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+      MultiHeadAttention,
       Upsample, PixelShuffle,
       fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
       testmode!, trainmode!
@@ -59,6 +61,7 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
+include("layers/attention.jl")
 include("layers/show.jl")

 include("loading.jl")

src/layers/attention.jl
Lines changed: 55 additions & 226 deletions

@@ -1,36 +1,32 @@
-using Flux, Functors, Test, LinearAlgebra, Random, Statistics
-using CUDA
-using NeuralAttentionlib
-using NeuralAttentionlib: score_returning
-using BenchmarkTools
-using Flux: glorot_uniform
-CUDA.allowscalar(false)

 const A3{T} = AbstractArray{T, 3}
-const A4{T} = AbstractArray{T, 4}
 const TuplInt2 = Union{Int, Tuple{Int, Int}}
 const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}

-include("attention_nnlib.jl")
-include("attention_tullio.jl")
-
-
 """
-    MultiHeadAttention(dims, nheads; [bias, init, dropout_prob])
+    MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])
+
+The multi-head dot-product attention layer used in Transformer architectures [1].

-Multi-head dot-product attention layer.
+[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.

 # Arguments

-- `dims`: ...
-- `nheads`: number of heads.
-- `init`: weight initializer for the Dense layers.
-- `bias` : whether pointwise QKVO dense transforms use bias.
-- `dropout_prob`: dropout probability for the attention scores.
+- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
+  In the most general case, it is given as
+  `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
+  It can also take simpler forms, such as
+  `dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`,
+  `in_dim::Int => qkv_dim => out_dim`.
+
+- `nheads`: number of heads. Default `8`.
+- `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
+- `bias`: whether the pointwise QKVO dense transforms use bias. Default `false`.
+- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.

 # Forward

-    (::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
+    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])

 - `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
 - `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
@@ -39,38 +35,58 @@ Multi-head dot-product attention layer.
   `(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
 - `withscores`: Whether to return the attention scores. Default `false`.

+Alternatively, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
+and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
+
+
+See also [`NNlib.dot_product_attention`](@ref).
+
 # Examples

 ```julia
-mha = MultiHeadAttention(64, 8)
+mha = MultiHeadAttention(64, nheads = 8)
+q = rand(Float32, (64, 10, 32))
+k = rand(Float32, (64, 20, 32))
+v = rand(Float32, (64, 20, 32))
+y = mha(q, k, v)  # [y] = [64, 10, 32]
+
+mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
+y = mha(q)  # self-attention; [y] = [1024, 10, 32]
 ```
 """
 struct MultiHeadAttention{P1, D, P2}
   nheads::Int
-  qkv_proj::P1
+  q_proj::P1
+  k_proj::P1
+  v_proj::P1
   attn_drop::D
   out_proj::P2
 end

 @functor MultiHeadAttention

-function MultiHeadAttention(dims, nheads::Int;
+function MultiHeadAttention(dims;
+                nheads::Int = 8,
                 bias::Bool = false,
                 init = glorot_uniform,
                 dropout_prob = 0.0)

-  dims = mha_process_dims(dims)
+  dims = normalize_mha_dims(dims)
   @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
-  qkv_proj = QKVProj(dims; bias, init)
+  @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
+  q_proj = Dense(dims.q_in => dims.qk; bias, init)
+  k_proj = Dense(dims.k_in => dims.qk; bias, init)
+  v_proj = Dense(dims.v_in => dims.v; bias, init)
   attn_drop = Dropout(dropout_prob)
   out_proj = Dense(dims.v => dims.out; bias, init)
-  return MultiHeadAttention(nheads, qkv_proj, attn_drop, out_proj)
+  return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
 end

-mha_process_dims(dims::Int) =
+# turns the dims argument into a named tuple
+normalize_mha_dims(dims::Int) =
   (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

-function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
+function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
   if in isa Int
     q_in = k_in = v_in = in
   else
@@ -85,209 +101,22 @@ function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2,
 end

 # self-attention
-(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...)
+(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

 # key and value are the same
-(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)
+(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
-                withscores=false, mask=nothing, impl=:nnlib)
+function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
+                withscores=false, mask=nothing)
   ## [q_in] = [q_in_dim, q_len, batch_size]
   ## [k_in] = [k_in_dim, kv_len, batch_size]
   ## [v_in] = [v_in_dim, kv_len, batch_size]
-
-  q, k, v = m.qkv_proj(q_in, k_in, v_in)
-  # [q] = [qk_dim, q_len, batch_size]
-  # [k] = [qk_dim, kv_len, batch_size]
-  # [v] = [v_dim, kv_len, batch_size]
-
-  if impl == :tullio
-    x, α = dot_product_attention_tullio(m.nheads, q, k, v; mask, dropout=m.attn_drop)
-  elseif impl == :nalib
-    x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.nheads, q, k, v, mask)
-  elseif impl == :nnlib
-    x, α = dot_product_attention(q, k, v, bias; m.nheads, mask, fdrop=m.attn_drop)
-  else
-    error("Unknown attention implementation")
-  end
-
-  x = m.out_proj(x)
-
+  q = mha.q_proj(q_in)  # [q] = [qk_dim, q_len, batch_size]
+  k = mha.k_proj(k_in)  # [k] = [qk_dim, kv_len, batch_size]
+  v = mha.v_proj(v_in)  # [v] = [v_dim, kv_len, batch_size]
+  x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
+  x = mha.out_proj(x)
+  # [x] = [out_dim, q_len, batch_size]
+  # [α] = [kv_len, q_len, nheads, batch_size]
   return withscores ? (x, α) : x
 end
-
-struct QKVProj
-  q_proj::Dense
-  k_proj::Dense
-  v_proj::Dense
-end
-
-@functor QKVProj
-
-function QKVProj(dims; bias = false, init=glorot_uniform)
-  return QKVProj(
-    Dense(dims.q_in => dims.qk; bias, init),
-    Dense(dims.k_in => dims.qk; bias, init),
-    Dense(dims.v_in => dims.v; bias, init))
-end
-
-function (proj::QKVProj)(q_in, k_in, v_in)
-  return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
-end
-
-function perf(dim, len, batch_size, nheads)
-  mha = MultiHeadAttention(dim, nheads)
-  x = rand(Float32, (dim, len, batch_size))
-
-  println("tullio")
-  @btime $mha($x, impl=:tullio);
-  @btime gradient(m -> sum(m($x, impl=:tullio)), $mha);
-
-  println("nalib")
-  @btime $mha($x, $x, $x, impl=:nalib);
-  @btime gradient(m -> sum(m($x, impl=:nalib)), $mha);
-
-  println("nnlib")
-  @btime $mha($x, $x, $x, impl=:nnlib);
-  @btime gradient(m -> sum(m($x, impl=:nnlib)), $mha);
-
-  if CUDA.functional()
-    mha_gpu = mha |> gpu
-    x_gpu = x |> gpu
-
-    println("tullio - gpu")
-    @btime $mha_gpu($x_gpu, impl=:tullio);
-    @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);
-
-    println("nalib - gpu")
-    @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
-    @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);
-
-    println("nnlib - gpu")
-    @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnlib);
-    @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnlib)), $mha_gpu);
-  end
-  return nothing
-end
-
-function test(dim, nheads, len, batch_size)
-  mha = MultiHeadAttention(dim, nheads)
-  q = rand(Float32, (dim, len, batch_size))
-  k = rand(Float32, (dim, len, batch_size))
-  v = rand(Float32, (dim, len, batch_size))
-
-  y, α = mha(q, k, v, impl=:tullio, withscores=true)
-  @test y isa Array{Float32, 3}
-  @test size(y) == (dim, len, batch_size)
-  @test α isa Array{Float32, 4}
-  @test size(α) == (len, len, nheads, batch_size)
-
-  y2, α2 = mha(q, k, v, impl=:nalib, withscores=true)
-  @test size(y) == size(y2)
-  @test y2 ≈ y
-  @test size(α) == size(α2)
-  @test α2 ≈ α
-
-  y2b, α2b = mha(q, k, v, impl=:nnlib, withscores=true)
-  @test size(y) == size(y2b)
-  @test y2b ≈ y
-  @test size(α) == size(α2b)
-  @test α2b ≈ α
-
-  mask = make_causal_mask(q)
-  y3, α3 = mha(q, k, v; impl=:tullio, withscores=true, mask)
-  y4, α4 = mha(q, k, v, impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
-  @test y3 ≈ y4
-  @test α3 ≈ α4
-
-  if CUDA.functional()
-    mha_gpu = mha |> gpu
-    q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu
-
-    y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio)
-    y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib)
-    @test Array(y_gpu) ≈ Array(y_gpu2)
-    @test Array(y_gpu) ≈ y
-  end
-  return nothing
-end
-
-test(4, 2, 3, 1)
-
-perf(128, 8, 128, 32)
-
-## M1 Pro, NNlib v0.8.12
-# tullio
-#   2.948 ms (77 allocations: 7.25 MiB)
-#   15.041 ms (1124 allocations: 16.71 MiB)
-# nalib
-#   3.503 ms (89 allocations: 7.75 MiB)
-#   15.828 ms (604 allocations: 14.70 MiB)
-# nnlib
-#   3.611 ms (87 allocations: 9.25 MiB)
-#   16.497 ms (1055 allocations: 20.71 MiB)
-
-## M1 Pro, NNlib v0.8.13 (fast_maximum)
-# tullio
-#   2.427 ms (71 allocations: 7.13 MiB)
-#   14.510 ms (1118 allocations: 16.59 MiB)
-# nalib
-#   3.052 ms (84 allocations: 7.63 MiB)
-#   15.327 ms (599 allocations: 14.57 MiB)
-# nnlib
-#   3.166 ms (81 allocations: 9.13 MiB)
-#   16.082 ms (1049 allocations: 20.58 MiB)
-
-## Threadripper, NNlib v0.8.12
-# tullio
-#   5.658 ms (77 allocations: 7.25 MiB)
-#   22.373 ms (1124 allocations: 16.71 MiB)
-# nalib
-#   6.187 ms (89 allocations: 7.75 MiB)
-#   23.723 ms (604 allocations: 14.70 MiB)
-# nnlib
-#   6.473 ms (87 allocations: 9.25 MiB)
-#   24.966 ms (1055 allocations: 20.71 MiB)
-# tullio - gpu
-#   145.332 μs (520 allocations: 24.52 KiB)
-#   902.020 μs (2221 allocations: 117.19 KiB)
-# nalib - gpu
-#   162.354 μs (410 allocations: 18.03 KiB)
-#   604.111 μs (1263 allocations: 71.78 KiB)
-# nnlib - gpu
-#   156.383 μs (440 allocations: 20.00 KiB)
-#   835.374 μs (1969 allocations: 100.58 KiB)
-
-## Threadripper, NNlib v0.8.13 (fast_maximum)
-# tullio
-#   4.599 ms (71 allocations: 7.13 MiB)
-#   20.699 ms (1118 allocations: 16.59 MiB)
-# nalib
-#   5.049 ms (84 allocations: 7.63 MiB)
-#   22.252 ms (599 allocations: 14.57 MiB)
-# nnlib
-#   5.378 ms (81 allocations: 9.13 MiB)
-#   23.453 ms (1049 allocations: 20.58 MiB)
-# tullio - gpu
-#   145.824 μs (520 allocations: 24.52 KiB)
-#   915.305 μs (2221 allocations: 117.19 KiB)
-# nalib - gpu
-#   164.789 μs (410 allocations: 18.03 KiB)
-#   610.835 μs (1263 allocations: 71.78 KiB)
-# nnlib - gpu
-#   157.785 μs (440 allocations: 20.00 KiB)
-#   852.087 μs (1969 allocations: 100.58 KiB)
-
-
-# function prof()
-#   dim, len, batch_size, nheads = 128, 8, 128, 32;
-#   # dim = 384; len = 128; batch_size = 32; nheads = 12
-#   mha = MultiHeadAttention(dim, nheads)
-#   x = rand(Float32, (dim, len, batch_size))
-#   @btime mha(x, impl=:tullio);
-#   @btime mha(x, impl=:nnlib);
-#   @profview mha(x, impl=:tullio);
-#   @profview prof(mha, x);
-#   y, α = mha(x; impl=:nnlib, withscores=true, mask)
-#   y2, α2 = mha(x; impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
-# end
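The docstring's example covers the common cases; the sketch below exercises the rest of the forward interface described above: the fully general `dims` form, returned attention scores, and a causal mask. It assumes a Flux build that contains this commit and uses NNlib's `make_causal_mask` helper (the removed test code relied on the same function).

```julia
using Flux, NNlib  # assumes a Flux build that includes this commit

# fully general form: (q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim
mha = MultiHeadAttention((64, 32, 32) => (128, 96) => 64; nheads = 8)

q = rand(Float32, 64, 10, 4)  # (q_in_dim, q_len,  batch_size)
k = rand(Float32, 32, 20, 4)  # (k_in_dim, kv_len, batch_size)
v = rand(Float32, 32, 20, 4)  # (v_in_dim, kv_len, batch_size)

y, α = mha(q, k, v; withscores = true)
size(y)  # (64, 10, 4)     i.e. (out_dim, q_len, batch_size)
size(α)  # (20, 10, 8, 4)  i.e. (kv_len, q_len, nheads, batch_size)

# self-attention with a causal mask
x = rand(Float32, 64, 10, 4)
self_mha = MultiHeadAttention(64 => (128, 96) => 64; nheads = 8)
mask = NNlib.make_causal_mask(x)  # (q_len, q_len) boolean mask, broadcast over heads and batch
y2 = self_mha(x; mask)            # size(y2) == (64, 10, 4)
```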
