Skip to content

Commit

Permalink
Merge pull request #236 from ami-iit/update_crba
Browse files Browse the repository at this point in the history
Simplify CRBA by merging two `jax.lax.cond` branches
  • Loading branch information
diegoferigo authored Sep 19, 2024
2 parents 3587b33 + 9736acb commit df28248
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions src/jaxsim/rbda/crba.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,48 +94,48 @@ def backward_pass(

j = i

CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix]
carry_inner_fn = (j, Fi, M)
FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix]
fake_while_carry = (j, Fi, M)

def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
j, Fi, M = carry
# This internal for loop implements the while loop of the CRBA algorithm
# to compute off-diagonal blocks of the mass matrix M.
# In pseudocode it is implemented as a while loop. However, in order to enable
# applying reverse-mode AD, we implement it as a nested for loop with a fixed
# number of iterations and a branching model to skip for loop iterations.
def fake_while_loop(
carry: FakeWhileCarry, i: jtp.Int
) -> tuple[FakeWhileCarry, None]:

Fi = i_X_λi[j].T @ Fi
j = λ[j]
jj = j - 1
def compute(carry: FakeWhileCarry) -> FakeWhileCarry:

M_ij = Fi.T @ S[j]
j, Fi, M = carry

M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())
M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())
Fi = i_X_λi[j].T @ Fi
j = λ[j]

return j, Fi, M
M_ij = Fi.T @ S[j]

# The following functions are part of a (rather messy) workaround for computing
# a while loop using a for loop with fixed number of iterations.
def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
j, _, _ = carry
out = jax.lax.cond(
pred=(λ[j] > 0),
true_fun=while_loop_body,
false_fun=lambda carry: carry,
operand=carry,
)
return out, None
jj = j - 1
M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())
M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())

return j, Fi, M

j, _, _ = carry
return jax.lax.cond(
pred=(k == j),
true_fun=compute_inner,
false_fun=lambda carry: (carry, None),

j, Fi, M = jax.lax.cond(
pred=jnp.logical_and(i == λ[j], λ[j] > 0),
true_fun=compute,
false_fun=lambda carry: carry,
operand=carry,
)

return (j, Fi, M), None

(j, Fi, M), _ = (
jax.lax.scan(
f=inner_fn,
init=carry_inner_fn,
f=fake_while_loop,
init=fake_while_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
)
if model.number_of_links() > 1
Expand Down

0 comments on commit df28248

Please sign in to comment.