Commit 0155f61

MultiHeadAttention implementation (#2146)
* move multiheadattention from Metalhead
* generic attention
* [ci skip] updates
* [ci skip] updates
* [ci skip] fix tullio impl
* causal mask
* [ci skip] mask
* [ci skip] updates
* [ci skip] add native implementation
* support mask = :causal
* [ci skip] factor out impl
* [ci skip] remove jax
* [ci skip] more benchs
* finish up
* cleanup
* add cuda tests
* cleanup tests
* IntOrDims
* cleanup
* remove with_scores
* improve docstring
1 parent 1258ddf commit 0155f61
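
Before the per-file diffs, here is a brief orientation sketch (not part of the commit) showing how the new layer composes with existing Flux layers in a pre-norm residual block. The `PreNormSelfAttention` wrapper and all sizes below are illustrative assumptions; the layer's own docstring in `src/layers/attention.jl` is the authoritative reference.

```julia
using Flux

# `PreNormSelfAttention` is a hypothetical wrapper, not part of this commit;
# it only illustrates how MultiHeadAttention composes with existing Flux layers.
struct PreNormSelfAttention{N, A}
  norm::N
  mha::A
end

Flux.@functor PreNormSelfAttention

PreNormSelfAttention(dim::Int; nheads = 8) =
  PreNormSelfAttention(LayerNorm(dim), MultiHeadAttention(dim; nheads))

# MultiHeadAttention returns `(output, attention_scores)`; keep only the output
# and add a residual connection around the attention sub-layer.
(blk::PreNormSelfAttention)(x) = x .+ first(blk.mha(blk.norm(x)))

block = PreNormSelfAttention(64; nheads = 8)
x = rand(Float32, 64, 10, 32)   # (features, sequence length, batch)
y = block(x)                    # size (64, 10, 32)
```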

6 files changed (+270, -17 lines)

src/Flux.jl

Lines changed: 4 additions & 1 deletion
```diff
@@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
         RNN, LSTM, GRU, GRUv3,
         SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
         AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
-        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+        Dropout, AlphaDropout,
+        LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+        MultiHeadAttention,
         Upsample, PixelShuffle,
         fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
         testmode!, trainmode!
@@ -60,6 +62,7 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
+include("layers/attention.jl")
 include("layers/show.jl")
 
 include("loading.jl")
```

src/layers/attention.jl

Lines changed: 133 additions & 0 deletions
````julia
const A3{T} = AbstractArray{T, 3}
const IntOrDims{N} = Union{Int, Dims{N}}

"""
    MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])

The multi-head dot-product attention layer used in Transformer architectures [1].

Returns the transformed input sequence and the attention scores.

[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.

# Arguments

- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
  In the most general case, it is given as
  a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
  It can also take the simpler forms
  b) `dims::Int`;
  c) `in_dim::Int => (qk_dim, v_dim) => out_dim`;
  d) `in_dim::Int => qkv_dim => out_dim`.
- `nheads`: number of heads. Default `8`.
- `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
- `bias`: whether pointwise QKVO dense transforms use bias. Default `false`.
- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.

# Forward

    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])

The arguments of the forward pass are:

- `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
- `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
- `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before the softmax.
  Default `nothing`.
- `mask`: Input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`NNlib.make_causal_mask`](@ref) for creating causal masks.
  Default `nothing`.

Alternative calling signatures are `mha(q_in)`, equivalent to `mha(q_in, q_in, q_in)` (self-attention),
and `mha(q_in, k_in)`, equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).

See also [`NNlib.dot_product_attention`](@ref).

# Examples

```julia
mha = MultiHeadAttention(64, nheads = 8)
q = rand(Float32, (64, 10, 32))
k = rand(Float32, (64, 20, 32))
v = rand(Float32, (64, 20, 32))
y, α = mha(q, k, v)
# [y] = [64, 10, 32]
# [α] = [20, 10, 8, 32]

mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
y, α = mha(q) # self-attention
# [y] = [1024, 10, 32]
# [α] = [10, 10, 8, 32]
```
"""
struct MultiHeadAttention{P1, D, P2}
  nheads::Int
  q_proj::P1
  k_proj::P1
  v_proj::P1
  attn_drop::D
  out_proj::P2
end

@functor MultiHeadAttention

function MultiHeadAttention(dims;
                            nheads::Int = 8,
                            bias::Bool = false,
                            init = glorot_uniform,
                            dropout_prob = 0.0)

  dims = normalize_mha_dims(dims)
  @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
  @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
  q_proj = Dense(dims.q_in => dims.qk; bias, init)
  k_proj = Dense(dims.k_in => dims.qk; bias, init)
  v_proj = Dense(dims.v_in => dims.v; bias, init)
  attn_drop = Dropout(dropout_prob)
  out_proj = Dense(dims.v => dims.out; bias, init)
  return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
end

# turns the dims argument into a named tuple
normalize_mha_dims(dims::Int) =
  (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}})
  if in isa Int
    q_in = k_in = v_in = in
  else
    q_in, k_in, v_in = in
  end
  if qkv isa Int
    qk = v = qkv
  else
    qk, v = qkv
  end
  return (; q_in, k_in, v_in, qk, v, out)
end

# self-attention
(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

# key and value are the same
(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
                                   bias=nothing; mask=nothing)
  ## [q_in] = [q_in_dim, q_len, batch_size]
  ## [k_in] = [k_in_dim, kv_len, batch_size]
  ## [v_in] = [v_in_dim, kv_len, batch_size]
  q = mha.q_proj(q_in)  # [q] = [qk_dim, q_len, batch_size]
  k = mha.k_proj(k_in)  # [k] = [qk_dim, kv_len, batch_size]
  v = mha.v_proj(v_in)  # [v] = [v_dim, kv_len, batch_size]
  x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
  x = mha.out_proj(x)
  # [x] = [out_dim, q_len, batch_size]
  # [α] = [kv_len, q_len, nheads, batch_size]
  return x, α
end
````
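
To make the `dims` forms above concrete, here is a small usage sketch (not part of the diff; the sizes and variable names are arbitrary). It exercises the four constructor forms from the docstring and a causal-mask call via `NNlib.make_causal_mask`; internally the forward pass defers to `NNlib.dot_product_attention` for the scaled dot-product softmax.

```julia
using Flux, NNlib

# b) a single Int: all embedding dimensions are 64
mha_b = MultiHeadAttention(64; nheads = 8)

# d) in_dim => qkv_dim => out_dim: q/k/v all project 32 -> 48, output projects 48 -> 32
mha_d = MultiHeadAttention(32 => 48 => 32; nheads = 6)

# c) in_dim => (qk_dim, v_dim) => out_dim: separate sizes for the q/k and v projections
mha_c = MultiHeadAttention(32 => (48, 24) => 32; nheads = 6)

# a) fully general: distinct q/k/v input dims, e.g. for cross-attention
mha_a = MultiHeadAttention((20, 30, 40) => (48, 24) => 16; nheads = 6)

# causal self-attention, using the mask keyword described in the docstring
x = rand(Float32, 64, 10, 4)        # (q_in_dim, q_len, batch_size)
mask = NNlib.make_causal_mask(x)    # boolean (10, 10) mask over (kv_len, q_len)
y, α = mha_b(x; mask)               # y: (64, 10, 4), α: (10, 10, 8, 4)
```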

test/cuda/layers.jl

Lines changed: 26 additions & 0 deletions
```diff
@@ -338,3 +338,29 @@ end
     @test eltype(pool(reshape(gx,3,4,1))) == Float16
   end
 end
+
+@testset "MultiHeadAttention" begin
+  dim = 4; nheads = 2; len = 3; batch_size = 5
+  mha_cpu = MultiHeadAttention(dim; nheads)
+  x_cpu = rand(Float32, (dim, len, batch_size))
+  y_cpu, α_cpu = mha_cpu(x_cpu)
+
+  mha_gpu = mha_cpu |> gpu
+  x_gpu = x_cpu |> gpu
+  y_gpu, α_gpu = mha_gpu(x_gpu)
+  @test y_gpu isa CuArray{Float32}
+  @test α_gpu isa CuArray{Float32}
+  @test Array(y_gpu) ≈ y_cpu atol=1e-4
+  @test Array(α_gpu) ≈ α_cpu atol=1e-4
+
+  gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
+    y, α = mha(x)
+    return sum(y.^2) + sum(α.^2)
+  end
+  gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
+    y, α = mha(x)
+    return sum(y.^2) + sum(α.^2)
+  end
+  check_grad(gm_gpu, gm_cpu)
+  check_grad(gx_gpu, gx_cpu)
+end
```

test/layers/attention.jl

Lines changed: 65 additions & 0 deletions
```julia
@testset "attention" begin
  dim = 4; nheads = 2; len = 3; batch_size = 5
  mha = MultiHeadAttention(dim; nheads)
  q = rand(Float32, (dim, len, batch_size))
  k = rand(Float32, (dim, len, batch_size))
  v = rand(Float32, (dim, len, batch_size))

  y, α = mha(q, k, v)
  @test y isa Array{Float32, 3}
  @test size(y) == (dim, len, batch_size)
  @test α isa Array{Float32, 4}
  @test size(α) == (len, len, nheads, batch_size)

  @testset "self-attention" begin
    y1, α1 = mha(q)
    y2, α2 = mha(q, q, q)
    @test y1 ≈ y2
    @test α1 ≈ α2
  end

  @testset "key and value are the same" begin
    y1, α1 = mha(q, k)
    y2, α2 = mha(q, k, k)
    @test y1 ≈ y2
    @test α1 ≈ α2
  end

  @testset "change dims" begin
    dims = 4 => 10 => 5
    nhead = 5
    mha2 = MultiHeadAttention(dims; nheads = nhead)
    y2, _ = mha2(q, k, v)
    @test size(y2) == (dims.second.second, len, batch_size)
  end

  @testset "mask" begin
    mask = NNlib.make_causal_mask(q)
    y, α = mha(q; mask)
    @test all(α[2, 1, :, :] .== 0)
    @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
  end

  @testset "bias" begin
    # use bias to produce a causal mask
    b = zeros(Float32, (len, len))
    for i in 1:len, j in i:len
      b[i, j] = typemax(Float32)
    end
    y, α = mha(q, k, v, b)
    @test all(α[2, 1, :, :] .== 0)
    @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
  end

  @testset "gradient" begin
    gm, gq = gradient(mha, q) do mha, q
      y, α = mha(q)
      return sum(y.^2) + sum(α.^2)
    end
    check_grad_type(gm, mha)
    check_grad_type(gq, q)
  end
end
```

test/runtests.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -33,6 +33,7 @@ Random.seed!(0)
 end
 
 @testset "Layers" begin
+  include("layers/attention.jl")
   include("layers/basic.jl")
   include("layers/normalisation.jl")
   include("layers/stateless.jl")
```

test/test_utils.jl

Lines changed: 41 additions & 16 deletions
```diff
@@ -1,27 +1,33 @@
-function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu, g_cpu;
+                    rtol=1e-4, atol=1e-4,
+                    allow_nothing::Bool=false)
   allow_nothing && return
   @show g_gpu g_cpu
   @test false
 end
-check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol; allow_nothing::Bool) =
-  check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing)
-check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
+  check_grad(g_gpu[], g_cpu[]; rtol, atol, allow_nothing)
+
+check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test true
-check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test g_cpu ≈ g_gpu rtol=rtol atol=atol
-check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol
 
-function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
   for (v1, v2) in zip(g_gpu, g_cpu)
-    check_grad(v1, v2, atol, rtol; allow_nothing)
+    check_grad(v1, v2; rtol, atol, allow_nothing)
   end
 end
 
-function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
   for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu))
     @test k1 == k2
-    check_grad(v1, v2, atol, rtol; allow_nothing)
+    check_grad(v1, v2; rtol, atol, allow_nothing)
   end
 end
 
@@ -31,10 +37,14 @@ check_type(x::CuArray{Float32}) = true
 check_type(x::Array{Float32}) = true
 
 function gpu_autodiff_test(
-  f_cpu, xs_cpu::Array{Float32}...;
-  test_equal=true, rtol=1e-4, atol=1e-4,
-  checkgrad::Bool = true, allow_nothing::Bool = false,
-  )
+    f_cpu,
+    xs_cpu::Array{Float32}...;
+    test_equal=true,
+    rtol=1e-4, atol=1e-4,
+    checkgrad::Bool = true,
+    allow_nothing::Bool = false,
+  )
+
   # Compare CPU & GPU function outputs.
   f_gpu = f_cpu |> gpu
   xs_gpu = gpu.(xs_cpu)
@@ -60,7 +70,7 @@ function gpu_autodiff_test(
   if test_equal
     @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
     for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu)
-      check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing)
+      check_grad(g_gpu, g_cpu; atol, rtol, allow_nothing)
     end
   end
 
@@ -78,7 +88,22 @@ function gpu_autodiff_test(
     @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
     @assert length(ps_gpu) == length(ps_cpu)
     for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu)
-      check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing)
+      check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu]; atol, rtol, allow_nothing)
     end
   end
 end
+
+# check_grad_type checks that the gradient type matches the primal type.
+
+check_grad_type(g::Nothing, x) = nothing
+
+function check_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2}
+  @test T1 == T2
+  @test size(g) == size(x)
+end
+
+function check_grad_type(g::NamedTuple, x::T) where T
+  for f in fieldnames(T)
+    check_grad_type(g[f], getfield(x, f))
+  end
+end
```
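
As a usage sketch for the reworked helpers (not from the diff; it assumes a CUDA-capable test environment and uses a plain `Dense` layer as a stand-in model), the tolerances and `allow_nothing` are now keyword arguments with defaults, and `check_grad_type` compares gradient eltypes and sizes against the primal:

```julia
using Flux, CUDA

m_cpu = Dense(3 => 2)
m_gpu = m_cpu |> gpu
x_cpu = rand(Float32, 3, 5)
x_gpu = x_cpu |> gpu

# explicit-style gradients w.r.t. the model on CPU and GPU
g_cpu = gradient(m -> sum(abs2, m(x_cpu)), m_cpu)[1]
g_gpu = gradient(m -> sum(abs2, m(x_gpu)), m_gpu)[1]

check_grad(g_gpu, g_cpu)               # uses the default rtol = atol = 1e-4
check_grad(g_gpu, g_cpu; atol = 1e-3)  # override a single tolerance
check_grad_type(g_cpu, m_cpu)          # gradient eltype and size match the primal
```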
