diff --git a/docs/src/reference.md b/docs/src/reference.md
index a034c92c5..cb10efe59 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -33,6 +33,14 @@ tanhshrink
 trelu
 ```
 
+## Attention
+
+```@docs
+dot_product_attention
+dot_product_attention_scores
+make_causal_mask
+```
+
 ## Softmax
 
 `Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.
diff --git a/src/NNlib.jl b/src/NNlib.jl
index acca75299..8be0d01bc 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -41,6 +41,9 @@ for f in ACTIVATIONS
 end
 export sigmoid, hardsigmoid, logsigmoid, thresholdrelu  # Aliases
 
+include("attention.jl")
+export dot_product_attention, dot_product_attention_scores, make_causal_mask
+
 include("dropout.jl")
 export dropout, dropout!
diff --git a/src/attention.jl b/src/attention.jl
new file mode 100644
index 000000000..fb11e82d0
--- /dev/null
+++ b/src/attention.jl
@@ -0,0 +1,144 @@
+const AA3{T} = AbstractArray{T,3}
+const AA4{T} = AbstractArray{T,4}
+const AA{N,T} = AbstractArray{T,N}
+
+"""
+    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])
+
+Multihead dot product attention used in transformer architectures.
+
+The input arrays must have their first two dimensions given by the number of features
+and the sequence length, followed by an arbitrary number of batch dimensions or none.
+
+Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
+of size `(kv_len, q_len, nheads, batch_size...)`.
+
+See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.
+
+# Arguments
+
+- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
+- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
+- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
+- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  It will be added to the attention scores before applying the softmax. Default `nothing`.
+- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
+  Default `identity` (no dropout).
+- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  The mask is applied to the attention scores just before the softmax.
+  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
+- `nheads`: Number of heads to split the input arrays into. Default `1`.
+
+# Examples
+
+```julia
+q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
+y, α = dot_product_attention(q, k, v)
+```
+"""
+function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
+    batch_size = size(q)[3:end]
+    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
+    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
+
+    x, α = dot_product_attention(q, k, v, args...; kws...)
+
+    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
+    α = reshape(α, size(α)[1:3]..., batch_size...)
+    return x, α
+end
+
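For reference, a minimal sketch of what the wrapper above does, with made-up sizes: the trailing batch dimensions are flattened into one, the 3-dimensional method is called, and the outputs are reshaped back. The `flat` helper is only for illustration and is not part of the PR.

```julia
using NNlib

q = rand(Float32, 8, 5, 2, 3)   # (qk_dim, q_len, batch dims...)
k = rand(Float32, 8, 7, 2, 3)   # (qk_dim, kv_len, batch dims...)
v = rand(Float32, 4, 7, 2, 3)   # (v_dim, kv_len, batch dims...)

y, α = dot_product_attention(q, k, v; nheads=2)

# Doing the flattening by hand gives the same result.
flat(x) = reshape(x, size(x, 1), size(x, 2), :)
y3, α3 = dot_product_attention(flat(q), flat(k), flat(v); nheads=2)
y ≈ reshape(y3, size(y3, 1), size(y3, 2), 2, 3)   # true
α ≈ reshape(α3, size(α3)[1:3]..., 2, 3)           # true
```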
+function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
+            fdrop=identity, mask=nothing, nheads=1)
+
+    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
+    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
+    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
+
+    # Multihead attention. TODO create fastpath for singlehead attention.
+    q, k, v = split_heads.((q, k, v), nheads)
+    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
+    return join_heads(x), α
+end
+
+function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
+    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
+    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
+    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
+
+    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
+    # [α] = [kv_len, q_len, nheads, batch_size]
+
+    # The following permutedims and batched_mul are equivalent to
+    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
+    vt = permutedims(v, (1, 3, 2, 4))
+    x = batched_mul(vt, α)
+    x = permutedims(x, (1, 3, 2, 4))
+    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
+    return x, α
+end
+
+"""
+    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])
+
+Return the attention scores for [`dot_product_attention`](@ref).
+Input arrays must have dimensions
+`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
+
+See [`dot_product_attention`](@ref) for more details.
+"""
+function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
+            fdrop=identity, mask=nothing) where T
+
+    # The following permutedims and batched_mul are equivalent to
+    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
+    kt = permutedims(k, (3, 1, 2, 4))
+    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
+    logits = batched_mul(kt, qt)
+    # [logits] = [kv_len, q_len, nheads, batch_size]
+
+    logits = apply_attn_bias(logits, bias)
+    logits = apply_attn_mask(logits, mask)
+
+    α = softmax(logits, dims=1)
+    return fdrop(α)
+end
+
+apply_attn_bias(logits, bias::Nothing) = logits
+
+apply_attn_bias(logits, bias) = logits .+ bias
+
+apply_attn_mask(logits, mask::Nothing) = logits
+
+function apply_attn_mask(logits, mask)
+    neginf = typemin(eltype(logits))
+    ifelse.(mask, logits, neginf)
+end
+
+"""
+    make_causal_mask(x; dims=2)
+
+Return a boolean square matrix `m` of side `size(x, dims)`, stored in the same kind of array as `x`.
+Its elements are set such that `m[i, j] == i ≤ j`.
+
+Can be used to mask the attention scores in [`dot_product_attention`](@ref).
+"""
+function make_causal_mask(x::AbstractArray; dims::Int=2)
+    len = size(x, dims)
+    mask = triu(trues_like(x, (len, len)))
+    return mask
+end
+
+trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
+falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)
+
+split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
+join_heads(x) = reshape(x, :, size(x)[3:end]...)
+
+@non_differentiable make_causal_mask(::Any...)
+@non_differentiable trues_like(::Any...)
+@non_differentiable falses_like(::Any...)
+
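To make the new API concrete, here is a hedged end-to-end sketch of causal multihead self-attention built only from the functions added in this file; the sizes are illustrative.

```julia
using NNlib

dim, len, batch, nheads = 16, 10, 4, 4
q = k = v = rand(Float32, dim, len, batch)

# Causal self-attention: query i may only attend to keys j ≤ i.
mask = make_causal_mask(q)                  # (len, len) boolean causal mask over the sequence
y, α = dot_product_attention(q, k, v; nheads, mask)

size(y) == (dim, len, batch)                # true
size(α) == (len, len, nheads, batch)        # true
all(sum(α; dims=1) .≈ 1)                    # each column of scores sums to one
all(α[2, 1, :, :] .== 0)                    # key 2 is masked out for query 1
```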
diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl
index 7e5e7fd72..7458f6fa2 100644
--- a/src/batched/batchedmul.jl
+++ b/src/batched/batchedmul.jl
@@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A)
     batched_mul(A, B) -> C
     A ⊠ B  # \\boxtimes
 
-Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
-If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
+Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]`, where `k...`
+represents any indices in the trailing batch dimensions.
+
+If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
 
 To transpose each matrix, apply `batched_transpose` to the array, or `batched_adjoint` for conjugate-transpose:
 
@@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
 Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
 and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
 """
+function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    batch_size = size(x)[3:end]
+    @assert batch_size == size(y)[3:end] "Batch size has to be the same for the two arrays."
+    x2 = reshape(x, size(x, 1), size(x, 2), :)
+    y2 = reshape(y, size(y, 1), size(y, 2), :)
+    z = batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), batch_size...)
+end
+
 function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
     size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
         throw(DimensionMismatch("batch size mismatch: A != B"))
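A small sketch of what the new method enables, with illustrative sizes: arrays with several trailing batch dimensions are multiplied slice by slice, exactly as if those dimensions had been flattened first.

```julia
using NNlib

A = rand(Float32, 3, 4, 2, 5)    # two trailing batch dimensions
B = rand(Float32, 4, 6, 2, 5)

C = batched_mul(A, B)            # also available as A ⊠ B
size(C) == (3, 6, 2, 5)          # true

# Every batch slice is an ordinary matrix product:
C[:, :, 2, 3] ≈ A[:, :, 2, 3] * B[:, :, 2, 3]    # true
```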
diff --git a/src/gemm.jl b/src/gemm.jl
index 91f88fc82..95c39d23f 100644
--- a/src/gemm.jl
+++ b/src/gemm.jl
@@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings
             end
 
-            C
+            return C
         end
     end
 end
diff --git a/test/attention.jl b/test/attention.jl
new file mode 100644
index 000000000..b21088330
--- /dev/null
+++ b/test/attention.jl
@@ -0,0 +1,76 @@
+@testset "different batchsizes" begin
+    n = 15
+    lenq = 3
+    lenkv = 4
+    for batch_size in [(), 1, 2, (2, 1, 3)], nheads in [1, 3, 5]
+        q = rand(Float32, n, lenq, batch_size...)
+        k = rand(Float32, n, lenkv, batch_size...)
+        v = rand(Float32, n, lenkv, batch_size...)
+        y, α = dot_product_attention(q, k, v; nheads)
+        @test y isa Array{Float32}
+        @test size(y) == (n, lenq, batch_size...)
+        @test size(α) == (lenkv, lenq, nheads, batch_size...)
+        @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
+    end
+end
+
+@testset "dot_product_attention_scores" begin
+    q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
+    α = dot_product_attention_scores(q, k)
+    q2, k2 = reshape.((q, k), 8, 3, 1)
+    y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
+    @test α ≈ α2
+end
+
+@testset "specific results" begin
+    q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
+    y, α = dot_product_attention(q, k, v; nheads=2)
+    ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788]
+    ytrue = reshape(ytrue, 4, 3, 1)
+    αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801]
+    αtrue = reshape(αtrue, 3, 3, 2, 1)
+    @test y ≈ ytrue atol=1e-5
+    @test α ≈ αtrue atol=1e-5
+end
+
+@testset "mask" begin
+    q = rand(4, 2, 3, 1)
+    k = rand(4, 2, 5, 1)
+
+    mask = rand(Bool, (5, 3))
+    α = dot_product_attention_scores(q, k; mask)
+    @test all((α[:, :, 1, 1] .> 0) .== mask)
+    @test all((α[:, :, 2, 1] .> 0) .== mask)
+
+    @testset "causal" begin
+        x = rand(4, 2, 3, 1)
+        mask = make_causal_mask(x, dims=3)
+        α = dot_product_attention_scores(x, x; mask)
+        @test all((α[:, :, 1, 1] .> 0) .== mask)
+        @test all((α[:, :, 2, 1] .> 0) .== mask)
+    end
+end
+
+@testset "dropout" begin
+    q = k = v = rand(10, 10, 10)
+    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1 - p)
+    y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
+    @test 0.6 > mean(>(0), α) > 0.4
+end
+
+@testset "bias" begin
+    q = rand(4, 5, 1)
+    k = v = rand(4, 3, 1)
+    bias = randn(3, 5)
+    y, α = dot_product_attention(q, k, v, bias; nheads=2)
+    @test size(α) == (3, 5, 2, 1)
+    @test size(y) == (4, 5, 1)
+end
+
+@testset "gradient" begin
+    q = rand(4, 5, 1)
+    k = v = rand(4, 3, 1)
+    bias = randn(3, 5)
+    y, α = dot_product_attention(q, k, v, bias; nheads=2)
+    gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 16084b4d2..e7987ef62 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -39,6 +39,10 @@ include("test_utils.jl")
         include("activations.jl")
     end
 
+    @testset "Attention" begin
+        include("attention.jl")
+    end
+
     @testset "Batched Multiplication" begin
         include("batchedmul.jl")
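Finally, as a cross-check on the reference values hard-coded in the "specific results" testset above, here is a naive single-head implementation (illustrative only, not part of the PR; `naive_attention` is a made-up name) that `dot_product_attention` should reproduce for a batch of one and the default `nheads=1`:

```julia
using NNlib

# Naive single-head reference: α = softmax(Kᵀq / √d), y = V α.
function naive_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix)
    logits = (k' * q) ./ sqrt(size(q, 1))
    α = softmax(logits; dims=1)     # (kv_len, q_len)
    return v * α, α
end

q, k, v = rand(8, 5), rand(8, 7), rand(4, 7)
y_ref, α_ref = naive_attention(q, k, v)

# Compare against the batched implementation with a batch of one.
y, α = dot_product_attention(reshape(q, 8, 5, 1), reshape(k, 8, 7, 1), reshape(v, 4, 7, 1))
y_ref ≈ y[:, :, 1]        # true
α_ref ≈ α[:, :, 1, 1]     # true
```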