Skip to content

Fix adjoint for NonlinearSolution constructor #998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

jClugstor
Copy link
Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Should hopefully fix SciML/NonlinearSolve.jl#581, in conjuction with #997

@ChrisRackauckas
Copy link
Member

Tests fail.

ChrisRackauckas added a commit to SciML/DiffEqBase.jl that referenced this pull request Apr 25, 2025
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like:

* SciML/SciMLBase.jl#997
* SciML/SciMLBase.jl#998

Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures.

So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
ChrisRackauckas added a commit that referenced this pull request Apr 25, 2025
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like:

* #997
* #998

Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures.

So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
T, N, uType, R, P, A, O, uType2, S, Tr}}, u,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably track .prob as well (for gradients against parameters).

Also what is the type of ? What types could we encounter here? If it's another NonlinearSolution, we should investigate what produces it and potentially try to return a tangent type

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an mwe is

@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
       D(D(y)) ~ λ * y - g
       x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

prob = ODEProblem(pend, [x => 1, y => 0], (0.0, 1.5), [g => 1.5], guesses ==> 1])

sol = solve(prob, Rodas5P())

get_vars = getsym(prob, [pend.x+pend.y]);

Zygote.gradient(sol) do sol
    u = get_vars(sol)
    # u = sol[pend.x+pend.y]
    sum(reduce(vcat, u))
end

which points to this branch

VA = recursivecopy(VA)
and
VA = recursivecopy(VA)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was always a NonlinearSolution when I was testing.

Copy link
Member Author

@jClugstor jClugstor May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just now getting back to this.
To track .prob, would it be like this:

function ChainRulesCore.rrule(
        ::Type{<:SciMLBase.NonlinearSolution{
            T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
        args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
    function NonlinearSolutionAdjoint(ȳ)
        (NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...)
    end
    SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...),
    NonlinearSolutionAdjoint
end

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this looks reasonable

T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow up, is the type of a solution type still with the latest SciMLSensitivity

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is a Tangent type with the latest SciMLSensitivity:

typeof(ȳ) = Tangent{Any}(u = 1.0, resid = ChainRulesCore.ZeroTangent(), prob = ChainRulesCore.ZeroTangent(), alg = ChainRulesCore.ZeroTangent(), retcode = ChainRulesCore.ZeroTangent(), original = ChainRulesCore.ZeroTangent(), left = ChainRulesCore.ZeroTangent(), right = ChainRulesCore.ZeroTangent(), stats = ChainRulesCore.ZeroTangent(), trace = ChainRulesCore.ZeroTangent())

@jClugstor jClugstor force-pushed the nonlinearsolution_adjoint branch from 031520d to d948f07 Compare May 14, 2025 15:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

IntervalNonlinearProblem fails with Zygote
3 participants