-
-
Notifications
You must be signed in to change notification settings - Fork 105
Fix adjoint for NonlinearSolution constructor #998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Tests fail. |
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like: * SciML/SciMLBase.jl#997 * SciML/SciMLBase.jl#998 Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures. So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like: * #997 * #998 Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures. So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
ext/SciMLBaseChainRulesCoreExt.jl
Outdated
T, N, uType, R, P, A, O, uType2, S, Tr}}, u, | ||
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} | ||
function NonlinearSolutionAdjoint(ȳ) | ||
(NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably track .prob
as well (for gradients against parameters).
Also what is the type of ȳ
? What types could we encounter here? If it's another NonlinearSolution, we should investigate what produces it and potentially try to return a tangent type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
an mwe is
@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)
prob = ODEProblem(pend, [x => 1, y => 0], (0.0, 1.5), [g => 1.5], guesses = [λ => 1])
sol = solve(prob, Rodas5P())
get_vars = getsym(prob, [pend.x+pend.y]);
Zygote.gradient(sol) do sol
u = get_vars(sol)
# u = sol[pend.x+pend.y]
sum(reduce(vcat, u))
end
which points to this branch
SciMLBase.jl/ext/SciMLBaseZygoteExt.jl
Line 106 in 42cdf6a
VA = recursivecopy(VA) |
SciMLBase.jl/ext/SciMLBaseZygoteExt.jl
Line 188 in 42cdf6a
VA = recursivecopy(VA) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ȳ
was always a NonlinearSolution
when I was testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, just now getting back to this.
To track .prob
, would it be like this:
function ChainRulesCore.rrule(
::Type{<:SciMLBase.NonlinearSolution{
T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...)
end
SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...),
NonlinearSolutionAdjoint
end
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes this looks reasonable
ext/SciMLBaseChainRulesCoreExt.jl
Outdated
T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob, | ||
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} | ||
function NonlinearSolutionAdjoint(ȳ) | ||
(NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a follow up, is the type of ȳ
a solution type still with the latest SciMLSensitivity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ȳ
is a Tangent type with the latest SciMLSensitivity:
typeof(ȳ) = Tangent{Any}(u = 1.0, resid = ChainRulesCore.ZeroTangent(), prob = ChainRulesCore.ZeroTangent(), alg = ChainRulesCore.ZeroTangent(), retcode = ChainRulesCore.ZeroTangent(), original = ChainRulesCore.ZeroTangent(), left = ChainRulesCore.ZeroTangent(), right = ChainRulesCore.ZeroTangent(), stats = ChainRulesCore.ZeroTangent(), trace = ChainRulesCore.ZeroTangent())
031520d
to
d948f07
Compare
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Should hopefully fix SciML/NonlinearSolve.jl#581, in conjuction with #997