From 4914b2db030f3ca96cc40f1344e98978d51ec328 Mon Sep 17 00:00:00 2001 From: Krzysztof Rusek Date: Fri, 3 Nov 2023 17:23:13 +0100 Subject: [PATCH 1/3] Add RATTLE solver --- diffrax/__init__.py | 1 + diffrax/solver/__init__.py | 1 + diffrax/solver/rattle.py | 107 +++++++++++++++++++++++++++++++++++++ test/test_solver.py | 30 +++++++++++ 4 files changed, 139 insertions(+) create mode 100644 diffrax/solver/rattle.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d8aca8d4..7e33c4bf 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -71,6 +71,7 @@ Midpoint, MultiButcherTableau, Ralston, + Rattle, ReversibleHeun, SemiImplicitEuler, Sil3, diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index ace213c4..77652d35 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -24,6 +24,7 @@ from .midpoint import Midpoint from .milstein import ItoMilstein, StratonovichMilstein from .ralston import Ralston +from .rattle import Rattle from .reversible_heun import ReversibleHeun from .runge_kutta import ( AbstractDIRK, diff --git a/diffrax/solver/rattle.py b/diffrax/solver/rattle.py new file mode 100644 index 00000000..c4821287 --- /dev/null +++ b/diffrax/solver/rattle.py @@ -0,0 +1,107 @@ +from typing import Tuple, NamedTuple, Callable +import jax +import jax.numpy as jnp +import jax.tree_util as jtu + +from equinox.internal import ω + +from ..custom_types import Bool, DenseInfo, PyTree, Scalar +from ..local_interpolation import LocalLinearInterpolation +from ..solution import RESULTS +from ..term import AbstractTerm +from .base import AbstractImplicitSolver + +_ErrorEstimate = None +_SolverState = None + + +class RattleVars(NamedTuple): + p_1_2: PyTree # Midpoint momentum + q_1: PyTree # Midpoint position + p_1: PyTree # final momentum + lam: PyTree # Midpoint Lagrange multiplier (state) + mu: PyTree # final Lagrange multiplier (momentum) + + +class Rattle(AbstractImplicitSolver): + """ Rattle method. + + 2nd order symplectic method with constrains. + + ??? cite "Reference" + + ```bibtex + @article{ANDERSEN198324, + title = {Rattle: A “velocity” version of the shake algorithm for molecular dynamics calculations}, + journal = {Journal of Computational Physics}, + volume = {52}, + number = {1}, + pages = {24-34}, + year = {1983}, + author = {Hans C Andersen}, + } + ``` + """ + + term_structure = (AbstractTerm, AbstractTerm) + interpolation_cls = LocalLinearInterpolation + # Fix TypeError: non-default argument 'constrain' follows default argument + constrain: Callable = None + + def order(self, terms): + return 2 + + def init(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, t1: Scalar, y0: PyTree, + args: PyTree, ) -> _SolverState: + return None + + def step(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, t1: Scalar, y0: Tuple[PyTree, PyTree], + args: PyTree, solver_state: _SolverState, made_jump: Bool, ) -> Tuple[ + Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + term_1, term_2 = terms + midpoint = (t1 + t0) / 2 + + control1_half_1 = term_1.contr(t0, midpoint) + control1_half_2 = term_1.contr(midpoint, t1) + + control2_half_1 = term_2.contr(t0, midpoint) + control2_half_2 = term_2.contr(midpoint, t1) + + p0, q0 = y0 + + def eq(x: RattleVars, args=None): + _, vjp_fun = jax.vjp(self.constrain, q0) + _, vjp_fun_mu = jax.vjp(self.constrain, x.q_1) + + zero = ((p0 ** ω - control1_half_1 * (vjp_fun(x.lam)[0]) ** ω + term_1.vf_prod(t0, q0, args, + control1_half_1) ** ω - x.p_1_2 ** ω).ω, + (q0 ** ω + term_2.vf_prod(t0, x.p_1_2, args, control2_half_1) ** ω + term_2.vf_prod(midpoint, + x.p_1_2, args, + control2_half_2) ** ω - x.q_1 ** ω).ω, + self.constrain(x.q_1), ( + x.p_1_2 ** ω + term_1.vf_prod(midpoint, x.q_1, args, control1_half_2) ** ω - ( + control1_half_2 * vjp_fun_mu(x.mu)[0] ** ω) - x.p_1 ** ω).ω, + jax.jvp(self.constrain, (x.q_1,), (term_2.vf(t1, x.p_1, args),))[1]) + return zero + + cs = jax.eval_shape(self.constrain, q0) + + init_vars = RattleVars(p_1_2=p0, q_1=(q0 ** ω * 2).ω, p_1=p0, + lam=jtu.tree_map(jnp.zeros_like, cs), + mu=jtu.tree_map(jnp.zeros_like, cs)) + + sol = self.nonlinear_solver(eq, init_vars, None) + + y1 = (sol.root.p_1, sol.root.q_1) + dense_info = dict(y0=y0, y1=y1) + return y1, None, dense_info, None, RESULTS.successful + + def func(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, y0: Tuple[PyTree, PyTree], args: PyTree) -> \ + Tuple[PyTree, PyTree]: + term_1, term_2 = terms + y0_1, y0_2 = y0 + f1 = term_1.func(t0, y0_2, args) + f2 = term_2.func(t0, y0_1, args) + return (f1, f2) diff --git a/test/test_solver.py b/test/test_solver.py index 67ed0701..53255a6c 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -470,3 +470,33 @@ def vector_field(t, y, args): return out.ys[0] f(1.0) + + +def test_rattle(): + import numpy as np + def constrain(q): + return jnp.sqrt(jnp.sum(q**2, keepdims=True))-1. + + rat = diffrax.Rattle(nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-4, atol=1e-6),constrain=constrain) + + + def H(p, q): + del q + return p @ p.T / 2. + + # V = p^2/2m m=1, v=1 + + + terms = (diffrax.ODETerm(lambda t, q, args: -jax.grad(H, argnums=1)(jnp.zeros_like(q), q)), + diffrax.ODETerm(lambda t, p, args: jax.grad(H, argnums=0)(p, jnp.zeros_like(p))) + ) + #p,q + y0 = (jnp.asarray([1.,0.]), jnp.asarray([0.,1.]) ) + t1 = 2*jnp.pi/4 + n=2**12 + dt = t1/n + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve(terms,rat,0.0,t1,dt0=dt,y0=y0,saveat=saveat) + p1,q1 = solution.ys + assert np.allclose(p1,jnp.asarray([0.,-1.]), rtol=1e-4, atol=1e-4) + assert np.allclose(q1, jnp.asarray([1., 0.]), rtol=1e-4, atol=1e-4) From c8394607b79774c6182eed1dec5974c5a0cb87ae Mon Sep 17 00:00:00 2001 From: Krzysztof Rusek Date: Fri, 3 Nov 2023 17:32:21 +0100 Subject: [PATCH 2/3] add note to rattle test --- test/test_solver.py | 52 ++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/test/test_solver.py b/test/test_solver.py index 53255a6c..ff51c375 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -473,30 +473,38 @@ def vector_field(t, y, args): def test_rattle(): - import numpy as np - def constrain(q): - return jnp.sqrt(jnp.sum(q**2, keepdims=True))-1. + import numpy as np - rat = diffrax.Rattle(nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-4, atol=1e-6),constrain=constrain) + def constrain(q): + return jnp.sqrt(jnp.sum(q**2, keepdims=True)) - 1.0 + rat = diffrax.Rattle( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-4, atol=1e-6), + constrain=constrain, + ) - def H(p, q): - del q - return p @ p.T / 2. - - # V = p^2/2m m=1, v=1 + # Potential free movement on a circle + def H(p, q): + del q + return p @ p.T / 2.0 + # V = p^2/2m m=1, v=1 - terms = (diffrax.ODETerm(lambda t, q, args: -jax.grad(H, argnums=1)(jnp.zeros_like(q), q)), - diffrax.ODETerm(lambda t, p, args: jax.grad(H, argnums=0)(p, jnp.zeros_like(p))) - ) - #p,q - y0 = (jnp.asarray([1.,0.]), jnp.asarray([0.,1.]) ) - t1 = 2*jnp.pi/4 - n=2**12 - dt = t1/n - saveat = diffrax.SaveAt(t1=True) - solution = diffrax.diffeqsolve(terms,rat,0.0,t1,dt0=dt,y0=y0,saveat=saveat) - p1,q1 = solution.ys - assert np.allclose(p1,jnp.asarray([0.,-1.]), rtol=1e-4, atol=1e-4) - assert np.allclose(q1, jnp.asarray([1., 0.]), rtol=1e-4, atol=1e-4) + terms = ( + diffrax.ODETerm( + lambda t, q, args: -jax.grad(H, argnums=1)(jnp.zeros_like(q), q) + ), + diffrax.ODETerm( + lambda t, p, args: jax.grad(H, argnums=0)(p, jnp.zeros_like(p)) + ), + ) + # p,q + y0 = (jnp.asarray([1.0, 0.0]), jnp.asarray([0.0, 1.0])) + t1 = 2 * jnp.pi / 4 + n = 2**12 + dt = t1 / n + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve(terms, rat, 0.0, t1, dt0=dt, y0=y0, saveat=saveat) + p1, q1 = solution.ys + assert np.allclose(p1, jnp.asarray([0.0, -1.0]), rtol=1e-4, atol=1e-4) + assert np.allclose(q1, jnp.asarray([1.0, 0.0]), rtol=1e-4, atol=1e-4) From f2b432b03a92e63c70e5bfffac587aba559e1d8f Mon Sep 17 00:00:00 2001 From: Krzysztof Rusek Date: Fri, 3 Nov 2023 17:38:15 +0100 Subject: [PATCH 3/3] typing constrain function --- diffrax/solver/rattle.py | 91 ++++++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/diffrax/solver/rattle.py b/diffrax/solver/rattle.py index c4821287..5a1e484a 100644 --- a/diffrax/solver/rattle.py +++ b/diffrax/solver/rattle.py @@ -1,19 +1,22 @@ -from typing import Tuple, NamedTuple, Callable +from typing import Callable, NamedTuple, Tuple + import jax import jax.numpy as jnp import jax.tree_util as jtu - from equinox.internal import ω -from ..custom_types import Bool, DenseInfo, PyTree, Scalar +from ..custom_types import Array, Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation from ..solution import RESULTS from ..term import AbstractTerm from .base import AbstractImplicitSolver + _ErrorEstimate = None _SolverState = None +ConstrainFn = Callable[[PyTree], Array] + class RattleVars(NamedTuple): p_1_2: PyTree # Midpoint momentum @@ -24,15 +27,16 @@ class RattleVars(NamedTuple): class Rattle(AbstractImplicitSolver): - """ Rattle method. + """Rattle method. - 2nd order symplectic method with constrains. + 2nd order symplectic method with constrains `constrain(x)=0`. ??? cite "Reference" ```bibtex @article{ANDERSEN198324, - title = {Rattle: A “velocity” version of the shake algorithm for molecular dynamics calculations}, + title = {Rattle: A “velocity” version of the shake + algorithm for molecular dynamics calculations}, journal = {Journal of Computational Physics}, volume = {52}, number = {1}, @@ -46,18 +50,31 @@ class Rattle(AbstractImplicitSolver): term_structure = (AbstractTerm, AbstractTerm) interpolation_cls = LocalLinearInterpolation # Fix TypeError: non-default argument 'constrain' follows default argument - constrain: Callable = None + constrain: ConstrainFn = None def order(self, terms): return 2 - def init(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, t1: Scalar, y0: PyTree, - args: PyTree, ) -> _SolverState: + def init( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: PyTree, + args: PyTree, + ) -> _SolverState: return None - def step(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, t1: Scalar, y0: Tuple[PyTree, PyTree], - args: PyTree, solver_state: _SolverState, made_jump: Bool, ) -> Tuple[ - Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + def step( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree, + solver_state: _SolverState, + made_jump: Bool, + ) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del solver_state, made_jump term_1, term_2 = terms @@ -75,22 +92,39 @@ def eq(x: RattleVars, args=None): _, vjp_fun = jax.vjp(self.constrain, q0) _, vjp_fun_mu = jax.vjp(self.constrain, x.q_1) - zero = ((p0 ** ω - control1_half_1 * (vjp_fun(x.lam)[0]) ** ω + term_1.vf_prod(t0, q0, args, - control1_half_1) ** ω - x.p_1_2 ** ω).ω, - (q0 ** ω + term_2.vf_prod(t0, x.p_1_2, args, control2_half_1) ** ω + term_2.vf_prod(midpoint, - x.p_1_2, args, - control2_half_2) ** ω - x.q_1 ** ω).ω, - self.constrain(x.q_1), ( - x.p_1_2 ** ω + term_1.vf_prod(midpoint, x.q_1, args, control1_half_2) ** ω - ( - control1_half_2 * vjp_fun_mu(x.mu)[0] ** ω) - x.p_1 ** ω).ω, - jax.jvp(self.constrain, (x.q_1,), (term_2.vf(t1, x.p_1, args),))[1]) + zero = ( + ( + p0**ω + - control1_half_1 * (vjp_fun(x.lam)[0]) ** ω + + term_1.vf_prod(t0, q0, args, control1_half_1) ** ω + - x.p_1_2**ω + ).ω, + ( + q0**ω + + term_2.vf_prod(t0, x.p_1_2, args, control2_half_1) ** ω + + term_2.vf_prod(midpoint, x.p_1_2, args, control2_half_2) ** ω + - x.q_1**ω + ).ω, + self.constrain(x.q_1), + ( + x.p_1_2**ω + + term_1.vf_prod(midpoint, x.q_1, args, control1_half_2) ** ω + - (control1_half_2 * vjp_fun_mu(x.mu)[0] ** ω) + - x.p_1**ω + ).ω, + jax.jvp(self.constrain, (x.q_1,), (term_2.vf(t1, x.p_1, args),))[1], + ) return zero cs = jax.eval_shape(self.constrain, q0) - init_vars = RattleVars(p_1_2=p0, q_1=(q0 ** ω * 2).ω, p_1=p0, - lam=jtu.tree_map(jnp.zeros_like, cs), - mu=jtu.tree_map(jnp.zeros_like, cs)) + init_vars = RattleVars( + p_1_2=p0, + q_1=(q0**ω * 2).ω, + p_1=p0, + lam=jtu.tree_map(jnp.zeros_like, cs), + mu=jtu.tree_map(jnp.zeros_like, cs), + ) sol = self.nonlinear_solver(eq, init_vars, None) @@ -98,8 +132,13 @@ def eq(x: RattleVars, args=None): dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful - def func(self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, y0: Tuple[PyTree, PyTree], args: PyTree) -> \ - Tuple[PyTree, PyTree]: + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree, + ) -> Tuple[PyTree, PyTree]: term_1, term_2 = terms y0_1, y0_2 = y0 f1 = term_1.func(t0, y0_2, args)