Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from jaxopt to optax in relaxed rigid contact model #244

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ dependencies:
- python >= 3.12.0
- coloredlogs
- jax >= 0.4.26
- jaxopt >= 0.8.0
- jaxlib >= 0.4.26
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- optax >= 0.2.3
- pptree
- qpax
- rod >= 0.3.3
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ classifiers = [
dependencies = [
"coloredlogs",
"jax >= 0.4.26",
"jaxopt >= 0.8.0",
"jaxlib >= 0.4.26",
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
"pptree",
"optax >= 0.2.3",
"qpax",
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
Expand Down
72 changes: 60 additions & 12 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxopt
import optax

import jaxsim.api as js
import jaxsim.typing as jtp
Expand Down Expand Up @@ -297,24 +298,71 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
A = G + R
b = CW_al_free_WC - a_ref

objective = lambda x: jnp.sum(jnp.square(A @ x + b))
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))

# Compute the 3D linear force in C[W] frame
opt = jaxopt.LBFGS(
fun=objective,
maxiter=self.parameters.max_iterations,
tol=self.parameters.tolerance,
maxls=30,
history_size=10,
max_stepsize=100.0,
)
def run_optimization(
init_params: jtp.Array,
fun: Callable,
opt: optax.GradientTransformation,
maxiter: jtp.Int,
tol: jtp.Float,
**kwargs,
):
value_and_grad_fn = optax.value_and_grad_from_state(fun)

def step(carry):
params, state = carry
value, grad = value_and_grad_fn(
params,
state=state,
A=A,
b=b,
)
updates, state = opt.update(
updates=grad,
state=state,
params=params,
value=value,
grad=grad,
value_fn=fun,
A=A,
b=b,
)
params = optax.apply_updates(params, updates)
return params, state

def continuing_criterion(carry):
_, state = carry
iter_num = optax.tree_utils.tree_get(state, "count")
grad = optax.tree_utils.tree_get(state, "grad")
err = optax.tree_utils.tree_l2_norm(grad)
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))

init_carry = (init_params, opt.init(init_params))
final_params, final_state = jax.lax.while_loop(
continuing_criterion, step, init_carry
)
return final_params, final_state

init_params = (
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
+ D[:, jnp.newaxis] * velocity
).flatten()

CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)
# Compute the 3D linear force in C[W] frame
CW_f_Ci, _ = run_optimization(
init_params=init_params,
A=A,
b=b,
maxiter=self.parameters.max_iterations,
opt=optax.lbfgs(
memory_size=10,
),
fun=objective,
tol=self.parameters.tolerance,
)

CW_f_Ci = CW_f_Ci.reshape((-1, 3))

def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
W_Xf_CW = Adjoint.from_transform(
Expand Down