From dae1e07605f8bb10102a4d1b5b18d8f8703bec20 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 30 Apr 2025 04:57:57 -0500 Subject: [PATCH 1/3] Fix & test * with matrix-of-adjoint This provides a bandaid for broken code on Julia 1.10, and adds a test for all versions. --- src/matrix_multiply_add.jl | 20 ++++++++++++++++++++ test/matrix_multiply_add.jl | 12 +++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index b8e1a71c4..9c6497033 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -248,6 +248,26 @@ const StaticVecOrMatLikeForFiveArgMulDest{T} = Union{ return _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B, NoMulAdd{TMul, TDest}()) end +@static if Base.VERSION < v"1.11" + function LinearAlgebra.mul!( + dest::AbstractMatrix{<:StaticMatrix}, + A::AbstractMatrix{<:StaticVector}, + B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} + ) where V<:StaticVector + axdest, axA, axB = axes(dest), axes(A), axes(B) + axA[2] == axB[1] || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB")) + axdest == (axA[1], axB[2]) || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB and assign to array with axes $axdest")) + fill!(dest, zero(eltype(dest))) + # This is not maximally efficient, but it produces the right answer on older Julia versions. + for ij in CartesianIndices(dest) + for k in axA[2] + dest[ij] += A[ij[1], k] * B[k, ij[2]] + end + end + return dest + end +end + """ multiplied_dimension(A, B) diff --git a/test/matrix_multiply_add.jl b/test/matrix_multiply_add.jl index 7d5bd8c7b..3523dfec9 100644 --- a/test/matrix_multiply_add.jl +++ b/test/matrix_multiply_add.jl @@ -1,6 +1,7 @@ using StaticArrays using LinearAlgebra using BenchmarkTools +using Random using Test macro test_noalloc(ex) @@ -226,7 +227,7 @@ function test_wrappers_for_size(N, test_block) A_block = rand(SMatrix{N,N,SMatrix{2,2,Int,4}}) B_block = rand(SMatrix{N,N,SMatrix{2,2,Int,4}}) bv_block = rand(SVector{N,SMatrix{2,2,Int,4}}) - + # matrix-vector for wrapper in mul_add_wrappers # LinearAlgebra can't handle these @@ -252,3 +253,12 @@ end test_wrappers_for_size(8, false) test_wrappers_for_size(16, false) end + +@testset "Adjoints and covariances" begin + X = [randn(SVector{3,Float64}) for _ in CartesianIndices((1:2, 1:3))] + μ = sum(X; dims=2)/size(X, 2) + ΔX = X .- μ + @test (ΔX*ΔX')[1, 1] ≈ sum(dx * dx' for dx in ΔX[1, :]) + B = [randn(SVector{3,Float64}) for _ in CartesianIndices((1:2, 1:2))] + @test_throws DimensionMismatch ΔX * B' +end From 325ef1239b0162688610be8bc524c967580c9b55 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 30 Apr 2025 05:32:21 -0500 Subject: [PATCH 2/3] fix ambiguity --- src/matrix_multiply_add.jl | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index 9c6497033..b2aa42186 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -249,22 +249,34 @@ const StaticVecOrMatLikeForFiveArgMulDest{T} = Union{ end @static if Base.VERSION < v"1.11" + let + function mulfix!(dest, A, B) + axdest, axA, axB = axes(dest), axes(A), axes(B) + axA[2] == axB[1] || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB")) + axdest == (axA[1], axB[2]) || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB and assign to array with axes $axdest")) + fill!(dest, zero(eltype(dest))) + # This is not maximally efficient, but it produces the right answer on older Julia versions. + for ij in CartesianIndices(dest) + for k in axA[2] + dest[ij] += A[ij[1], k] * B[k, ij[2]] + end + end + return dest + end + end + function LinearAlgebra.mul!( + dest::AbstractMatrix{<:StaticMatrix}, + A::LinearAlgebra.AbstractTriangular{<:StaticVector}, + B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} + ) where V<:StaticVector + mulfix!(dest, A, B) + end function LinearAlgebra.mul!( dest::AbstractMatrix{<:StaticMatrix}, A::AbstractMatrix{<:StaticVector}, B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} ) where V<:StaticVector - axdest, axA, axB = axes(dest), axes(A), axes(B) - axA[2] == axB[1] || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB")) - axdest == (axA[1], axB[2]) || throw(DimensionMismatch("Tried to multiply arrays with axes $axA and $axB and assign to array with axes $axdest")) - fill!(dest, zero(eltype(dest))) - # This is not maximally efficient, but it produces the right answer on older Julia versions. - for ij in CartesianIndices(dest) - for k in axA[2] - dest[ij] += A[ij[1], k] * B[k, ij[2]] - end - end - return dest + mulfix!(dest, A, B) end end From 6a24f711757d250a808e8d3625a9cad0ce31dad3 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 30 Apr 2025 05:49:27 -0500 Subject: [PATCH 3/3] fix scoping --- src/matrix_multiply_add.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index b2aa42186..a958024a4 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -263,20 +263,20 @@ end end return dest end - end - function LinearAlgebra.mul!( - dest::AbstractMatrix{<:StaticMatrix}, - A::LinearAlgebra.AbstractTriangular{<:StaticVector}, - B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} - ) where V<:StaticVector - mulfix!(dest, A, B) - end - function LinearAlgebra.mul!( - dest::AbstractMatrix{<:StaticMatrix}, - A::AbstractMatrix{<:StaticVector}, - B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} - ) where V<:StaticVector - mulfix!(dest, A, B) + function LinearAlgebra.mul!( + dest::AbstractMatrix{<:StaticMatrix}, + A::LinearAlgebra.AbstractTriangular{<:StaticVector}, + B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} + ) where V<:StaticVector + mulfix!(dest, A, B) + end + function LinearAlgebra.mul!( + dest::AbstractMatrix{<:StaticMatrix}, + A::AbstractMatrix{<:StaticVector}, + B::Adjoint{Adjoint{Float64, V}, <:AbstractMatrix{V}} + ) where V<:StaticVector + mulfix!(dest, A, B) + end end end