diff --git a/Project.toml b/Project.toml index 490aaad5..4655ac0b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.10.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] Adapt = "2, 3" @@ -24,9 +25,8 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "LinearAlgebra", "EllipsisNotation", "StaticArrays", "FillArrays"] +test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "EllipsisNotation", "StaticArrays", "FillArrays"] diff --git a/src/OffsetArrays.jl b/src/OffsetArrays.jl index a00aedff..b46d4149 100644 --- a/src/OffsetArrays.jl +++ b/src/OffsetArrays.jl @@ -651,6 +651,7 @@ if isdefined(Base, :IdentityUnitRange) no_offset_view(a::Base.Slice) = Base.Slice(UnitRange(a)) no_offset_view(S::SubArray) = view(parent(S), map(no_offset_view, parentindices(S))...) end +no_offset_view(A::PermutedDimsArray{T,N,perm,iperm,P}) where {T,N,perm,iperm,P} = PermutedDimsArray(no_offset_view(parent(A)), perm) no_offset_view(a::Array) = a no_offset_view(i::Number) = i no_offset_view(A::AbstractArray) = _no_offset_view(axes(A), A) @@ -853,6 +854,8 @@ end import Adapt Adapt.adapt_structure(to, O::OffsetArray) = parent_call(x -> Adapt.adapt(to, x), O) +include("linearalgebra.jl") + if Base.VERSION >= v"1.4.2" include("precompile.jl") _precompile_() diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl new file mode 100644 index 00000000..2815ca41 --- /dev/null +++ b/src/linearalgebra.jl @@ -0,0 +1,56 @@ +using LinearAlgebra +using LinearAlgebra: MulAddMul, mul!, AdjOrTrans + +@inline LinearAlgebra.generic_matvecmul!(C::OffsetVector, fA::Function, A::AbstractVecOrMat, B::AbstractVector, + alpha, beta) = unwrap_matvecmul!(C, fA, A, B, alpha, beta) + +@inline function unwrap_matvecmul!(C::OffsetVector, fA, A::AbstractVecOrMat, B::AbstractVector, + alpha, beta) + + mB_axis = Base.axes1(B) + mA_axis, nA_axis = axes(fA(A)) + + if mB_axis != nA_axis + throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with axes(B) == ($(UnitRange(mB_axis)),)")) + end + if mA_axis != Base.axes1(C) + throw(DimensionMismatch("mul! got axes(C) == ($(UnitRange(Base.axes1(C))),), expected $(UnitRange(mA_axis))")) + end + + mul!(no_offset_view(C), fA(no_offset_view(A)), no_offset_view(B), alpha, beta) + C +end + +# The signatures of these differs from LinearAlgebra's *only* on C: +@inline LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, fA::Function, fB::Function, A::AbstractMatrix, B::AbstractMatrix, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) + +@inline LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, fA::Function, fB::Function, A::AbstractVecOrMat, B::AbstractVecOrMat, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) + +@inline LinearAlgebra.generic_matmatmul!(C::AdjOrTrans{<:Any, <:OffsetArray}, fA::Function, fB::Function, A::AbstractMatrix, B::AbstractMatrix, + alpha, beta) = unwrap_matmatmul!(C, fA, fB, A, B, alpha, beta) + +@inline function unwrap_matmatmul!(C::AbstractVecOrMat, fA, fB, A::AbstractVecOrMat, B::AbstractVecOrMat, + alpha, beta) + + mA_axis, nA_axis = axes(fA(A)) + mB_axis, nB_axis = axes(fB(B)) + + if nA_axis != mB_axis + throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with $(UnitRange(mB_axis)) from B")) + elseif mA_axis != axes(C,1) + throw(DimensionMismatch("mul! got axes(C,1) == $(UnitRange(axes(C,1))), expected $(UnitRange(mA_axis)) from A")) + elseif nB_axis != axes(C,2) + throw(DimensionMismatch("mul! got axes(C,2) == $(UnitRange(axes(C,2))), expected $(UnitRange(nB_axis)) from B")) + end + + # Must be sure `no_offset_view(C)` won't match signature above! + mul!(no_offset_view(C), fA(no_offset_view(A)), fB(no_offset_view(B)), alpha, beta) + C +end + +no_offset_view(A::Adjoint) = adjoint(no_offset_view(parent(A))) +no_offset_view(A::Transpose) = transpose(no_offset_view(parent(A))) +no_offset_view(D::Diagonal) = Diagonal(no_offset_view(parent(D))) +