Conversation

sammccallum

Re-opening #593.

Implements an AbstractReversibleSolver base class and ReversibleAdjoint for reversible backpropagation.

This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.

Implementation

AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:
    ...

This method should reconstruct `y0` and `solver_state` at `t0` from `y1` and `solver_state` at `t1`. See the aforementioned solvers for examples.
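For illustration, here is a minimal sketch (not the library code) of what an implementation might look like for a semi-implicit-Euler-style method, whose forward step updates the two components in sequence and is therefore exactly invertible by undoing those updates in reverse order. The names below just mirror the abstract signature above; ω is the usual equinox.internal ω syntax used elsewhere in this thread.

def backward_step(self, terms, t0, t1, y1, args, solver_state, made_jump):
    term_f, term_g = terms
    x1, v1 = y1
    # Forward was: x1 = x0 + f(t0, v0) * control, then v1 = v0 + g(t0, x1) * control.
    # Undo the second update first (it only used x1, which we still have)...
    v0 = (v1**ω - term_g.vf_prod(t0, x1, args, term_g.contr(t0, t1)) ** ω).ω
    # ...then undo the first update, which used v0.
    x0 = (x1**ω - term_f.vf_prod(t0, v0, args, term_f.contr(t0, t1)) ** ω).ω
    y0 = (x0, v0)
    dense_info = dict(y0=y0, y1=y1)
    return y0, dense_info, solver_state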

When backpropagating, ReversibleAdjoint uses this backward_step to reconstruct state. We then take a vjp through a local forward step and accumulate gradients.

ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!

We allow arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull back these gradients onto y0, args, terms using the solver.init method.
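As a rough sketch of that final pull-back (the names grad_solver_state and tnext0 are assumed here, not the PR's exact code):

import equinox as eqx

# Differentiate solver.init at the arguments it was originally called with...
_, init_vjp = eqx.filter_vjp(solver.init, terms, t0, tnext0, y0, args)
# ...then push the accumulated solver_state cotangent back through it.
grad_terms, _, _, grad_y0, grad_args = init_vjp(grad_solver_state)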

sammccallum (Author)

I've also added the Reversible RK solver here, which just subclasses AbstractReversibleSolver. Let me know what you think of this and I can add some documentation when it's good to go!

@patrick-kidger (Owner) left a comment

Okay, gosh, this one took far too long for me to get around to. Thank you for your patience! If I can, I'd like this to be the next big thing I focus on getting into Diffrax.

@@ -105,6 +106,7 @@
MultiButcherTableau as MultiButcherTableau,
QUICSORT as QUICSORT,
Ralston as Ralston,
Reversible as Reversible,
patrick-kidger (Owner)

Let's call this something more specific to help distinguish what it is! I think here it's not clear that it's a solver, and even if it was ReversibleSolver then that still wouldn't disambiguate amongst the various kinds of reversibility it's possible to cook up.

What do you call this yourself / in your paper? Its structure is Hamiltonian/coupled/etc so maybe a word of that sort is suitable.

sammccallum (Author)

The call will look like this:

solver = diffrax.Reversible(diffrax.Tsit5())

with the idea being that you are "reversifying" Tsit5. So it isn't a solver in itself, but a wrapper. The boring name could be something like ReversibleWrapper and the fun name could be something like Reversify. Thoughts?

patrick-kidger (Owner)

Sure, but in the future we may cook up some other way of reversifying a solver! We should pick a name for this one that leaves that possibility open.

sammccallum (Author)

James has gone for U-Reversible (after the Uno reverse card :). The analogy is that we take a step forward from z0 to y1, then reverse and pull back from y1 onto z1.

reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

patrick-kidger (Owner)

A very minor bug here: if it just so happens that we run with t0 == t1 then we'll end up with reversible_ts = [t0 inf inf inf ...], which will not produce desired results in the backward solve.

We have a special branch to handle the saving in the t0 == t1 case, we should add a line handling the state.reversible_ts is not None case there.
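Hypothetically, something like the following in that branch (names taken from the snippet above; the real code in _integrate.py will differ):

if state.reversible_ts is not None:
    # Record the (degenerate) interval explicitly, so the backward solve
    # doesn't read the inf padding.
    reversible_ts = state.reversible_ts.at[1].set(t1)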


# Pull solver_state gradients back onto y0, args, terms.

_, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args)
patrick-kidger (Owner)

It's not super clear to me that ts[0], ts[1] are the correct values here. It looks like the saving routine is storing tprev, which is not necessarily the same as state.tnext, and the latter is what solver.init was originally called with.

In principle the step size controller could return anything at all; in practice it is possible for tprev to be 2 ULPs greater than state.tnext when passing to the other side of a discontinuity.

sammccallum (Author)

The ts here are reversible_ts, which follows the same logic as SaveAt(steps=True).

Is it not the case that the state.tnext identified in diffeqsolve (and used for solver.init) has to be the first step that the solver took? I appreciate that they can be different at later points in the solve, but my understanding was that the first step was set in diffeqsolve?

patrick-kidger (Owner)

So I think in this case you're saving the tprev of the second step, not the tnext of the first.

sammccallum (Author)

Yep, you're right.

We now return the tprev and tnext passed to solver.init as a residual in the reversible loop and use these to get the vjp.

Comment on lines 1409 to 1415
# Reversible info
if max_steps is None:
reversible_ts = None
reversible_save_index = None
else:
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
reversible_save_index = 0
patrick-kidger (Owner)

I've thought of an alternative for this extra buffer, btw: ReversibleAdjoint.loop could intercept saveat and add a SubSaveAt(steps=True, save_fn=lambda t, y, args: None) to record the extra times. Then peel it off again when returning the final state.

I think that (a) might be doable without making any changes to _integrate.py and (b) would allow for also supporting SaveAt(steps=True). (As in that case we can just skip adding the extra SubSaveAt.) And (c) would avoid a few of the subtle issues I've commented on above about exactly which tprev/tnext-like value is actually being saved, because you can trust in the rest of the existing diffeqsolve to do that for you.

It's not a strong suggestion though.
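For concreteness, a hypothetical sketch of the interception (note that diffrax's SubSaveAt names its callback fn; save_fn above is shorthand):

extra = diffrax.SubSaveAt(steps=True, fn=lambda t, y, args: None)
# Wrap the user's saveat together with the extra time-recording sub...
saveat = diffrax.SaveAt(subs=(saveat.subs, extra))
# ...run the usual loop, then peel extra's entry off the returned state.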

sammccallum (Author)

Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)

patrick-kidger (Owner)

Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.

y1 = (self.coupling_parameter * (ω(y0) - ω(z0)) + ω(step_z0)).ω

step_y1, y_error, _, _, result2 = self.solver.step(
    terms, t1, t0, y1, args, original_solver_state, True
)
patrick-kidger (Owner)

I've just spotted this has t0 and t1 back-to-front, which I think may in general mess with our solving logic as described previously. Is this intended / is it possible not to?

sammccallum (Author)

Yeah, this is intended and is essential to the solver!

@sammccallum force-pushed the AbstractReversibleSolver branch from 0cfd4ec to 3a26ac3 on May 14, 2025 10:14
@patrick-kidger (Owner) left a comment

Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular, now that we've settled on just AbstractERK, I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁

# (i.e. the state used on the forward). Otherwise, in `ReversibleAdjoint`,
# we would take a local forward step from an incorrect `solver_state`.
solver_state = jax.lax.cond(
    tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None
)
patrick-kidger (Owner)

I think this predicate is assuming we're solving over a time interval of the form [0, T]? But in practice the left endpoint might be nonzero.

This aside, it's (a) possible to use jnp.where, which is lower-overhead for small stuff like this, but also (b) I think lax.cond(pred, lambda: foo, lambda: bar) should work, without arguments.
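Sketches of both variants, with t_initial (the solve's left endpoint) as a hypothetical name addressing the [0, T] point:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

pred = tm1 > t_initial
# (b) lax.cond with zero-argument branches:
solver_state = jax.lax.cond(pred, lambda: (tm1, ym1, dt), lambda: (t0, y0, dt))
# (a) or jnp.where applied leafwise, lower-overhead for small state:
solver_state = jtu.tree_map(
    lambda a, b: jnp.where(pred, a, b), (tm1, ym1, dt), (t0, y0, dt)
)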

Comment on lines 70 to 73
# We pre-compute the step size to avoid numerical instability during the
# backward_step. This is okay (albeit slightly ugly) as `LeapfrogMidpoint` can't
# be used with adaptive step sizes.
dt = t1 - t0
patrick-kidger (Owner)

I think this will mean we silently get the wrong results if used with diffrax.StepTo with non-constant step size.

I think we might be able to save this by adjusting the backward step logic, so that it computes y0 in the step that returns it, rather than in the step before:

def backward_step(...):
    t2, y2 = solver_state
    control = terms.contr(t0, t2)
    y0 = (y2**ω - terms.vf_prod(t0, y1, args, control) ** ω).ω
    solver_state = (t1, y1)
    return y0, ...

note that this is also essentially the same idea as what we currently do in the forward step. I think this might require some detail around how to handle the first and last step, still.

sammccallum (Author)

That's much nicer! This should work, but as you say, requires special handling of the first and last step. I've had a think about how to do this (and the point above about assuming [0, T]) and I believe it's not possible without knowing where the start and end points of the solve are (or at least when we're at the first/last step)!

This feels like it's questioning the attempt to use a single-step API to represent multi-step solvers - hinting at this line:

# TODO: support arbitrary linear multistep methods

I have a few ideas for how to expand the leapfrog midpoint API to support this (e.g. have first_step, last_step flags or hold t0, t1 as attributes) but I'm not sure these are elegant enough for diffrax lol... wdyt?

patrick-kidger (Owner)

So I think a few points.

On having a method for handling the last step

So one idea that has previously come to mind is to implement reversibility via a solver.adjoint_solver() -> AbstractSolver, which then has a .init() method that can be called at the start of its solve, i.e. the end of the time interval.

This is mostly just an API convenience, it just gives us a natural place to do last-step handling without introducing a new method. It also feels kind of nice to be able to say e.g. Euler.adjoint_solver() -> ImplicitEuler.

I am here using 'adjoint' in the same sense as in Hairer&Wanner, referring to the solver that undoes the effect of the original solver. At this point I am not talking about backprop in any flavour; just the reconstruction.

This approach might actually also seem to lean into doing the backward pass via diffeqsolve(solver=adjoint_solver, stepsize_controller=StepTo(ts=reversible_ts), ...), but I don't think that's flexible enough to handle backprop through the interpolation, so that might not be desirable.

On handling the first step

...this bit I'm less sure about 😁 In principle it could be done by having a solver_state for the adjoint solver described above, and having that do step-counting. This seems finickity in several ways. I think it is also natural to wonder if this would cleanly extend to linear multistep methods with k steps of setup, without just introducing a big lax.switch over all the branches, yuck. (Since I could see the forward pass could probably be implemented without that, with a bit of trickery similar to how we handle the loop over the k stages of an RK method.)

On linear multistep methods in general

So I think every time I've seen these implemented, it has still been with a single-step-and-state interface. For example when implementing the first k steps of an Adams-Bashforth-Moulton method then you gradually build up the information you need in a circular buffer.

That said, this isn't an argument that I would take too strongly! Diffrax does a lot of things that no-one else does, no reason why we couldn't invent something different here. I'd be open to trying out something else.


Overall: I have probably helped you very little 😁 I'm not 100% sure how to tackle this either. LMK what you find?

@sammccallum (Author) commented Jul 13, 2025

Okay, sorry for taking so long to get back to this Patrick...

The good news is that we can handle the first step (last step of reverse solve) easily by storing the t0 that we were initialised with. The bad news is that I'm convinced it is impossible (for a single-step-and-state implementation of leapfrog) to handle "true" algebraic reversibility (i.e. reconstruct state going backwards) when $t_{i+1} - t_{i}\neq t_i-t_{i-1}$. That is, where the step size is not equal between consecutive steps. We would therefore have to disallow non-constant step sizes for leapfrog.

This is probably not too much of an issue as I imagine there are very few use cases for this being otherwise. But, are you okay with this? If so, I think we may not need to go down the multi-step rabbit hole.

patrick-kidger (Owner)

Whoops, just realised that I completely forgot to write back.

So conceptually I think that's reasonable. I think this would only occur if using solver=LeapfrogMidpoint(), stepsize_controller=StepTo(...). We don't really have a way to block that off API-wise, however -- the only opportunity a solver is given to reject anything is when working with adaptive stepsize controllers (by returning a None for the error estimate). Do you have a description of what goes wrong? I had believed this was possible.

(I was originally thinking that maybe we could avoid this simply by not making LeapfrogMidpoint be reversible, but I think that might end up being a bit of a footgun -- this may block us off from implementing e.g. reversible symplectic methods in the future.)

patrick-kidger (Owner)

If I understand you correctly this is the central issue:

> but have no way of reconstructing tm1 without knowing what the step size was, tm1 = t0 - dt.

Except, I think we do know this? All the times were recorded into reversible_ts for exactly this reason.

sammccallum (Author)

Yep, we know all the times but have no way of passing that information through the AbstractSolver.step API.

Potentially this could be solved by a MultiStepControl term where the times at all steps taken within an interval can be returned.
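Spelling the obstruction out: the leapfrog midpoint update over $[t_{n-1}, t_{n+1}]$ is

$$y_{n+1} = y_{n-1} + (t_{n+1} - t_{n-1})\, f(t_n, y_n),$$

so the backward step needs $t_{n-1}$, which a step API spanning only $[t_n, t_{n+1}]$ cannot supply unless the step size is constant (in which case $t_{n-1} = 2 t_n - t_{n+1}$).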

patrick-kidger (Owner)

However this is a brand-new backward_step API ^^ Let's make that be what it takes to make that work!


Practically speaking maybe that just means a backward_init and a backward_step (to match the forward init and step). Perhaps make this available via backward_init(reversible_ts, ...) and then let that put this in some backward_state if the solver will need that.

sammccallum (Author)

Okay, I think I've settled on a simple and relatively non-intrusive solution. backward_step now takes a ts_state argument which I am imagining will act as a place for "any time state that is required to reverse a step". Currently this just holds tm1 to handle LeapfrogMidpoint but could be extended if required.

We can now use StepTo with LeapfrogMidpoint.
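A hedged sketch of the amended signature (exact argument order assumed), mirroring the abstract method at the top of this thread:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
    ts_state: PyTree[RealScalarLike],  # e.g. carries tm1 for LeapfrogMidpoint
) -> tuple[Y, DenseInfo, _SolverState]:
    ...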

patrick-kidger (Owner)

Awesome! Let me know when you next want me to do a review of this branch 😁


solver: AbstractERK
patrick-kidger (Owner)

Now that we're special-casing to just this, then I am much more comfortable with the logic below!

raise ValueError(
    "`UReversible` is only compatible with `AbstractERK` base solvers."
)
original_solver_init = self.solver.init(terms, t0, t1, y0, args)
patrick-kidger (Owner)

So IIUC we're always going to be using the non-FSAL version of AbstractERK here?

If so then we get to sidestep all the painful solver_state difficulties that we've been debating back-and-forth, because this will just always be None?

In which case can we have an assert original_solver_init is None here, and likewise elide it throughout the rest of the logic below? (If need be we can also add a flag to AbstractRungeKutta to force non-FSAL-ness, and set that here.)

sammccallum (Author)

Yes, this is correct. We are always using an AbstractERK with made_jump=True, but we haven't explicitly turned off the fsal flag. Before we switched to made_jump we were turning off fsal by:

object.__setattr__(self.solver.tableau, "fsal", False)

But this modifies solver in place which is not ideal. This is to say that original_solver_init is not None in the current implementation. It would be nice to set fsal=False so that we can pass around the None throughout.

patrick-kidger (Owner)

I think something like this should do the trick then:

def __init__(self, solver: AbstractERK):
    self.solver = eqx.tree_at(lambda s: s.disable_fsal, solver, True)

And then add this flag to AbstractRungeKutta here:

scan_kind: None | Literal["lax", "checkpointed", "bounded"] = None

Which takes effect here:

fsal = fsal and not vf_expensive

@patrick-kidger changed the base branch from main to dev on June 16, 2025 22:39
@patrick-kidger (Owner) commented Jun 16, 2025

Heads-up that I've just updated the base branch to dev. It looks like there are a number of old commits sitting around on this PR, likely from where this branch forked off of main. You should be able to resolve these by first (a) squashing all the commits that actually belong on this branch together, and then (b) rebasing that new single commit on top of dev.

(Unrelatedly, lmk when this branch is ready for review.)
