| 
 | 1 | +@testset "batched_mul" begin  | 
 | 2 | +    using NNlib: batched_mul, batched_mul!, batched_vec,  | 
 | 3 | +                 batched_adjoint, batched_transpose  | 
 | 4 | + | 
 | 5 | +    A = randn(Float32, 3,3,2);  | 
 | 6 | +    B = randn(Float32, 3,3,2);  | 
 | 7 | + | 
 | 8 | +    C = batched_mul(A, B)  | 
 | 9 | +    @test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B))  | 
 | 10 | + | 
 | 11 | +    Ct = batched_mul(batched_transpose(A), B)  | 
 | 12 | +    @test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B))  | 
 | 13 | + | 
 | 14 | +    Ca = batched_mul(A, batched_adjoint(B))  | 
 | 15 | +    @test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B)))  | 
 | 16 | + | 
 | 17 | +    # 5-arg batched_mul!  | 
 | 18 | +    C .= pi  | 
 | 19 | +    batched_mul!(C, A, B, 2f0, 3f0)  | 
 | 20 | +    gpuCpi = MtlArray(similar(C)) .= pi  | 
 | 21 | +    @test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0)  | 
 | 22 | + | 
 | 23 | +    # PermutedDimsArray  | 
 | 24 | +    @test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B))  | 
 | 25 | + | 
 | 26 | +    D = permutedims(B, (1,3,2))  | 
 | 27 | +    Cp = batched_mul(batched_adjoint(A), B)  | 
 | 28 | +    @test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2)))  | 
 | 29 | + | 
 | 30 | +    # Methods which reshape  | 
 | 31 | +    M = randn(Float32, 3,3)  | 
 | 32 | + | 
 | 33 | +    Cm = batched_mul(A, M)  | 
 | 34 | +    @test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M))  | 
 | 35 | + | 
 | 36 | +    Cv = batched_vec(permutedims(A,(3,1,2)), M)  | 
 | 37 | +    @test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M))  | 
 | 38 | +end  | 
 | 39 | + | 
 | 40 | +function print_array_strs(x)  | 
 | 41 | +    str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)  | 
 | 42 | +    return @view split(str, '\n')[2:end]  | 
 | 43 | +end  | 
 | 44 | + | 
 | 45 | +@testset "BatchedAdjOrTrans" begin  | 
 | 46 | +    x = rand(Float32, 3, 4, 2)  | 
 | 47 | +    y = MtlArray(x)  | 
 | 48 | + | 
 | 49 | +    bax = batched_adjoint(x)  | 
 | 50 | +    btx = batched_transpose(x)  | 
 | 51 | +    bay = batched_adjoint(y)  | 
 | 52 | +    bty = batched_transpose(y)  | 
 | 53 | + | 
 | 54 | +    @test sprint(show, bax) == sprint(show, bay)  | 
 | 55 | +    @test sprint(show, btx) == sprint(show, bty)  | 
 | 56 | + | 
 | 57 | +    @test print_array_strs(bax) == print_array_strs(bay)  | 
 | 58 | +    @test print_array_strs(btx) == print_array_strs(bty)  | 
 | 59 | + | 
 | 60 | +    @test Array(bax) == Array(bay)  | 
 | 61 | +    @test collect(bax) == collect(bay)  | 
 | 62 | +    @test Array(btx) == Array(bty)  | 
 | 63 | +    @test collect(btx) == collect(bty)  | 
 | 64 | + | 
 | 65 | +    for shape in (:, (12, 2))  | 
 | 66 | +        rbax = reshape(bax, shape)  | 
 | 67 | +        rbtx = reshape(btx, shape)  | 
 | 68 | +        rbay = reshape(bay, shape)  | 
 | 69 | +        rbty = reshape(bty, shape)  | 
 | 70 | + | 
 | 71 | +        @test sprint(show, rbax) == sprint(show, rbay)  | 
 | 72 | +        @test sprint(show, rbtx) == sprint(show, rbty)  | 
 | 73 | + | 
 | 74 | +        @test print_array_strs(rbax) == print_array_strs(rbay)  | 
 | 75 | +        @test print_array_strs(rbtx) == print_array_strs(rbty)  | 
 | 76 | + | 
 | 77 | +        @test Array(rbax) == Array(rbay)  | 
 | 78 | +        @test collect(rbax) == collect(rbay)  | 
 | 79 | +        @test Array(rbtx) == Array(rbty)  | 
 | 80 | +        @test collect(rbtx) == collect(rbty)  | 
 | 81 | +    end  | 
 | 82 | +end  | 
0 commit comments