Skip to content

Commit

Permalink
Avoid evaluation of branches leading to NaNs in the optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 30, 2025
1 parent 82e8d9f commit 1f24739
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dataclasses
import functools
from collections.abc import Callable
from typing import Any

import jax
Expand Down Expand Up @@ -362,19 +361,22 @@ def compute_contact_forces(
# ========================================

def run_optimization(
init_params: jtp.Vector,
fun: Callable,
opt: optax.GradientTransformationExtraArgs,
maxiter: int,
tol: float,
carry: tuple[jtp.Vector, jtp.Int, jtp.Float],
) -> tuple[jtp.Vector, optax.OptState]:

init_params, maxiter, tol = carry

optimizer = optax.lbfgs(**solver_options)

# Get the function to compute the loss and the gradient w.r.t. its inputs.
value_and_grad_fn = optax.value_and_grad_from_state(fun)
value_and_grad_fn = optax.value_and_grad_from_state(objective)

# Initialize the carry of the following loop.
OptimizationCarry = tuple[jtp.Vector, optax.OptState]
init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
init_carry: OptimizationCarry = (
init_params,
optimizer.init(params=init_params),
)

def step(carry: OptimizationCarry) -> OptimizationCarry:

Expand All @@ -387,13 +389,13 @@ def step(carry: OptimizationCarry) -> OptimizationCarry:
b=b,
)

updates, state = opt.update(
updates, state = optimizer.update(
updates=grad,
state=state,
params=params,
value=value,
grad=grad,
value_fn=fun,
value_fn=objective,
A=A,
b=b,
)
Expand All @@ -413,11 +415,9 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:

return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))

final_params, final_state = jax.lax.while_loop(
continuing_criterion, step, init_carry
)
final_params, _ = jax.lax.while_loop(continuing_criterion, step, init_carry)

return final_params, final_state
return final_params, None

# ======================================
# Compute the contact forces with L-BFGS
Expand Down Expand Up @@ -448,12 +448,15 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
maxiter = solver_options.pop("maxiter")

# Compute the 3D linear force in C[W] frame.
solution, _ = run_optimization(
init_params=init_params,
fun=objective,
opt=optax.lbfgs(**solver_options),
tol=tol,
maxiter=maxiter,
solution, _ = jax.lax.cond(
pred=jnp.any(δ > 0),
true_fun=run_optimization,
false_fun=lambda *_: (init_params, None),
operand=(
init_params,
maxiter,
tol,
),
)

# Reshape the optimized solution to be a matrix of 3D contact forces.
Expand Down

0 comments on commit 1f24739

Please sign in to comment.