Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

[weakdeps]
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
SciMLOperatorsLoopVectorizationExt = "LoopVectorization"
SciMLOperatorsSparseArraysExt = "SparseArrays"
SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"

Expand All @@ -22,6 +24,7 @@ Accessors = "0.1.42"
ArrayInterface = "7.19"
DocStringExtensions = "0.9.4"
LinearAlgebra = "1.10"
LoopVectorization = "0.12"
SparseArrays = "1.10"
StaticArraysCore = "1"
julia = "1.10"
46 changes: 46 additions & 0 deletions ext/SciMLOperatorsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module SciMLOperatorsLoopVectorizationExt

import LoopVectorization: @turbo
import SciMLOperators

# `MatrixOperator`s whose wrapped matrix `A` is a `StridedMatrix`; only these
# receive the `@turbo` kernels defined below.
const StridedMatrixOperator = SciMLOperators.MatrixOperator{<:Any, <:StridedMatrix}

# Opt strided matrix operators in to the fast outer-multiplication path.
SciMLOperators._has_tensor_outer_mul_fast(::StridedMatrixOperator) = true

"""
    _tensor_outer_mul_fast!(w, outer, C, mi, mo, no, k)

Overwrite `w` with the contraction of `outer.A` against the second axis of `C`:
with `C` viewed as an `(mi, no, k)` array and `w` as an `(mi, mo, k)` array,

    w[i, m, j] = sum(outer.A[m, o] * C[i, o, j] for o in 1:no)

`C` and `w` are reshaped (no copy), so they must hold exactly `mi * no * k` and
`mi * mo * k` elements respectively. Returns `w`.
"""
function SciMLOperators._tensor_outer_mul_fast!(
        w, outer::StridedMatrixOperator, C, mi::Int, mo::Int, no::Int, k::Int
    )
    A = outer.A
    C3 = reshape(C, (mi, no, k))
    W3 = reshape(w, (mi, mo, k))

    # `i` is the innermost loop so both `C3` and `W3` are walked along their
    # first (contiguous) axis.
    @turbo for j in 1:k, m in 1:mo, i in 1:mi
        acc = zero(eltype(w))
        for o in 1:no
            acc += A[m, o] * C3[i, o, j]
        end
        W3[i, m, j] = acc
    end

    return w
end

"""
    _tensor_outer_mul_fast!(w, outer, C, mi, mo, no, k, α, β)

Scaled, accumulating variant: with the same reshaped views as the 7-argument
method,

    w[i, m, j] = α * sum(outer.A[m, o] * C[i, o, j] for o in 1:no) + β * w[i, m, j]

When `β` is zero the previous contents of `w` are not read at all, so `NaN`/`Inf`
values sitting in an uninitialised or recycled destination cannot leak into the
result (the same convention BLAS `gemm` and `LinearAlgebra.mul!` follow).
Returns `w`.
"""
function SciMLOperators._tensor_outer_mul_fast!(
        w, outer::StridedMatrixOperator, C, mi::Int, mo::Int, no::Int, k::Int, α, β
    )
    A = outer.A
    C3 = reshape(C, (mi, no, k))
    W3 = reshape(w, (mi, mo, k))

    if iszero(β)
        # Pure overwrite: `0 * NaN` would be `NaN`, so skip the read of `W3`.
        @turbo for j in 1:k, m in 1:mo, i in 1:mi
            acc = zero(eltype(w))
            for o in 1:no
                acc += A[m, o] * C3[i, o, j]
            end
            W3[i, m, j] = α * acc
        end
    else
        @turbo for j in 1:k, m in 1:mo, i in 1:mi
            acc = zero(eltype(w))
            for o in 1:no
                acc += A[m, o] * C3[i, o, j]
            end
            W3[i, m, j] = α * acc + β * W3[i, m, j]
        end
    end

    return w
end

end
13 changes: 13 additions & 0 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ end
# helper functions
const PERM = (2, 1, 3)

"""
    _has_tensor_outer_mul_fast(outer) -> Bool

Trait queried by `outer_mul!` to decide whether `_tensor_outer_mul_fast!` should
be called for `outer`. This generic fallback answers `false` for every argument.
"""
function _has_tensor_outer_mul_fast(outer)
    return false
end

"""
    _tensor_outer_mul_fast!

Function stub declared without methods here; `outer_mul!` only calls it when
`_has_tensor_outer_mul_fast(outer)` returns `true`.
"""
function _tensor_outer_mul_fast! end

function outer_mul(L::TensorProductOperator, v::AbstractVecOrMat, C::AbstractVecOrMat)
outer, inner = L.ops

Expand Down Expand Up @@ -465,6 +468,11 @@ function outer_mul!(w::AbstractVecOrMat, L::TensorProductOperator, v::AbstractVe
return w
end

if _has_tensor_outer_mul_fast(outer)
_tensor_outer_mul_fast!(w, outer, C1, mi, mo, no, k)
return w
end

C2, C3 = L.cache[2:3]

C1 = reshape(C1, (mi, no, k))
Expand Down Expand Up @@ -503,6 +511,11 @@ function outer_mul!(
return w
end

if _has_tensor_outer_mul_fast(outer)
_tensor_outer_mul_fast!(w, outer, v, mi, mo, no, k, α, β)
return w
end

C2, C3, c4 = L.cache[2:4]

C = reshape(v, (mi, no, k))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -10,5 +11,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
FFTW = "1.10.0"
LoopVectorization = "0.12"
SafeTestsets = "0.1.0"
Zygote = "0.7.10"
1 change: 1 addition & 0 deletions test/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using SciMLOperators, LinearAlgebra
using SparseArrays
using Random
using Test
using LoopVectorization

using SciMLOperators: InvertibleOperator, InvertedOperator, ⊗, AbstractSciMLOperator
using FFTW
Expand Down
Loading