From bbf8f832a023a6955642e6a1d487d8c57ca40883 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:09:30 +0100 Subject: [PATCH 01/36] Create configuration for GPU benchmarks on local runner --- .github/workflows/gpu_benchmark.yml | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 .github/workflows/gpu_benchmark.yml diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml new file mode 100644 index 000000000..c02b4bab9 --- /dev/null +++ b/.github/workflows/gpu_benchmark.yml @@ -0,0 +1,53 @@ +name: GPU Benchmarks + +on: + pull_request: + types: [opened, reopened, synchronize] + workflow_dispatch: + schedule: + - cron: "0 0 * * 1" # Run At 00:00 on Monday + +permissions: + pull-requests: write + +jobs: + benchmark: + runs-on: [self-hosted, gpu] + container: + image: ghcr.io/nvidia/jax:jax + options: --gpus all + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up environment + run: | + pip install .[all] + + - name: Run JaxSim benchmarks + id: benchmark + run: | + echo "### Benchmark Results" > results.txt + pytest --benchmark-only --gpu-only >> results.txt + # Output the results to GitHub Actions for use in the comment + echo "results=$(cat results.txt)" >> $GITHUB_ENV + + - name: Debugging comment for benchmarks start + uses: thollander/actions-comment-pull-request@v3 + with: + message: | + Running GPU benchmarks for this PR :rocket: + comment-tag: to_delete_on_completion + mode: delete-on-completion + + - name: Post results in PR comment + if: github.event_name == 'pull_request' + uses: thollander/actions-comment-pull-request@v3 + with: + pr-number: ${{ github.event.number }} + message: | + _(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_ + ${{ env.results }} + comment-tag: execution + mode: upsert From 23aada9e2998d0870fe887601abc526b1238b070 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 10:41:06 +0100 Subject: [PATCH 02/36] Restore `api.model.forward_kinematics` for benchmarking --- src/jaxsim/api/model.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 47f601512..65d1f0cf1 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1118,6 +1118,34 @@ def forward_dynamics_crb( return v̇_WB, s̈ +@jax.jit +@js.common.named_scope +def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Matrix: + """ + Compute the forward kinematics of the model. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + The nL x 4 x 4 array containing the stacked homogeneous transformations + of the links. The first axis is the link index. + """ + + W_H_LL, _ = jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=data.base_position, + base_quaternion=data.base_quaternion, + joint_positions=data.joint_positions, + joint_velocities=data.joint_velocities, + base_linear_velocity_inertial=data._base_linear_velocity, + base_angular_velocity_inertial=data._base_angular_velocity, + ) + + return W_H_LL + + @jax.jit @js.common.named_scope def free_floating_mass_matrix( From f791f9144e07ef6c2799cda6d7a1403a84336f90 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 10:41:29 +0100 Subject: [PATCH 03/36] Explicitly set `SoftContacts` in benchmarks --- tests/test_benchmark.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index cf6898f18..747eab2d3 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -104,6 +104,9 @@ def test_soft_contact_model( ): model = jaxsim_model_ergocub_reduced + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.SoftContacts() + benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) From 071b249024a0966fae841af0e6e3cb0cd501b5cd Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 11:25:38 +0100 Subject: [PATCH 04/36] Refactor workflow to improve performance result handling and caching --- .github/workflows/gpu_benchmark.yml | 51 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index c02b4bab9..06880b304 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -12,10 +12,10 @@ permissions: jobs: benchmark: - runs-on: [self-hosted, gpu] + runs-on: self-hosted container: image: ghcr.io/nvidia/jax:jax - options: --gpus all + options: --rm --gpus all steps: - name: Checkout repository @@ -25,29 +25,30 @@ jobs: run: | pip install .[all] - - name: Run JaxSim benchmarks - id: benchmark - run: | - echo "### Benchmark Results" > results.txt - pytest --benchmark-only --gpu-only >> results.txt - # Output the results to GitHub Actions for use in the comment - echo "results=$(cat results.txt)" >> $GITHUB_ENV + - name: Performance regression check + uses: actions/checkout@v4 - - name: Debugging comment for benchmarks start - uses: thollander/actions-comment-pull-request@v3 + - name: Run benchmark and store result + run: pytest bench.py --benchmark-json output.json + + - name: Download previous benchmark data + uses: actions/cache@v4 with: - message: | - Running GPU benchmarks for this PR :rocket: - comment-tag: to_delete_on_completion - mode: delete-on-completion - - - name: Post results in PR comment - if: github.event_name == 'pull_request' - uses: thollander/actions-comment-pull-request@v3 + path: ./cache + key: ${{ runner.os }}-benchmark + + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 with: - pr-number: ${{ github.event.number }} - message: | - _(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_ - ${{ env.results }} - comment-tag: execution - mode: upsert + tool: 'pytest' + output-file-path: output.json + external-data-json-path: ./cache/benchmark-data.json + fail-on-alert: true + summary-always: true + alert-threshold: 150% + auto-push: false + gh-pages-branch: gh-pages + + - name: Push benchmark result + if: github.event_name != 'pull_request' + run: git push 'https://ami-iit:${{ secrets.GITHUB_TOKEN }}@github.com/ami-iit/jaxsim.git' gh-pages:gh-pages From 6538907322f19995a687341b60c9ac7ff1faecfb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 11:36:39 +0100 Subject: [PATCH 05/36] Add benchmark for full simulation step --- tests/test_benchmark.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 747eab2d3..bcac88151 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -132,3 +132,15 @@ def test_relaxed_rigid_contact_model( model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) + + +@pytest.mark.benchmark +def test_simulation_step( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() + + benchmark_test_function(js.model.step, model, benchmark, batch_size) From 3494ff708b0f4879757e7997e523786cadbb6034 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Fri, 28 Feb 2025 13:40:23 +0100 Subject: [PATCH 06/36] Update benchmark command --- .github/workflows/gpu_benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index 06880b304..d3480c112 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -29,7 +29,7 @@ jobs: uses: actions/checkout@v4 - name: Run benchmark and store result - run: pytest bench.py --benchmark-json output.json + run: pytest tests/test_benchmark.py --benchmark-only --benchmark-json output.json - name: Download previous benchmark data uses: actions/cache@v4 From a6aac9e08634e3ec3c51161d640b118ecedc42a0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 28 Feb 2025 13:58:21 +0100 Subject: [PATCH 07/36] Increase ulimit stack size and use host IPC namespace --- .github/workflows/gpu_benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index d3480c112..7fc1dd23f 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -15,7 +15,7 @@ jobs: runs-on: self-hosted container: image: ghcr.io/nvidia/jax:jax - options: --rm --gpus all + options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 steps: - name: Checkout repository From 9aedc95e29d77745ab7c224689dedd7246b3fa8c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 28 Feb 2025 14:00:46 +0100 Subject: [PATCH 08/36] Install Gazebo SDF --- .github/workflows/gpu_benchmark.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index 7fc1dd23f..84675dae3 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -25,6 +25,14 @@ jobs: run: | pip install .[all] + - name: Install Gazebo SDF + run: | + apt update && apt install -y lsb-release wget + sh -c 'echo "deb http://packages.osrfoundation.org/gazebo/ubuntu-stable `lsb_release -cs` main" > /etc/apt/sources.list.d/gazebo-stable.list' + wget http://packages.osrfoundation.org/gazebo.key -O - | apt-key add - + apt-get update + apt install -y libsdformat14-dev libsdformat14 + - name: Performance regression check uses: actions/checkout@v4 From 662c08e02553d564312af7f0955bcf66ee8f75ed Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 28 Feb 2025 14:54:32 +0100 Subject: [PATCH 09/36] Skip soft and rigid contacts model tests --- .github/workflows/gpu_benchmark.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index 84675dae3..7f1ae3663 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -37,7 +37,8 @@ jobs: uses: actions/checkout@v4 - name: Run benchmark and store result - run: pytest tests/test_benchmark.py --benchmark-only --benchmark-json output.json + run: | + pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --benchmark-only --benchmark-json output.json - name: Download previous benchmark data uses: actions/cache@v4 From 9e37572dce70075d8b80334a2b2c40aeb4197dfc Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:00:23 +0100 Subject: [PATCH 10/36] Pass API token for commenting PRs --- .github/workflows/gpu_benchmark.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index 7f1ae3663..a71d7cab5 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -57,6 +57,7 @@ jobs: alert-threshold: 150% auto-push: false gh-pages-branch: gh-pages + github-token: ${{ secrets.GITHUB_TOKEN }} - name: Push benchmark result if: github.event_name != 'pull_request' From 66b8c7031211f563fbe3e81e388dc3b4d9196a6d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 28 Feb 2025 15:03:08 +0100 Subject: [PATCH 11/36] Ensure GPU utilization and benchmark vectorized computations --- .github/workflows/gpu_benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index a71d7cab5..f169fdaac 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -38,7 +38,7 @@ jobs: - name: Run benchmark and store result run: | - pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --benchmark-only --benchmark-json output.json + pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json - name: Download previous benchmark data uses: actions/cache@v4 From 33628bbb6c46b688d1ffc72f965bf57f43570ab7 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 28 Feb 2025 15:08:26 +0100 Subject: [PATCH 12/36] Always comment PRs and add permissions for GH pages --- .github/workflows/gpu_benchmark.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/gpu_benchmark.yml b/.github/workflows/gpu_benchmark.yml index f169fdaac..9d3d4968e 100644 --- a/.github/workflows/gpu_benchmark.yml +++ b/.github/workflows/gpu_benchmark.yml @@ -9,6 +9,8 @@ on: permissions: pull-requests: write + deployments: write + contents: write jobs: benchmark: @@ -54,6 +56,7 @@ jobs: external-data-json-path: ./cache/benchmark-data.json fail-on-alert: true summary-always: true + comment-always: true alert-threshold: 150% auto-push: false gh-pages-branch: gh-pages From 05111afa6d256dfe6d388360c5272e1ba7cfff5e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Jan 2025 15:00:53 +0100 Subject: [PATCH 13/36] Reintroduce soft, viscoelastic and rigid contact models This reverts commit 9408f744ec6302c107c886f43b92529809acbd21. --- src/jaxsim/rbda/contacts/__init__.py | 12 +- src/jaxsim/rbda/contacts/rigid.py | 462 +++++++++ src/jaxsim/rbda/contacts/soft.py | 480 ++++++++++ src/jaxsim/rbda/contacts/visco_elastic.py | 1066 +++++++++++++++++++++ tests/test_automatic_differentiation.py | 57 +- 5 files changed, 2073 insertions(+), 4 deletions(-) create mode 100644 src/jaxsim/rbda/contacts/rigid.py create mode 100644 src/jaxsim/rbda/contacts/soft.py create mode 100644 src/jaxsim/rbda/contacts/visco_elastic.py diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 3688468cf..06646f14d 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,13 @@ -from . import relaxed_rigid +from . import relaxed_rigid, rigid, soft, visco_elastic from .common import ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams +from .rigid import RigidContacts, RigidContactsParams +from .soft import SoftContacts, SoftContactsParams +from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams -ContactParamsTypes = RelaxedRigidContactsParams +ContactParamsTypes = ( + SoftContactsParams + | RigidContactsParams + | RelaxedRigidContactsParams + | ViscoElasticContactsParams +) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py new file mode 100644 index 000000000..d04a7b895 --- /dev/null +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -0,0 +1,462 @@ +from __future__ import annotations + +import dataclasses +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr + +from . import common +from .common import ContactModel, ContactsParams + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class RigidContactsParams(ContactsParams): + """Parameters of the rigid contacts model.""" + + # Static friction coefficient + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + # Baumgarte proportional term + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + # Baumgarte derivative term + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + ) + ) + + def __eq__(self, other: RigidContactsParams) -> bool: + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + mu: jtp.FloatLike | None = None, + K: jtp.FloatLike | None = None, + D: jtp.FloatLike | None = None, + ) -> Self: + """Create a `RigidContactParams` instance.""" + + return cls( + mu=jnp.array( + mu + if mu is not None + else cls.__dataclass_fields__["mu"].default_factory() + ).astype(float), + K=jnp.array( + K if K is not None else cls.__dataclass_fields__["K"].default_factory() + ).astype(float), + D=jnp.array( + D if D is not None else cls.__dataclass_fields__["D"].default_factory() + ).astype(float), + ) + + def valid(self) -> jtp.BoolLike: + """Check if the parameters are valid.""" + return bool( + jnp.all(self.mu >= 0.0) + and jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + ) + + +@jax_dataclasses.pytree_dataclass +class RigidContacts(ContactModel): + """Rigid contacts model.""" + + regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field( + default=1e-6, kw_only=True + ) + + _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( + default=("solver_tol",), kw_only=True + ) + _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( + default=(1e-3,), kw_only=True + ) + + @property + def solver_options(self) -> dict[str, Any]: + """Get the solver options as a dictionary.""" + + return dict( + zip( + self._solver_options_keys, + self._solver_options_values, + strict=True, + ) + ) + + @classmethod + def build( + cls: type[Self], + regularization_delassus: jtp.FloatLike | None = None, + solver_options: dict[str, Any] | None = None, + **kwargs, + ) -> Self: + """ + Create a `RigidContacts` instance with specified parameters. + + Args: + regularization_delassus: + The regularization term to add to the diagonal of the Delassus matrix. + solver_options: The options to pass to the QP solver. + **kwargs: Extra arguments which are ignored. + + Returns: + The `RigidContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + # Get the default solver options. + default_solver_options = dict( + zip(cls._solver_options_keys, cls._solver_options_values, strict=True) + ) + + # Create the solver options to set by combining the default solver options + # with the user-provided solver options. + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) + + # Make sure that the solver options are hashable. + # We need to check this because the solver options are static. + try: + hash(tuple(solver_options.values())) + except TypeError as exc: + raise ValueError( + "The values of the solver options must be hashable." + ) from exc + + return cls( + regularization_delassus=float( + regularization_delassus + if regularization_delassus is not None + else cls.__dataclass_fields__["regularization_delassus"].default + ), + _solver_options_keys=tuple(solver_options.keys()), + _solver_options_values=tuple(solver_options.values()), + **kwargs, + ) + + @staticmethod + def compute_impact_velocity( + inactive_collidable_points: jtp.ArrayLike, + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + generalized_velocity: jtp.VectorLike, + ) -> jtp.Vector: + """ + Return the new velocity of the system after a potential impact. + + Args: + inactive_collidable_points: The activation state of the collidable points. + M: The mass matrix of the system (in mixed representation). + J_WC: The Jacobian matrix of the collidable points (in mixed representation). + generalized_velocity: The generalized velocity of the system. + + Note: + The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity` + must be expressed in the same velocity representation. + """ + + # Compute system velocity after impact maintaining zero linear velocity of active points. + sl = jnp.s_[:, 0:3, :] + Jl_WC = J_WC[sl] + + # Zero out the jacobian rows of inactive points. + Jl_WC = jnp.vstack( + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) + ) + + A = jnp.vstack( + [ + jnp.hstack([M, -Jl_WC.T]), + jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]), + ] + ) + b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])]) + + BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0] + + return BW_ν_post_impact[0 : M.shape[0]] + + @jax.jit + def compute_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + link_forces: + Optional `(n_links, 6)` matrix of external forces acting on the links, + expressed in the same representation of data. + joint_force_references: + Optional `(n_joints,)` vector of joint forces. + + Returns: + A tuple containing as first element the computed contact forces. + """ + + # Import qpax privately just in this method. + import qpax + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + n_collidable_points = len(indices_of_enabled_collidable_points) + + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ) + + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() + if joint_force_references is not None + else jnp.zeros((model.number_of_joints(),)) + ) + + # Compute kin-dyn quantities used in the contact model. + with data.switch_velocity_representation(VelRepr.Mixed): + BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + + W_H_C = js.contact.transforms(model=model, data=data) + + # Compute the position and linear velocities (mixed representation) of + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) + + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, model.terrain + ) + + # Build a references object to simplify converting link forces. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # Compute the generalized free acceleration. + with ( + references.switch_velocity_representation(VelRepr.Mixed), + data.switch_velocity_representation(VelRepr.Mixed), + ): + + BW_ν̇_free = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references( + model=model + ), + ) + ) + + # Compute the free linear acceleration of the collidable points. + # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. + free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( + BW_nu=BW_ν, + BW_nu_dot=BW_ν̇_free, + CW_J_WC_BW=J_WC, + CW_J_dot_WC_BW=J̇_WC, + ).flatten() + + # Compute stabilization term. + baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( + inactive_collidable_points=(δ <= 0), + δ=δ, + δ_dot=δ_dot, + n=n̂, + K=data.contacts_params.K, + D=data.contacts_params.D, + ).flatten() + + # Compute the Delassus matrix. + delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + + # Initialize regularization term of the Delassus matrix for + # better numerical conditioning. + Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0]) + + # Construct the quadratic cost function. + Q = delassus_matrix + Iε + q = free_contact_acc - baumgarte_term + + # Construct the inequality constraints. + G = RigidContacts._compute_ineq_constraint_matrix( + inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu + ) + h_bounds = RigidContacts._compute_ineq_bounds( + n_collidable_points=n_collidable_points + ) + + # Construct the equality constraints. + A = jnp.zeros((0, 3 * n_collidable_points)) + b = jnp.zeros((0,)) + + # Solve the following optimization problem with qpax: + # + # min_{x} 0.5 x⊤ Q x + q⊤ x + # + # s.t. A x = b + # G x ≤ h + # + # TODO: add possibility to notify if the QP problem did not converge. + solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841 + Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options + ) + + # Reshape the optimized solution to be a matrix of 3D contact forces. + CW_fl_C = solution.reshape(-1, 3) + + # Convert the contact forces from mixed to inertial-fixed representation. + W_f_C = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) + ), + )(CW_fl_C, W_H_C) + + return W_f_C, {} + + @staticmethod + def _delassus_matrix( + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + ) -> jtp.Matrix: + + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + return delassus_matrix + + @staticmethod + def _compute_ineq_constraint_matrix( + inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike + ) -> jtp.Matrix: + """ + Compute the inequality constraint matrix for a single collidable point. + + Rows 0-3: enforce the friction pyramid constraint, + Row 4: last one is for the non negativity of the vertical force + Row 5: contact complementarity condition + """ + G_single_point = jnp.array( + [ + [1, 0, -mu], + [0, 1, -mu], + [-1, 0, -mu], + [0, -1, -mu], + [0, 0, -1], + [0, 0, 0], + ] + ) + G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) + G = G.at[:, 5, 2].set(inactive_collidable_points) + + G = jax.scipy.linalg.block_diag(*G) + return G + + @staticmethod + def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: + + n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) + + @staticmethod + def _linear_acceleration_of_collidable_points( + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, + ) -> jtp.Matrix: + + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + + CW_a_WC = CW_a_WC.reshape(-1, 6) + return CW_a_WC[:, 0:3].squeeze() + + @staticmethod + def _compute_baumgarte_stabilization_term( + inactive_collidable_points: jtp.ArrayLike, + δ: jtp.ArrayLike, + δ_dot: jtp.ArrayLike, + n: jtp.ArrayLike, + K: jtp.FloatLike, + D: jtp.FloatLike, + ) -> jtp.Array: + + return jnp.where( + inactive_collidable_points[:, jnp.newaxis], + jnp.zeros_like(n), + (K * δ + D * δ_dot)[:, jnp.newaxis] * n, + ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py new file mode 100644 index 000000000..dde16cfb2 --- /dev/null +++ b/src/jaxsim/rbda/contacts/soft.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +import dataclasses +import functools + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.math +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.math import StandardGravity +from jaxsim.terrain import Terrain + +from . import common + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class SoftContactsParams(common.ContactsParams): + """Parameters of the soft contacts model.""" + + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(1e6, dtype=float) + ) + + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(2000, dtype=float) + ) + + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + p: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + q: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.p), + HashedNumpyArray.hash_of_array(self.q), + ) + ) + + def __eq__(self, other: SoftContactsParams) -> bool: + + if not isinstance(other, SoftContactsParams): + return NotImplemented + + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + K: jtp.FloatLike = 1e6, + D: jtp.FloatLike = 2_000, + mu: jtp.FloatLike = 0.5, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A SoftContactsParams instance with the specified parameters. + """ + + return SoftContactsParams( + K=jnp.array(K, dtype=float), + D=jnp.array(D, dtype=float), + mu=jnp.array(mu, dtype=float), + p=jnp.array(p, dtype=float), + q=jnp.array(q, dtype=float), + ) + + @classmethod + def build_default_from_jaxsim_model( + cls: type[Self], + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = StandardGravity, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> SoftContactsParams: + """ + Create a SoftContactsParams instance with good default parameters. + + Args: + model: The target model. + standard_gravity: The standard gravity constant. + static_friction_coefficient: + The static friction coefficient between the model and the terrain. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of contacts supporting the weight of the model + in steady state. + damping_ratio: The ratio controlling the damping behavior. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A `SoftContactsParams` instance with the specified parameters. + + Note: + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Use symbols for input parameters. + ξ = damping_ratio + δ_max = max_penetration + μc = static_friction_coefficient + + # Compute the total mass of the model. + m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() + + # Rename the standard gravity. + g = standard_gravity + + # Compute the average support force on each collidable point. + f_average = m * g / number_of_active_collidable_points_steady_state + + # Compute the stiffness to get the desired steady-state penetration. + # Note that this is dependent on the non-linear exponent used in + # the damping term of the Hunt/Crossley model. + K = f_average / jnp.power(δ_max, 1 + p) + + # Compute the damping using the damping ratio. + critical_damping = 2 * jnp.sqrt(K * m) + D = ξ * critical_damping + + return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q) + + def valid(self) -> jtp.BoolLike: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return jnp.hstack( + [ + self.K >= 0.0, + self.D >= 0.0, + self.mu >= 0.0, + self.p >= 0.0, + self.q >= 0.0, + ] + ).all() + + +@jax_dataclasses.pytree_dataclass +class SoftContacts(common.ContactModel): + """Soft contacts model.""" + + @classmethod + def build( + cls: type[Self], + model: js.model.JaxSimModel | None = None, + **kwargs, + ) -> Self: + """ + Create a `SoftContacts` instance with specified parameters. + + Args: + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + **kwargs: Additional parameters to pass to the contact model. + + Returns: + The `SoftContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls(**kwargs) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def hunt_crossley_contact_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + terrain: Terrain, + K: jtp.FloatLike, + D: jtp.FloatLike, + mu: jtp.FloatLike, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force using the Hunt/Crossley model. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + terrain: The terrain model. + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + # Convert the input vectors to arrays. + W_p_C = jnp.array(position, dtype=float).squeeze() + W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() + m = jnp.array(tangential_deformation, dtype=float).squeeze() + + # Use symbol for the static friction. + μ = mu + + # Compute the penetration depth, its rate, and the considered terrain normal. + δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) + + # There are few operations like computing the norm of a vector with zero length + # or computing the square root of zero that are problematic in an AD context. + # To avoid these issues, we introduce a small tolerance ε to their arguments + # and make sure that we do not check them against zero directly. + ε = jnp.finfo(float).eps + + # Compute the powers of the penetration depth. + # Inject ε to address AD issues in differentiating the square root when + # p and q are fractional. + δp = jnp.power(δ + ε, p) + δq = jnp.power(δ + ε, q) + + # ======================== + # Compute the normal force + # ======================== + + # Non-linear spring-damper model (Hunt/Crossley model). + # This is the force magnitude along the direction normal to the terrain. + force_normal_mag = (K * δp) * δ + (D * δq) * δ̇ + + # Depending on the magnitude of δ̇, the normal force could be negative. + force_normal_mag = jnp.maximum(0.0, force_normal_mag) + + # Compute the 3D linear force in C[W] frame. + f_normal = force_normal_mag * n̂ + + # ============================ + # Compute the tangential force + # ============================ + + # Extract the tangential component of the velocity. + v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂ + + # Extract the normal and tangential components of the material deformation. + m_normal = jnp.dot(m, n̂) * n̂ + m_tangential = m - jnp.dot(m, n̂) * n̂ + + # Compute the tangential force in the sticking case. + # Using the tangential component of the material deformation should not be + # necessary if the sticking-slipping transition occurs in a terrain area + # with a locally constant normal. However, this assumption is not true in + # general, especially for highly uneven terrains. + f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential) + + # Detect the contact type (sticking or slipping). + # Note that if there is no contact, sticking is set to True, and this detail + # is exploited in the computation of the `contact_status` variable. + sticking = jnp.logical_or( + δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2 + ) + + # Compute the direction of the tangential force. + # To prevent dividing by zero, we use a switch statement. + norm = jaxsim.math.safe_norm(f_tangential) + f_tangential_direction = f_tangential / ( + norm + jnp.finfo(float).eps * (norm == 0) + ) + + # Project the tangential force to the friction cone if slipping. + f_tangential = jnp.where( + sticking, + f_tangential, + jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, + ) + + # Set the tangential force to zero if there is no contact. + f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential) + + # ===================================== + # Compute the material deformation rate + # ===================================== + + # Compute the derivative of the material deformation. + # Note that we included an additional relaxation of `m_normal` in the + # sticking case, so that the normal deformation that could have accumulated + # from a previous slipping phase can relax to zero. + ṁ_no_contact = -(K / D) * m + ṁ_sticking = v_tangential - (K / D) * m_normal + ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq) + + # Compute the contact status: + # 0: slipping + # 1: sticking + # 2: no contact + contact_status = sticking.astype(int) + contact_status += (δ <= 0).astype(int) + + # Select the right material deformation rate depending on the contact status. + ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact) + + # ========================================== + # Compute and return the final contact force + # ========================================== + + # Sum the normal and tangential forces. + CW_fl = f_normal + f_tangential + + return CW_fl, ṁ + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def compute_contact_force( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: SoftContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + parameters: The parameters of the soft contacts model. + terrain: The terrain model. + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( + position=position, + velocity=velocity, + tangential_deformation=tangential_deformation, + terrain=terrain, + K=parameters.K, + D=parameters.D, + mu=parameters.mu, + p=parameters.p, + q=parameters.q, + ) + + # Pack a mixed 6D force. + CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) + + # Compute the 6D force transform from the mixed to the inertial-fixed frame. + W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( + translation=jnp.array(position), inverse=True + ).T + + # Compute the 6D force in the inertial-fixed frame. + W_f = W_Xf_CW @ CW_f + + return W_f, ṁ + + @staticmethod + @jax.jit + def compute_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + A tuple containing as first element the computed contact forces, and as + second element a dictionary with derivative of the material deformation. + """ + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + # Compute the position and linear velocities (mixed representation) of + # all the collidable points belonging to the robot and extract the ones + # for the enabled collidable points. + W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) + + # Extract the material deformation corresponding to the collidable points. + m = data.state.extended["tangential_deformation"] + + m_enabled = m[indices_of_enabled_collidable_points] + + # Initialize the tangential deformation rate array for every collidable point. + ṁ = jnp.zeros_like(m) + + # Compute the contact forces only for the enabled collidable points. + # Since we treat them as independent, we can vmap the computation. + W_f, ṁ_enabled = jax.vmap( + lambda p, v, m: SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=data.contacts_params, + terrain=model.terrain, + ) + )(W_p_C, W_ṗ_C, m_enabled) + + ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) + + return W_f, dict(m_dot=ṁ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py new file mode 100644 index 000000000..40ad4ab61 --- /dev/null +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -0,0 +1,1066 @@ +from __future__ import annotations + +import dataclasses +import functools +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim +import jaxsim.api as js +import jaxsim.exceptions +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation +from jaxsim.math import StandardGravity +from jaxsim.terrain import Terrain + +from . import common +from .soft import SoftContacts, SoftContactsParams + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class ViscoElasticContactsParams(common.ContactsParams): + """Parameters of the visco-elastic contacts model.""" + + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(1e6, dtype=float) + ) + + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(2000, dtype=float) + ) + + static_friction: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + p: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + q: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + @classmethod + def build( + cls: type[Self], + K: jtp.FloatLike = 1e6, + D: jtp.FloatLike = 2_000, + static_friction: jtp.FloatLike = 0.5, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + static_friction: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model. + + Returns: + A ViscoElasticParams instance with the specified parameters. + """ + + return ViscoElasticContactsParams( + K=jnp.array(K, dtype=float), + D=jnp.array(D, dtype=float), + static_friction=jnp.array(static_friction, dtype=float), + p=jnp.array(p, dtype=float), + q=jnp.array(q, dtype=float), + ) + + @classmethod + def build_default_from_jaxsim_model( + cls: type[Self], + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = StandardGravity, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a ViscoElasticContactsParams instance with good default parameters. + + Args: + model: The target model. + standard_gravity: The standard gravity constant. + static_friction_coefficient: + The static friction coefficient between the model and the terrain. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of contacts supporting the weight of the model + in steady state. + damping_ratio: The ratio controlling the damping behavior. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model. + + Returns: + A `ViscoElasticContactsParams` instance with the specified parameters. + + Note: + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Call the SoftContact builder instead of duplicating the logic. + soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_penetration, + number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state, + damping_ratio=damping_ratio, + ) + + return ViscoElasticContactsParams.build( + K=soft_contacts_params.K, + D=soft_contacts_params.D, + static_friction=soft_contacts_params.mu, + p=p, + q=q, + ) + + def valid(self) -> jtp.BoolLike: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return ( + jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + and jnp.all(self.static_friction >= 0.0) + and jnp.all(self.p >= 0.0) + and jnp.all(self.q >= 0.0) + ) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.static_friction), + HashedNumpyArray.hash_of_array(self.p), + HashedNumpyArray.hash_of_array(self.q), + ) + ) + + def __eq__(self, other: ViscoElasticContactsParams) -> bool: + + if not isinstance(other, ViscoElasticContactsParams): + return False + + return hash(self) == hash(other) + + +@jax_dataclasses.pytree_dataclass +class ViscoElasticContacts(common.ContactModel): + """Visco-elastic contacts model.""" + + max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25) + + @classmethod + def build( + cls: type[Self], + model: js.model.JaxSimModel | None = None, + max_squarings: jtp.IntLike | None = None, + **kwargs, + ) -> Self: + """ + Create a `ViscoElasticContacts` instance with specified parameters. + + Args: + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + max_squarings: + The maximum number of squarings performed in the matrix exponential. + **kwargs: Extra arguments to ignore. + + Returns: + The `ViscoElasticContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls( + max_squarings=int( + max_squarings + if max_squarings is not None + else cls.__dataclass_fields__["max_squarings"].default + ), + ) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + + @jax.jit + def compute_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike | None = None, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + dt: The time step to consider. If not specified, it is read from the model. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding + to the velocity representation of `data`. + joint_force_references: The joint force references to apply. + + Note: + This contact model, contrarily to most other contact models, requires the + knowledge of the integration step. It is not straightforward to assess how + this contact model behaves when used with high-order Runge-Kutta schemes. + For the time being, it is recommended to use a simple forward Euler scheme. + The main benefit of this model is that the stiff contact dynamics is computed + separately from the rest of the system dynamics, which allows to use simple + integration schemes without altering significantly the simulation stability. + + Returns: + A tuple containing as first element the computed 6D contact force applied to + the contact point and expressed in the world frame, and as second element + a dictionary of optional additional information. + """ + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + # Initialize the time step. + dt = dt if dt is not None else model.time_step + + # Compute the average contact linear forces in mixed representation by + # integrating the contact dynamics in the continuous time domain. + CW_f̅l, CW_fl̿, m_tf = ( + ViscoElasticContacts._compute_contact_forces_with_exponential_integration( + model=model, + data=data, + dt=jnp.array(dt).astype(float), + link_forces=link_forces, + joint_force_references=joint_force_references, + indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, + max_squarings=self.max_squarings, + ) + ) + + # ============================================ + # Compute the inertial-fixed 6D contact forces + # ============================================ + + # Compute the transforms of the mixed frames `C[W] = (W_p_C, [W])` + # associated to each collidable point. + W_H_C = js.contact.transforms(model=model, data=data)[ + indices_of_enabled_collidable_points, :, : + ] + + # Vmapped transformation from mixed to inertial-fixed representation. + compute_forces_inertial_fixed_vmap = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_C, + is_force=True, + ) + ) + ) + + # Express the linear contact forces in the inertial-fixed frame. + W_f̅_C, W_f̿_C = jax.vmap( + lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C) + )(jnp.stack([CW_f̅l, CW_fl̿])) + + return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf) + + @staticmethod + @functools.partial(jax.jit, static_argnames=("max_squarings",)) + def _compute_contact_forces_with_exponential_integration( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + indices_of_enabled_collidable_points: jtp.VectorLike | None = None, + max_squarings: int = 25, + ) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]: + """ + Compute the average contact forces by integrating the contact dynamics. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + dt: The integration time step. + link_forces: The 6D forces to apply to the links. + joint_force_references: The joint force references to apply. + indices_of_enabled_collidable_points: + The indices of the enabled collidable points. + max_squarings: + The maximum number of squarings performed in the matrix exponential. + + Returns: + A tuple containing: + - The average contact forces. + - The average of the average contact forces. + - The tangential deformation at the final state. + """ + + # ========================== + # Populate missing arguments + # ========================== + + indices = ( + indices_of_enabled_collidable_points + if indices_of_enabled_collidable_points is not None + else jnp.arange( + len(model.kin_dyn_parameters.contact_parameters.body) + ).astype(int) + ) + + # ================================== + # Compute the contact point dynamics + # ================================== + + p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data) + m_t0 = data.state.extended["tangential_deformation"][indices, :] + + p_t0 = p_t0[indices, :] + v_t0 = v_t0[indices, :] + + # Compute the linearized contact dynamics. + # Note that it linearizes the (non-linear) contact model at (p, v, m)[t0]. + A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + indices_of_enabled_collidable_points=indices, + p_t0=p_t0, + v_t0=v_t0, + m_t0=m_t0, + ) + + # ============================================= + # Compute the integrals of the contact dynamics + # ============================================= + + # Pack the initial state of the contact points. + x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()]) + + # Pack the augmented matrix used to compute the single and double integral + # of the exponential integration. + A̅ = jnp.vstack( + [ + jnp.hstack( + [ + A, + jnp.vstack(b), + jnp.vstack(x_t0), + jnp.vstack(jnp.zeros_like(x_t0)), + ] + ), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 1, 0]), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 1]), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 0]), + ] + ) + + # Compute the matrix exponential. + exp_tA = jax.scipy.linalg.expm( + (dt * A̅).astype(float), max_squarings=max_squarings + ) + + # Integrate the contact dynamics in the continuous time domain. + x_int, x_int2 = ( + jnp.hstack([jnp.eye(A.shape[0]), jnp.zeros(shape=(A.shape[0], 3))]) + @ exp_tA + @ jnp.vstack([jnp.zeros(shape=(A.shape[0] + 1, 2)), jnp.eye(2)]) + ).T + + jaxsim.exceptions.raise_runtime_error_if( + condition=jnp.isnan(x_int).any(), + msg="NaN integration, try to increase `max_squarings` or decreasing `dt`", + ) + + # ========================== + # Compute the contact forces + # ========================== + + # Compute the average contact forces. + CW_f̅, _ = jnp.split( + (A_sc @ x_int / dt + b_sc).reshape(-1, 3), + indices_or_sections=2, + ) + + # Compute the average of the average contact forces. + CW_f̿, _ = jnp.split( + (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3), + indices_or_sections=2, + ) + + # Extract the tangential deformation at the final state. + x_tf = x_int / dt + m_tf = jnp.split(x_tf, 3)[2].reshape(-1, 3) + + return CW_f̅, CW_f̿, m_tf + + @staticmethod + @jax.jit + def _contact_points_dynamics( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + indices_of_enabled_collidable_points: jtp.VectorLike | None = None, + p_t0: jtp.MatrixLike | None = None, + v_t0: jtp.MatrixLike | None = None, + m_t0: jtp.MatrixLike | None = None, + ) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]: + """ + Compute the dynamics of the contact points. + + Note: + This function projects the system dynamics to the contact space and + returns the matrices of a linear system to simulate its evolution. + Since the active contact model can be non-linear, this function also + linearizes the contact model at the initial state. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + link_forces: The 6D forces to apply to the links. + joint_force_references: The joint force references to apply. + indices_of_enabled_collidable_points: + The indices of the enabled collidable points. + p_t0: The initial position of the collidable points. + v_t0: The initial velocity of the collidable points. + m_t0: The initial tangential deformation of the collidable points. + + Returns: + A tuple containing: + - The `A` matrix of the linear system that models the contact dynamics. + - The `b` vector of the linear system that models the contact dynamics. + - The `A_sc` matrix of the linear system that approximates the contact model. + - The `b_sc` vector of the linear system that approximates the contact model. + """ + + indices_of_enabled_collidable_points = ( + indices_of_enabled_collidable_points + if indices_of_enabled_collidable_points is not None + else jnp.arange( + len(model.kin_dyn_parameters.contact_parameters.body) + ).astype(int) + ) + + p_t0 = jnp.atleast_2d( + p_t0 + if p_t0 is not None + else js.contact.collidable_point_positions(model=model, data=data)[ + indices_of_enabled_collidable_points, : + ] + ) + + v_t0 = jnp.atleast_2d( + v_t0 + if v_t0 is not None + else js.contact.collidable_point_velocities(model=model, data=data)[ + indices_of_enabled_collidable_points, : + ] + ) + + m_t0 = jnp.atleast_2d( + m_t0 + if m_t0 is not None + else data.state.extended["tangential_deformation"][ + indices_of_enabled_collidable_points, : + ] + ) + + # We expect that the 6D forces of the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + references = js.references.JaxSimModelReferences.build( + model=model, + link_forces=link_forces, + joint_force_references=joint_force_references, + data=data, + velocity_representation=data.velocity_representation, + ) + + # =========================== + # Linearize the contact model + # =========================== + + # Linearize the contact model at the initial state of all considered + # contact points. + A_sc_points, b_sc_points = jax.vmap( + lambda p, v, m: ViscoElasticContacts._linearize_contact_model( + position=p, + velocity=v, + tangential_deformation=m, + parameters=data.contacts_params, + terrain=model.terrain, + ) + )(p_t0, v_t0, m_t0) + + # Since x = [p1, p2, ..., v1, v2, ..., m1, m2, ...], we need to split the A_sc of + # individual points since otherwise we'd get x = [ p1, v1, m1, p2, v2, m2, ...]. + A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1) + + # We want to have in output first the forces and then the material deformation rates. + # Therefore, we need to extract the components is A_sc_* separately. + A_sc = jnp.vstack( + [ + jnp.hstack( + [ + jax.scipy.linalg.block_diag(*A_sc_p[:, 0:3, :]), + jax.scipy.linalg.block_diag(*A_sc_v[:, 0:3, :]), + jax.scipy.linalg.block_diag(*A_sc_m[:, 0:3, :]), + ], + ), + jnp.hstack( + [ + jax.scipy.linalg.block_diag(*A_sc_p[:, 3:6, :]), + jax.scipy.linalg.block_diag(*A_sc_v[:, 3:6, :]), + jax.scipy.linalg.block_diag(*A_sc_m[:, 3:6, :]), + ] + ), + ] + ) + + # We need to do the same for the b_sc. + b_sc = jnp.hstack( + [b_sc_points[:, 0:3].flatten(), b_sc_points[:, 3:6].flatten()] + ) + + # =========================================================== + # Compute the A and b matrices of the contact points dynamics + # =========================================================== + + with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): + + BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + CW_Jl_WC = js.contact.jacobian( + model=model, + data=data, + output_vel_repr=jaxsim.VelRepr.Mixed, + )[indices_of_enabled_collidable_points, 0:3, :] + + CW_J̇l_WC = js.contact.jacobian_derivative( + model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed + )[indices_of_enabled_collidable_points, 0:3, :] + + # Compute the Delassus matrix. + ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0] + + I_nc = jnp.eye(v_t0.flatten().size) + O_nc = jnp.zeros(shape=(p_t0.flatten().size, p_t0.flatten().size)) + + # Pack the A matrix. + A = jnp.vstack( + [ + jnp.hstack([O_nc, I_nc, O_nc]), + ψ @ jnp.split(A_sc, 2, axis=0)[0], + jnp.split(A_sc, 2, axis=0)[1], + ] + ) + + # Short names for few variables. + ν = BW_ν + J = jnp.vstack(CW_Jl_WC) + J̇ = jnp.vstack(CW_J̇l_WC) + + # Compute the free system acceleration components. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Mixed), + references.switch_velocity_representation(jaxsim.VelRepr.Mixed), + ): + + BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + ) + + # Pack the free system acceleration in mixed representation. + ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free]) + + # Compute the acceleration of collidable points. + # This is the true derivative of ṗ only in mixed representation. + p̈ = J @ ν̇_free + J̇ @ ν + + # Pack the b array. + b = jnp.hstack( + [ + jnp.zeros_like(p_t0.flatten()), + p̈ + ψ @ jnp.split(b_sc, indices_or_sections=2)[0], + jnp.split(b_sc, indices_or_sections=2)[1], + ] + ) + + return A, b, A_sc, b_sc + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def _linearize_contact_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: ViscoElasticContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Matrix, jtp.Vector]: + """ + Linearize the Hunt/Crossley contact model at the initial state. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing the `A` matrix and the `b` vector of the linear system + corresponding to the contact dynamics linearized at the initial state. + """ + + # Initialize the state at which the model is linearized. + p0 = jnp.array(position, dtype=float).squeeze() + v0 = jnp.array(velocity, dtype=float).squeeze() + m0 = jnp.array(tangential_deformation, dtype=float).squeeze() + + # ============ + # Compute A_sc + # ============ + + compute_contact_force_non_linear_model = functools.partial( + ViscoElasticContacts._compute_contact_force_non_linear_model, + parameters=parameters, + terrain=terrain, + ) + + # Compute with AD the functions to get the Jacobians of CW_fl. + df_dp_fun, df_dv_fun, df_dm_fun = ( + jax.jacrev( + lambda p0, v0, m0: compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + )[0], + argnums=num, + ) + for num in (0, 1, 2) + ) + + # Compute with AD the functions to get the Jacobians of ṁ. + dṁ_dp_fun, dṁ_dv_fun, dṁ_dm_fun = ( + jax.jacrev( + lambda p0, v0, m0: compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + )[1], + argnums=num, + ) + for num in (0, 1, 2) + ) + + # Compute the Jacobians of the contact forces w.r.t. the state. + df_dp = jnp.vstack(df_dp_fun(p0, v0, m0)) + df_dv = jnp.vstack(df_dv_fun(p0, v0, m0)) + df_dm = jnp.vstack(df_dm_fun(p0, v0, m0)) + + # Compute the Jacobians of the material deformation rate w.r.t. the state. + dṁ_dp = jnp.vstack(dṁ_dp_fun(p0, v0, m0)) + dṁ_dv = jnp.vstack(dṁ_dv_fun(p0, v0, m0)) + dṁ_dm = jnp.vstack(dṁ_dm_fun(p0, v0, m0)) + + # Pack the A matrix. + A_sc = jnp.vstack( + [ + jnp.hstack([df_dp, df_dv, df_dm]), + jnp.hstack([dṁ_dp, dṁ_dv, dṁ_dm]), + ] + ) + + # ============ + # Compute b_sc + # ============ + + # Compute the output of the non-linear model at the initial state. + x0 = jnp.hstack([p0, v0, m0]) + f0, ṁ0 = compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + ) + + # Pack the b vector. + b_sc = jnp.hstack([f0, ṁ0]) - A_sc @ x0 + + return A_sc, b_sc + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def _compute_contact_force_non_linear_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: ViscoElasticContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact forces using the non-linear Hunt/Crossley model. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing: + - The linear contact force in the mixed contact frame. + - The rate of material deformation. + """ + + # Compute the linear contact force in mixed representation using + # the non-linear Hunt/Crossley model. + # The following function also returns the rate of material deformation. + CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( + position=position, + velocity=velocity, + tangential_deformation=tangential_deformation, + terrain=terrain, + K=parameters.K, + D=parameters.D, + mu=parameters.static_friction, + p=parameters.p, + q=parameters.q, + ) + + return CW_fl, ṁ + + @staticmethod + @jax.jit + def integrate_data_with_average_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + average_link_contact_forces_inertial: jtp.MatrixLike | None = None, + average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None, + ) -> js.data.JaxSimModelData: + """ + Advance the system state by integrating the dynamics. + + Args: + model: The model to consider. + data: The data of the considered model. + dt: The integration time step. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding + to the velocity representation of `data`. + joint_force_references: The joint force references to apply. + average_link_contact_forces_inertial: + The average contact forces computed with the exponential integrator and + expressed in the inertial-fixed frame. + average_of_average_link_contact_forces_mixed: + The average of the average contact forces computed with the exponential + integrator and expressed in the mixed frame. + + Returns: + The data object storing the system state at the final time. + """ + + s_t0 = data.joint_positions() + W_p_B_t0 = data.base_position() + W_Q_B_t0 = data.base_orientation(dcm=False) + + ṡ_t0 = data.joint_velocities() + with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): + W_ṗ_B_t0 = data.base_velocity()[0:3] + W_ω_WB_t0 = data.base_velocity()[3:6] + + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + W_ν_t0 = data.generalized_velocity() + + # We expect that the 6D forces of the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + references = js.references.JaxSimModelReferences.build( + model=model, + link_forces=link_forces, + joint_force_references=joint_force_references, + data=data, + velocity_representation=data.velocity_representation, + ) + + W_f̅_L = ( + jnp.array(average_link_contact_forces_inertial) + if average_link_contact_forces_inertial is not None + else jnp.zeros_like(references._link_forces) + ).astype(float) + + LW_f̿_L = ( + jnp.array(average_of_average_link_contact_forces_mixed) + if average_of_average_link_contact_forces_mixed is not None + else W_f̅_L + ).astype(float) + + # Compute the system inertial acceleration, used to integrate the system velocity. + # It considers the average contact forces computed with the exponential integrator. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Inertial), + references.switch_velocity_representation(jaxsim.VelRepr.Inertial), + ): + + W_ν̇_pr = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_force_references=references.joint_force_references( + model=model + ), + link_forces=W_f̅_L + references.link_forces(model=model, data=data), + ) + ) + + # Compute the system mixed acceleration, used to integrate the system position. + # It considers the average of the average contact forces computed with the + # exponential integrator. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Mixed), + references.switch_velocity_representation(jaxsim.VelRepr.Mixed), + ): + + BW_ν̇_pr2 = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_force_references=references.joint_force_references( + model=model + ), + link_forces=LW_f̿_L + references.link_forces(model=model, data=data), + ) + ) + + # Integrate the system velocity using the inertial-fixed acceleration. + W_ν_plus = W_ν_t0 + dt * W_ν̇_pr + + # Integrate the system position using the mixed velocity. + q_plus = jnp.hstack( + [ + # Note: here both ṗ and p̈ -> need mixed representation. + W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3], + jaxsim.math.Quaternion.integration( + dt=dt, + quaternion=W_Q_B_t0, + omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]), + omega_in_body_fixed=False, + ).squeeze(), + s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:], + ] + ) + + # Create the data at the final time. + data_tf = data.copy() + data_tf = data_tf.reset_joint_positions(q_plus[7:]) + data_tf = data_tf.reset_base_position(q_plus[0:3]) + data_tf = data_tf.reset_base_quaternion(q_plus[3:7]) + data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:]) + data_tf = data_tf.reset_base_velocity( + W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial + ) + + return data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + + +@jax.jit +def step( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike | None = None, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, +) -> tuple[js.data.JaxSimModelData, dict[str, Any]]: + """ + Step the system dynamics with the visco-elastic contact model. + + Args: + model: The model to consider. + data: The data of the considered model. + dt: The time step to consider. If not specified, it is read from the model. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding to + the velocity representation of `data`. + joint_force_references: The joint force references to consider. + + Returns: + A tuple containing the new data of the model + and an empty dictionary of auxiliary data. + """ + + assert isinstance(model.contact_model, ViscoElasticContacts) + assert isinstance(data.contacts_params, ViscoElasticContactsParams) + + # Compute the contact forces in inertial-fixed representation. + # TODO: understand what's wrong in other representations. + data_inertial_fixed = data.replace( + velocity_representation=jaxsim.VelRepr.Inertial, validate=False + ) + + # Create the references object. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + velocity_representation=data.velocity_representation, + ) + + # Initialize the time step. + dt = dt if dt is not None else model.time_step + + # Compute the contact forces with the exponential integrator. + W_f̅_C, aux_data = model.contact_model.compute_contact_forces( + model=model, + data=data_inertial_fixed, + dt=jnp.array(dt).astype(float), + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + ) + + # Extract the final material deformation and the average of average forces + # from the dictionary containing auxiliary data. + m_tf = aux_data["m_tf"] + W_f̿_C = aux_data["W_f_avg2_C"] + + # =============================== + # Compute the link contact forces + # =============================== + + # Get the link contact forces by summing the forces of contact points belonging + # to the same link. + W_f̅_L, W_f̿_L = jax.vmap( + lambda W_f_C: model.contact_model.link_forces_from_contact_forces( + model=model, data=data_inertial_fixed, contact_forces=W_f_C + ) + )(jnp.stack([W_f̅_C, W_f̿_C])) + + # Compute the link transforms. + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Inertial + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) + ) + + # For integration purpose, we need the average of average forces expressed in + # mixed representation. + LW_f̿_L = jax.vmap( + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_L, + is_force=True, + ) + ) + )(W_f̿_L, W_H_L) + + # ========================== + # Integrate the system state + # ========================== + + # Integrate the system dynamics using the average contact forces. + data_tf: js.data.JaxSimModelData = ( + model.contact_model.integrate_data_with_average_contact_forces( + model=model, + data=data_inertial_fixed, + dt=dt, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + average_link_contact_forces_inertial=W_f̅_L, + average_of_average_link_contact_forces_mixed=LW_f̿_L, + ) + ) + + # Store the tangential deformation at the final state. + # Note that this was integrated in the continuous time domain, therefore it should + # be much more accurate than the one computed with the discrete soft contacts. + with data_tf.mutable_context(): + + # Extract the indices corresponding to the enabled collidable points. + # The visco-elastic contact model computed only their contact forces. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + data_tf.state.extended |= { + "tangential_deformation": data_tf.state.extended["tangential_deformation"] + .at[indices_of_enabled_collidable_points] + .set(m_tf) + } + + # Restore the original velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + + return data_tf, {} diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index e248aefea..48e78c3b2 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -11,6 +11,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr +from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. @@ -290,6 +291,55 @@ def test_ad_jacobian( ) +def test_ad_soft_contacts( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) + p = jax.random.uniform(subkey1, shape=(3,), minval=-1) + v = jax.random.uniform(subkey2, shape=(3,), minval=-1) + m = jax.random.uniform(subkey3, shape=(3,), minval=-1) + + # Get the soft contacts parameters. + parameters = js.contact.estimate_good_contact_parameters(model=model) + + # ==== + # Test + # ==== + + # Get a closure exposing only the parameters to be differentiated. + def close_over_inputs_and_parameters( + p: jtp.VectorLike, + v: jtp.VectorLike, + m: jtp.VectorLike, + params: SoftContactsParams, + ) -> tuple[jtp.Vector, jtp.Vector]: + + W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=params, + terrain=model.terrain, + ) + + return W_f_Ci, CW_ṁ + + # Check derivatives against finite differences. + check_grads( + f=close_over_inputs_and_parameters, + args=(p, v, m, parameters), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + # On GPU, the tolerance needs to be increased. + rtol=0.02 if "gpu" in {d.platform for d in p.devices()} else None, + ) + + def test_ad_integration( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, @@ -329,9 +379,10 @@ def step( s: jtp.Vector, W_v_WB: jtp.Vector, ṡ: jtp.Vector, + m: jtp.Vector, τ: jtp.Vector, W_f_L: jtp.Matrix, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the # quaternion non-unitary, which will cause the AD check to fail. @@ -345,6 +396,7 @@ def step( base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, + extended_state={"tangential_deformation": m}, ) data_xf = js.model.step( @@ -359,8 +411,9 @@ def step( xf_s = data_xf.joint_positions xf_W_v_WB = data_xf.base_velocity xf_ṡ = data_xf.joint_velocities + xf_m = data_xf.extended_state["tangential_deformation"] - return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ + return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m # Check derivatives against finite differences. # We set forward mode only because the backward mode is not supported by the From 29f9c6173750c09d08848848e3194513fca74b5d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Jan 2025 15:54:38 +0100 Subject: [PATCH 14/36] Streamline new API changes to alternative contact models --- src/jaxsim/api/contact.py | 95 +++++++++++- src/jaxsim/api/contact_model.py | 14 +- src/jaxsim/api/data.py | 20 +++ src/jaxsim/api/integrators.py | 11 ++ src/jaxsim/api/model.py | 78 ++++++++-- src/jaxsim/math/__init__.py | 2 +- src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- src/jaxsim/rbda/contacts/rigid.py | 18 +-- src/jaxsim/rbda/contacts/soft.py | 12 +- src/jaxsim/rbda/contacts/visco_elastic.py | 147 +++++++------------ src/jaxsim/rbda/utils.py | 4 +- tests/test_automatic_differentiation.py | 10 +- tests/test_simulations.py | 170 +++++++++++++++++++++- 13 files changed, 438 insertions(+), 145 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index f0d65c4a0..2e3a39bb0 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -37,14 +37,11 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - # Switch to inertial-fixed since the RBDAs expect velocities in this representation. - with data.switch_velocity_representation(VelRepr.Inertial): - - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, - ) + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + link_transforms=data._link_transforms, + link_velocities=data._link_velocities, + ) return W_p_Ci, W_ṗ_Ci @@ -164,7 +161,11 @@ def estimate_good_soft_contacts_parameters( def estimate_good_contact_parameters( model: js.model.JaxSimModel, *, + standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + max_penetration: jtp.FloatLike | None = None, **kwargs, ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ @@ -172,7 +173,12 @@ def estimate_good_contact_parameters( Args: model: The model to consider. + standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. + number_of_active_collidable_points_steady_state: + The number of active collidable points in steady state. + damping_ratio: The damping ratio. + max_penetration: The maximum penetration allowed. kwargs: Additional model-specific parameters passed to the builder method of the parameters class. @@ -190,8 +196,81 @@ def estimate_good_contact_parameters( specific application. """ + def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: + """ + Displacement between the CoM and the lowest collidable point using zero + joint positions. + """ + + zero_data = js.data.JaxSimModelData.build( + model=model, + ) + + W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] + + if model.floating_base(): + W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] + return 2 * (W_pz_CoM - W_pz_C.min()) + + return 2 * W_pz_CoM + + max_δ = ( + max_penetration + if max_penetration is not None + # Consider as default a 0.5% of the model height. + else 0.005 * estimate_model_height(model=model) + ) + + nc = number_of_active_collidable_points_steady_state + match model.contact_model: + case contacts.SoftContacts(): + assert isinstance(model.contact_model, contacts.SoftContacts) + + parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + **kwargs, + ) + + case contacts.ViscoElasticContacts(): + assert isinstance(model.contact_model, contacts.ViscoElasticContacts) + + parameters = ( + contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + **kwargs, + ) + ) + + case contacts.RigidContacts(): + assert isinstance(model.contact_model, contacts.RigidContacts) + + # Disable Baumgarte stabilization by default since it does not play + # well with the forward Euler integrator. + K = kwargs.get("K", 0.0) + + parameters = contacts.RigidContactsParams.build( + mu=static_friction_coefficient, + **( + dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs + ), + ) + case contacts.RelaxedRigidContacts(): assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py index 1eeca4c5d..cda5222ba 100644 --- a/src/jaxsim/api/contact_model.py +++ b/src/jaxsim/api/contact_model.py @@ -5,6 +5,7 @@ import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim.rbda.contacts import SoftContacts @jax.jit @@ -15,7 +16,7 @@ def link_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, -) -> jtp.Matrix: +) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: """ Compute the 6D contact forces of all links of the model in inertial representation. @@ -33,11 +34,14 @@ def link_contact_forces( """ # Compute the contact forces for each collidable point with the active contact model. - W_f_C, _ = model.contact_model.compute_contact_forces( + W_f_C, aux_dict = model.contact_model.compute_contact_forces( model=model, data=data, - link_forces=link_forces, - joint_force_references=joint_torques, + **( + dict(link_forces=link_forces, joint_force_references=joint_torques) + if not isinstance(model.contact_model, SoftContacts) + else {} + ), ) # Compute the 6D forces applied to the links equivalent to the forces applied @@ -46,7 +50,7 @@ def link_contact_forces( model=model, data=data, contact_forces=W_f_C ) - return W_f_L + return W_f_L, aux_dict @staticmethod diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 813cf0094..cddeaf665 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -64,6 +64,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None) _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None) + # Extended state for soft and rigid contact models. + contact_state: dict[str, jtp.Array] = dataclasses.field(default=None) + @staticmethod def build( model: js.model.JaxSimModel, @@ -73,6 +76,7 @@ def build( base_linear_velocity: jtp.VectorLike | None = None, base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, + contact_state: dict[str, jtp.Array] | None = None, velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ @@ -89,6 +93,7 @@ def build( The base angular velocity in the selected representation. joint_velocities: The joint velocities. velocity_representation: The velocity representation to use. It defaults to mixed if not provided. + contact_state: The optional contact state. Returns: A `JaxSimModelData` initialized with the given state. @@ -171,6 +176,20 @@ def build( ) ) + contact_state = ( + { + "tangential_deformation": jnp.zeros_like( + model.kin_dyn_parameters.contact_parameters.point + ) + } + if isinstance( + model.contact_model, + jaxsim.rbda.contacts.SoftContacts + | jaxsim.rbda.contacts.ViscoElasticContacts, + ) + else contact_state or {} + ) + model_data = JaxSimModelData( velocity_representation=velocity_representation, _base_quaternion=base_quaternion, @@ -183,6 +202,7 @@ def build( _joint_transforms=joint_transforms, _link_transforms=link_transforms, _link_velocities=link_velocities_inertial, + contact_state=contact_state or {}, ) if not model_data.valid(model=model): diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 0f10b231a..e9ddea1c2 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -16,6 +16,8 @@ def semi_implicit_euler_integration( data: js.data.JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, + *, + extended_contact_state: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" @@ -64,6 +66,10 @@ def semi_implicit_euler_integration( s = data.joint_positions + dt * ṡ + integrated_contact_state = jax.tree.map( + lambda x, x_dot: x + dt * x_dot, data.contact_state, extended_contact_state + ) + # TODO: Avoid double replace, e.g. by computing cached value here data = dataclasses.replace( data, @@ -73,6 +79,7 @@ def semi_implicit_euler_integration( _joint_velocities=ṡ, _base_linear_velocity=W_v_B[0:3], _base_angular_velocity=W_ω_WB, + contact_state=integrated_contact_state, ) # Update the cached computations. @@ -86,6 +93,8 @@ def rk4_integration( data: JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, + *, + extended_contact_state: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the Runge-Kutta 4 method.""" @@ -116,6 +125,7 @@ def f(x) -> dict[str, jtp.Matrix]: base_linear_velocity=data._base_linear_velocity, base_angular_velocity=data._base_angular_velocity, joint_velocities=data._joint_velocities, + contact_state=data.contact_state, ) euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt @@ -143,6 +153,7 @@ def f(x) -> dict[str, jtp.Matrix]: "_base_linear_velocity": x_tf["base_linear_velocity"], "_base_angular_velocity": x_tf["base_angular_velocity"], "_joint_velocities": x_tf["joint_velocities"], + "contact_state": x_tf["contact_state"], }, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 65d1f0cf1..4648382dc 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -47,13 +47,13 @@ class JaxSimModel(JaxsimDataclass): default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) - gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY + gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field( default=None, repr=False ) - contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field( + contact_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field( default=None, repr=False ) @@ -122,6 +122,7 @@ def build_from_model_description( contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, integrator: IntegratorType | None = None, + gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -143,6 +144,7 @@ def build_from_model_description( If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. integrator: The integrator to use for the simulation. + gravity: The gravity constant. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -175,8 +177,9 @@ def build_from_model_description( time_step=time_step, terrain=terrain, contact_model=contact_model, - contacts_params=contact_params, + contact_params=contact_params, integrator=integrator, + gravity=gravity, ) # Store the origin of the model, in case downstream logic needs it. @@ -194,7 +197,7 @@ def build( time_step: jtp.FloatLike | None = None, terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, - contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, + contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, ) -> JaxSimModel: @@ -214,8 +217,8 @@ def build( The optional name of the model overriding the physics model name. contact_model: The contact model to consider. - If not specified, a soft contacts model is used. - contacts_params: The parameters of the soft contacts. + If not specified, a relaxed-constraints rigid contacts model is used. + contact_params: The parameters of the contact model. integrator: The integrator to use for the simulation. gravity: The gravity constant. @@ -249,8 +252,8 @@ def build( else jaxsim.rbda.contacts.RelaxedRigidContacts.build() ) - if contacts_params is None: - contacts_params = contact_model._parameters_class() + if contact_params is None: + contact_params = contact_model._parameters_class() # Consider the default integrator if not specified. integrator = ( @@ -268,7 +271,7 @@ def build( time_step=time_step, terrain=terrain, contact_model=contact_model, - contacts_params=contacts_params, + contact_params=contact_params, integrator=integrator, gravity=gravity, # The following is wrapped as hashless since it's a static argument, and we @@ -470,7 +473,7 @@ def reduce( time_step=model.time_step, terrain=model.terrain, contact_model=model.contact_model, - contacts_params=model.contacts_params, + contact_params=model.contact_params, gravity=model.gravity, integrator=model.integrator, ) @@ -2099,7 +2102,7 @@ def step( # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_L_terrain = js.contact_model.link_contact_forces( + W_f_L_terrain, aux_dict = js.contact_model.link_contact_forces( model=model, data=data, link_forces=W_f_L_external, @@ -2121,7 +2124,58 @@ def step( integrator_fn = _INTEGRATORS_MAP[model.integrator] data_tf = integrator_fn( - model=model, data=data, link_forces=W_f_L_total, joint_torques=τ_total + model=model, + data=data, + link_forces=W_f_L_total, + joint_torques=τ_total, + extended_contact_state=aux_dict, + ) + + if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts): + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + W_p_C = js.contact.collidable_point_positions(model, data_tf)[ + indices_of_enabled_collidable_points + ] + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + + with data_tf.switch_velocity_representation(VelRepr.Mixed): + J_WC = js.contact.jacobian(model, data_tf)[ + indices_of_enabled_collidable_points + ] + M = js.model.free_floating_mass_matrix(model, data_tf) + BW_ν_pre_impact = data_tf.generalized_velocity + + # Compute the impact velocity. + # It may be discontinuous in case new contacts are made. + BW_ν_post_impact = ( + jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( + generalized_velocity=BW_ν_pre_impact, + inactive_collidable_points=(δ <= 0), + M=M, + J_WC=J_WC, + ) + ) + + # Reset the generalized velocity. + data_tf = data_tf.replace( + model=model, + base_linear_velocity=BW_ν_post_impact[0:3], + base_angular_velocity=BW_ν_post_impact[3:6], + joint_velocities=BW_ν_post_impact[6:], + ) + + # Restore the input velocity representation + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False ) return data_tf diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index e7c221742..cf0bcb107 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -11,4 +11,4 @@ # Define the default standard gravity constant. -STANDARD_GRAVITY = -9.81 +STANDARD_GRAVITY = 9.81 diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 5e20ada54..19217850c 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -341,7 +341,7 @@ def compute_contact_forces( model=model, position_constraint=position_constraint, velocity_constraint=velocity, - parameters=model.contacts_params, + parameters=model.contact_params, ) # Compute the Delassus matrix and the free mixed linear acceleration of diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index d04a7b895..76bbd1076 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -263,7 +263,7 @@ def compute_contact_forces( # Compute kin-dyn quantities used in the contact model. with data.switch_velocity_representation(VelRepr.Mixed): - BW_ν = data.generalized_velocity() + BW_ν = data.generalized_velocity M = js.model.free_floating_mass_matrix(model=model, data=data) @@ -294,19 +294,13 @@ def compute_contact_forces( ) # Compute the generalized free acceleration. - with ( - references.switch_velocity_representation(VelRepr.Mixed), - data.switch_velocity_representation(VelRepr.Mixed), - ): - + with data.switch_velocity_representation(VelRepr.Mixed): BW_ν̇_free = jnp.hstack( js.ode.system_acceleration( model=model, data=data, link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references( - model=model - ), + joint_torques=references.joint_force_references(model=model), ) ) @@ -325,8 +319,8 @@ def compute_contact_forces( δ=δ, δ_dot=δ_dot, n=n̂, - K=data.contacts_params.K, - D=data.contacts_params.D, + K=model.contact_params.K, + D=model.contact_params.D, ).flatten() # Compute the Delassus matrix. @@ -342,7 +336,7 @@ def compute_contact_forces( # Construct the inequality constraints. G = RigidContacts._compute_ineq_constraint_matrix( - inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu + inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu ) h_bounds = RigidContacts._compute_ineq_bounds( n_collidable_points=n_collidable_points diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index dde16cfb2..86889d2d4 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -11,7 +11,7 @@ import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.math import StandardGravity +from jaxsim.math import STANDARD_GRAVITY from jaxsim.terrain import Terrain from . import common @@ -108,7 +108,7 @@ def build_default_from_jaxsim_model( cls: type[Self], model: js.model.JaxSimModel, *, - standard_gravity: jtp.FloatLike = StandardGravity, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, @@ -456,7 +456,11 @@ def compute_contact_forces( W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) # Extract the material deformation corresponding to the collidable points. - m = data.state.extended["tangential_deformation"] + m = ( + data.contact_state["tangential_deformation"] + if "tangential_deformation" in data.contact_state + else jnp.zeros_like(W_p_C) + ) m_enabled = m[indices_of_enabled_collidable_points] @@ -470,7 +474,7 @@ def compute_contact_forces( position=p, velocity=v, tangential_deformation=m, - parameters=data.contacts_params, + parameters=model.contact_params, terrain=model.terrain, ) )(W_p_C, W_ṗ_C, m_enabled) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 40ad4ab61..b9569db7c 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -14,7 +14,7 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.api.common import ModelDataWithVelocityRepresentation -from jaxsim.math import StandardGravity +from jaxsim.math import STANDARD_GRAVITY from jaxsim.terrain import Terrain from . import common @@ -90,7 +90,7 @@ def build_default_from_jaxsim_model( cls: type[Self], model: js.model.JaxSimModel, *, - standard_gravity: jtp.FloatLike = StandardGravity, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, @@ -375,7 +375,7 @@ def _compute_contact_forces_with_exponential_integration( # ================================== p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data) - m_t0 = data.state.extended["tangential_deformation"][indices, :] + m_t0 = data.contact_state["tangential_deformation"][indices, :] p_t0 = p_t0[indices, :] v_t0 = v_t0[indices, :] @@ -525,7 +525,7 @@ def _contact_points_dynamics( m_t0 = jnp.atleast_2d( m_t0 if m_t0 is not None - else data.state.extended["tangential_deformation"][ + else data.contact_state["tangential_deformation"][ indices_of_enabled_collidable_points, : ] ) @@ -551,7 +551,7 @@ def _contact_points_dynamics( position=p, velocity=v, tangential_deformation=m, - parameters=data.contacts_params, + parameters=model.contact_params, terrain=model.terrain, ) )(p_t0, v_t0, m_t0) @@ -599,12 +599,11 @@ def _contact_points_dynamics( CW_Jl_WC = js.contact.jacobian( model=model, data=data, - output_vel_repr=jaxsim.VelRepr.Mixed, )[indices_of_enabled_collidable_points, 0:3, :] - CW_J̇l_WC = js.contact.jacobian_derivative( - model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed - )[indices_of_enabled_collidable_points, 0:3, :] + CW_J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[ + indices_of_enabled_collidable_points, 0:3, : + ] # Compute the Delassus matrix. ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0] @@ -627,17 +626,12 @@ def _contact_points_dynamics( J̇ = jnp.vstack(CW_J̇l_WC) # Compute the free system acceleration components. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Mixed), - references.switch_velocity_representation(jaxsim.VelRepr.Mixed), - ): - - BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references(model=model), - ) + BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_torques=references.joint_force_references(model=model), + ) # Pack the free system acceleration in mixed representation. ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free]) @@ -802,8 +796,8 @@ def integrate_data_with_average_contact_forces( dt: jtp.FloatLike, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - average_link_contact_forces_inertial: jtp.MatrixLike | None = None, - average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None, + average_link_contact_forces: jtp.MatrixLike | None = None, + average_of_average_link_contact_forces: jtp.MatrixLike | None = None, ) -> js.data.JaxSimModelData: """ Advance the system state by integrating the dynamics. @@ -816,22 +810,21 @@ def integrate_data_with_average_contact_forces( The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. joint_force_references: The joint force references to apply. - average_link_contact_forces_inertial: - The average contact forces computed with the exponential integrator and - expressed in the inertial-fixed frame. - average_of_average_link_contact_forces_mixed: + average_link_contact_forces: + The average contact forces computed with the exponential integrator. + average_of_average_link_contact_forces: The average of the average contact forces computed with the exponential - integrator and expressed in the mixed frame. + integrator. Returns: The data object storing the system state at the final time. """ - s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() - W_Q_B_t0 = data.base_orientation(dcm=False) + s_t0 = data.joint_positions + W_p_B_t0 = data.base_position + W_Q_B_t0 = data.base_quaternion - ṡ_t0 = data.joint_velocities() + ṡ_t0 = data.joint_velocities with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): W_ṗ_B_t0 = data.base_velocity()[0:3] W_ω_WB_t0 = data.base_velocity()[3:6] @@ -850,53 +843,39 @@ def integrate_data_with_average_contact_forces( ) W_f̅_L = ( - jnp.array(average_link_contact_forces_inertial) - if average_link_contact_forces_inertial is not None + jnp.array(average_link_contact_forces) + if average_link_contact_forces is not None else jnp.zeros_like(references._link_forces) ).astype(float) LW_f̿_L = ( - jnp.array(average_of_average_link_contact_forces_mixed) - if average_of_average_link_contact_forces_mixed is not None + jnp.array(average_of_average_link_contact_forces) + if average_of_average_link_contact_forces is not None else W_f̅_L ).astype(float) # Compute the system inertial acceleration, used to integrate the system velocity. # It considers the average contact forces computed with the exponential integrator. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Inertial), - references.switch_velocity_representation(jaxsim.VelRepr.Inertial), - ): - - W_ν̇_pr = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_force_references=references.joint_force_references( - model=model - ), - link_forces=W_f̅_L + references.link_forces(model=model, data=data), - ) + W_ν̇_pr = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_torques=references.joint_force_references(model=model), + link_forces=W_f̅_L + references.link_forces(model=model, data=data), ) + ) # Compute the system mixed acceleration, used to integrate the system position. # It considers the average of the average contact forces computed with the # exponential integrator. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Mixed), - references.switch_velocity_representation(jaxsim.VelRepr.Mixed), - ): - - BW_ν̇_pr2 = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_force_references=references.joint_force_references( - model=model - ), - link_forces=LW_f̿_L + references.link_forces(model=model, data=data), - ) + BW_ν̇_pr2 = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_torques=references.joint_force_references(model=model), + link_forces=LW_f̿_L + references.link_forces(model=model, data=data), ) + ) # Integrate the system velocity using the inertial-fixed acceleration. W_ν_plus = W_ν_t0 + dt * W_ν̇_pr @@ -917,8 +896,7 @@ def integrate_data_with_average_contact_forces( ) # Create the data at the final time. - data_tf = data.copy() - data_tf = data_tf.reset_joint_positions(q_plus[7:]) + data_tf = data.reset_joint_positions(q_plus[7:]) data_tf = data_tf.reset_base_position(q_plus[0:3]) data_tf = data_tf.reset_base_quaternion(q_plus[3:7]) data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:]) @@ -926,6 +904,8 @@ def integrate_data_with_average_contact_forces( W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial ) + data_tf = data_tf.update_cached(model=model) + return data_tf.replace( velocity_representation=data.velocity_representation, validate=False ) @@ -958,13 +938,7 @@ def step( """ assert isinstance(model.contact_model, ViscoElasticContacts) - assert isinstance(data.contacts_params, ViscoElasticContactsParams) - - # Compute the contact forces in inertial-fixed representation. - # TODO: understand what's wrong in other representations. - data_inertial_fixed = data.replace( - velocity_representation=jaxsim.VelRepr.Inertial, validate=False - ) + assert isinstance(model.contact_params, ViscoElasticContactsParams) # Create the references object. references = js.references.JaxSimModelReferences.build( @@ -981,7 +955,7 @@ def step( # Compute the contact forces with the exponential integrator. W_f̅_C, aux_data = model.contact_model.compute_contact_forces( model=model, - data=data_inertial_fixed, + data=data, dt=jnp.array(dt).astype(float), link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), @@ -999,17 +973,13 @@ def step( # Get the link contact forces by summing the forces of contact points belonging # to the same link. W_f̅_L, W_f̿_L = jax.vmap( - lambda W_f_C: model.contact_model.link_forces_from_contact_forces( - model=model, data=data_inertial_fixed, contact_forces=W_f_C + lambda W_f_C: js.contact_model.link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C ) )(jnp.stack([W_f̅_C, W_f̿_C])) # Compute the link transforms. - W_H_L = ( - js.model.forward_kinematics(model=model, data=data) - if data.velocity_representation is not jaxsim.VelRepr.Inertial - else jnp.zeros(shape=(model.number_of_links(), 4, 4)) - ) + W_H_L = data.link_transforms # For integration purpose, we need the average of average forces expressed in # mixed representation. @@ -1032,12 +1002,12 @@ def step( data_tf: js.data.JaxSimModelData = ( model.contact_model.integrate_data_with_average_contact_forces( model=model, - data=data_inertial_fixed, + data=data, dt=dt, link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), - average_link_contact_forces_inertial=W_f̅_L, - average_of_average_link_contact_forces_mixed=LW_f̿_L, + average_link_contact_forces=W_f̅_L, + average_of_average_link_contact_forces=LW_f̿_L, ) ) @@ -1052,15 +1022,10 @@ def step( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) - data_tf.state.extended |= { - "tangential_deformation": data_tf.state.extended["tangential_deformation"] + data_tf.contact_state |= { + "tangential_deformation": data_tf.contact_state["tangential_deformation"] .at[indices_of_enabled_collidable_points] .set(m_tf) } - # Restore the original velocity representation. - data_tf = data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) - - return data_tf, {} + return data_tf diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 2ef119862..b2435dbf4 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -132,10 +132,10 @@ def process_inputs( if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) - # Check that the quaternion does not contain NaNs. + # Check that the quaternion does not contain NaN values. exceptions.raise_value_error_if( condition=jnp.isnan(W_Q_B).any(), - msg="A RBDA received a quaternion that contains NaNs.", + msg="A RBDA received a quaternion that contains NaN values.", ) # Check that the quaternion is unary since our RBDAs make this assumption in order diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 48e78c3b2..a7c7ebd37 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -298,6 +298,9 @@ def test_ad_soft_contacts( model = jaxsim_models_types + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(model=model) + _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) p = jax.random.uniform(subkey1, shape=(3,), minval=-1) v = jax.random.uniform(subkey2, shape=(3,), minval=-1) @@ -379,10 +382,9 @@ def step( s: jtp.Vector, W_v_WB: jtp.Vector, ṡ: jtp.Vector, - m: jtp.Vector, τ: jtp.Vector, W_f_L: jtp.Matrix, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the # quaternion non-unitary, which will cause the AD check to fail. @@ -396,7 +398,6 @@ def step( base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, - extended_state={"tangential_deformation": m}, ) data_xf = js.model.step( @@ -411,9 +412,8 @@ def step( xf_s = data_xf.joint_positions xf_W_v_WB = data_xf.base_velocity xf_ṡ = data_xf.joint_velocities - xf_m = data_xf.extended_state["tangential_deformation"] - return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m + return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ # Check derivatives against finite differences. # We set forward mode only because the backward mode is not supported by the diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 56a65ef0a..2fde89cb4 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -61,7 +61,7 @@ def test_box_with_external_forces( additive=False, ) - # Initialize the integrator. + # Initialize the simulation horizon. tf = 0.5 T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) @@ -179,12 +179,172 @@ def run_simulation( for _ in T_ns: - data = js.model.step( + match model.contact_model: + + case jaxsim.rbda.contacts.ViscoElasticContacts(): + + data = jaxsim.rbda.contacts.visco_elastic.step( + model=model, + data=data, + ) + + case _: + + data = js.model.step( + model=model, + data=data, + ) + return data + + +def test_simulation_with_soft_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + # Define the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() + model.contact_params = js.contact.estimate_good_contact_parameters( model=model, - data=data, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=max_penetration, ) - return data + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_visco_elastic_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + # Define the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build() + model.contact_params = js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=max_penetration, + ) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( + solver_options={"solver_tol": 1e-3} + ) + model.contact_params = model.contact_model._parameters_class(K=1e5) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) def test_simulation_with_relaxed_rigid_contacts( @@ -198,6 +358,8 @@ def test_simulation_with_relaxed_rigid_contacts( model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( solver_options={"tol": 1e-3}, ) + model.contact_params = model.contact_model._parameters_class() + # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool From 8eef5a2d5f62af15e8374aa18024e8e750008a69 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 30 Jan 2025 21:56:24 +0100 Subject: [PATCH 15/36] Remove viscoelastic contact model --- src/jaxsim/api/contact.py | 15 - src/jaxsim/api/data.py | 6 +- src/jaxsim/rbda/contacts/__init__.py | 8 +- src/jaxsim/rbda/contacts/visco_elastic.py | 1031 --------------------- tests/test_simulations.py | 67 +- 5 files changed, 7 insertions(+), 1120 deletions(-) delete mode 100644 src/jaxsim/rbda/contacts/visco_elastic.py diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 2e3a39bb0..b8ae78585 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -238,21 +238,6 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: **kwargs, ) - case contacts.ViscoElasticContacts(): - assert isinstance(model.contact_model, contacts.ViscoElasticContacts) - - parameters = ( - contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model( - model=model, - standard_gravity=standard_gravity, - static_friction_coefficient=static_friction_coefficient, - max_penetration=max_δ, - number_of_active_collidable_points_steady_state=nc, - damping_ratio=damping_ratio, - **kwargs, - ) - ) - case contacts.RigidContacts(): assert isinstance(model.contact_model, contacts.RigidContacts) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index cddeaf665..0b6f43d88 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -182,11 +182,7 @@ def build( model.kin_dyn_parameters.contact_parameters.point ) } - if isinstance( - model.contact_model, - jaxsim.rbda.contacts.SoftContacts - | jaxsim.rbda.contacts.ViscoElasticContacts, - ) + if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts) else contact_state or {} ) diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 06646f14d..32f05e229 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,13 +1,9 @@ -from . import relaxed_rigid, rigid, soft, visco_elastic +from . import relaxed_rigid, rigid, soft from .common import ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams from .rigid import RigidContacts, RigidContactsParams from .soft import SoftContacts, SoftContactsParams -from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams ContactParamsTypes = ( - SoftContactsParams - | RigidContactsParams - | RelaxedRigidContactsParams - | ViscoElasticContactsParams + SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py deleted file mode 100644 index b9569db7c..000000000 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ /dev/null @@ -1,1031 +0,0 @@ -from __future__ import annotations - -import dataclasses -import functools -from typing import Any - -import jax -import jax.numpy as jnp -import jax_dataclasses - -import jaxsim -import jaxsim.api as js -import jaxsim.exceptions -import jaxsim.typing as jtp -from jaxsim import logging -from jaxsim.api.common import ModelDataWithVelocityRepresentation -from jaxsim.math import STANDARD_GRAVITY -from jaxsim.terrain import Terrain - -from . import common -from .soft import SoftContacts, SoftContactsParams - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - - -@jax_dataclasses.pytree_dataclass -class ViscoElasticContactsParams(common.ContactsParams): - """Parameters of the visco-elastic contacts model.""" - - K: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(1e6, dtype=float) - ) - - D: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(2000, dtype=float) - ) - - static_friction: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.5, dtype=float) - ) - - p: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.5, dtype=float) - ) - - q: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.5, dtype=float) - ) - - @classmethod - def build( - cls: type[Self], - K: jtp.FloatLike = 1e6, - D: jtp.FloatLike = 2_000, - static_friction: jtp.FloatLike = 0.5, - p: jtp.FloatLike = 0.5, - q: jtp.FloatLike = 0.5, - ) -> Self: - """ - Create a SoftContactsParams instance with specified parameters. - - Args: - K: The stiffness parameter. - D: The damping parameter of the soft contacts model. - static_friction: The static friction coefficient. - p: - The exponent p corresponding to the damping-related non-linearity - of the Hunt/Crossley model. - q: - The exponent q corresponding to the spring-related non-linearity - of the Hunt/Crossley model. - - Returns: - A ViscoElasticParams instance with the specified parameters. - """ - - return ViscoElasticContactsParams( - K=jnp.array(K, dtype=float), - D=jnp.array(D, dtype=float), - static_friction=jnp.array(static_friction, dtype=float), - p=jnp.array(p, dtype=float), - q=jnp.array(q, dtype=float), - ) - - @classmethod - def build_default_from_jaxsim_model( - cls: type[Self], - model: js.model.JaxSimModel, - *, - standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, - static_friction_coefficient: jtp.FloatLike = 0.5, - max_penetration: jtp.FloatLike = 0.001, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, - damping_ratio: jtp.FloatLike = 1.0, - p: jtp.FloatLike = 0.5, - q: jtp.FloatLike = 0.5, - ) -> Self: - """ - Create a ViscoElasticContactsParams instance with good default parameters. - - Args: - model: The target model. - standard_gravity: The standard gravity constant. - static_friction_coefficient: - The static friction coefficient between the model and the terrain. - max_penetration: The maximum penetration depth. - number_of_active_collidable_points_steady_state: - The number of contacts supporting the weight of the model - in steady state. - damping_ratio: The ratio controlling the damping behavior. - p: - The exponent p corresponding to the damping-related non-linearity - of the Hunt/Crossley model. - q: - The exponent q corresponding to the spring-related non-linearity - of the Hunt/Crossley model. - - Returns: - A `ViscoElasticContactsParams` instance with the specified parameters. - - Note: - The `damping_ratio` parameter allows to operate on the following conditions: - - ξ > 1.0: over-damped - - ξ = 1.0: critically damped - - ξ < 1.0: under-damped - """ - - # Call the SoftContact builder instead of duplicating the logic. - soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model( - model=model, - standard_gravity=standard_gravity, - static_friction_coefficient=static_friction_coefficient, - max_penetration=max_penetration, - number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state, - damping_ratio=damping_ratio, - ) - - return ViscoElasticContactsParams.build( - K=soft_contacts_params.K, - D=soft_contacts_params.D, - static_friction=soft_contacts_params.mu, - p=p, - q=q, - ) - - def valid(self) -> jtp.BoolLike: - """ - Check if the parameters are valid. - - Returns: - `True` if the parameters are valid, `False` otherwise. - """ - - return ( - jnp.all(self.K >= 0.0) - and jnp.all(self.D >= 0.0) - and jnp.all(self.static_friction >= 0.0) - and jnp.all(self.p >= 0.0) - and jnp.all(self.q >= 0.0) - ) - - def __hash__(self) -> int: - - from jaxsim.utils.wrappers import HashedNumpyArray - - return hash( - ( - HashedNumpyArray.hash_of_array(self.K), - HashedNumpyArray.hash_of_array(self.D), - HashedNumpyArray.hash_of_array(self.static_friction), - HashedNumpyArray.hash_of_array(self.p), - HashedNumpyArray.hash_of_array(self.q), - ) - ) - - def __eq__(self, other: ViscoElasticContactsParams) -> bool: - - if not isinstance(other, ViscoElasticContactsParams): - return False - - return hash(self) == hash(other) - - -@jax_dataclasses.pytree_dataclass -class ViscoElasticContacts(common.ContactModel): - """Visco-elastic contacts model.""" - - max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25) - - @classmethod - def build( - cls: type[Self], - model: js.model.JaxSimModel | None = None, - max_squarings: jtp.IntLike | None = None, - **kwargs, - ) -> Self: - """ - Create a `ViscoElasticContacts` instance with specified parameters. - - Args: - model: - The robot model considered by the contact model. - If passed, it is used to estimate good default parameters. - max_squarings: - The maximum number of squarings performed in the matrix exponential. - **kwargs: Extra arguments to ignore. - - Returns: - The `ViscoElasticContacts` instance. - """ - - if len(kwargs) != 0: - logging.debug(msg=f"Ignoring extra arguments: {kwargs}") - - return cls( - max_squarings=int( - max_squarings - if max_squarings is not None - else cls.__dataclass_fields__["max_squarings"].default - ), - ) - - @classmethod - def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: - """ - Build zero state variables of the contact model. - """ - - # Initialize the material deformation to zero. - tangential_deformation = jnp.zeros( - shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), - dtype=float, - ) - - return {"tangential_deformation": tangential_deformation} - - @jax.jit - def compute_contact_forces( - self, - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - dt: jtp.FloatLike | None = None, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: - """ - Compute the contact forces. - - Args: - model: The robot model considered by the contact model. - data: The data of the considered model. - dt: The time step to consider. If not specified, it is read from the model. - link_forces: - The 6D forces to apply to the links expressed in the frame corresponding - to the velocity representation of `data`. - joint_force_references: The joint force references to apply. - - Note: - This contact model, contrarily to most other contact models, requires the - knowledge of the integration step. It is not straightforward to assess how - this contact model behaves when used with high-order Runge-Kutta schemes. - For the time being, it is recommended to use a simple forward Euler scheme. - The main benefit of this model is that the stiff contact dynamics is computed - separately from the rest of the system dynamics, which allows to use simple - integration schemes without altering significantly the simulation stability. - - Returns: - A tuple containing as first element the computed 6D contact force applied to - the contact point and expressed in the world frame, and as second element - a dictionary of optional additional information. - """ - - # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - # Initialize the time step. - dt = dt if dt is not None else model.time_step - - # Compute the average contact linear forces in mixed representation by - # integrating the contact dynamics in the continuous time domain. - CW_f̅l, CW_fl̿, m_tf = ( - ViscoElasticContacts._compute_contact_forces_with_exponential_integration( - model=model, - data=data, - dt=jnp.array(dt).astype(float), - link_forces=link_forces, - joint_force_references=joint_force_references, - indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, - max_squarings=self.max_squarings, - ) - ) - - # ============================================ - # Compute the inertial-fixed 6D contact forces - # ============================================ - - # Compute the transforms of the mixed frames `C[W] = (W_p_C, [W])` - # associated to each collidable point. - W_H_C = js.contact.transforms(model=model, data=data)[ - indices_of_enabled_collidable_points, :, : - ] - - # Vmapped transformation from mixed to inertial-fixed representation. - compute_forces_inertial_fixed_vmap = jax.vmap( - lambda CW_fl_C, W_H_C: ( - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_C, - is_force=True, - ) - ) - ) - - # Express the linear contact forces in the inertial-fixed frame. - W_f̅_C, W_f̿_C = jax.vmap( - lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C) - )(jnp.stack([CW_f̅l, CW_fl̿])) - - return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf) - - @staticmethod - @functools.partial(jax.jit, static_argnames=("max_squarings",)) - def _compute_contact_forces_with_exponential_integration( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - dt: jtp.FloatLike, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - indices_of_enabled_collidable_points: jtp.VectorLike | None = None, - max_squarings: int = 25, - ) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]: - """ - Compute the average contact forces by integrating the contact dynamics. - - Args: - model: The robot model considered by the contact model. - data: The data of the considered model. - dt: The integration time step. - link_forces: The 6D forces to apply to the links. - joint_force_references: The joint force references to apply. - indices_of_enabled_collidable_points: - The indices of the enabled collidable points. - max_squarings: - The maximum number of squarings performed in the matrix exponential. - - Returns: - A tuple containing: - - The average contact forces. - - The average of the average contact forces. - - The tangential deformation at the final state. - """ - - # ========================== - # Populate missing arguments - # ========================== - - indices = ( - indices_of_enabled_collidable_points - if indices_of_enabled_collidable_points is not None - else jnp.arange( - len(model.kin_dyn_parameters.contact_parameters.body) - ).astype(int) - ) - - # ================================== - # Compute the contact point dynamics - # ================================== - - p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data) - m_t0 = data.contact_state["tangential_deformation"][indices, :] - - p_t0 = p_t0[indices, :] - v_t0 = v_t0[indices, :] - - # Compute the linearized contact dynamics. - # Note that it linearizes the (non-linear) contact model at (p, v, m)[t0]. - A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - indices_of_enabled_collidable_points=indices, - p_t0=p_t0, - v_t0=v_t0, - m_t0=m_t0, - ) - - # ============================================= - # Compute the integrals of the contact dynamics - # ============================================= - - # Pack the initial state of the contact points. - x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()]) - - # Pack the augmented matrix used to compute the single and double integral - # of the exponential integration. - A̅ = jnp.vstack( - [ - jnp.hstack( - [ - A, - jnp.vstack(b), - jnp.vstack(x_t0), - jnp.vstack(jnp.zeros_like(x_t0)), - ] - ), - jnp.hstack([jnp.zeros(A.shape[1]), 0, 1, 0]), - jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 1]), - jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 0]), - ] - ) - - # Compute the matrix exponential. - exp_tA = jax.scipy.linalg.expm( - (dt * A̅).astype(float), max_squarings=max_squarings - ) - - # Integrate the contact dynamics in the continuous time domain. - x_int, x_int2 = ( - jnp.hstack([jnp.eye(A.shape[0]), jnp.zeros(shape=(A.shape[0], 3))]) - @ exp_tA - @ jnp.vstack([jnp.zeros(shape=(A.shape[0] + 1, 2)), jnp.eye(2)]) - ).T - - jaxsim.exceptions.raise_runtime_error_if( - condition=jnp.isnan(x_int).any(), - msg="NaN integration, try to increase `max_squarings` or decreasing `dt`", - ) - - # ========================== - # Compute the contact forces - # ========================== - - # Compute the average contact forces. - CW_f̅, _ = jnp.split( - (A_sc @ x_int / dt + b_sc).reshape(-1, 3), - indices_or_sections=2, - ) - - # Compute the average of the average contact forces. - CW_f̿, _ = jnp.split( - (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3), - indices_or_sections=2, - ) - - # Extract the tangential deformation at the final state. - x_tf = x_int / dt - m_tf = jnp.split(x_tf, 3)[2].reshape(-1, 3) - - return CW_f̅, CW_f̿, m_tf - - @staticmethod - @jax.jit - def _contact_points_dynamics( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - indices_of_enabled_collidable_points: jtp.VectorLike | None = None, - p_t0: jtp.MatrixLike | None = None, - v_t0: jtp.MatrixLike | None = None, - m_t0: jtp.MatrixLike | None = None, - ) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]: - """ - Compute the dynamics of the contact points. - - Note: - This function projects the system dynamics to the contact space and - returns the matrices of a linear system to simulate its evolution. - Since the active contact model can be non-linear, this function also - linearizes the contact model at the initial state. - - Args: - model: The robot model considered by the contact model. - data: The data of the considered model. - link_forces: The 6D forces to apply to the links. - joint_force_references: The joint force references to apply. - indices_of_enabled_collidable_points: - The indices of the enabled collidable points. - p_t0: The initial position of the collidable points. - v_t0: The initial velocity of the collidable points. - m_t0: The initial tangential deformation of the collidable points. - - Returns: - A tuple containing: - - The `A` matrix of the linear system that models the contact dynamics. - - The `b` vector of the linear system that models the contact dynamics. - - The `A_sc` matrix of the linear system that approximates the contact model. - - The `b_sc` vector of the linear system that approximates the contact model. - """ - - indices_of_enabled_collidable_points = ( - indices_of_enabled_collidable_points - if indices_of_enabled_collidable_points is not None - else jnp.arange( - len(model.kin_dyn_parameters.contact_parameters.body) - ).astype(int) - ) - - p_t0 = jnp.atleast_2d( - p_t0 - if p_t0 is not None - else js.contact.collidable_point_positions(model=model, data=data)[ - indices_of_enabled_collidable_points, : - ] - ) - - v_t0 = jnp.atleast_2d( - v_t0 - if v_t0 is not None - else js.contact.collidable_point_velocities(model=model, data=data)[ - indices_of_enabled_collidable_points, : - ] - ) - - m_t0 = jnp.atleast_2d( - m_t0 - if m_t0 is not None - else data.contact_state["tangential_deformation"][ - indices_of_enabled_collidable_points, : - ] - ) - - # We expect that the 6D forces of the `link_forces` argument are expressed - # in the frame corresponding to the velocity representation of `data`. - references = js.references.JaxSimModelReferences.build( - model=model, - link_forces=link_forces, - joint_force_references=joint_force_references, - data=data, - velocity_representation=data.velocity_representation, - ) - - # =========================== - # Linearize the contact model - # =========================== - - # Linearize the contact model at the initial state of all considered - # contact points. - A_sc_points, b_sc_points = jax.vmap( - lambda p, v, m: ViscoElasticContacts._linearize_contact_model( - position=p, - velocity=v, - tangential_deformation=m, - parameters=model.contact_params, - terrain=model.terrain, - ) - )(p_t0, v_t0, m_t0) - - # Since x = [p1, p2, ..., v1, v2, ..., m1, m2, ...], we need to split the A_sc of - # individual points since otherwise we'd get x = [ p1, v1, m1, p2, v2, m2, ...]. - A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1) - - # We want to have in output first the forces and then the material deformation rates. - # Therefore, we need to extract the components is A_sc_* separately. - A_sc = jnp.vstack( - [ - jnp.hstack( - [ - jax.scipy.linalg.block_diag(*A_sc_p[:, 0:3, :]), - jax.scipy.linalg.block_diag(*A_sc_v[:, 0:3, :]), - jax.scipy.linalg.block_diag(*A_sc_m[:, 0:3, :]), - ], - ), - jnp.hstack( - [ - jax.scipy.linalg.block_diag(*A_sc_p[:, 3:6, :]), - jax.scipy.linalg.block_diag(*A_sc_v[:, 3:6, :]), - jax.scipy.linalg.block_diag(*A_sc_m[:, 3:6, :]), - ] - ), - ] - ) - - # We need to do the same for the b_sc. - b_sc = jnp.hstack( - [b_sc_points[:, 0:3].flatten(), b_sc_points[:, 3:6].flatten()] - ) - - # =========================================================== - # Compute the A and b matrices of the contact points dynamics - # =========================================================== - - with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): - - BW_ν = data.generalized_velocity() - - M = js.model.free_floating_mass_matrix(model=model, data=data) - - CW_Jl_WC = js.contact.jacobian( - model=model, - data=data, - )[indices_of_enabled_collidable_points, 0:3, :] - - CW_J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[ - indices_of_enabled_collidable_points, 0:3, : - ] - - # Compute the Delassus matrix. - ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0] - - I_nc = jnp.eye(v_t0.flatten().size) - O_nc = jnp.zeros(shape=(p_t0.flatten().size, p_t0.flatten().size)) - - # Pack the A matrix. - A = jnp.vstack( - [ - jnp.hstack([O_nc, I_nc, O_nc]), - ψ @ jnp.split(A_sc, 2, axis=0)[0], - jnp.split(A_sc, 2, axis=0)[1], - ] - ) - - # Short names for few variables. - ν = BW_ν - J = jnp.vstack(CW_Jl_WC) - J̇ = jnp.vstack(CW_J̇l_WC) - - # Compute the free system acceleration components. - BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_torques=references.joint_force_references(model=model), - ) - - # Pack the free system acceleration in mixed representation. - ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free]) - - # Compute the acceleration of collidable points. - # This is the true derivative of ṗ only in mixed representation. - p̈ = J @ ν̇_free + J̇ @ ν - - # Pack the b array. - b = jnp.hstack( - [ - jnp.zeros_like(p_t0.flatten()), - p̈ + ψ @ jnp.split(b_sc, indices_or_sections=2)[0], - jnp.split(b_sc, indices_or_sections=2)[1], - ] - ) - - return A, b, A_sc, b_sc - - @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) - def _linearize_contact_model( - position: jtp.VectorLike, - velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, - parameters: ViscoElasticContactsParams, - terrain: Terrain, - ) -> tuple[jtp.Matrix, jtp.Vector]: - """ - Linearize the Hunt/Crossley contact model at the initial state. - - Args: - position: The position of the contact point. - velocity: The velocity of the contact point. - tangential_deformation: The tangential deformation of the contact point. - parameters: The parameters of the contact model. - terrain: The considered terrain. - - Returns: - A tuple containing the `A` matrix and the `b` vector of the linear system - corresponding to the contact dynamics linearized at the initial state. - """ - - # Initialize the state at which the model is linearized. - p0 = jnp.array(position, dtype=float).squeeze() - v0 = jnp.array(velocity, dtype=float).squeeze() - m0 = jnp.array(tangential_deformation, dtype=float).squeeze() - - # ============ - # Compute A_sc - # ============ - - compute_contact_force_non_linear_model = functools.partial( - ViscoElasticContacts._compute_contact_force_non_linear_model, - parameters=parameters, - terrain=terrain, - ) - - # Compute with AD the functions to get the Jacobians of CW_fl. - df_dp_fun, df_dv_fun, df_dm_fun = ( - jax.jacrev( - lambda p0, v0, m0: compute_contact_force_non_linear_model( - position=p0, velocity=v0, tangential_deformation=m0 - )[0], - argnums=num, - ) - for num in (0, 1, 2) - ) - - # Compute with AD the functions to get the Jacobians of ṁ. - dṁ_dp_fun, dṁ_dv_fun, dṁ_dm_fun = ( - jax.jacrev( - lambda p0, v0, m0: compute_contact_force_non_linear_model( - position=p0, velocity=v0, tangential_deformation=m0 - )[1], - argnums=num, - ) - for num in (0, 1, 2) - ) - - # Compute the Jacobians of the contact forces w.r.t. the state. - df_dp = jnp.vstack(df_dp_fun(p0, v0, m0)) - df_dv = jnp.vstack(df_dv_fun(p0, v0, m0)) - df_dm = jnp.vstack(df_dm_fun(p0, v0, m0)) - - # Compute the Jacobians of the material deformation rate w.r.t. the state. - dṁ_dp = jnp.vstack(dṁ_dp_fun(p0, v0, m0)) - dṁ_dv = jnp.vstack(dṁ_dv_fun(p0, v0, m0)) - dṁ_dm = jnp.vstack(dṁ_dm_fun(p0, v0, m0)) - - # Pack the A matrix. - A_sc = jnp.vstack( - [ - jnp.hstack([df_dp, df_dv, df_dm]), - jnp.hstack([dṁ_dp, dṁ_dv, dṁ_dm]), - ] - ) - - # ============ - # Compute b_sc - # ============ - - # Compute the output of the non-linear model at the initial state. - x0 = jnp.hstack([p0, v0, m0]) - f0, ṁ0 = compute_contact_force_non_linear_model( - position=p0, velocity=v0, tangential_deformation=m0 - ) - - # Pack the b vector. - b_sc = jnp.hstack([f0, ṁ0]) - A_sc @ x0 - - return A_sc, b_sc - - @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) - def _compute_contact_force_non_linear_model( - position: jtp.VectorLike, - velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, - parameters: ViscoElasticContactsParams, - terrain: Terrain, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Compute the contact forces using the non-linear Hunt/Crossley model. - - Args: - position: The position of the contact point. - velocity: The velocity of the contact point. - tangential_deformation: The tangential deformation of the contact point. - parameters: The parameters of the contact model. - terrain: The considered terrain. - - Returns: - A tuple containing: - - The linear contact force in the mixed contact frame. - - The rate of material deformation. - """ - - # Compute the linear contact force in mixed representation using - # the non-linear Hunt/Crossley model. - # The following function also returns the rate of material deformation. - CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( - position=position, - velocity=velocity, - tangential_deformation=tangential_deformation, - terrain=terrain, - K=parameters.K, - D=parameters.D, - mu=parameters.static_friction, - p=parameters.p, - q=parameters.q, - ) - - return CW_fl, ṁ - - @staticmethod - @jax.jit - def integrate_data_with_average_contact_forces( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - dt: jtp.FloatLike, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - average_link_contact_forces: jtp.MatrixLike | None = None, - average_of_average_link_contact_forces: jtp.MatrixLike | None = None, - ) -> js.data.JaxSimModelData: - """ - Advance the system state by integrating the dynamics. - - Args: - model: The model to consider. - data: The data of the considered model. - dt: The integration time step. - link_forces: - The 6D forces to apply to the links expressed in the frame corresponding - to the velocity representation of `data`. - joint_force_references: The joint force references to apply. - average_link_contact_forces: - The average contact forces computed with the exponential integrator. - average_of_average_link_contact_forces: - The average of the average contact forces computed with the exponential - integrator. - - Returns: - The data object storing the system state at the final time. - """ - - s_t0 = data.joint_positions - W_p_B_t0 = data.base_position - W_Q_B_t0 = data.base_quaternion - - ṡ_t0 = data.joint_velocities - with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): - W_ṗ_B_t0 = data.base_velocity()[0:3] - W_ω_WB_t0 = data.base_velocity()[3:6] - - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - W_ν_t0 = data.generalized_velocity() - - # We expect that the 6D forces of the `link_forces` argument are expressed - # in the frame corresponding to the velocity representation of `data`. - references = js.references.JaxSimModelReferences.build( - model=model, - link_forces=link_forces, - joint_force_references=joint_force_references, - data=data, - velocity_representation=data.velocity_representation, - ) - - W_f̅_L = ( - jnp.array(average_link_contact_forces) - if average_link_contact_forces is not None - else jnp.zeros_like(references._link_forces) - ).astype(float) - - LW_f̿_L = ( - jnp.array(average_of_average_link_contact_forces) - if average_of_average_link_contact_forces is not None - else W_f̅_L - ).astype(float) - - # Compute the system inertial acceleration, used to integrate the system velocity. - # It considers the average contact forces computed with the exponential integrator. - W_ν̇_pr = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_torques=references.joint_force_references(model=model), - link_forces=W_f̅_L + references.link_forces(model=model, data=data), - ) - ) - - # Compute the system mixed acceleration, used to integrate the system position. - # It considers the average of the average contact forces computed with the - # exponential integrator. - BW_ν̇_pr2 = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_torques=references.joint_force_references(model=model), - link_forces=LW_f̿_L + references.link_forces(model=model, data=data), - ) - ) - - # Integrate the system velocity using the inertial-fixed acceleration. - W_ν_plus = W_ν_t0 + dt * W_ν̇_pr - - # Integrate the system position using the mixed velocity. - q_plus = jnp.hstack( - [ - # Note: here both ṗ and p̈ -> need mixed representation. - W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3], - jaxsim.math.Quaternion.integration( - dt=dt, - quaternion=W_Q_B_t0, - omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]), - omega_in_body_fixed=False, - ).squeeze(), - s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:], - ] - ) - - # Create the data at the final time. - data_tf = data.reset_joint_positions(q_plus[7:]) - data_tf = data_tf.reset_base_position(q_plus[0:3]) - data_tf = data_tf.reset_base_quaternion(q_plus[3:7]) - data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:]) - data_tf = data_tf.reset_base_velocity( - W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial - ) - - data_tf = data_tf.update_cached(model=model) - - return data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) - - -@jax.jit -def step( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - dt: jtp.FloatLike | None = None, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, -) -> tuple[js.data.JaxSimModelData, dict[str, Any]]: - """ - Step the system dynamics with the visco-elastic contact model. - - Args: - model: The model to consider. - data: The data of the considered model. - dt: The time step to consider. If not specified, it is read from the model. - link_forces: - The 6D forces to apply to the links expressed in the frame corresponding to - the velocity representation of `data`. - joint_force_references: The joint force references to consider. - - Returns: - A tuple containing the new data of the model - and an empty dictionary of auxiliary data. - """ - - assert isinstance(model.contact_model, ViscoElasticContacts) - assert isinstance(model.contact_params, ViscoElasticContactsParams) - - # Create the references object. - references = js.references.JaxSimModelReferences.build( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - velocity_representation=data.velocity_representation, - ) - - # Initialize the time step. - dt = dt if dt is not None else model.time_step - - # Compute the contact forces with the exponential integrator. - W_f̅_C, aux_data = model.contact_model.compute_contact_forces( - model=model, - data=data, - dt=jnp.array(dt).astype(float), - link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references(model=model), - ) - - # Extract the final material deformation and the average of average forces - # from the dictionary containing auxiliary data. - m_tf = aux_data["m_tf"] - W_f̿_C = aux_data["W_f_avg2_C"] - - # =============================== - # Compute the link contact forces - # =============================== - - # Get the link contact forces by summing the forces of contact points belonging - # to the same link. - W_f̅_L, W_f̿_L = jax.vmap( - lambda W_f_C: js.contact_model.link_forces_from_contact_forces( - model=model, data=data, contact_forces=W_f_C - ) - )(jnp.stack([W_f̅_C, W_f̿_C])) - - # Compute the link transforms. - W_H_L = data.link_transforms - - # For integration purpose, we need the average of average forces expressed in - # mixed representation. - LW_f̿_L = jax.vmap( - lambda W_f_L, W_H_L: ( - ModelDataWithVelocityRepresentation.inertial_to_other_representation( - array=W_f_L, - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_L, - is_force=True, - ) - ) - )(W_f̿_L, W_H_L) - - # ========================== - # Integrate the system state - # ========================== - - # Integrate the system dynamics using the average contact forces. - data_tf: js.data.JaxSimModelData = ( - model.contact_model.integrate_data_with_average_contact_forces( - model=model, - data=data, - dt=dt, - link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references(model=model), - average_link_contact_forces=W_f̅_L, - average_of_average_link_contact_forces=LW_f̿_L, - ) - ) - - # Store the tangential deformation at the final state. - # Note that this was integrated in the continuous time domain, therefore it should - # be much more accurate than the one computed with the discrete soft contacts. - with data_tf.mutable_context(): - - # Extract the indices corresponding to the enabled collidable points. - # The visco-elastic contact model computed only their contact forces. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - data_tf.contact_state |= { - "tangential_deformation": data_tf.contact_state["tangential_deformation"] - .at[indices_of_enabled_collidable_points] - .set(m_tf) - } - - return data_tf diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 2fde89cb4..456b4003f 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -179,21 +179,11 @@ def run_simulation( for _ in T_ns: - match model.contact_model: - - case jaxsim.rbda.contacts.ViscoElasticContacts(): - - data = jaxsim.rbda.contacts.visco_elastic.step( - model=model, - data=data, - ) - - case _: + data = js.model.step( + model=model, + data=data, + ) - data = js.model.step( - model=model, - data=data, - ) return data @@ -248,55 +238,6 @@ def test_simulation_with_soft_contacts( assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) -def test_simulation_with_visco_elastic_contacts( - jaxsim_model_box: js.model.JaxSimModel, -): - - model = jaxsim_model_box - - # Define the maximum penetration of each collidable point at steady state. - max_penetration = 0.001 - - with model.editable(validate=False) as model: - - model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build() - model.contact_params = js.contact.estimate_good_contact_parameters( - model=model, - number_of_active_collidable_points_steady_state=4, - static_friction_coefficient=1.0, - damping_ratio=1.0, - max_penetration=max_penetration, - ) - - # Enable a subset of the collidable points. - enabled_collidable_points_mask = np.zeros( - len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool - ) - enabled_collidable_points_mask[[0, 1, 2, 3]] = True - model.kin_dyn_parameters.contact_parameters.enabled = tuple( - enabled_collidable_points_mask.tolist() - ) - - # Check jaxsim_model_box@conftest.py. - box_height = 0.1 - - # Build the data of the model. - data_t0 = js.data.JaxSimModelData.build( - model=model, - base_position=jnp.array([0.0, 0.0, box_height * 2]), - velocity_representation=VelRepr.Inertial, - ) - - # =========================================== - # Run the simulation and test the final state - # =========================================== - - data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) - - assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) - assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) - - def test_simulation_with_rigid_contacts( jaxsim_model_box: js.model.JaxSimModel, ): From 65ac30c7967fd24ef32373b436d5604510731c49 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 12:56:32 +0100 Subject: [PATCH 16/36] Convert base post impact velocity to inertial before setting --- src/jaxsim/api/model.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 4648382dc..844f42cc9 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2165,17 +2165,20 @@ def step( ) ) - # Reset the generalized velocity. - data_tf = data_tf.replace( - model=model, - base_linear_velocity=BW_ν_post_impact[0:3], - base_angular_velocity=BW_ν_post_impact[3:6], - joint_velocities=BW_ν_post_impact[6:], + BW_ν_post_impact_inertial = data_tf.other_representation_to_inertial( + array=BW_ν_post_impact[0:6], + other_representation=VelRepr.Mixed, + transform=data_tf._base_transform.at[0:3, 0:3].set(jnp.eye(3)), + is_force=False, ) - # Restore the input velocity representation - data_tf = data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) + # Reset the generalized velocity. + data_tf = dataclasses.replace( + data_tf, + velocity_representation=data.velocity_representation, + _base_linear_velocity=BW_ν_post_impact_inertial[0:3], + _base_angular_velocity=BW_ν_post_impact_inertial[3:6], + _joint_velocities=BW_ν_post_impact[6:], + ) return data_tf From 1d0b47cad79e666b1743d3515b2b1f67c2e2f384 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 14:24:59 +0100 Subject: [PATCH 17/36] Use polymorphism for estimating good contact parameters --- src/jaxsim/api/contact.py | 58 +++------------------- src/jaxsim/rbda/contacts/common.py | 78 ++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b8ae78585..371e23dcb 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -11,7 +11,6 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform -from jaxsim.rbda import contacts from .common import VelRepr @@ -166,7 +165,6 @@ def estimate_good_contact_parameters( number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, - **kwargs, ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ Estimate good contact parameters. @@ -179,9 +177,6 @@ def estimate_good_contact_parameters( The number of active collidable points in steady state. damping_ratio: The damping ratio. max_penetration: The maximum penetration allowed. - kwargs: - Additional model-specific parameters passed to the builder method of - the parameters class. Returns: The estimated good contacts parameters. @@ -223,51 +218,14 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: nc = number_of_active_collidable_points_steady_state - match model.contact_model: - - case contacts.SoftContacts(): - assert isinstance(model.contact_model, contacts.SoftContacts) - - parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model( - model=model, - standard_gravity=standard_gravity, - static_friction_coefficient=static_friction_coefficient, - max_penetration=max_δ, - number_of_active_collidable_points_steady_state=nc, - damping_ratio=damping_ratio, - **kwargs, - ) - - case contacts.RigidContacts(): - assert isinstance(model.contact_model, contacts.RigidContacts) - - # Disable Baumgarte stabilization by default since it does not play - # well with the forward Euler integrator. - K = kwargs.get("K", 0.0) - - parameters = contacts.RigidContactsParams.build( - mu=static_friction_coefficient, - **( - dict( - K=K, - D=2 * jnp.sqrt(K), - ) - | kwargs - ), - ) - - case contacts.RelaxedRigidContacts(): - assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) - - parameters = contacts.RelaxedRigidContactsParams.build( - mu=static_friction_coefficient, - **kwargs, - ) - - case _: - raise ValueError(f"Invalid contact model: {model.contact_model}") - - return parameters + return model.contact_model._parameters_class().build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + ) @jax.jit diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index e7539fc66..159fdb254 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -9,6 +9,7 @@ import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim.math import STANDARD_GRAVITY from jaxsim.utils import JaxsimDataclass try: @@ -80,6 +81,83 @@ def build(cls: type[Self], **kwargs) -> Self: """ pass + def build_default_from_jaxsim_model( + self: type[Self], + model: js.model.JaxSimModel, + *, + stiffness: jtp.FloatLike | None = None, + damping: jtp.FloatLike | None = None, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a `ContactsParams` instance with default parameters. + + Args: + model: The robot model considered by the contact model. + stiffness: The stiffness of the contact model. + damping: The damping of the contact model. + standard_gravity: The standard gravity acceleration. + static_friction_coefficient: The static friction coefficient. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of active collidable points in steady state. + damping_ratio: The damping ratio. + p: The first parameter of the contact model. + q: The second parameter of the contact model. + + Returns: + The `ContactsParams` instance. + + Note: + The `stiffness` is intended as the terrain stiffness in the Soft Contacts model, + while it is the Baumgarte stabilization stiffness in the Rigid Contacts model. + + The `damping` is intended as the terrain damping in the Soft Contacts model, + while it is the Baumgarte stabilization damping in the Rigid Contacts model. + + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Use symbols for input parameters. + ξ = damping_ratio + δ_max = max_penetration + μc = static_friction_coefficient + + # Compute the total mass of the model. + m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() + + # Rename the standard gravity. + g = standard_gravity + + # Compute the average support force on each collidable point. + f_average = m * g / number_of_active_collidable_points_steady_state + + # Compute the stiffness to get the desired steady-state penetration. + # Note that this is dependent on the non-linear exponent used in + # the damping term of the Hunt/Crossley model. + K = f_average / jnp.power(δ_max, 1 + p) if stiffness is None else stiffness + + # Compute the damping using the damping ratio. + critical_damping = 2 * jnp.sqrt(K * m) + D = ξ * critical_damping if damping is None else damping + + return self.build( + K=K, + D=D, + mu=μc, + p=p, + q=q, + ) + @abc.abstractmethod def valid(self, **kwargs) -> jtp.BoolLike: """ From 882241635de1f080724429fa6801196eaf0d27ac Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 14:25:30 +0100 Subject: [PATCH 18/36] Use polymorphism to update the contact state during the step --- src/jaxsim/api/model.py | 5 ++++- src/jaxsim/rbda/contacts/common.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 844f42cc9..4ee35e6db 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2109,6 +2109,9 @@ def step( joint_torques=τ_total, ) + # Update the contact state data + contact_state = model.contact_model.update_contact_state(aux_dict) + # ============================== # Compute the total link forces # ============================== @@ -2128,7 +2131,7 @@ def step( data=data, link_forces=W_f_L_total, joint_torques=τ_total, - extended_contact_state=aux_dict, + extended_contact_state=contact_state, ) if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts): diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 159fdb254..beaf86b60 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -251,3 +251,27 @@ def _parameters_class(cls) -> type[ContactsParams]: else cls.__class__.__name__ + "Params" ), ) + + def update_contact_state( + self: type[Self], old_contact_state: dict[str, jtp.Array] + ) -> dict[str, jtp.Array]: + """ + Update the contact state. + + Args: + old_contact_state: The old contact state. + + Returns: + The updated contact state. + """ + + # Import the contact models to avoid circular imports. + from .relaxed_rigid import RelaxedRigidContacts + from .rigid import RigidContacts + from .soft import SoftContacts + + match self: + case SoftContacts(): + return {"tangential_deformation": old_contact_state["m_dot"]} + case RigidContacts() | RelaxedRigidContacts(): + return {} From e008f0d865267a27bdf01fc9de999ff9cbc83e0a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 14:34:26 +0100 Subject: [PATCH 19/36] Refactor initialization of `contact_state` --- src/jaxsim/api/data.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 0b6f43d88..0fa001b1d 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -176,15 +176,13 @@ def build( ) ) - contact_state = ( - { - "tangential_deformation": jnp.zeros_like( - model.kin_dyn_parameters.contact_parameters.point - ) - } - if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts) - else contact_state or {} - ) + contact_state = contact_state or {} + + if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts): + contact_state.setdefault( + "tangential_deformation", + jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point), + ) model_data = JaxSimModelData( velocity_representation=velocity_representation, @@ -198,7 +196,7 @@ def build( _joint_transforms=joint_transforms, _link_transforms=link_transforms, _link_velocities=link_velocities_inertial, - contact_state=contact_state or {}, + contact_state=contact_state, ) if not model_data.valid(model=model): From 8e6c3b511878ed3790b182fadd09f285ea5ccaca Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 14:45:16 +0100 Subject: [PATCH 20/36] Use polymorphism to update the post impact velocity --- src/jaxsim/api/model.py | 52 ++------------------ src/jaxsim/rbda/contacts/common.py | 76 ++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 49 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 4ee35e6db..83fd0cdb5 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2134,54 +2134,8 @@ def step( extended_contact_state=contact_state, ) - if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts): - # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - W_p_C = js.contact.collidable_point_positions(model, data_tf)[ - indices_of_enabled_collidable_points - ] - - # Compute the penetration depth of the collidable points. - δ, *_ = jax.vmap( - jaxsim.rbda.contacts.common.compute_penetration_data, - in_axes=(0, 0, None), - )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) - - with data_tf.switch_velocity_representation(VelRepr.Mixed): - J_WC = js.contact.jacobian(model, data_tf)[ - indices_of_enabled_collidable_points - ] - M = js.model.free_floating_mass_matrix(model, data_tf) - BW_ν_pre_impact = data_tf.generalized_velocity - - # Compute the impact velocity. - # It may be discontinuous in case new contacts are made. - BW_ν_post_impact = ( - jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( - generalized_velocity=BW_ν_pre_impact, - inactive_collidable_points=(δ <= 0), - M=M, - J_WC=J_WC, - ) - ) - - BW_ν_post_impact_inertial = data_tf.other_representation_to_inertial( - array=BW_ν_post_impact[0:6], - other_representation=VelRepr.Mixed, - transform=data_tf._base_transform.at[0:3, 0:3].set(jnp.eye(3)), - is_force=False, - ) - - # Reset the generalized velocity. - data_tf = dataclasses.replace( - data_tf, - velocity_representation=data.velocity_representation, - _base_linear_velocity=BW_ν_post_impact_inertial[0:3], - _base_angular_velocity=BW_ν_post_impact_inertial[3:6], - _joint_velocities=BW_ν_post_impact[6:], - ) + data_tf = model.contact_model.update_velocity_after_impact( + model=model, data=data_tf + ) return data_tf diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index beaf86b60..294f32ae3 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import dataclasses import functools import jax @@ -275,3 +276,78 @@ def update_contact_state( return {"tangential_deformation": old_contact_state["m_dot"]} case RigidContacts() | RelaxedRigidContacts(): return {} + + @jax.jit + @js.common.named_scope + def update_velocity_after_impact( + self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData + ) -> js.data.JaxSimModelData: + """ + Update the velocity after an impact. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + + Returns: + The updated data of the considered model. + """ + + # Import the rigid contact model to avoid circular imports. + from jaxsim.api.common import VelRepr + + from .rigid import RigidContacts + + if isinstance(self, RigidContacts): + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + W_p_C = js.contact.collidable_point_positions(model, data)[ + indices_of_enabled_collidable_points + ] + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + + original_representation = data.velocity_representation + + with data.switch_velocity_representation(VelRepr.Mixed): + J_WC = js.contact.jacobian(model, data)[ + indices_of_enabled_collidable_points + ] + M = js.model.free_floating_mass_matrix(model, data) + BW_ν_pre_impact = data.generalized_velocity + + # Compute the impact velocity. + # It may be discontinuous in case new contacts are made. + BW_ν_post_impact = ( + jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( + generalized_velocity=BW_ν_pre_impact, + inactive_collidable_points=(δ <= 0), + M=M, + J_WC=J_WC, + ) + ) + + BW_ν_post_impact_inertial = data.other_representation_to_inertial( + array=BW_ν_post_impact[0:6], + other_representation=VelRepr.Mixed, + transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)), + is_force=False, + ) + + # Reset the generalized velocity. + data = dataclasses.replace( + data, + velocity_representation=original_representation, + _base_linear_velocity=BW_ν_post_impact_inertial[0:3], + _base_angular_velocity=BW_ν_post_impact_inertial[3:6], + _joint_velocities=BW_ν_post_impact[6:], + ) + + return data From d3b84044cdbd6ce2053bbae0ea0713f6f35cae89 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 16:10:47 +0100 Subject: [PATCH 21/36] Make `build_default_from_jaxsim_model` compatible with all models --- src/jaxsim/rbda/contacts/common.py | 3 +++ src/jaxsim/rbda/contacts/relaxed_rigid.py | 25 +++++++++++------------ src/jaxsim/rbda/contacts/soft.py | 2 ++ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 294f32ae3..4fc8088f0 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -95,6 +95,7 @@ def build_default_from_jaxsim_model( damping_ratio: jtp.FloatLike = 1.0, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, + **kwargs, ) -> Self: """ Create a `ContactsParams` instance with default parameters. @@ -111,6 +112,7 @@ def build_default_from_jaxsim_model( damping_ratio: The damping ratio. p: The first parameter of the contact model. q: The second parameter of the contact model. + **kwargs: Optional additional arguments. Returns: The `ContactsParams` instance. @@ -157,6 +159,7 @@ def build_default_from_jaxsim_model( mu=μc, p=p, q=q, + **kwargs, ) @abc.abstractmethod diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 19217850c..8ed05bf44 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -64,12 +64,12 @@ class RelaxedRigidContactsParams(common.ContactsParams): ) # Stiffness - stiffness: jtp.Float = dataclasses.field( + K: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) # Damping - damping: jtp.Float = dataclasses.field( + D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) @@ -90,8 +90,8 @@ def __hash__(self) -> int: HashedNumpyArray(self.width), HashedNumpyArray(self.midpoint), HashedNumpyArray(self.power), - HashedNumpyArray(self.stiffness), - HashedNumpyArray(self.damping), + HashedNumpyArray(self.K), + HashedNumpyArray(self.D), HashedNumpyArray(self.mu), ) ) @@ -110,9 +110,10 @@ def build( width: jtp.FloatLike | None = None, midpoint: jtp.FloatLike | None = None, power: jtp.FloatLike | None = None, - stiffness: jtp.FloatLike | None = None, - damping: jtp.FloatLike | None = None, + K: jtp.FloatLike | None = None, + D: jtp.FloatLike | None = None, mu: jtp.FloatLike | None = None, + **kwargs, ) -> Self: """Create a `RelaxedRigidContactsParams` instance.""" @@ -151,13 +152,11 @@ def default(name: str): power=jnp.array( power if power is not None else default("power"), dtype=float ), - stiffness=jnp.array( - stiffness if stiffness is not None else default("stiffness"), + K=jnp.array( + K if K is not None else default("K"), dtype=float, ), - damping=jnp.array( - damping if damping is not None else default("damping"), dtype=float - ), + D=jnp.array(D if D is not None else default("D"), dtype=float), mu=jnp.array(mu if mu is not None else default("mu"), dtype=float), ) @@ -505,8 +504,8 @@ def _regularizers( "width", "midpoint", "power", - "stiffness", - "damping", + "K", + "D", "mu", ) ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 86889d2d4..43b3773eb 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -76,6 +76,7 @@ def build( mu: jtp.FloatLike = 0.5, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, + **kwargs, ) -> Self: """ Create a SoftContactsParams instance with specified parameters. @@ -90,6 +91,7 @@ def build( q: The exponent q corresponding to the spring-related non-linearity of the Hunt/Crossley model + **kwargs: Additional parameters to pass to the contact model. Returns: A SoftContactsParams instance with the specified parameters. From 6853fe3fa16d43a343e093574889c2bcace0e432 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 16:12:16 +0100 Subject: [PATCH 22/36] Unify `api.contact_model` in `api.contact` --- src/jaxsim/api/__init__.py | 1 - src/jaxsim/api/contact.py | 98 +++++++++++++++++++++++++++++ src/jaxsim/api/contact_model.py | 105 -------------------------------- src/jaxsim/api/model.py | 2 +- 4 files changed, 99 insertions(+), 107 deletions(-) delete mode 100644 src/jaxsim/api/contact_model.py diff --git a/src/jaxsim/api/__init__.py b/src/jaxsim/api/__init__.py index 1d9a7cafe..ec4bd36ca 100644 --- a/src/jaxsim/api/__init__.py +++ b/src/jaxsim/api/__init__.py @@ -4,7 +4,6 @@ actuation_model, com, contact, - contact_model, frame, integrators, joint, diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 371e23dcb..9ee210412 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -11,6 +11,7 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform +from jaxsim.rbda.contacts import SoftContacts from .common import VelRepr @@ -526,3 +527,100 @@ def compute_O_J̇_WC_I( ) return O_J̇_WC + + +@jax.jit +@js.common.named_scope +def link_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_torques: jtp.VectorLike | None = None, +) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: + """ + Compute the 6D contact forces of all links of the model in inertial representation. + + Args: + model: The model to consider. + data: The data of the considered model. + link_forces: + The 6D external forces to apply to the links expressed in inertial representation + joint_torques: + The joint torques acting on the joints. + + Returns: + A `(nL, 6)` array containing the stacked 6D contact forces of the links, + expressed in inertial representation. + """ + + # Compute the contact forces for each collidable point with the active contact model. + W_f_C, aux_dict = model.contact_model.compute_contact_forces( + model=model, + data=data, + **( + dict(link_forces=link_forces, joint_force_references=joint_torques) + if not isinstance(model.contact_model, SoftContacts) + else {} + ), + ) + + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + W_f_L = link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C + ) + + return W_f_L, aux_dict + + +@staticmethod +def link_forces_from_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + contact_forces: jtp.MatrixLike, +) -> jtp.Matrix: + """ + Compute the link forces from the contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + contact_forces: The contact forces computed by the contact model. + + Returns: + The 6D contact forces applied to the links and expressed in the frame of + the velocity representation of data. + """ + + # Get the object storing the contact parameters of the model. + contact_parameters = model.kin_dyn_parameters.contact_parameters + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + contact_parameters.indices_of_enabled_collidable_points + ) + + # Convert the contact forces to a JAX array. + W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) + + # Construct the vector defining the parent link index of each collidable point. + # We use this vector to sum the 6D forces of all collidable points rigidly + # attached to the same link. + parent_link_index_of_collidable_points = jnp.array( + contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + # Create the mask that associate each collidable point to their parent link. + # We use this mask to sum the collidable points to the right link. + mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + model.number_of_links() + ) + + # Sum the forces of all collidable points rigidly attached to a body. + # Since the contact forces W_f_C are expressed in the world frame, + # we don't need any coordinate transformation. + W_f_L = mask.T @ W_f_C + + return W_f_L diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py deleted file mode 100644 index cda5222ba..000000000 --- a/src/jaxsim/api/contact_model.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -import jax -import jax.numpy as jnp - -import jaxsim.api as js -import jaxsim.typing as jtp -from jaxsim.rbda.contacts import SoftContacts - - -@jax.jit -@js.common.named_scope -def link_contact_forces( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - link_forces: jtp.MatrixLike | None = None, - joint_torques: jtp.VectorLike | None = None, -) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: - """ - Compute the 6D contact forces of all links of the model in inertial representation. - - Args: - model: The model to consider. - data: The data of the considered model. - link_forces: - The 6D external forces to apply to the links expressed in inertial representation - joint_torques: - The joint torques acting on the joints. - - Returns: - A `(nL, 6)` array containing the stacked 6D contact forces of the links, - expressed in inertial representation. - """ - - # Compute the contact forces for each collidable point with the active contact model. - W_f_C, aux_dict = model.contact_model.compute_contact_forces( - model=model, - data=data, - **( - dict(link_forces=link_forces, joint_force_references=joint_torques) - if not isinstance(model.contact_model, SoftContacts) - else {} - ), - ) - - # Compute the 6D forces applied to the links equivalent to the forces applied - # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces( - model=model, data=data, contact_forces=W_f_C - ) - - return W_f_L, aux_dict - - -@staticmethod -def link_forces_from_contact_forces( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - contact_forces: jtp.MatrixLike, -) -> jtp.Matrix: - """ - Compute the link forces from the contact forces. - - Args: - model: The robot model considered by the contact model. - data: The data of the considered model. - contact_forces: The contact forces computed by the contact model. - - Returns: - The 6D contact forces applied to the links and expressed in the frame of - the velocity representation of data. - """ - - # Get the object storing the contact parameters of the model. - contact_parameters = model.kin_dyn_parameters.contact_parameters - - # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - contact_parameters.indices_of_enabled_collidable_points - ) - - # Convert the contact forces to a JAX array. - W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) - - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - # Create the mask that associate each collidable point to their parent link. - # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() - ) - - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_C are expressed in the world frame, - # we don't need any coordinate transformation. - W_f_L = mask.T @ W_f_C - - return W_f_L diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 83fd0cdb5..77e8e7379 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2102,7 +2102,7 @@ def step( # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_L_terrain, aux_dict = js.contact_model.link_contact_forces( + W_f_L_terrain, aux_dict = js.contact.link_contact_forces( model=model, data=data, link_forces=W_f_L_external, From 39f99894763d89e51619732224b376dab0400260 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 10 Feb 2025 16:16:19 +0100 Subject: [PATCH 23/36] Rename `aux_dict` to `old_contact_state` --- src/jaxsim/api/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 77e8e7379..dedc06a91 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2102,15 +2102,16 @@ def step( # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_L_terrain, aux_dict = js.contact.link_contact_forces( + W_f_L_terrain, old_contact_state = js.contact.link_contact_forces( model=model, data=data, link_forces=W_f_L_external, joint_torques=τ_total, ) - # Update the contact state data - contact_state = model.contact_model.update_contact_state(aux_dict) + # Update the contact state data. This is necessary only for the contact models + # that require propagation and integration of contact state. + contact_state = model.contact_model.update_contact_state(old_contact_state) # ============================== # Compute the total link forces From 6716a41b070e077321fa92a0156202b532bfbc81 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 25 Feb 2025 17:04:08 +0100 Subject: [PATCH 24/36] Make static methods functions in `RigidContacts` --- src/jaxsim/rbda/contacts/rigid.py | 160 ++++++++++++++++-------------- 1 file changed, 83 insertions(+), 77 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 76bbd1076..afe9c3eb5 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp import jax_dataclasses +import qpax import jaxsim.api as js import jaxsim.typing as jtp @@ -239,9 +240,6 @@ def compute_contact_forces( A tuple containing as first element the computed contact forces. """ - # Import qpax privately just in this method. - import qpax - # Get the indices of the enabled collidable points. indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points @@ -306,7 +304,7 @@ def compute_contact_forces( # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. - free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( + free_contact_acc = _linear_acceleration_of_collidable_points( BW_nu=BW_ν, BW_nu_dot=BW_ν̇_free, CW_J_WC_BW=J_WC, @@ -314,7 +312,7 @@ def compute_contact_forces( ).flatten() # Compute stabilization term. - baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( + baumgarte_term = _compute_baumgarte_stabilization_term( inactive_collidable_points=(δ <= 0), δ=δ, δ_dot=δ_dot, @@ -324,7 +322,7 @@ def compute_contact_forces( ).flatten() # Compute the Delassus matrix. - delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + delassus_matrix = _compute_delassus_matrix(M=M, J_WC=J_WC) # Initialize regularization term of the Delassus matrix for # better numerical conditioning. @@ -335,12 +333,10 @@ def compute_contact_forces( q = free_contact_acc - baumgarte_term # Construct the inequality constraints. - G = RigidContacts._compute_ineq_constraint_matrix( + G = _compute_ineq_constraint_matrix( inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu ) - h_bounds = RigidContacts._compute_ineq_bounds( - n_collidable_points=n_collidable_points - ) + h_bounds = _compute_ineq_bounds(n_collidable_points=n_collidable_points) # Construct the equality constraints. A = jnp.zeros((0, 3 * n_collidable_points)) @@ -375,82 +371,92 @@ def compute_contact_forces( return W_f_C, {} - @staticmethod - def _delassus_matrix( - M: jtp.MatrixLike, - J_WC: jtp.MatrixLike, - ) -> jtp.Matrix: - sl = jnp.s_[:, 0:3, :] - J_WC_lin = jnp.vstack(J_WC[sl]) +@jax.jit +@js.common.named_scope +def _compute_delassus_matrix( + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, +) -> jtp.Matrix: + + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + return delassus_matrix + + +@jax.jit +@js.common.named_scope +def _compute_ineq_constraint_matrix( + inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike +) -> jtp.Matrix: + """ + Compute the inequality constraint matrix for a single collidable point. + + Rows 0-3: enforce the friction pyramid constraint, + Row 4: last one is for the non negativity of the vertical force + Row 5: contact complementarity condition + """ + G_single_point = jnp.array( + [ + [1, 0, -mu], + [0, 1, -mu], + [-1, 0, -mu], + [0, -1, -mu], + [0, 0, -1], + [0, 0, 0], + ] + ) + G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) + G = G.at[:, 5, 2].set(inactive_collidable_points) - delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T - return delassus_matrix + G = jax.scipy.linalg.block_diag(*G) + return G - @staticmethod - def _compute_ineq_constraint_matrix( - inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike - ) -> jtp.Matrix: - """ - Compute the inequality constraint matrix for a single collidable point. - Rows 0-3: enforce the friction pyramid constraint, - Row 4: last one is for the non negativity of the vertical force - Row 5: contact complementarity condition - """ - G_single_point = jnp.array( - [ - [1, 0, -mu], - [0, 1, -mu], - [-1, 0, -mu], - [0, -1, -mu], - [0, 0, -1], - [0, 0, 0], - ] - ) - G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) - G = G.at[:, 5, 2].set(inactive_collidable_points) +@jax.jit +@js.common.named_scope +def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: - G = jax.scipy.linalg.block_diag(*G) - return G + n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) - @staticmethod - def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: - n_constraints = 6 * n_collidable_points - return jnp.zeros(shape=(n_constraints,)) +@jax.jit +@js.common.named_scope +def _linear_acceleration_of_collidable_points( + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, +) -> jtp.Matrix: - @staticmethod - def _linear_acceleration_of_collidable_points( - BW_nu: jtp.ArrayLike, - BW_nu_dot: jtp.ArrayLike, - CW_J_WC_BW: jtp.MatrixLike, - CW_J_dot_WC_BW: jtp.MatrixLike, - ) -> jtp.Matrix: + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW - BW_ν = BW_nu - BW_ν̇ = BW_nu_dot - CW_J̇_WC_BW = CW_J_dot_WC_BW + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ - # Compute the linear acceleration of the collidable points. - # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. - CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + CW_a_WC = CW_a_WC.reshape(-1, 6) + return CW_a_WC[:, 0:3].squeeze() - CW_a_WC = CW_a_WC.reshape(-1, 6) - return CW_a_WC[:, 0:3].squeeze() - @staticmethod - def _compute_baumgarte_stabilization_term( - inactive_collidable_points: jtp.ArrayLike, - δ: jtp.ArrayLike, - δ_dot: jtp.ArrayLike, - n: jtp.ArrayLike, - K: jtp.FloatLike, - D: jtp.FloatLike, - ) -> jtp.Array: - - return jnp.where( - inactive_collidable_points[:, jnp.newaxis], - jnp.zeros_like(n), - (K * δ + D * δ_dot)[:, jnp.newaxis] * n, - ) +@jax.jit +@js.common.named_scope +def _compute_baumgarte_stabilization_term( + inactive_collidable_points: jtp.ArrayLike, + δ: jtp.ArrayLike, + δ_dot: jtp.ArrayLike, + n: jtp.ArrayLike, + K: jtp.FloatLike, + D: jtp.FloatLike, +) -> jtp.Array: + + return jnp.where( + inactive_collidable_points[:, jnp.newaxis], + jnp.zeros_like(n), + (K * δ + D * δ_dot)[:, jnp.newaxis] * n, + ) From f3ee2b6f234985ae42d2890f43cf24326f1248d6 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 16:49:40 +0100 Subject: [PATCH 25/36] Make `system_acceleration` return the contact state derivative --- src/jaxsim/api/integrators.py | 26 +++++++++----------- src/jaxsim/api/model.py | 30 +---------------------- src/jaxsim/api/ode.py | 28 ++++++++++++++++++--- src/jaxsim/rbda/contacts/relaxed_rigid.py | 4 +-- src/jaxsim/rbda/contacts/rigid.py | 4 +-- 5 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index e9ddea1c2..383552934 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -16,15 +16,13 @@ def semi_implicit_euler_integration( data: js.data.JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, - *, - extended_contact_state: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): # Compute the system acceleration - W_v̇_WB, s̈ = js.ode.system_acceleration( + W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration( model=model, data=data, link_forces=link_forces, @@ -67,7 +65,9 @@ def semi_implicit_euler_integration( s = data.joint_positions + dt * ṡ integrated_contact_state = jax.tree.map( - lambda x, x_dot: x + dt * x_dot, data.contact_state, extended_contact_state + lambda x, x_dot: x + dt * x_dot, + data.contact_state, + contact_state_derivative, ) # TODO: Avoid double replace, e.g. by computing cached value here @@ -93,8 +93,6 @@ def rk4_integration( data: JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, - *, - extended_contact_state: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the Runge-Kutta 4 method.""" @@ -146,15 +144,13 @@ def f(x) -> dict[str, jtp.Matrix]: data_tf = dataclasses.replace( data, - **{ - "_base_position": x_tf["base_position"], - "_base_quaternion": x_tf["base_quaternion"], - "_joint_positions": x_tf["joint_positions"], - "_base_linear_velocity": x_tf["base_linear_velocity"], - "_base_angular_velocity": x_tf["base_angular_velocity"], - "_joint_velocities": x_tf["joint_velocities"], - "contact_state": x_tf["contact_state"], - }, + _base_position=x_tf["base_position"], + _base_quaternion=x_tf["base_quaternion"], + _joint_positions=x_tf["joint_positions"], + _base_linear_velocity=x_tf["base_linear_velocity"], + _base_angular_velocity=x_tf["base_angular_velocity"], + _joint_velocities=x_tf["joint_velocities"], + contact_state=x_tf["contact_state"], ) return data_tf.replace(model=model) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index dedc06a91..21ceaf1ae 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2092,33 +2092,6 @@ def step( model, data, joint_force_references=τ_references ) - # ====================== - # Compute contact forces - # ====================== - - W_f_L_terrain = jnp.zeros_like(W_f_L_external) - - if len(model.kin_dyn_parameters.contact_parameters.body) > 0: - - # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact - # with the terrain. - W_f_L_terrain, old_contact_state = js.contact.link_contact_forces( - model=model, - data=data, - link_forces=W_f_L_external, - joint_torques=τ_total, - ) - - # Update the contact state data. This is necessary only for the contact models - # that require propagation and integration of contact state. - contact_state = model.contact_model.update_contact_state(old_contact_state) - - # ============================== - # Compute the total link forces - # ============================== - - W_f_L_total = W_f_L_external + W_f_L_terrain - # ============================= # Advance the simulation state # ============================= @@ -2130,9 +2103,8 @@ def step( data_tf = integrator_fn( model=model, data=data, - link_forces=W_f_L_total, + link_forces=W_f_L_external, joint_torques=τ_total, - extended_contact_state=contact_state, ) data_tf = model.contact_model.update_velocity_after_impact( diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 2f61e1d43..dc903b8c1 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -46,12 +46,33 @@ def system_acceleration( else jnp.zeros((model.number_of_links(), 6)) ).astype(float) + # ====================== + # Compute contact forces + # ====================== + + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: + + # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact + # with the terrain. + W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces( + model=model, + data=data, + link_forces=f_L, + joint_torques=joint_torques, + ) + + W_f_L_total = f_L + W_f_L_terrain + + # Update the contact state data. This is necessary only for the contact models + # that require propagation and integration of contact state. + contact_state = model.contact_model.update_contact_state(contact_state_derivative) + # Store the link forces in a references object. references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, - link_forces=f_L, + link_forces=W_f_L_total, ) # Compute forward dynamics. @@ -68,7 +89,7 @@ def system_acceleration( link_forces=references.link_forces(model=model, data=data), ) - return v̇_WB, s̈ + return v̇_WB, s̈, contact_state @jax.jit @@ -144,7 +165,7 @@ def system_dynamics( """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_v̇_WB, s̈ = system_acceleration( + W_v̇_WB, s̈, contact_state_derivative = system_acceleration( model=model, data=data, joint_torques=joint_torques, @@ -164,4 +185,5 @@ def system_dynamics( base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, + contact_state=contact_state_derivative, ) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 8ed05bf44..6f7aad5e4 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -313,11 +313,11 @@ def compute_contact_forces( BW_ν = data.generalized_velocity BW_ν̇_free = jnp.hstack( - js.ode.system_acceleration( + js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), - joint_torques=references.joint_force_references(model=model), + joint_forces=references.joint_force_references(model=model), ) ) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index afe9c3eb5..86b64e022 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -294,11 +294,11 @@ def compute_contact_forces( # Compute the generalized free acceleration. with data.switch_velocity_representation(VelRepr.Mixed): BW_ν̇_free = jnp.hstack( - js.ode.system_acceleration( + js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), - joint_torques=references.joint_force_references(model=model), + joint_forces=references.joint_force_references(model=model), ) ) From 4683e81ae803f7d446c80ede0c7085ce25be7b50 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 16:50:21 +0100 Subject: [PATCH 26/36] Parametrize contact tests on integrators --- tests/test_simulations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 456b4003f..6b40a65cd 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -188,7 +188,7 @@ def run_simulation( def test_simulation_with_soft_contacts( - jaxsim_model_box: js.model.JaxSimModel, + jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box @@ -239,7 +239,7 @@ def test_simulation_with_soft_contacts( def test_simulation_with_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, + jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box From 71235ccb826d06cea3cdb7dd93a1306acfff2387 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 16:50:45 +0100 Subject: [PATCH 27/36] Allow replacing the contact state in `JaxSimModelData` --- src/jaxsim/api/data.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 0fa001b1d..d477a8705 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -413,19 +413,31 @@ def replace( base_linear_velocity: jtp.Vector | None = None, base_angular_velocity: jtp.Vector | None = None, base_position: jtp.Vector | None = None, + *, + contact_state: dict[str, jtp.Array] | None = None, validate: bool = False, ) -> Self: """ Replace the attributes of the `JaxSimModelData` object. """ - if joint_positions is None: - joint_positions = self.joint_positions - if joint_velocities is None: - joint_velocities = self.joint_velocities - if base_quaternion is None: - base_quaternion = self.base_quaternion - if base_position is None: - base_position = self.base_position + + joint_positions = ( + self.joint_positions if joint_positions is None else joint_positions + ) + joint_velocities = ( + self.joint_velocities if joint_velocities is None else joint_velocities + ) + base_quaternion = ( + self.base_quaternion if base_quaternion is None else base_quaternion + ) + base_position = self.base_position if base_position is None else base_position + contact_state = self.contact_state if contact_state is None else contact_state + + if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts): + contact_state.setdefault( + "tangential_deformation", + jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point), + ) joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float) joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float) From 4f0b7df1c06da4b1f2c7656112c34e0ba3c40990 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 16:51:16 +0100 Subject: [PATCH 28/36] Clarify that the gravity should be passed as positive constant --- src/jaxsim/api/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 21ceaf1ae..60ad2f947 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -144,7 +144,7 @@ def build_from_model_description( If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. integrator: The integrator to use for the simulation. - gravity: The gravity constant. + gravity: The gravity constant. Normally passed as a positive value. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -179,7 +179,7 @@ def build_from_model_description( contact_model=contact_model, contact_params=contact_params, integrator=integrator, - gravity=gravity, + gravity=-gravity, ) # Store the origin of the model, in case downstream logic needs it. From 510a20de588f29eca62848909aefd208b9ce71a3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 16:51:32 +0100 Subject: [PATCH 29/36] Remove JIT compilation from inequality bounds function --- src/jaxsim/rbda/contacts/rigid.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 86b64e022..26fefa5b7 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -415,11 +415,16 @@ def _compute_ineq_constraint_matrix( return G -@jax.jit -@js.common.named_scope def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: + """ + Compute the inequality bounds for the contact forces. + + Note: + Do not JIT this function as the output shape depends on the number of collidable points. + """ n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) From 2f6c231d5836a7eb5874a419cc2d84a10f66c544 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 17:15:59 +0100 Subject: [PATCH 30/36] Remove unused methods in soft contacts --- src/jaxsim/rbda/contacts/soft.py | 69 -------------------------------- 1 file changed, 69 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 43b3773eb..24f26be1c 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -11,7 +11,6 @@ import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.math import STANDARD_GRAVITY from jaxsim.terrain import Terrain from . import common @@ -105,74 +104,6 @@ def build( q=jnp.array(q, dtype=float), ) - @classmethod - def build_default_from_jaxsim_model( - cls: type[Self], - model: js.model.JaxSimModel, - *, - standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, - static_friction_coefficient: jtp.FloatLike = 0.5, - max_penetration: jtp.FloatLike = 0.001, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, - damping_ratio: jtp.FloatLike = 1.0, - p: jtp.FloatLike = 0.5, - q: jtp.FloatLike = 0.5, - ) -> SoftContactsParams: - """ - Create a SoftContactsParams instance with good default parameters. - - Args: - model: The target model. - standard_gravity: The standard gravity constant. - static_friction_coefficient: - The static friction coefficient between the model and the terrain. - max_penetration: The maximum penetration depth. - number_of_active_collidable_points_steady_state: - The number of contacts supporting the weight of the model - in steady state. - damping_ratio: The ratio controlling the damping behavior. - p: - The exponent p corresponding to the damping-related non-linearity - of the Hunt/Crossley model. - q: - The exponent q corresponding to the spring-related non-linearity - of the Hunt/Crossley model - - Returns: - A `SoftContactsParams` instance with the specified parameters. - - Note: - The `damping_ratio` parameter allows to operate on the following conditions: - - ξ > 1.0: over-damped - - ξ = 1.0: critically damped - - ξ < 1.0: under-damped - """ - - # Use symbols for input parameters. - ξ = damping_ratio - δ_max = max_penetration - μc = static_friction_coefficient - - # Compute the total mass of the model. - m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() - - # Rename the standard gravity. - g = standard_gravity - - # Compute the average support force on each collidable point. - f_average = m * g / number_of_active_collidable_points_steady_state - - # Compute the stiffness to get the desired steady-state penetration. - # Note that this is dependent on the non-linear exponent used in - # the damping term of the Hunt/Crossley model. - K = f_average / jnp.power(δ_max, 1 + p) - - # Compute the damping using the damping ratio. - critical_damping = 2 * jnp.sqrt(K * m) - D = ξ * critical_damping - - return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q) - def valid(self) -> jtp.BoolLike: """ Check if the parameters are valid. From 49f062e28d710cac7be68806735d767a13cd0f00 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 26 Feb 2025 17:45:56 +0100 Subject: [PATCH 31/36] Speed up rigid contact model --- src/jaxsim/rbda/contacts/common.py | 17 +++--- src/jaxsim/rbda/contacts/rigid.py | 92 ++++++++++-------------------- 2 files changed, 38 insertions(+), 71 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 4fc8088f0..e0b7c0e65 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -317,8 +317,6 @@ def update_velocity_after_impact( in_axes=(0, 0, None), )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) - original_representation = data.velocity_representation - with data.switch_velocity_representation(VelRepr.Mixed): J_WC = js.contact.jacobian(model, data)[ indices_of_enabled_collidable_points @@ -344,13 +342,12 @@ def update_velocity_after_impact( is_force=False, ) - # Reset the generalized velocity. - data = dataclasses.replace( - data, - velocity_representation=original_representation, - _base_linear_velocity=BW_ν_post_impact_inertial[0:3], - _base_angular_velocity=BW_ν_post_impact_inertial[3:6], - _joint_velocities=BW_ν_post_impact[6:], - ) + # Reset the generalized velocity. + data = dataclasses.replace( + data, + _base_linear_velocity=BW_ν_post_impact_inertial[0:3], + _base_angular_velocity=BW_ν_post_impact_inertial[3:6], + _joint_velocities=BW_ν_post_impact[6:], + ) return data diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 26fefa5b7..fde4c46f2 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -196,11 +196,7 @@ def compute_impact_velocity( # Zero out the jacobian rows of inactive points. Jl_WC = jnp.vstack( - jnp.where( - inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], - jnp.zeros_like(Jl_WC), - Jl_WC, - ) + jax.vmap(lambda J, δ: J * δ)(Jl_WC, inactive_collidable_points) ) A = jnp.vstack( @@ -259,16 +255,39 @@ def compute_contact_forces( else jnp.zeros((model.number_of_joints(),)) ) + # Build a references object to simplify converting link forces. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # Compute the transforms of the implicit frames corresponding to the + # collidable points. + W_H_C = js.contact.transforms(model=model, data=data) + # Compute kin-dyn quantities used in the contact model. with data.switch_velocity_representation(VelRepr.Mixed): BW_ν = data.generalized_velocity M = js.model.free_floating_mass_matrix(model=model, data=data) - J_WC = js.contact.jacobian(model=model, data=data) - J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + Jl_WC = jnp.vstack(js.contact.jacobian(model=model, data=data)[:, :3, :]) + J̇l_WC = jnp.vstack( + js.contact.jacobian_derivative(model=model, data=data)[:, :3, :] + ) - W_H_C = js.contact.transforms(model=model, data=data) + # Compute the generalized free acceleration. + BW_ν̇_free = jnp.hstack( + js.model.forward_dynamics_aba( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_forces=references.joint_force_references(model=model), + ) + ) # Compute the position and linear velocities (mixed representation) of # all enabled collidable points belonging to the robot. @@ -282,34 +301,9 @@ def compute_contact_forces( position, velocity, model.terrain ) - # Build a references object to simplify converting link forces. - references = js.references.JaxSimModelReferences.build( - model=model, - data=data, - velocity_representation=data.velocity_representation, - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - - # Compute the generalized free acceleration. - with data.switch_velocity_representation(VelRepr.Mixed): - BW_ν̇_free = jnp.hstack( - js.model.forward_dynamics_aba( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_forces=references.joint_force_references(model=model), - ) - ) - # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. - free_contact_acc = _linear_acceleration_of_collidable_points( - BW_nu=BW_ν, - BW_nu_dot=BW_ν̇_free, - CW_J_WC_BW=J_WC, - CW_J_dot_WC_BW=J̇_WC, - ).flatten() + CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν # Compute stabilization term. baumgarte_term = _compute_baumgarte_stabilization_term( @@ -322,7 +316,7 @@ def compute_contact_forces( ).flatten() # Compute the Delassus matrix. - delassus_matrix = _compute_delassus_matrix(M=M, J_WC=J_WC) + delassus_matrix = _compute_delassus_matrix(M=M, J_WC=Jl_WC) # Initialize regularization term of the Delassus matrix for # better numerical conditioning. @@ -330,7 +324,7 @@ def compute_contact_forces( # Construct the quadratic cost function. Q = delassus_matrix + Iε - q = free_contact_acc - baumgarte_term + q = CW_al_free_WC - baumgarte_term # Construct the inequality constraints. G = _compute_ineq_constraint_matrix( @@ -379,10 +373,7 @@ def _compute_delassus_matrix( J_WC: jtp.MatrixLike, ) -> jtp.Matrix: - sl = jnp.s_[:, 0:3, :] - J_WC_lin = jnp.vstack(J_WC[sl]) - - delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + delassus_matrix = J_WC @ jnp.linalg.pinv(M) @ J_WC.T return delassus_matrix @@ -428,27 +419,6 @@ def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: return jnp.zeros(shape=(n_constraints,)) -@jax.jit -@js.common.named_scope -def _linear_acceleration_of_collidable_points( - BW_nu: jtp.ArrayLike, - BW_nu_dot: jtp.ArrayLike, - CW_J_WC_BW: jtp.MatrixLike, - CW_J_dot_WC_BW: jtp.MatrixLike, -) -> jtp.Matrix: - - BW_ν = BW_nu - BW_ν̇ = BW_nu_dot - CW_J̇_WC_BW = CW_J_dot_WC_BW - - # Compute the linear acceleration of the collidable points. - # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. - CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ - - CW_a_WC = CW_a_WC.reshape(-1, 6) - return CW_a_WC[:, 0:3].squeeze() - - @jax.jit @js.common.named_scope def _compute_baumgarte_stabilization_term( From 383a15bce2d69cae8a609925e6064fe9b5b618e4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 15:28:02 +0100 Subject: [PATCH 32/36] Prevent recompilation on enabled collidable points --- src/jaxsim/rbda/contacts/rigid.py | 77 ++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index fde4c46f2..532de623d 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -196,7 +196,11 @@ def compute_impact_velocity( # Zero out the jacobian rows of inactive points. Jl_WC = jnp.vstack( - jax.vmap(lambda J, δ: J * δ)(Jl_WC, inactive_collidable_points) + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) ) A = jnp.vstack( @@ -264,20 +268,31 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) - # Compute the transforms of the implicit frames corresponding to the - # collidable points. + # Compute the position and linear velocities (mixed representation) of + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) + + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, model.terrain + ) + W_H_C = js.contact.transforms(model=model, data=data) - # Compute kin-dyn quantities used in the contact model. - with data.switch_velocity_representation(VelRepr.Mixed): + with ( + references.switch_velocity_representation(VelRepr.Mixed), + data.switch_velocity_representation(VelRepr.Mixed), + ): + # Compute kin-dyn quantities used in the contact model. BW_ν = data.generalized_velocity M = js.model.free_floating_mass_matrix(model=model, data=data) - Jl_WC = jnp.vstack(js.contact.jacobian(model=model, data=data)[:, :3, :]) - J̇l_WC = jnp.vstack( - js.contact.jacobian_derivative(model=model, data=data)[:, :3, :] - ) + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) # Compute the generalized free acceleration. BW_ν̇_free = jnp.hstack( @@ -303,7 +318,12 @@ def compute_contact_forces( # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. - CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν + free_contact_acc = _linear_acceleration_of_collidable_points( + BW_nu=BW_ν, + BW_nu_dot=BW_ν̇_free, + CW_J_WC_BW=J_WC, + CW_J_dot_WC_BW=J̇_WC, + ).flatten() # Compute stabilization term. baumgarte_term = _compute_baumgarte_stabilization_term( @@ -316,7 +336,7 @@ def compute_contact_forces( ).flatten() # Compute the Delassus matrix. - delassus_matrix = _compute_delassus_matrix(M=M, J_WC=Jl_WC) + delassus_matrix = _delassus_matrix(M=M, J_WC=J_WC) # Initialize regularization term of the Delassus matrix for # better numerical conditioning. @@ -324,13 +344,13 @@ def compute_contact_forces( # Construct the quadratic cost function. Q = delassus_matrix + Iε - q = CW_al_free_WC - baumgarte_term + q = free_contact_acc - baumgarte_term # Construct the inequality constraints. G = _compute_ineq_constraint_matrix( inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu ) - h_bounds = _compute_ineq_bounds(n_collidable_points=n_collidable_points) + h_bounds = jnp.zeros(shape=(n_collidable_points * 6,)) # Construct the equality constraints. A = jnp.zeros((0, 3 * n_collidable_points)) @@ -366,14 +386,16 @@ def compute_contact_forces( return W_f_C, {} -@jax.jit -@js.common.named_scope -def _compute_delassus_matrix( +@staticmethod +def _delassus_matrix( M: jtp.MatrixLike, J_WC: jtp.MatrixLike, ) -> jtp.Matrix: - delassus_matrix = J_WC @ jnp.linalg.pinv(M) @ J_WC.T + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T return delassus_matrix @@ -419,6 +441,27 @@ def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: return jnp.zeros(shape=(n_constraints,)) +@jax.jit +@js.common.named_scope +def _linear_acceleration_of_collidable_points( + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, +) -> jtp.Matrix: + + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + + CW_a_WC = CW_a_WC.reshape(-1, 6) + return CW_a_WC[:, 0:3].squeeze() + + @jax.jit @js.common.named_scope def _compute_baumgarte_stabilization_term( From c853c4fff60b2373a980f0a20f6a885bbb4186b6 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 15:31:46 +0100 Subject: [PATCH 33/36] Remove duplicated Hunt-Crossley model --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 153 +--------------------- 1 file changed, 2 insertions(+), 151 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 6f7aad5e4..46003882a 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import functools from collections.abc import Callable from typing import Any @@ -11,12 +10,10 @@ import optax import jaxsim.api as js -import jaxsim.rbda.contacts import jaxsim.typing as jtp from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr -from jaxsim.terrain.terrain import Terrain -from . import common +from . import common, soft try: from typing import Self @@ -424,7 +421,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: # Initialize the optimized forces with a linear Hunt/Crossley model. init_params = jax.vmap( - lambda p, v: self._hunt_crossley_contact_model( + lambda p, v: soft.SoftContacts.hunt_crossley_contact_model( position=p, velocity=v, terrain=model.terrain, @@ -601,149 +598,3 @@ def compute_row( ) return a_ref, jnp.diag(R), K, D - - @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) - def _hunt_crossley_contact_model( - position: jtp.VectorLike, - velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, - terrain: Terrain, - K: jtp.FloatLike, - D: jtp.FloatLike, - mu: jtp.FloatLike, - p: jtp.FloatLike = 0.5, - q: jtp.FloatLike = 0.5, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Compute the contact force using the Hunt/Crossley model. - - Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - tangential_deformation: The material deformation of the collidable point. - terrain: The terrain model. - K: The stiffness parameter. - D: The damping parameter of the soft contacts model. - mu: The static friction coefficient. - p: - The exponent p corresponding to the damping-related non-linearity - of the Hunt/Crossley model. - q: - The exponent q corresponding to the spring-related non-linearity - of the Hunt/Crossley model - - Returns: - A tuple containing the computed contact force and the derivative of the - material deformation. - """ - - # Convert the input vectors to arrays. - W_p_C = jnp.array(position, dtype=float).squeeze() - W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() - m = jnp.array(tangential_deformation, dtype=float).squeeze() - - # Use symbol for the static friction. - μ = mu - - # Compute the penetration depth, its rate, and the considered terrain normal. - δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) - - # There are few operations like computing the norm of a vector with zero length - # or computing the square root of zero that are problematic in an AD context. - # To avoid these issues, we introduce a small tolerance ε to their arguments - # and make sure that we do not check them against zero directly. - ε = jnp.finfo(float).eps - - # Compute the powers of the penetration depth. - # Inject ε to address AD issues in differentiating the square root when - # p and q are fractional. - δp = jnp.power(δ + ε, p) - δq = jnp.power(δ + ε, q) - - # ======================== - # Compute the normal force - # ======================== - - # Non-linear spring-damper model (Hunt/Crossley model). - # This is the force magnitude along the direction normal to the terrain. - force_normal_mag = (K * δp) * δ + (D * δq) * δ̇ - - # Depending on the magnitude of δ̇, the normal force could be negative. - force_normal_mag = jnp.maximum(0.0, force_normal_mag) - - # Compute the 3D linear force in C[W] frame. - f_normal = force_normal_mag * n̂ - - # ============================ - # Compute the tangential force - # ============================ - - # Extract the tangential component of the velocity. - v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂ - - # Extract the normal and tangential components of the material deformation. - m_normal = jnp.dot(m, n̂) * n̂ - m_tangential = m - jnp.dot(m, n̂) * n̂ - - # Compute the tangential force in the sticking case. - # Using the tangential component of the material deformation should not be - # necessary if the sticking-slipping transition occurs in a terrain area - # with a locally constant normal. However, this assumption is not true in - # general, especially for highly uneven terrains. - f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential) - - # Detect the contact type (sticking or slipping). - # Note that if there is no contact, sticking is set to True, and this detail - # is exploited in the computation of the `contact_status` variable. - sticking = jnp.logical_or( - δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2 - ) - - # Compute the direction of the tangential force. - # To prevent dividing by zero, we use a switch statement. - norm = jaxsim.math.safe_norm(f_tangential) - f_tangential_direction = f_tangential / ( - norm + jnp.finfo(float).eps * (norm == 0) - ) - - # Project the tangential force to the friction cone if slipping. - f_tangential = jnp.where( - sticking, - f_tangential, - jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, - ) - - # Set the tangential force to zero if there is no contact. - f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential) - - # ===================================== - # Compute the material deformation rate - # ===================================== - - # Compute the derivative of the material deformation. - # Note that we included an additional relaxation of `m_normal` in the - # sticking case, so that the normal deformation that could have accumulated - # from a previous slipping phase can relax to zero. - ṁ_no_contact = -(K / D) * m - ṁ_sticking = v_tangential - (K / D) * m_normal - ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq) - - # Compute the contact status: - # 0: slipping - # 1: sticking - # 2: no contact - contact_status = sticking.astype(int) - contact_status += (δ <= 0).astype(int) - - # Select the right material deformation rate depending on the contact status. - ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact) - - # ========================================== - # Compute and return the final contact force - # ========================================== - - # Sum the normal and tangential forces. - CW_fl = f_normal + f_tangential - - return CW_fl, ṁ From 1dd113d18bb1a2237e73e335b8b7b35edee081f2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 15:32:14 +0100 Subject: [PATCH 34/36] Fix link forces representation in rigid contacts --- src/jaxsim/rbda/contacts/rigid.py | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 532de623d..969c3fcba 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -216,6 +216,7 @@ def compute_impact_velocity( return BW_ν_post_impact[0 : M.shape[0]] @jax.jit + @js.common.named_scope def compute_contact_forces( self, model: js.model.JaxSimModel, @@ -304,18 +305,6 @@ def compute_contact_forces( ) ) - # Compute the position and linear velocities (mixed representation) of - # all enabled collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) - - # Compute the penetration depth and velocity of the collidable points. - # Note that this function considers the penetration in the normal direction. - δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain - ) - # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. free_contact_acc = _linear_acceleration_of_collidable_points( @@ -428,19 +417,6 @@ def _compute_ineq_constraint_matrix( return G -def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: - """ - Compute the inequality bounds for the contact forces. - - Note: - Do not JIT this function as the output shape depends on the number of collidable points. - """ - - n_constraints = 6 * n_collidable_points - - return jnp.zeros(shape=(n_constraints,)) - - @jax.jit @js.common.named_scope def _linear_acceleration_of_collidable_points( From ee5d8c8e1a94c119d863a49afd87b1339ad9ac20 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 1 Mar 2025 12:56:19 +0100 Subject: [PATCH 35/36] Move `test_collidable_point_jacobians` to `test_api_contact` --- tests/test_api_contact.py | 32 ++++++++++++++++++++++++++++++++ tests/test_contact.py | 37 ------------------------------------- 2 files changed, 32 insertions(+), 37 deletions(-) delete mode 100644 tests/test_contact.py diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 499b09428..d83c2f5bc 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -68,6 +68,38 @@ def test_contact_kinematics( assert W_ṗ_C == pytest.approx(CW_vl_WC) +def test_collidable_point_jacobians( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + _, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + # ===== + # Tests + # ===== + + # Compute the velocity of the collidable points with a RBDA. + # This function always returns the linear part of the mixed velocity of the + # implicit frame C corresponding to the collidable point. + W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + + # Compute the generalized velocity and the free-floating Jacobian of the frame C. + ν = data.generalized_velocity + CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) + + # Compute the velocity of the collidable points using the Jacobians. + v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) + + assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3]) + + def test_contact_jacobian_derivative( jaxsim_models_floating_base: js.model.JaxSimModel, velocity_representation: VelRepr, diff --git a/tests/test_contact.py b/tests/test_contact.py deleted file mode 100644 index 6cef55481..000000000 --- a/tests/test_contact.py +++ /dev/null @@ -1,37 +0,0 @@ -import jax -import pytest - -import jaxsim.api as js -from jaxsim import VelRepr - - -def test_collidable_point_jacobians( - jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, - prng_key: jax.Array, -): - - model = jaxsim_models_types - - _, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( - model=model, key=subkey, velocity_representation=velocity_representation - ) - - # ===== - # Tests - # ===== - - # Compute the velocity of the collidable points with a RBDA. - # This function always returns the linear part of the mixed velocity of the - # implicit frame C corresponding to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) - - # Compute the generalized velocity and the free-floating Jacobian of the frame C. - ν = data.generalized_velocity - CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) - - # Compute the velocity of the collidable points using the Jacobians. - v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) - - assert W_ṗ_C == pytest.approx(v_WC_from_jax[:, 0:3]) From 59a2a7b682d7b149607d2f25fbe4591df8dffb92 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 1 Mar 2025 14:53:25 +0100 Subject: [PATCH 36/36] Lint code and remove duplications --- src/jaxsim/api/com.py | 4 +--- src/jaxsim/api/contact.py | 8 ++------ src/jaxsim/api/data.py | 22 +++++++++++--------- src/jaxsim/api/ode.py | 6 +++--- src/jaxsim/parsers/kinematic_graph.py | 2 +- src/jaxsim/rbda/__init__.py | 2 +- src/jaxsim/rbda/contacts/soft.py | 2 +- src/jaxsim/rbda/forward_kinematics.py | 29 --------------------------- 8 files changed, 21 insertions(+), 54 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index c952fa0ac..ee85d078b 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -301,9 +301,7 @@ def other_representation_to_body( C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - L_H_C = L_H_W = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L) - )(W_H_L) + L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841 L_v_LC = L_v_LW = jax.vmap( # noqa: F841 lambda i: -js.link.velocity( diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 9ee210412..5873d7e4c 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -267,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt # Build the link-to-point transform from the displacement between the link frame L # and the implicit contact frame C. - L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci) + L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci) # Compose the work-to-link and link-to-point transforms. return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) @@ -567,9 +567,7 @@ def link_contact_forces( # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces( - model=model, data=data, contact_forces=W_f_C - ) + W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) return W_f_L, aux_dict @@ -577,7 +575,6 @@ def link_contact_forces( @staticmethod def link_forces_from_contact_forces( model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, *, contact_forces: jtp.MatrixLike, ) -> jtp.Matrix: @@ -586,7 +583,6 @@ def link_forces_from_contact_forces( Args: model: The robot model considered by the contact model. - data: The data of the considered model. contact_forces: The contact forces computed by the contact model. Returns: diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index d477a8705..d7acecef6 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -5,9 +5,9 @@ from collections.abc import Sequence try: - from typing import override + from typing import Self, override except ImportError: - from typing_extensions import override + from typing_extensions import override, Self import jax import jax.numpy as jnp @@ -22,11 +22,6 @@ from . import common from .common import VelRepr -try: - from typing import Self -except ImportError: - from typing_extensions import Self - @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): @@ -364,11 +359,14 @@ def base_transform(self) -> jtp.Matrix: @js.common.named_scope @jax.jit - def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: + def reset_base_quaternion( + self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike + ) -> Self: """ Reset the base quaternion. Args: + model: The JaxSim model to use. base_quaternion: The base orientation as a quaternion. Returns: @@ -380,15 +378,18 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: norm = jaxsim.math.safe_norm(W_Q_B) W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) - return self.replace(validate=True, base_quaternion=W_Q_B) + return self.replace(model=model, base_quaternion=W_Q_B) @js.common.named_scope @jax.jit - def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: + def reset_base_pose( + self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike + ) -> Self: """ Reset the base pose. Args: + model: The JaxSim model to use. base_pose: The base pose as an SE(3) matrix. Returns: @@ -399,6 +400,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: W_p_B = base_pose[0:3, 3] W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3]) return self.replace( + model=model, base_position=W_p_B, base_quaternion=W_Q_B, ) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index dc903b8c1..8c8a949b6 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -50,6 +50,9 @@ def system_acceleration( # Compute contact forces # ====================== + W_f_L_terrain = jnp.zeros_like(f_L) + contact_state_derivative = {} + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact @@ -95,7 +98,6 @@ def system_acceleration( @jax.jit @js.common.named_scope def system_position_dynamics( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: @@ -103,7 +105,6 @@ def system_position_dynamics( Compute the dynamics of the system position. Args: - model: The model to consider. data: The data of the considered model. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient for adjusting the quaternion norm. @@ -173,7 +174,6 @@ def system_dynamics( ) W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( - model=model, data=data, baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9c16136d3..b8b935a36 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -973,7 +973,7 @@ def find_parent_link_of_frame(self, name: str) -> str: if frame.parent_name in self.graph.links_dict: return frame.parent_name - elif frame.parent_name in self.graph.frames_dict: + if frame.parent_name in self.graph.frames_dict: return self.find_parent_link_of_frame(name=frame.parent_name) msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'" diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 25bafc1ae..eb89bd208 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -2,7 +2,7 @@ from .aba import aba from .collidable_points import collidable_points_pos_vel from .crba import crba -from .forward_kinematics import forward_kinematics, forward_kinematics_model +from .forward_kinematics import forward_kinematics_model from .jacobian import ( jacobian, jacobian_derivative_full_doubly_left, diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 24f26be1c..53654cad4 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -414,4 +414,4 @@ def compute_contact_forces( ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) - return W_f, dict(m_dot=ṁ) + return W_f, {"m_dot": ṁ} diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 355fe347c..58c230c7d 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -111,32 +111,3 @@ def propagate_kinematics( ) return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi - - -def forward_kinematics( - model: js.model.JaxSimModel, - link_index: jtp.Int, - base_position: jtp.VectorLike, - base_quaternion: jtp.VectorLike, - joint_positions: jtp.VectorLike, -) -> jtp.Matrix: - """ - Compute the forward kinematics of a specific link. - - Args: - model: The model to consider. - link_index: The index of the link to consider. - base_position: The position of the base link. - base_quaternion: The quaternion of the base link. - joint_positions: The positions of the joints. - - Returns: - The SE(3) transform of the link. - """ - - return forward_kinematics_model( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - )[link_index]