From 10cc1d6d19d1e78253a3dce40234f89c31ffdead Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 12 Feb 2024 12:20:15 +0100 Subject: [PATCH] sq --- .../simulation/integrators_variable_step.py | 100 ++++++++++++------ 1 file changed, 66 insertions(+), 34 deletions(-) diff --git a/src/jaxsim/simulation/integrators_variable_step.py b/src/jaxsim/simulation/integrators_variable_step.py index f6476d806..dacb16723 100644 --- a/src/jaxsim/simulation/integrators_variable_step.py +++ b/src/jaxsim/simulation/integrators_variable_step.py @@ -8,14 +8,7 @@ from jaxsim import typing as jtp from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.simulation.integrators import ( - State, - StateDerivative, - StateDerivativeCallable, - Time, - TimeHorizon, - TimeStep, -) +from jaxsim.simulation.integrators import Time, TimeHorizon, TimeStep from jaxsim.simulation.ode_data import ODEState from jaxsim.sixd import so3 @@ -27,6 +20,15 @@ BETA_MAX_DEFAULT = 2.5 MAX_STEP_REJECTIONS_DEFAULT = 5 +# Contrarily to the fixed-step integrators that operate on generic PyTrees, +# these variable-step integrators operate only on arrays (that could be the +# flatted PyTree). +State = jtp.Vector +StateDerivative = jtp.Vector +StateDerivativeCallable = Callable[ + [State, Time], tuple[StateDerivative, dict[str, Any]] +] + class AdaptiveIntegratorType(enum.IntEnum): HeunEuler = enum.auto() @@ -63,7 +65,7 @@ def initial_step_size( x0: State, t0: Time, f: StateDerivativeCallable, - order: jtp.IntLike = 2, + order: jtp.IntLike, rtol: jtp.FloatLike = RTOL_DEFAULT, atol: jtp.FloatLike = ATOL_DEFAULT, ) -> tuple[jtp.Float, StateDerivative]: @@ -89,6 +91,7 @@ def initial_step_size( E. Hairer, S. P. Norsett G. Wanner. """ + # Compute the state derivative at the initial state. ẋ0 = f(x0, t0)[0] # Scale the initial state and its derivative. @@ -140,9 +143,10 @@ def scale_array( the local integration error. """ - # Use a zeroed second state if not provided + # Use a zeroed second state if not provided. x2 = x2 if x2 is not None else jnp.zeros_like(x1) + # Return: atol + max(|x1|, |x2|) * rtol. return ( atol + jnp.vstack( @@ -169,8 +173,8 @@ def error_local( Args: x0: The initial state $x(t_0)$. xf: The final state $x(t_f)$. - error_estimate: The optional error estimate. In not given, it is computed - as the difference between the final and initial states. + error_estimate: The optional error estimate. In not given, it is computed as the + absolute value of the difference between the final and initial states. rtol: The relative tolerance to scale the state. atol: The absolute tolerance to scale the state. norm_ord: The norm to use to compute the error. Default is the infinity norm. @@ -183,18 +187,13 @@ def error_local( sc = scale_array(x1=x0, x2=xf, rtol=rtol, atol=atol) # Compute the error estimate if not given. - error_estimate = error_estimate if error_estimate is not None else xf - x0 + error_estimate = error_estimate if error_estimate is not None else jnp.abs(xf - x0) # Then, compute the local error by properly scaling the given error estimate and apply # the desired norm (default is infinity norm, that is the maximum absolute value). return jnp.linalg.norm(error_estimate / sc, ord=norm_ord) -# ======================= -# Runge-Kutta Integrators -# ======================= - - @functools.partial(jax.jit, static_argnames=["f"]) def runge_kutta_from_butcher_tableau( x0: State, @@ -205,35 +204,59 @@ def runge_kutta_from_butcher_tableau( c: jax.Array, b: jax.Array, A: jax.Array, - f0: StateDerivative | None = None, + dxdt0: StateDerivative | None = None, ) -> tuple[jax.Array, jax.Array, jax.Array | float, dict[str, Any]]: - """""" + """ + Advance a state vector by integrating a system dynamics with a Runge-Kutta integrator. + + Args: + x0: The initial state. + t0: The initial time. + dt: The integration time step. + f: The state derivative function :math:`f(x, t)`. + c: The :math:`\mathbf{c}` parameter of the Butcher tableau. + b: The :math:`\mathbf{b}` parameter of the Butcher tableau. + A: The :math:`\mathbf{A}` parameter of the Butcher tableau. + dxdt0: The optional pre-computed state derivative at the + initial :math:`(x_0, t_0)`, useful for FSAL schemes. + + Returns: + A tuple containing the next state, the intermediate states :math:`\mathbf{k}_i`, + the error estimate, and the auxiliary dictionary returned by `f`. + + Note: + If `b.T` has multiple rows (used e.g. in embedded Runge-Kutta methods), the first + returned argument is a 2D array having as many rows as `b.T`. Each i-th row + corresponds to the solution computed with coefficients of the i-th row of `b.T`. + """ # Adjust sizes of Butcher tableau arrays. c = jnp.atleast_1d(c.squeeze()) b = jnp.atleast_2d(b.squeeze()) A = jnp.atleast_2d(A.squeeze()) - h = dt + # Use a symbol for the time step. + Δt = dt # Initialize the carry of the for loop with the stacked kᵢ vectors. carry0 = jnp.zeros(shape=(c.size, x0.size), dtype=float) - # Allow FSAL (first-same-as-last) property by passing f0 = f(x0, t0) from + # Allow FSAL (first-same-as-last) property by passing ẋ0 = f(x0, t0) from # the previous iteration. - get_ẋ0 = lambda: f0 if f0 is not None else f(x0, t0)[0] + get_ẋ0 = lambda: dxdt0 if dxdt0 is not None else f(x0, t0)[0] - # We use a `jax.lax.scan` to have only a single instance of the compiled f function. + # We use a `jax.lax.scan` to have only a single instance of the compiled `f` function. # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code - # would include 4 repetitions of the f logic, making everything extremely slow. + # would include 4 repetitions of the `f` logic, making everything extremely slow. def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[Any, None]: """""" + # Unpack the carry k = carry def compute_ki(): - xi = x0 + h * jnp.dot(A[i, :], k) - ti = t0 + c[i] * h + xi = x0 + Δt * jnp.dot(A[i, :], k) + ti = t0 + c[i] * Δt return f(xi, ti)[0] # This selector enables FSAL property in the first iteration (i=0). @@ -253,11 +276,18 @@ def compute_ki(): xs=jnp.arange(c.size), ) - # Compute the output state and the error estimate. + # Compute the output state. # Note that z contains as many new states as the rows of `b.T`. - z = x0 + h * jnp.dot(b.T, k) - error_estimate = dt * jnp.dot(b.T[-1] - b.T[0], k) + z = x0 + Δt * jnp.dot(b.T, k) + # Compute the error estimate if `b.T` has multiple rows, otherwise return 0. + error_estimate = jax.lax.select( + pred=b.T.shape[0] == 1, + on_true=jnp.array(0.0, dtype=float), + on_false=dt * jnp.dot(b.T[-1] - b.T[0], k), + ) + + # TODO: populate the auxiliary dictionary return z, k, error_estimate, dict() @@ -405,9 +435,9 @@ def bogacki_shampine( return x_next, z_order, error, aux_dict -# ======================================= -# Variable-step integrators (single step) -# ======================================= +# ========================================== +# Variable-step RK integrators (single step) +# ========================================== @functools.partial( @@ -460,7 +490,7 @@ def odeint_embedded_rk_one_step( Δt0, ẋ0 = jax.lax.cond( pred=jnp.where(dt0 is None, 0.0, dt0) == 0.0, true_fun=lambda _: initial_step_size( - x0=x0, t0=t0, f=f, order=q, atol=atol, rtol=rtol + x0=x0, t0=t0, f=f, order=p, atol=atol, rtol=rtol ), false_fun=lambda _: (dt0, f(x0, t0)[0]), operand=None, @@ -744,7 +774,9 @@ def tf_next_state(x0: State, xf: State, t0: Time, dt: TimeStep) -> State: static_argnames=[ "f", "odeint_adaptive_one_step", + "integrator_type", "debug_buffers_size_per_step", + "tf_next_state", ], ) def _ode_integration_adaptive_template(