error with ensemble solves when using Zygote #1160

Open
SebastianM-C opened this issue Feb 5, 2025 · 0 comments

Describe the bug 🐞

I'm trying to take the gradient of a function that contains an ensemble solve. The gradient computation errors because, somehow, `+` ends up being called on two ensemble solutions (the `accum` frame in the stacktrace).
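For context, here is a minimal sketch of why a `+` between two solutions can appear at all. When a value such as `sim` is used several times (presumably `sim.u[i]` inside the loop of `ensemble_loss`), reverse-mode AD adds up the gradient contributions from each use. The `toy_accum` below is a hypothetical, simplified stand-in modeled on Zygote's accumulation rule, where `nothing` means "no gradient"; it is not Zygote's code. Elementwise accumulation skips `nothing`, whereas a plain broadcasted `+` over the same containers throws a `MethodError`.

```julia
# Hypothetical stand-in for an AD accumulation rule (not Zygote's code):
# `nothing` means "no gradient", so it is skipped instead of added.
toy_accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
# Containers are accumulated slot by slot.
toy_accum(x::AbstractVector, y::AbstractVector) = toy_accum.(x, y)

# Gradient contributions from two separate uses of a 3-slot container; each
# use touched only one slot and left `nothing` elsewhere (similar in spirit
# to the Vector{Union{Nothing, ...}} seen in the stacktrace).
g1 = Union{Nothing, Vector{Float64}}[[1.0, 2.0], nothing, nothing]
g2 = Union{Nothing, Vector{Float64}}[nothing, [3.0, 4.0], nothing]

# Slot-wise accumulation handles the `nothing` entries.
total = toy_accum(g1, g2)

# A plain broadcasted `+` tries +(::Vector{Float64}, ::Nothing) and fails.
naive_failed = try
    g1 .+ g2
    false
catch err
    err isa MethodError
end
```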

Expected behavior

`Zygote.gradient` should return the gradient of the loss with respect to the parameters `x`.

Minimal Reproducible Example 👇

using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote

function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end

    l / length(data)
end

function ensemble_setup(x)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2))
    end

    function f(du, u, p, t)
        du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
        du[2] = -3 * u[2] + u[1] * u[2]
    end

    prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)

    prob, prob_func
end

function ensemble_loss(x, data)
    prob, prob_func = ensemble_setup(x)

    ensembleprob = EnsembleProblem(prob; prob_func, safetycopy=false)

    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
        trajectories=3, saveat=[0., 0.4, 0.9],
        save_end=true)

    loss = zero(eltype(data))
    for i in Base.OneTo(3)
        sol = sim.u[i]
        loss += mae2(sol, data)
    end

    loss
end
_data = [1.1 2 4
    0 5. 6]
ensemble_loss(rand(4), _data)

Zygote.gradient(x -> ensemble_loss(x, _data), rand(4))
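To make the loss concrete: `mae2` averages squared differences (so, despite its name, it is a mean squared error) between the first `size(data, 2)` saved states and the columns of `data`. The sketch below reuses `mae2` from the example with a hypothetical NamedTuple `fake_sol` standing in for an ODE solution, since only its `u` field is read. Its fourth state plays the role of an extra saved point (such as an endpoint kept by `save_end=true`) and, as the second call shows, is never read.

```julia
# `mae2` copied from the example above: mean of squared differences between
# the first size(data, 2) saved states and the columns of `data`.
function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end

    l / length(data)
end

# Hypothetical stand-in for an ODE solution: only the `u` field is used.
fake_sol = (u = [[1.0, 0.0], [2.0, 4.0], [4.0, 6.0], [9.0, 9.0]],)
# Same data matrix as in the example: 2 state components x 3 time points.
_data = [1.1 2 4
    0 5. 6]

# Squared errors: 0.1^2 + 0 + 0 + 1^2 + 0 + 0 = 1.01, divided by 6 entries.
l1 = mae2(fake_sol, _data)

# Replacing the unused fourth state leaves the loss unchanged.
fake_sol2 = (u = [fake_sol.u[1:3]; [[-5.0, -5.0]]],)
l2 = mae2(fake_sol2, _data)
```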

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching size(::Nothing)
The function `size` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  size(::IdentityOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:21
  size(::NullOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:115
  size(::LLVM.FunctionParameterSet)
   @ LLVM ~/.julia/packages/LLVM/b3kFs/src/core/function.jl:200
  ...

Stacktrace:
  [1] size
    @ ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:481 [inlined]
  [2] axes(VA::EnsembleSolution{Any, 1, Vector{Union{Nothing, RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{}}}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:485
  [3] combine_axes
    @ ./broadcast.jl:497 [inlined]
  [4] instantiate
    @ ./broadcast.jl:307 [inlined]
  [5] materialize
    @ ./broadcast.jl:872 [inlined]
  [6] +(A::EnsembleSolution{Any, 1, Vector{Union{…}}}, B::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:661
  [7] accum(x::EnsembleSolution{Any, 1, Vector{Union{…}}}, y::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:17
  [8] ensemble_loss
    @ ~/dev/ensemble_zygote.jl:41 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(ensemble_loss), Vector{Float64}, Matrix{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [10] #22
    @ ~/dev/ensemble_zygote.jl:52 [inlined]
 [11] (::Zygote.Pullback{Tuple{var"#22#23", Vector{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#1986#back#198"{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148
 [14] top-level scope
    @ ~/dev/ensemble_zygote.jl:52
Some type information was truncated. Use `show(err)` to see complete types.
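Reading the stacktrace bottom-up: frame [7] is Zygote's `accum` calling `+` on two `EnsembleSolution`s, frame [6] is RecursiveArrayTools' `+`, which broadcasts, and frames [2] and [1] compute `axes`/`size`, which fails because the solution vector is typed `Vector{Union{Nothing, ...}}` and presumably holds `nothing` entries. The toy type below is a hypothetical illustration of that last step, not RecursiveArrayTools' actual implementation: if a vector-of-arrays container derives its shape from an element, a `nothing` element produces the same `MethodError: no method matching size(::Nothing)`.

```julia
# Hypothetical toy container, NOT RecursiveArrayTools' real type: a vector of
# arrays whose slots may be `nothing`, like the Vector{Union{Nothing, ...}}
# inside the EnsembleSolution in the stacktrace.
struct ToyVoA
    u::Vector{Union{Nothing, Vector{Float64}}}
end

# Shape = (inner size..., number of slots), read off the first slot.
toy_size(A::ToyVoA) = (size(A.u[1])..., length(A.u))

# Addition that checks shapes first, loosely like a broadcasted `+`
# computing axes before doing any arithmetic.
function toy_add(A::ToyVoA, B::ToyVoA)
    toy_size(A) == toy_size(B) || throw(DimensionMismatch("shapes differ"))
    return ToyVoA([a + b for (a, b) in zip(A.u, B.u)])
end

# Fully populated containers add fine.
ok = toy_add(ToyVoA([[1.0], [2.0]]), ToyVoA([[3.0], [4.0]]))

# A `nothing` in the first slot fails while computing the shape,
# with MethodError: no method matching size(::Nothing).
err = try
    toy_add(ToyVoA([nothing, [2.0]]), ToyVoA([[1.0], nothing]))
    nothing
catch e
    e
end
```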

Environment:

  • Output of using Pkg; Pkg.status()
Status `~/dev/Project.toml`
  [1ed8b502] SciMLSensitivity v7.72.0
⌅ [e88e6eb3] Zygote v0.6.75
  • Output of versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × Intel(R) Core(TM) i9-14900K
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code