Skip to content

Commit

Permalink
Use JaxOPT PGD instead of L-BFGS
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 26, 2025
1 parent 07c192b commit 67f876f
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def compute_contact_forces(

# Create the objective function to minimize as a lambda computing the cost
# from the optimized variables x.
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
objective = lambda x: jnp.sum(jnp.square(A @ x + b))

# ========================================
# Helper function to run the L-BFGS solver
Expand Down Expand Up @@ -466,13 +466,27 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
maxiter = solver_options.pop("maxiter")

# Compute the 3D linear force in C[W] frame.
solution, _ = run_optimization(
init_params=init_params,
# solution, _ = run_optimization(
# init_params=init_params,
# fun=objective,
# opt=optax.lbfgs(**solver_options),
# tol=tol,
# maxiter=maxiter,
# )

import jaxopt

opt = jaxopt.ProjectedGradient(
fun=objective,
opt=optax.lbfgs(**solver_options),
projection=jaxopt.projection.projection_non_negative,
# maxiter=maxiter,
tol=tol,
maxiter=maxiter,
maxls=30,
implicit_diff=False,
# history_size=10,
# max_stepsize=100.0,
)
solution = opt.run(init_params=init_params).params

# Reshape the optimized solution to be a matrix of 3D contact forces.
CW_fl_C = solution.reshape(-1, 3)
Expand Down

0 comments on commit 67f876f

Please sign in to comment.