Skip to content

Commit 37ee18e

Browse files
patrick-kidgerthibmonsel
authored andcommitted
In progress commit on branch delay.
1 parent f506ee5 commit 37ee18e

24 files changed

+12389
-105
lines changed

diffrax/__init__.py

Lines changed: 93 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,113 @@
11
import importlib.metadata
22

3-
from ._adjoint import (
4-
AbstractAdjoint as AbstractAdjoint,
5-
BacksolveAdjoint as BacksolveAdjoint,
6-
DirectAdjoint as DirectAdjoint,
7-
ImplicitAdjoint as ImplicitAdjoint,
8-
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
9-
)
10-
from ._autocitation import citation as citation, citation_rules as citation_rules
11-
from ._brownian import (
12-
AbstractBrownianPath as AbstractBrownianPath,
13-
UnsafeBrownianPath as UnsafeBrownianPath,
14-
VirtualBrownianTree as VirtualBrownianTree,
15-
)
16-
from ._event import (
17-
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
18-
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
19-
SteadyStateEvent as SteadyStateEvent,
20-
)
3+
from ._adjoint import AbstractAdjoint as AbstractAdjoint
4+
from ._adjoint import BacksolveAdjoint as BacksolveAdjoint
5+
from ._adjoint import DirectAdjoint as DirectAdjoint
6+
from ._adjoint import ImplicitAdjoint as ImplicitAdjoint
7+
from ._adjoint import RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint
8+
from ._autocitation import citation as citation
9+
from ._autocitation import citation_rules as citation_rules
10+
from ._brownian import AbstractBrownianPath
11+
from ._brownian import AbstractBrownianPath as AbstractBrownianPath
12+
from ._brownian import UnsafeBrownianPath
13+
from ._brownian import UnsafeBrownianPath as UnsafeBrownianPath
14+
from ._brownian import VirtualBrownianTree
15+
from ._brownian import VirtualBrownianTree as VirtualBrownianTree
16+
from ._delays import Delays as Delays
17+
from ._delays import bind_history as bind_history
18+
from ._delays import history_extrapolation_implicit as history_extrapolation_implicit
19+
from ._delays import maybe_find_discontinuity as maybe_find_discontinuity
20+
from ._event import AbstractDiscreteTerminatingEvent
21+
from ._event import AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent
22+
from ._event import DiscreteTerminatingEvent
23+
from ._event import DiscreteTerminatingEvent as DiscreteTerminatingEvent
24+
from ._event import SteadyStateEvent
25+
from ._event import SteadyStateEvent as SteadyStateEvent
2126
from ._global_interpolation import (
2227
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
28+
)
29+
from ._global_interpolation import CubicInterpolation as CubicInterpolation
30+
from ._global_interpolation import DenseInterpolation as DenseInterpolation
31+
from ._global_interpolation import LinearInterpolation as LinearInterpolation
32+
from ._global_interpolation import (
2333
backward_hermite_coefficients as backward_hermite_coefficients,
24-
CubicInterpolation as CubicInterpolation,
25-
DenseInterpolation as DenseInterpolation,
26-
linear_interpolation as linear_interpolation,
27-
LinearInterpolation as LinearInterpolation,
34+
)
35+
from ._global_interpolation import linear_interpolation as linear_interpolation
36+
from ._global_interpolation import (
2837
rectilinear_interpolation as rectilinear_interpolation,
2938
)
3039
from ._integrate import diffeqsolve as diffeqsolve
3140
from ._local_interpolation import (
3241
AbstractLocalInterpolation as AbstractLocalInterpolation,
42+
)
43+
from ._local_interpolation import (
3344
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
34-
LocalLinearInterpolation as LocalLinearInterpolation,
35-
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
3645
)
46+
from ._local_interpolation import LocalLinearInterpolation as LocalLinearInterpolation
47+
from ._local_interpolation import (
48+
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation,
49+
) # noqa: E501
3750
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
3851
from ._path import AbstractPath as AbstractPath
39-
from ._root_finder import (
40-
VeryChord as VeryChord,
41-
with_stepsize_controller_tols as with_stepsize_controller_tols,
42-
)
43-
from ._saveat import SaveAt as SaveAt, SubSaveAt as SubSaveAt
44-
from ._solution import (
45-
is_event as is_event,
46-
is_okay as is_okay,
47-
is_successful as is_successful,
48-
RESULTS as RESULTS,
49-
Solution as Solution,
50-
)
51-
from ._solver import (
52-
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
53-
AbstractDIRK as AbstractDIRK,
54-
AbstractERK as AbstractERK,
55-
AbstractESDIRK as AbstractESDIRK,
56-
AbstractImplicitSolver as AbstractImplicitSolver,
57-
AbstractItoSolver as AbstractItoSolver,
58-
AbstractRungeKutta as AbstractRungeKutta,
59-
AbstractSDIRK as AbstractSDIRK,
60-
AbstractSolver as AbstractSolver,
61-
AbstractStratonovichSolver as AbstractStratonovichSolver,
62-
AbstractWrappedSolver as AbstractWrappedSolver,
63-
Bosh3 as Bosh3,
64-
ButcherTableau as ButcherTableau,
65-
CalculateJacobian as CalculateJacobian,
66-
Dopri5 as Dopri5,
67-
Dopri8 as Dopri8,
68-
Euler as Euler,
69-
EulerHeun as EulerHeun,
70-
HalfSolver as HalfSolver,
71-
Heun as Heun,
72-
ImplicitEuler as ImplicitEuler,
73-
ItoMilstein as ItoMilstein,
74-
KenCarp3 as KenCarp3,
75-
KenCarp4 as KenCarp4,
76-
KenCarp5 as KenCarp5,
77-
Kvaerno3 as Kvaerno3,
78-
Kvaerno4 as Kvaerno4,
79-
Kvaerno5 as Kvaerno5,
80-
LeapfrogMidpoint as LeapfrogMidpoint,
81-
Midpoint as Midpoint,
82-
MultiButcherTableau as MultiButcherTableau,
83-
Ralston as Ralston,
84-
ReversibleHeun as ReversibleHeun,
85-
SemiImplicitEuler as SemiImplicitEuler,
86-
Sil3 as Sil3,
87-
StratonovichMilstein as StratonovichMilstein,
88-
Tsit5 as Tsit5,
89-
)
52+
from ._root_finder import VeryChord as VeryChord
53+
from ._root_finder import with_stepsize_controller_tols as with_stepsize_controller_tols
54+
from ._saveat import SaveAt as SaveAt
55+
from ._saveat import SubSaveAt as SubSaveAt
56+
from ._solution import RESULTS as RESULTS
57+
from ._solution import Solution as Solution
58+
from ._solution import is_event as is_event
59+
from ._solution import is_okay as is_okay
60+
from ._solution import is_successful as is_successful
61+
from ._solver import AbstractAdaptiveSolver as AbstractAdaptiveSolver
62+
from ._solver import AbstractDIRK as AbstractDIRK
63+
from ._solver import AbstractERK as AbstractERK
64+
from ._solver import AbstractESDIRK as AbstractESDIRK
65+
from ._solver import AbstractImplicitSolver as AbstractImplicitSolver
66+
from ._solver import AbstractItoSolver as AbstractItoSolver
67+
from ._solver import AbstractRungeKutta as AbstractRungeKutta
68+
from ._solver import AbstractSDIRK as AbstractSDIRK
69+
from ._solver import AbstractSolver as AbstractSolver
70+
from ._solver import AbstractStratonovichSolver as AbstractStratonovichSolver
71+
from ._solver import AbstractWrappedSolver as AbstractWrappedSolver
72+
from ._solver import Bosh3 as Bosh3
73+
from ._solver import ButcherTableau as ButcherTableau
74+
from ._solver import CalculateJacobian as CalculateJacobian
75+
from ._solver import Dopri5 as Dopri5
76+
from ._solver import Dopri8 as Dopri8
77+
from ._solver import Euler as Euler
78+
from ._solver import EulerHeun as EulerHeun
79+
from ._solver import HalfSolver as HalfSolver
80+
from ._solver import Heun as Heun
81+
from ._solver import ImplicitEuler as ImplicitEuler
82+
from ._solver import ItoMilstein as ItoMilstein
83+
from ._solver import KenCarp3 as KenCarp3
84+
from ._solver import KenCarp4 as KenCarp4
85+
from ._solver import KenCarp5 as KenCarp5
86+
from ._solver import Kvaerno3 as Kvaerno3
87+
from ._solver import Kvaerno4 as Kvaerno4
88+
from ._solver import Kvaerno5 as Kvaerno5
89+
from ._solver import LeapfrogMidpoint as LeapfrogMidpoint
90+
from ._solver import Midpoint as Midpoint
91+
from ._solver import MultiButcherTableau as MultiButcherTableau
92+
from ._solver import Ralston as Ralston
93+
from ._solver import ReversibleHeun as ReversibleHeun
94+
from ._solver import SemiImplicitEuler as SemiImplicitEuler
95+
from ._solver import Sil3 as Sil3
96+
from ._solver import StratonovichMilstein as StratonovichMilstein
97+
from ._solver import Tsit5 as Tsit5
9098
from ._step_size_controller import (
9199
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
92-
AbstractStepSizeController as AbstractStepSizeController,
93-
ConstantStepSize as ConstantStepSize,
94-
PIDController as PIDController,
95-
StepTo as StepTo,
96100
)
97-
from ._term import (
98-
AbstractTerm as AbstractTerm,
99-
ControlTerm as ControlTerm,
100-
MultiTerm as MultiTerm,
101-
ODETerm as ODETerm,
102-
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
101+
from ._step_size_controller import (
102+
AbstractStepSizeController as AbstractStepSizeController,
103103
)
104-
104+
from ._step_size_controller import ConstantStepSize as ConstantStepSize
105+
from ._step_size_controller import PIDController as PIDController
106+
from ._step_size_controller import StepTo as StepTo
107+
from ._term import AbstractTerm as AbstractTerm
108+
from ._term import ControlTerm as ControlTerm
109+
from ._term import MultiTerm as MultiTerm
110+
from ._term import ODETerm as ODETerm
111+
from ._term import WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm
105112

106113
__version__ = importlib.metadata.version("diffrax")

diffrax/_adjoint.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def loop(
119119
solver,
120120
stepsize_controller,
121121
discrete_terminating_event,
122+
delays,
122123
saveat,
123124
t0,
124125
t1,
@@ -550,13 +551,15 @@ def _loop_backsolve_bwd(
550551
solver,
551552
stepsize_controller,
552553
discrete_terminating_event,
554+
delays,
553555
saveat,
554556
t0,
555557
t1,
556558
dt0,
557559
max_steps,
558560
throw,
559561
init_state,
562+
y0_history,
560563
):
561564
assert discrete_terminating_event is None
562565

@@ -594,6 +597,8 @@ def _loop_backsolve_bwd(
594597
adjoint=self,
595598
solver=solver,
596599
stepsize_controller=stepsize_controller,
600+
discrete_terminating_event=discrete_terminating_event,
601+
delays=delays,
597602
terms=adjoint_terms,
598603
dt0=None if dt0 is None else -dt0,
599604
max_steps=max_steps,
@@ -773,6 +778,7 @@ def loop(
773778
passed_solver_state,
774779
passed_controller_state,
775780
discrete_terminating_event,
781+
delays,
776782
**kwargs,
777783
):
778784
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -818,6 +824,10 @@ def loop(
818824
raise NotImplementedError(
819825
"`diffrax.BacksolveAdjoint` is not compatible with events."
820826
)
827+
if delays is not None:
828+
raise NotImplementedError(
829+
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
830+
)
821831

822832
y = init_state.y
823833
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -832,6 +842,7 @@ def loop(
832842
init_state=init_state,
833843
solver=solver,
834844
discrete_terminating_event=discrete_terminating_event,
845+
delays=delays,
835846
**kwargs,
836847
)
837848
final_state = _only_transpose_ys(final_state)

0 commit comments

Comments
 (0)