- using Flux, Functors, Test, LinearAlgebra, Random, Statistics
- using CUDA
- using NeuralAttentionlib
- using NeuralAttentionlib: score_returning
- using BenchmarkTools
- using Flux: glorot_uniform
- CUDA.allowscalar(false)

const A3{T} = AbstractArray{T, 3}
- const A4{T} = AbstractArray{T, 4}
const TuplInt2 = Union{Int, Tuple{Int, Int}}
const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}

- include("attention_nnlib.jl")
- include("attention_tullio.jl")
-
-
"""
-     MultiHeadAttention(dims, nheads; [bias, init, dropout_prob])
+     MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])
+
+ The multi-head dot-product attention layer used in Transformer architectures [1].

- Multi-head dot-product attention layer.
+ [1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.

# Arguments

- - `dims`: ...
- - `nheads`: number of heads.
- - `init`: weight initializer for the Dense layers.
- - `bias`: whether pointwise QKVO dense transforms use bias.
- - `dropout_prob`: dropout probability for the attention scores.
+ - `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
+           In the most general case, it is given as
+           `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
+           It can also take the simpler forms
+           `dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`,
+           `in_dim::Int => qkv_dim => out_dim`.
+
+ - `nheads`: number of heads. Default `8`.
+ - `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
+ - `bias`: whether pointwise QKVO dense transforms use bias. Default `false`.
+ - `dropout_prob`: dropout probability for the attention scores. Default `0.0`.

# Forward

-     (::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
+     (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])

- `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
- `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
@@ -39,38 +35,58 @@ Multi-head dot-product attention layer.
  `(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
- `withscores`: Whether to return the attention scores. Default `false`.

+ Alternatively, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
+ and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
+
+
+ See also [`NNlib.dot_product_attention`](@ref).
+

# Examples

```julia
- mha = MultiHeadAttention(64, 8)
+ mha = MultiHeadAttention(64, nheads = 8)
+ q = rand(Float32, (64, 10, 32))
+ k = rand(Float32, (64, 20, 32))
+ v = rand(Float32, (64, 20, 32))
+ y = mha(q, k, v)  # [y] = [64, 10, 32]
+
+ mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
+ y = mha(q)  # self-attention; [y] = [1024, 10, 32]
```
"""
struct MultiHeadAttention{P1, D, P2}
    nheads::Int
-     qkv_proj::P1
+     q_proj::P1
+     k_proj::P1
+     v_proj::P1
    attn_drop::D
    out_proj::P2
end

@functor MultiHeadAttention

- function MultiHeadAttention(dims, nheads::Int;
+ function MultiHeadAttention(dims;
+         nheads::Int = 8,
          bias::Bool = false,
          init = glorot_uniform,
          dropout_prob = 0.0)
-     dims = mha_process_dims(dims)
+     dims = normalize_mha_dims(dims)
    @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
-     qkv_proj = QKVProj(dims; bias, init)
+     @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
+     q_proj = Dense(dims.q_in => dims.qk; bias, init)
+     k_proj = Dense(dims.k_in => dims.qk; bias, init)
+     v_proj = Dense(dims.v_in => dims.v; bias, init)
    attn_drop = Dropout(dropout_prob)
    out_proj = Dense(dims.v => dims.out; bias, init)
-     return MultiHeadAttention(nheads, qkv_proj, attn_drop, out_proj)
+     return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
end

- mha_process_dims(dims::Int) =
+ # turns the dims argument into a named tuple
+ normalize_mha_dims(dims::Int) =
    (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

- function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
+ function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
    if in isa Int
        q_in = k_in = v_in = in
    else
@@ -85,209 +101,22 @@ function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2,
end
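For reference, the `Int` method above simply replicates one value across all six fields; the `Pair` method (whose body falls largely outside this hunk) presumably unpacks the pair forms listed in the docstring. An illustrative check, with the second result assumed rather than taken from the code shown here:

```julia
normalize_mha_dims(64)
# (q_in = 64, k_in = 64, v_in = 64, qk = 64, v = 64, out = 64)

# Assumed result for the general form (the unpacking code is elided in this hunk):
# normalize_mha_dims((64, 32, 32) => (64, 128) => 256)
# (q_in = 64, k_in = 32, v_in = 32, qk = 64, v = 128, out = 256)
```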

# self-attention
- (m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...)
+ (mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

# key and value are the same
- (m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)
+ (mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

- function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
-         withscores=false, mask=nothing, impl=:nnlib)
+ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
+         withscores=false, mask=nothing)
    ## [q_in] = [q_in_dim, q_len, batch_size]
    ## [k_in] = [k_in_dim, kv_len, batch_size]
    ## [v_in] = [v_in_dim, kv_len, batch_size]
-
-     q, k, v = m.qkv_proj(q_in, k_in, v_in)
-     # [q] = [qk_dim, q_len, batch_size]
-     # [k] = [qk_dim, kv_len, batch_size]
-     # [v] = [v_dim, kv_len, batch_size]
-
-     if impl == :tullio
-         x, α = dot_product_attention_tullio(m.nheads, q, k, v; mask, dropout=m.attn_drop)
-     elseif impl == :nalib
-         x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.nheads, q, k, v, mask)
-     elseif impl == :nnlib
-         x, α = dot_product_attention(q, k, v, bias; m.nheads, mask, fdrop=m.attn_drop)
-     else
-         error("Unknown attention implementation")
-     end
-
-     x = m.out_proj(x)
-
+     q = mha.q_proj(q_in)  # [q] = [qk_dim, q_len, batch_size]
+     k = mha.k_proj(k_in)  # [k] = [qk_dim, kv_len, batch_size]
+     v = mha.v_proj(v_in)  # [v] = [v_dim, kv_len, batch_size]
+     x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
+     x = mha.out_proj(x)
+     # [x] = [out_dim, q_len, batch_size]
+     # [α] = [kv_len, q_len, nheads, batch_size]
    return withscores ? (x, α) : x
end
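A small shape sanity-check for the rewritten forward pass (illustrative only, not part of the diff; it relies on the `in_dim => qkv_dim => out_dim` form and the self-attention shortcut defined above):

```julia
using Flux, Test

mha = MultiHeadAttention(64 => 32 => 16, nheads = 4)
x = rand(Float32, 64, 10, 2)        # (in_dim, len, batch_size)

y, α = mha(x; withscores = true)    # mha(x) is the self-attention shortcut for mha(x, x, x)
@test size(y) == (16, 10, 2)        # (out_dim, q_len, batch_size)
@test size(α) == (10, 10, 4, 2)     # (kv_len, q_len, nheads, batch_size)
```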
-
- struct QKVProj
-     q_proj::Dense
-     k_proj::Dense
-     v_proj::Dense
- end
-
- @functor QKVProj
-
- function QKVProj(dims; bias = false, init = glorot_uniform)
-     return QKVProj(
-         Dense(dims.q_in => dims.qk; bias, init),
-         Dense(dims.k_in => dims.qk; bias, init),
-         Dense(dims.v_in => dims.v; bias, init))
- end
-
- function (proj::QKVProj)(q_in, k_in, v_in)
-     return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
- end
-
- function perf(dim, len, batch_size, nheads)
-     mha = MultiHeadAttention(dim, nheads)
-     x = rand(Float32, (dim, len, batch_size))
-
-     println("tullio")
-     @btime $mha($x, impl=:tullio);
-     @btime gradient(m -> sum(m($x, impl=:tullio)), $mha);
-
-     println("nalib")
-     @btime $mha($x, $x, $x, impl=:nalib);
-     @btime gradient(m -> sum(m($x, impl=:nalib)), $mha);
-
-     println("nnlib")
-     @btime $mha($x, $x, $x, impl=:nnlib);
-     @btime gradient(m -> sum(m($x, impl=:nnlib)), $mha);
-
-     if CUDA.functional()
-         mha_gpu = mha |> gpu
-         x_gpu = x |> gpu
-
-         println("tullio - gpu")
-         @btime $mha_gpu($x_gpu, impl=:tullio);
-         @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);
-
-         println("nalib - gpu")
-         @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
-         @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);
-
-         println("nnlib - gpu")
-         @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnlib);
-         @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnlib)), $mha_gpu);
-     end
-     return nothing
- end
-
- function test(dim, nheads, len, batch_size)
-     mha = MultiHeadAttention(dim, nheads)
-     q = rand(Float32, (dim, len, batch_size))
-     k = rand(Float32, (dim, len, batch_size))
-     v = rand(Float32, (dim, len, batch_size))
-
-     y, α = mha(q, k, v, impl=:tullio, withscores=true)
-     @test y isa Array{Float32, 3}
-     @test size(y) == (dim, len, batch_size)
-     @test α isa Array{Float32, 4}
-     @test size(α) == (len, len, nheads, batch_size)
-
-     y2, α2 = mha(q, k, v, impl=:nalib, withscores=true)
-     @test size(y) == size(y2)
-     @test y2 ≈ y
-     @test size(α) == size(α2)
-     @test α2 ≈ α
-
-     y2b, α2b = mha(q, k, v, impl=:nnlib, withscores=true)
-     @test size(y) == size(y2b)
-     @test y2b ≈ y
-     @test size(α) == size(α2b)
-     @test α2b ≈ α
-
-     mask = make_causal_mask(q)
-     y3, α3 = mha(q, k, v; impl=:tullio, withscores=true, mask)
-     y4, α4 = mha(q, k, v, impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
-     @test y3 ≈ y4
-     @test α3 ≈ α4
-
-     if CUDA.functional()
-         mha_gpu = mha |> gpu
-         q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu
-
-         y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio)
-         y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib)
-         @test Array(y_gpu) ≈ Array(y_gpu2)
-         @test Array(y_gpu) ≈ y
-     end
-     return nothing
- end
-
- test(4, 2, 3, 1)
-
- perf(128, 8, 128, 32)
-
- ## M1 Pro, NNlib v0.8.12
- # tullio
- # 2.948 ms (77 allocations: 7.25 MiB)
- # 15.041 ms (1124 allocations: 16.71 MiB)
- # nalib
- # 3.503 ms (89 allocations: 7.75 MiB)
- # 15.828 ms (604 allocations: 14.70 MiB)
- # nnlib
- # 3.611 ms (87 allocations: 9.25 MiB)
- # 16.497 ms (1055 allocations: 20.71 MiB)
-
- ## M1 Pro, NNlib v0.8.13 (fast_maximum)
- # tullio
- # 2.427 ms (71 allocations: 7.13 MiB)
- # 14.510 ms (1118 allocations: 16.59 MiB)
- # nalib
- # 3.052 ms (84 allocations: 7.63 MiB)
- # 15.327 ms (599 allocations: 14.57 MiB)
- # nnlib
- # 3.166 ms (81 allocations: 9.13 MiB)
- # 16.082 ms (1049 allocations: 20.58 MiB)
-
- ## Threadripper, NNlib v0.8.12
- # tullio
- # 5.658 ms (77 allocations: 7.25 MiB)
- # 22.373 ms (1124 allocations: 16.71 MiB)
- # nalib
- # 6.187 ms (89 allocations: 7.75 MiB)
- # 23.723 ms (604 allocations: 14.70 MiB)
- # nnlib
- # 6.473 ms (87 allocations: 9.25 MiB)
- # 24.966 ms (1055 allocations: 20.71 MiB)
- # tullio - gpu
- # 145.332 μs (520 allocations: 24.52 KiB)
- # 902.020 μs (2221 allocations: 117.19 KiB)
- # nalib - gpu
- # 162.354 μs (410 allocations: 18.03 KiB)
- # 604.111 μs (1263 allocations: 71.78 KiB)
- # nnlib - gpu
- # 156.383 μs (440 allocations: 20.00 KiB)
- # 835.374 μs (1969 allocations: 100.58 KiB)
-
- ## Threadripper, NNlib v0.8.13 (fast_maximum)
- # tullio
- # 4.599 ms (71 allocations: 7.13 MiB)
- # 20.699 ms (1118 allocations: 16.59 MiB)
- # nalib
- # 5.049 ms (84 allocations: 7.63 MiB)
- # 22.252 ms (599 allocations: 14.57 MiB)
- # nnlib
- # 5.378 ms (81 allocations: 9.13 MiB)
- # 23.453 ms (1049 allocations: 20.58 MiB)
- # tullio - gpu
- # 145.824 μs (520 allocations: 24.52 KiB)
- # 915.305 μs (2221 allocations: 117.19 KiB)
- # nalib - gpu
- # 164.789 μs (410 allocations: 18.03 KiB)
- # 610.835 μs (1263 allocations: 71.78 KiB)
- # nnlib - gpu
- # 157.785 μs (440 allocations: 20.00 KiB)
- # 852.087 μs (1969 allocations: 100.58 KiB)
-
-
- # function prof()
- #     dim, len, batch_size, nheads = 128, 8, 128, 32;
- #     # dim = 384; len = 128; batch_size = 32; nheads = 12
- #     mha = MultiHeadAttention(dim, nheads)
- #     x = rand(Float32, (dim, len, batch_size))
- #     @btime mha(x, impl=:tullio);
- #     @btime mha(x, impl=:nnlib);
- #     @profview mha(x, impl=:tullio);
- #     @profview prof(mha, x);
- #     y, α = mha(x; impl=:nnlib, withscores=true, mask)
- #     y2, α2 = mha(x; impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
- # end