Skip to content

Share same cache among similar operators in AddedOperator#370

Open
albertomercurio wants to merge 4 commits intoSciML:masterfrom
albertomercurio:master
Open

Share same cache among similar operators in AddedOperator#370
albertomercurio wants to merge 4 commits intoSciML:masterfrom
albertomercurio:master

Conversation

@albertomercurio
Copy link
Copy Markdown
Contributor

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

When caching an AddedOperator, we currently cache all the suboperators iteratively. However, this is suboptimal, as many of them can in principle share the same cache when possible.

In this PR I implement this cache sharing in the AddedOperator. Let's imagine I have the sum of Composed + Tensor + Tensor + Composed + Matrix + Tensor, many of them require some cache, but they could share the same cache. We can be sure they can share the same cache if

  1. They are the same constructor (e.g., ComposedOperator)
  2. The size of their cache is the same

the first one can be checked even at compile time, while the second requires to define _get_cache_shapes(op, v), which returns the shape of the cache.

Benchmarks

This method allows to save a lot of memory, as also shown in this example

using SparseArrays
using SciMLOperators
using CairoMakie

function generate_op(N)
    M = 10
    sparsity = 1 / N

    A_list = ntuple(i -> MatrixOperator(sprand(N, N, sparsity)) * MatrixOperator(sprand(N, N, sparsity)), Val(M))

    op = SciMLOperators.AddedOperator(A_list)

    u = rand(N, N)
    return cache_operator(op, u)
    # return op
end

# %%

N_list = 4 .^ (2:7)
sizes_main = [Base.summarysize(generate_op(N)) for N in N_list]
sizes_pr = [Base.summarysize(generate_op(N)) for N in N_list]

# %%

fig = Figure()
ax = Axis(fig[1, 1], xlabel="N", ylabel="Size (MB)", xscale=log10, yscale=log10)

scatterlines!(ax, N_list, sizes_main ./ 1e6, label="main")
scatterlines!(ax, N_list, sizes_pr ./ 1e6, label="PR")

axislegend(ax; position = :lt)

fig
image

albertomercurio and others added 4 commits May 6, 2026 13:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant