4 changes: 3 additions & 1 deletion src/SciMLOperators.jl
@@ -234,7 +234,9 @@ export
export update_coefficients!,
update_coefficients, isconstant,
iscached,
-cache_operator, issquare,
+cache_operator, cache_operator_hinted,
+update_cache,
+issquare,
islinear,
concretize,
isconvertible, has_adjoint,
37 changes: 33 additions & 4 deletions src/basic.jl
@@ -648,9 +648,25 @@ has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops)
@generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat)
ops_types = L.parameters[2].parameters
N = length(ops_types)

# Sub-operators that share the same outermost type constructor (wrapper) have the same cache structure, so one of them can act as a cache "donor" for the rest, avoiding redundant allocation. The `donor` tuple maps each sub-operator to the index of the operator whose cache it reuses.

donor = ntuple(i -> findfirst(j -> ops_types[j].name.wrapper === ops_types[i].name.wrapper, 1:N), N)

# Unique variable names for each cached sub-operator
syms = ntuple(i -> Symbol(:op_, i), N)

# Emit cache_operator for donors, cache_operator_hinted for the rest
stmts = ntuple(N) do i
d = donor[i]
d == i ?
:($(syms[i]) = cache_operator(L.ops[$i], v)) :
:($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[d])), v))
end

return quote
-ops = Base.@ntuple $N i -> cache_operator(L.ops[i], v)
-return AddedOperator(ops)
+$(stmts...)
+return AddedOperator(($(syms...),))
end
end
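For intuition, the generated body for a sum of three sub-operators in which the first and third share a wrapper would expand to roughly the following (a hand-written sketch, not literal `@generated` output):

# Sketch: L = op1 + op2 + op3 with typeof(op1).name.wrapper === typeof(op3).name.wrapper,
# so donor == (1, 2, 1) and op3 reuses the cache op1 just built.
op_1 = cache_operator(L.ops[1], v)                        # donor: allocates its own cache
op_2 = cache_operator(L.ops[2], v)                        # donor: allocates its own cache
op_3 = cache_operator_hinted(L.ops[3], getcache(op_1), v) # non-donor: reuses op_1's cache
return AddedOperator((op_1, op_2, op_3))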

@@ -874,6 +890,7 @@ function update_coefficients(L::ComposedOperator, u, p, t; kwargs...)
end

getops(L::ComposedOperator) = L.ops
getcache(op::ComposedOperator) = op.cache

# Copy method to avoid aliasing
function Base.copy(L::ComposedOperator)
Expand Down Expand Up @@ -939,6 +956,16 @@ end
end
end

function _get_cache_shapes(L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.ops)
res = if v isa AbstractMatrix
ntuple(i -> (size(L.ops[i], 1), size(v, 2)), Val(N))
else
ntuple(i -> (size(L.ops[i], 1),), Val(N))
end
return res
end
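
As a concrete illustration with hypothetical sizes (assuming `ops` is stored outermost-first, as the `*` constructor produces it), each cache slot takes the output rows of its sub-operator:

A = MatrixOperator(rand(5, 4)); B = MatrixOperator(rand(4, 3))
L = A * B        # ComposedOperator with ops == (A, B)
v = rand(3, 2)
# _get_cache_shapes(L, v)       == ((5, 2), (4, 2))
# _get_cache_shapes(L, rand(3)) == ((5,), (4,))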

@generated function cache_self(L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.parameters[2].parameters) # Number of operators

@@ -1199,6 +1226,7 @@ function update_coefficients(L::InvertedOperator, u, p, t; kwargs...)
end

getops(L::InvertedOperator) = (L.L,)
getcache(op::InvertedOperator) = op.cache
islinear(L::InvertedOperator) = islinear(L.L)
isconvertible(::InvertedOperator) = false

@@ -1229,9 +1257,10 @@ function Base.copy(L::InvertedOperator)
)
end

_get_cache_shapes(::InvertedOperator, v::AbstractVecOrMat) = size(v)

function cache_self(L::InvertedOperator, u::AbstractVecOrMat)
-cache = zero(u)
-@reset L.cache = cache
+@reset L.cache = zero(u)
return L
end

7 changes: 7 additions & 0 deletions src/func.jl
@@ -590,6 +590,13 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
return L
end

function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat)
return (L.traits.sizes[1], L.traits.sizes[2])
end

getcache(op::FunctionOperator) = op.cache
update_cache(L::FunctionOperator, new_cache) = set_cache(L, new_cache)

# fix method ambiguity between AbstractArray and AbstractVecOrMat
cache_self(L::FunctionOperator, v::AbstractArray) = _cache_self(L, v)
cache_self(L::FunctionOperator, v::AbstractVecOrMat) = _cache_self(L, v)
56 changes: 56 additions & 0 deletions src/interface.jl
@@ -139,6 +139,15 @@ getops(L) = ()
"""
$SIGNATURES

Return the current cache held by `op`, or `nothing` if it holds none.
New operator types get the safe `nothing` default automatically; override
for types that store a shareable `.cache` field.
"""
getcache(::AbstractSciMLOperator) = nothing
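
For example, a plain `MatrixOperator` stores no cache and defines no override, so the default applies:

getcache(MatrixOperator(rand(3, 3)))  # returns nothing: no shareable cache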

"""
$SIGNATURES

Checks whether `L` has preallocated caches for inplace evaluations.
"""
function iscached(L::AbstractSciMLOperator)
@@ -179,6 +188,53 @@ end
cache_self(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L
cache_internals(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L

"""
$SIGNATURES

Return the expected cache shape specification for `L` given input `v`.
Returns `nothing` if `L` requires no cache.
The return value can be a single `NTuple{N,Int}` (single-array cache),
a `Tuple` of such shapes (multi-array cache), or `nothing` for absent slots.
"""
_get_cache_shapes(::AbstractSciMLOperator, ::AbstractVecOrMat) = nothing

"""
$SIGNATURES

Check whether `hint` is shape-compatible with `shapes` (as returned by `_get_cache_shapes`).
Uses `zip` to avoid integer-indexed `Tuple` access, and reads only array metadata (sizes), so it is safe for GPU arrays.
"""
_cache_compatible(hint, ::Nothing) = false
_cache_compatible(::Nothing, shapes) = false
_cache_compatible(::Nothing, ::Nothing) = false
_cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}) = size(hint) == shape
function _cache_compatible(hint::Tuple, shapes::Tuple)
length(hint) != length(shapes) && return false
return all(((h, s),) -> _cache_compatible(h, s), zip(hint, shapes))
end

"""
$SIGNATURES

Inject `new_cache` into `op` as its cache. Default uses `@reset op.cache = new_cache`.
Override for operators that don't use the `.cache` field convention.
"""
update_cache(op::AbstractSciMLOperator, new_cache) = @reset op.cache = new_cache

"""
$SIGNATURES

Like `cache_operator`, but tries to reuse `hint` (an existing cache from a compatible operator)
instead of allocating new buffers. Falls back to `cache_operator` when `hint` is not compatible.
"""
function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat)
if _cache_compatible(hint, _get_cache_shapes(op, v))
op = update_cache(op, hint)
return cache_internals(op, v)
end
return cache_operator(op, v)
end
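
A quick sketch of how these pieces fit together; importing the `_`-prefixed internals directly is an assumption made for illustration only:

using SciMLOperators: _cache_compatible

hint   = (zeros(4, 3), zeros(2, 4, 3))     # existing cache from a sibling operator
shapes = ((4, 3), (2, 4, 3))               # what the new operator expects
_cache_compatible(hint, shapes)            # true: every slot's size matches
_cache_compatible(hint, ((4, 3), (5, 1)))  # false: second slot mismatches
_cache_compatible(nothing, shapes)         # false: no hint to reuse

On a match, `cache_operator_hinted` grafts the hint onto the operator via `update_cache` (an Accessors.jl `@reset` of the `.cache` field) and then runs `cache_internals`; otherwise it falls back to a fresh `cache_operator`.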

###
# operator traits
###
48 changes: 35 additions & 13 deletions src/tensor.jl
@@ -175,6 +175,7 @@ function update_coefficients(L::TensorProductOperator, u, p, t; kwargs...)
end

getops(L::TensorProductOperator) = L.ops
getcache(op::TensorProductOperator) = op.cache

# Copy method to avoid aliasing
function Base.copy(L::TensorProductOperator)
@@ -362,30 +363,51 @@ function Base.:\(L::TensorProductOperator, v::AbstractVecOrMat)
return v isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,))
end

-function cache_self(L::TensorProductOperator, v::AbstractVecOrMat)
+function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat)
outer, inner = L.ops
outer isa IdentityOperator && return nothing

mi, ni = size(inner)
mo, no = size(outer)
k = size(v, 2)

-is_outer_identity = outer isa IdentityOperator
s1 = (mi, no * k)
s2 = (no, mi, k)
s3 = (mo, mi * k)
s4 = (mo * mi, k)

if reduce(&, issquare.(L.ops))
return (s1, s2, s3, s4, s1, s2, s3)
else
s5 = (ni, mo * k)
s6 = (mo, ni, k)
s7 = (no, ni * k)
return (s1, s2, s3, s4, s5, s6, s7)
end
end
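
For concreteness, with hypothetical sizes (a square 2×2 outer factor, a square 4×4 inner factor, and a 3-column input) the seven slots resolve as follows:

# L = outer ⊗ inner with mo = no = 2, mi = ni = 4, and v an 8×3 matrix (k = 3).
# Both factors are square, so the ldiv! slots alias the mul! slots:
#   s1 = (mi, no * k)  == (4, 6)     # inner * v
#   s2 = (no, mi, k)   == (2, 4, 3)  # permutation buffer
#   s3 = (mo, mi * k)  == (2, 12)    # outer * c2
#   s4 = (mo * mi, k)  == (8, 3)     # input copy for 5-arg mul!
# _get_cache_shapes(L, v) == (s1, s2, s3, s4, s1, s2, s3)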

function cache_self(L::TensorProductOperator, v::AbstractVecOrMat)
shapes = _get_cache_shapes(L, v)

# outer is IdentityOperator — no buffers needed
if isnothing(shapes)
@reset L.cache = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
return L
end

-# 3 arg mul!
-c1 = is_outer_identity ? nothing : lmul!(false, similar(v, (mi, no * k))) # c1 = inner * v
-c2 = is_outer_identity ? nothing : lmul!(false, similar(v, (no, mi, k))) # permute (2, 1, 3)
-c3 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo, mi * k))) # c3 = outer * c2
+s1, s2, s3, s4, s5, s6, s7 = shapes

-# 5 arg mul!
-c4 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo * mi, k))) # cache v in 5 arg mul!
+c1 = lmul!(false, similar(v, s1)) # inner * v (3-arg mul!)
+c2 = lmul!(false, similar(v, s2)) # permute (2,1,3)
+c3 = lmul!(false, similar(v, s3)) # outer * c2
+c4 = lmul!(false, similar(v, s4)) # copy of w for 5-arg mul!

-# 3 arg ldiv!
if mapreduce(issquare, &, L.ops)
-c5, c6, c7 = c1, c2, c3
+c5, c6, c7 = c1, c2, c3 # square case: ldiv! reuses mul! buffers
else
-c5 = lmul!(false, similar(v, (ni, mo * k))) # c5 = inner \ v
-c6 = lmul!(false, similar(v, (mo, ni, k))) # permute (2, 1, 3)
-c7 = lmul!(false, similar(v, (no, ni * k))) # c7 = outer \ c6
+c5 = lmul!(false, similar(v, s5)) # inner \ v (3-arg ldiv!)
+c6 = lmul!(false, similar(v, s6)) # permute (2,1,3)
+c7 = lmul!(false, similar(v, s7)) # outer \ c6
end

@reset L.cache = (c1, c2, c3, c4, c5, c6, c7)
50 changes: 50 additions & 0 deletions test/basic.jl
@@ -408,6 +408,56 @@ end
end
end

@testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin
using SciMLOperators: cache_operator_hinted

m1, m2 = 2, 4 # m1 * m2 == N

# C1 and C2: same wrapper (ComposedOperator), different inner type params
# C1 = A1*B1 → ops::Tuple{MatrixOperator, MatrixOperator}
# C2 = A2*B2' → ops::Tuple{MatrixOperator, AdjointOperator{…}}
A1 = MatrixOperator(rand(N, N)); B1 = MatrixOperator(rand(N, N))
A2 = MatrixOperator(rand(N, N)); B2 = MatrixOperator(rand(N, N))
C1 = A1 * B1
C2 = A2 * B2'

# T1, T2, T3: same wrapper (TensorProductOperator), different inner type params
# T1 = Ao ⊗ Ai, T2 = Ao' ⊗ Ai, T3 = Ao ⊗ Ai'
Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2))
T1 = TensorProductOperator(Ao, Ai)
T2 = TensorProductOperator(Ao', Ai)
T3 = TensorProductOperator(Ao, Ai')

L = C1 + T1 + C2 + T2 + T3 + A1 + A2
@test L isa AddedOperator
@test length(L.ops) == 7

# Matrix input
u = rand(N, K)
L = cache_operator(L, u)

# Correctness: the cached operator gives the right result
expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u
@test L * u ≈ expected

# Cache sharing: same-wrapper sub-operators with compatible sizes share physical buffers
@test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache (same wrapper)
@test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache (same wrapper)
@test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache (same wrapper)

# Vector input
v = rand(N)
L = cache_operator(L, v)

expected = C1 * v + T1 * v + C2 * v + T2 * v + T3 * v + A1 * v + A2 * v
@test L * v ≈ expected

@test L.ops[3].cache === L.ops[1].cache
@test L.ops[4].cache === L.ops[2].cache
@test L.ops[5].cache === L.ops[2].cache
end


@testset "ComposedOperator" begin
A = rand(N, N)
B = rand(N, N)