diff --git a/src/diagonal.jl b/src/diagonal.jl index b6e26249..7a304738 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -806,6 +806,11 @@ end kron(A::Diagonal, B::Diagonal) = Diagonal(kron(A.diag, B.diag)) +function kron!(C::Diagonal, A::Diagonal, B::Diagonal) + kron!(C.diag, A.diag, B.diag) + return C +end + function kron(A::Diagonal, B::SymTridiagonal) kdv = kron(A.diag, B.dv) # We don't need to drop the last element diff --git a/test/diagonal.jl b/test/diagonal.jl index 55d493e5..712f426c 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1396,14 +1396,26 @@ end end @testset "kron! for Diagonal" begin - a = Diagonal([2,2]) - b = Diagonal([1,1]) - c = Diagonal([0,0,0,0]) - kron!(c,b,a) - @test c == Diagonal([2,2,2,2]) - c=Diagonal(Vector{Float64}(undef, 4)) - kron!(c,a,b) - @test c == Diagonal([2,2,2,2]) + a = Diagonal([1, 2]) + b = Diagonal([3, 4]) + # Diagonal out + c = Diagonal([0, 0, 0, 0]) + kron!(c, b, a) + @test c == Diagonal([3, 6, 4, 8]) + @test c == kron!(fill(0, 4, 4), Matrix(b), Matrix(a)) # against dense kron! + c = Diagonal(Vector{Float64}(undef, 4)) + kron!(c, a, b) + @test c == Diagonal([3.0, 4.0, 6.0, 8.0]) + + # AbstractArray out + c = fill(0, 4, 4) + kron!(c, b, a) + @test c == diagm([3, 6, 4, 8]) + @test c == kron!(fill(0, 4, 4), Matrix(b), Matrix(a)) # against dense kron! + c = Matrix{Float64}(undef, 4, 4) + kron!(c, a, b) + @test c == diagm([3.0, 4.0, 6.0, 8.0]) + @test_throws DimensionMismatch kron!(Diagonal(zeros(5)), Diagonal(zeros(2)), Diagonal(zeros(2))) end @testset "uppertriangular/lowertriangular" begin