Describe the bug 🐞
I'm trying to take the gradient of a function that contains an ensemble solve. The gradient call errors because Zygote somehow ends up applying `+` to two ensemble solutions.
Expected behavior
Gradients should work.
Minimal Reproducible Example 👇
Without an MRE, we can only help to a limited extent, and the issue will receive less attention. To learn more about MREs, see Wikipedia and Stack Overflow.
```julia
using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote

function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end
    l / length(data)
end

function ensemble_setup(x)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2))
    end
    function f(du, u, p, t)
        du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
        du[2] = -3 * u[2] + u[1] * u[2]
    end
    prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)
    prob, prob_func
end

function ensemble_loss(x, data)
    prob, prob_func = ensemble_setup(x)
    ensembleprob = EnsembleProblem(prob; prob_func, safetycopy=false)
    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
                trajectories=3, saveat=[0., 0.4, 0.9],
                save_end=true)
    loss = zero(eltype(data))
    for i in Base.OneTo(3)
        sol = sim.u[i]
        loss += mae2(sol, data)
    end
    loss
end

# The data literal in the original report was garbled in extraction ("[1.12405.6]").
# Stand-in (assumption): any 2×N Matrix{Float64} whose N does not exceed
# the number of saved time points.
_data = rand(2, 3)
ensemble_loss(rand(4), _data)
Zygote.gradient(x -> ensemble_loss(x, _data), rand(4))
```
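A plausible reading of the failure, offered as an assumption rather than a confirmed diagnosis: each `sim.u[i]` inside the `for` loop gives Zygote its own cotangent for `sim`, with `nothing` in every slot except `i`, so several cotangents for the same container have to be summed. The plain-Julia sketch below mimics that pattern without any SciML types; `onehot_cotangent` is a hypothetical helper, not Zygote's actual pullback.

```julia
# Hypothetical helper (not a Zygote or SciML function): models the cotangent
# that the pullback of `v[i]` would hand back for the whole container `v`,
# using `nothing` as "zero gradient" in every slot except `i`.
function onehot_cotangent(n::Int, i::Int, Δ)
    ct = Vector{Union{Nothing, typeof(Δ)}}(nothing, n)
    ct[i] = Δ
    return ct
end

# Mimic the loop over `sim.u[i]` in `ensemble_loss`: three trajectories,
# one access each, so three separate cotangents for the same container.
cts = [onehot_cotangent(3, i, [1.0 * i, 2.0 * i]) for i in 1:3]

# Adding two of them with a plain broadcast pairs a vector with `nothing`
# in the first slot, and there is no `+` method for that combination.
naive_error = try
    cts[1] .+ cts[2]
    nothing
catch err
    err
end

println(typeof(naive_error))  # MethodError
```

In the real trace the error surfaces one step earlier, as `size(::Nothing)` while RecursiveArrayTools computes broadcast axes, but the root problem (a `nothing` slot reaching `+`) looks the same.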
Error & Stacktrace ⚠️
```
ERROR: MethodError: no method matching size(::Nothing)
The function `size` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  size(::IdentityOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:21
  size(::NullOperator)
   @ SciMLOperators ~/.julia/packages/SciMLOperators/KVzmP/src/basic.jl:115
  size(::LLVM.FunctionParameterSet)
   @ LLVM ~/.julia/packages/LLVM/b3kFs/src/core/function.jl:200
  ...

Stacktrace:
  [1] size
    @ ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:481 [inlined]
  [2] axes(VA::EnsembleSolution{Any, 1, Vector{Union{Nothing, RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{…}}}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:485
  [3] combine_axes
    @ ./broadcast.jl:497 [inlined]
  [4] instantiate
    @ ./broadcast.jl:307 [inlined]
  [5] materialize
    @ ./broadcast.jl:872 [inlined]
  [6] +(A::EnsembleSolution{Any, 1, Vector{Union{…}}}, B::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/Y3i0V/src/vector_of_array.jl:661
  [7] accum(x::EnsembleSolution{Any, 1, Vector{Union{…}}}, y::EnsembleSolution{Any, 1, Vector{Union{…}}})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:17
  [8] ensemble_loss
    @ ~/dev/ensemble_zygote.jl:41 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(ensemble_loss), Vector{Float64}, Matrix{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [10] #22
    @ ~/dev/ensemble_zygote.jl:52 [inlined]
 [11] (::Zygote.Pullback{Tuple{var"#22#23", Vector{…}}, Tuple{Zygote.Pullback{…}, Zygote.var"#1986#back#198"{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148
 [14] top-level scope
    @ ~/dev/ensemble_zygote.jl:52
Some type information was truncated. Use `show(err)` to see complete types.
```
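The trace shows Zygote's `accum` (frame [7]) calling `+` on two `EnsembleSolution`s (frame [6]) whose payload type is `Vector{Union{Nothing, VectorOfArray…}}`; RecursiveArrayTools then broadcasts and asks for the `size` of a `nothing` slot (frames [1] and [2]). Below is a minimal plain-Julia sketch, assuming those `nothing` slots stand for zero gradients, of a slot-wise merge that never hands `nothing` to `+`. `accum_slot`, `accum_slots` and `Slot` are hypothetical names for illustration; this is not the library's fix.

```julia
# Hypothetical helpers (not part of Zygote or RecursiveArrayTools).
# `accum_slot` uses the same convention as Zygote's generic `accum`:
# `nothing` stands for a zero gradient, otherwise the two values are added.
accum_slot(x, y) = x === nothing ? y : y === nothing ? x : x + y

# Merge two containers slot by slot, so `nothing` entries never reach `+`.
accum_slots(a::AbstractVector, b::AbstractVector) = map(accum_slot, a, b)

const Slot = Union{Nothing, Vector{Float64}}

# One-hot cotangents, shaped like the `Vector{Union{Nothing, …}}` payload
# of the `EnsembleSolution`s in the trace.
a = Slot[[1.0, 2.0], nothing, nothing]
b = Slot[nothing, [3.0, 4.0], nothing]
c = Slot[nothing, nothing, [5.0, 6.0]]

# Fold all three cotangents together; every slot ends up filled.
total = reduce(accum_slots, (a, b, c))
println(total)
```

The design choice is simply to apply the `nothing`-as-zero rule per slot instead of only at the top level, which is where the top-level `x + y` fallback in the trace breaks down.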
Environment (please complete the following information):
Output of using Pkg; Pkg.status()
Status `~/dev/Project.toml`
[1ed8b502] SciMLSensitivity v7.72.0
⌅ [e88e6eb3] Zygote v0.6.75
Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
Output of versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × Intel(R) Core(TM) i9-14900K
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code