diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index 2a756e71..50e859ab 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -234,7 +234,9 @@ export export update_coefficients!, update_coefficients, isconstant, iscached, - cache_operator, issquare, + cache_operator, cache_operator_hinted, + update_cache, + issquare, islinear, concretize, isconvertible, has_adjoint, diff --git a/src/basic.jl b/src/basic.jl index c7a34060..54b55ab6 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -648,9 +648,25 @@ has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) @generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat) ops_types = L.parameters[2].parameters N = length(ops_types) + + # If multiple sub-operators share the same outermost type constructor (wrapper), we can cache one of them and reuse the cache for the others. This is because operators with the same wrapper will have the same caching structure, so we can avoid redundant caching work. The `donor` tuple identifies which operator's cache to use for each sub-operator. + + donor = ntuple(i -> findfirst(j -> ops_types[j].name.wrapper === ops_types[i].name.wrapper, 1:N), N) + + # Unique variable names for each cached sub-operator + syms = ntuple(i -> Symbol(:op_, i), N) + + # Emit cache_operator for donors, cache_operator_hinted for the rest + stmts = ntuple(N) do i + d = donor[i] + d == i ? + :($(syms[i]) = cache_operator(L.ops[$i], v)) : + :($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[d])), v)) + end + return quote - ops = Base.@ntuple $N i -> cache_operator(L.ops[i], v) - return AddedOperator(ops) + $(stmts...) + return AddedOperator(($(syms...),)) end end @@ -874,6 +890,7 @@ function update_coefficients(L::ComposedOperator, u, p, t; kwargs...) end getops(L::ComposedOperator) = L.ops +getcache(op::ComposedOperator) = op.cache # Copy method to avoid aliasing function Base.copy(L::ComposedOperator) @@ -939,6 +956,16 @@ end end end +function _get_cache_shapes(L::ComposedOperator, v::AbstractVecOrMat) + N = length(L.ops) + res = if v isa AbstractMatrix + ntuple(i -> (size(L.ops[i], 1), size(v, 2)), Val(N)) + else + ntuple(i -> (size(L.ops[i], 1),), Val(N)) + end + return res +end + @generated function cache_self(L::ComposedOperator, v::AbstractVecOrMat) N = length(L.parameters[2].parameters) # Number of operators @@ -1199,6 +1226,7 @@ function update_coefficients(L::InvertedOperator, u, p, t; kwargs...) end getops(L::InvertedOperator) = (L.L,) +getcache(op::InvertedOperator) = op.cache islinear(L::InvertedOperator) = islinear(L.L) isconvertible(::InvertedOperator) = false @@ -1229,9 +1257,10 @@ function Base.copy(L::InvertedOperator) ) end +_get_cache_shapes(::InvertedOperator, v::AbstractVecOrMat) = size(v) + function cache_self(L::InvertedOperator, u::AbstractVecOrMat) - cache = zero(u) - @reset L.cache = cache + @reset L.cache = zero(u) return L end diff --git a/src/func.jl b/src/func.jl index 19cd14bd..eeaa28c9 100644 --- a/src/func.jl +++ b/src/func.jl @@ -590,6 +590,13 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray) return L end +function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat) + return (L.traits.sizes[1], L.traits.sizes[2]) +end + +getcache(op::FunctionOperator) = op.cache +update_cache(L::FunctionOperator, new_cache) = set_cache(L, new_cache) + # fix method amg bw AbstractArray, AbstractVecOrMat cache_self(L::FunctionOperator, v::AbstractArray) = _cache_self(L, v) cache_self(L::FunctionOperator, v::AbstractVecOrMat) = _cache_self(L, v) diff --git a/src/interface.jl b/src/interface.jl index 2801dc70..dc201749 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -139,6 +139,15 @@ getops(L) = () """ $SIGNATURES +Return the current cache held by `op`, or `nothing` if it holds none. +New operator types get the safe `nothing` default automatically; override +for types that store a shareable `.cache` field. +""" +getcache(::AbstractSciMLOperator) = nothing + +""" +$SIGNATURES + Checks whether `L` has preallocated caches for inplace evaluations. """ function iscached(L::AbstractSciMLOperator) @@ -179,6 +188,53 @@ end cache_self(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L cache_internals(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L +""" +$SIGNATURES + +Return the expected cache shape specification for `L` given input `v`. +Returns `nothing` if `L` requires no cache. +The return value can be a single `NTuple{N,Int}` (single-array cache), +a `Tuple` of such shapes (multi-array cache), or `nothing` for absent slots. +""" +_get_cache_shapes(::AbstractSciMLOperator, ::AbstractVecOrMat) = nothing + +""" +$SIGNATURES + +Check whether `hint` is shape-compatible with `shapes` (as returned by `_get_cache_shapes`). +Uses `zip` to avoid integer-indexed Tuple access. Reads only array metadata — safe on GPU. +""" +_cache_compatible(hint, ::Nothing) = false +_cache_compatible(::Nothing, shapes) = false +_cache_compatible(::Nothing, ::Nothing) = false +_cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}) = size(hint) == shape +function _cache_compatible(hint::Tuple, shapes::Tuple) + length(hint) != length(shapes) && return false + return all(((h, s),) -> _cache_compatible(h, s), zip(hint, shapes)) +end + +""" +$SIGNATURES + +Inject `new_cache` into `op` as its cache. Default uses `@reset op.cache = new_cache`. +Override for operators that don't use the `.cache` field convention. +""" +update_cache(op::AbstractSciMLOperator, new_cache) = @reset op.cache = new_cache + +""" +$SIGNATURES + +Like `cache_operator`, but tries to reuse `hint` (an existing cache from a compatible operator) +instead of allocating new buffers. Falls back to `cache_operator` when `hint` is not compatible. +""" +function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat) + if _cache_compatible(hint, _get_cache_shapes(op, v)) + op = update_cache(op, hint) + return cache_internals(op, v) + end + return cache_operator(op, v) +end + ### # operator traits ### diff --git a/src/tensor.jl b/src/tensor.jl index 34d8be87..871c8972 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -175,6 +175,7 @@ function update_coefficients(L::TensorProductOperator, u, p, t; kwargs...) end getops(L::TensorProductOperator) = L.ops +getcache(op::TensorProductOperator) = op.cache # Copy method to avoid aliasing function Base.copy(L::TensorProductOperator) @@ -362,30 +363,51 @@ function Base.:\(L::TensorProductOperator, v::AbstractVecOrMat) return v isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,)) end -function cache_self(L::TensorProductOperator, v::AbstractVecOrMat) +function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops + outer isa IdentityOperator && return nothing mi, ni = size(inner) mo, no = size(outer) k = size(v, 2) - is_outer_identity = outer isa IdentityOperator + s1 = (mi, no * k) + s2 = (no, mi, k) + s3 = (mo, mi * k) + s4 = (mo * mi, k) + + if reduce(&, issquare.(L.ops)) + return (s1, s2, s3, s4, s1, s2, s3) + else + s5 = (ni, mo * k) + s6 = (mo, ni, k) + s7 = (no, ni * k) + return (s1, s2, s3, s4, s5, s6, s7) + end +end + +function cache_self(L::TensorProductOperator, v::AbstractVecOrMat) + shapes = _get_cache_shapes(L, v) + + # outer is IdentityOperator — no buffers needed + if isnothing(shapes) + @reset L.cache = (nothing, nothing, nothing, nothing, nothing, nothing, nothing) + return L + end - # 3 arg mul! - c1 = is_outer_identity ? nothing : lmul!(false, similar(v, (mi, no * k))) # c1 = inner * v - c2 = is_outer_identity ? nothing : lmul!(false, similar(v, (no, mi, k))) # permute (2, 1, 3) - c3 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo, mi * k))) # c3 = outer * c2 + s1, s2, s3, s4, s5, s6, s7 = shapes - # 5 arg mul! - c4 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo * mi, k))) # cache v in 5 arg mul! + c1 = lmul!(false, similar(v, s1)) # inner * v (3-arg mul!) + c2 = lmul!(false, similar(v, s2)) # permute (2,1,3) + c3 = lmul!(false, similar(v, s3)) # outer * c2 + c4 = lmul!(false, similar(v, s4)) # copy of w for 5-arg mul! - # 3 arg ldiv! if mapreduce(issquare, &, L.ops) - c5, c6, c7 = c1, c2, c3 + c5, c6, c7 = c1, c2, c3 # square case: ldiv! reuses mul! buffers else - c5 = lmul!(false, similar(v, (ni, mo * k))) # c5 = inner \ v - c6 = lmul!(false, similar(v, (mo, ni, k))) # permute (2, 1, 3) - c7 = lmul!(false, similar(v, (no, ni * k))) # c7 = outer \ c6 + c5 = lmul!(false, similar(v, s5)) # inner \ v (3-arg ldiv!) + c6 = lmul!(false, similar(v, s6)) # permute (2,1,3) + c7 = lmul!(false, similar(v, s7)) # outer \ c6 end @reset L.cache = (c1, c2, c3, c4, c5, c6, c7) diff --git a/test/basic.jl b/test/basic.jl index aaab3567..74f27d9d 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -408,6 +408,56 @@ end end end +@testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin + using SciMLOperators: cache_operator_hinted + + m1, m2 = 2, 4 # m1 * m2 == N + + # C1 and C2: same wrapper (ComposedOperator), different inner type params + # C1 = A1*B1 → ops::Tuple{MatrixOperator, MatrixOperator} + # C2 = A2*B2' → ops::Tuple{MatrixOperator, AdjointOperator{…}} + A1 = MatrixOperator(rand(N, N)); B1 = MatrixOperator(rand(N, N)) + A2 = MatrixOperator(rand(N, N)); B2 = MatrixOperator(rand(N, N)) + C1 = A1 * B1 + C2 = A2 * B2' + + # T1, T2, T3: same wrapper (TensorProductOperator), different inner type params + # T1 = Ao ⊗ Ai, T2 = Ao' ⊗ Ai, T3 = Ao ⊗ Ai' + Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2)) + T1 = TensorProductOperator(Ao, Ai) + T2 = TensorProductOperator(Ao', Ai) + T3 = TensorProductOperator(Ao, Ai') + + L = C1 + T1 + C2 + T2 + T3 + A1 + A2 + @test L isa AddedOperator + @test length(L.ops) == 7 + + # Matrix input + u = rand(N, K) + L = cache_operator(L, u) + + # Correctness: the cached operator gives the right result + expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u + @test L * u ≈ expected + + # Cache sharing: same-wrapper sub-operators with compatible sizes share physical buffers + @test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache (same wrapper) + @test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache (same wrapper) + @test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache (same wrapper) + + # Vector input + v = rand(N) + L = cache_operator(L, v) + + expected = C1 * v + T1 * v + C2 * v + T2 * v + T3 * v + A1 * v + A2 * v + @test L * v ≈ expected + + @test L.ops[3].cache === L.ops[1].cache + @test L.ops[4].cache === L.ops[2].cache + @test L.ops[5].cache === L.ops[2].cache +end + + @testset "ComposedOperator" begin A = rand(N, N) B = rand(N, N)