Skip to content

Commit

Permalink
Merge pull request #268 from ami-iit/fix/deprecation_in_tests
Browse files Browse the repository at this point in the history
Fix API deprecations in tests
  • Loading branch information
flferretti authored Oct 18, 2024
2 parents 2995e47 + 05134cb commit 481b6d8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
7 changes: 3 additions & 4 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_ad_soft_contacts(
m = jax.random.uniform(subkey3, shape=(3,), minval=-1)

# Get the soft contacts parameters.
parameters = js.contact.estimate_good_soft_contacts_parameters(model=model)
parameters = js.contact.estimate_good_contact_parameters(model=model)

# ====
# Test
Expand Down Expand Up @@ -374,8 +374,8 @@ def test_ad_integration(
)

# Initialize the integrator.
t0, dt = 0.0, 0.001
integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt)
t0 = 0.0
integrator_state = integrator.init(x0=data.state, t0=t0, dt=model.time_step)

# Function exposing only the parameters to be differentiated.
def step(
Expand Down Expand Up @@ -408,7 +408,6 @@ def step(
)

data_xf, _ = js.model.step(
dt=dt,
model=model,
data=data_x0,
integrator=integrator,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def test_call_jit_compiled_function_passing_different_objects(

# If this function has never been compiled by any other test, JAX will
# jit-compile it here.
_ = js.contact.estimate_good_soft_contacts_parameters(model=model1)
_ = js.contact.estimate_good_contact_parameters(model=model1)

# Now JAX should not compile it again.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):
# Beyond running without any JIT recompilations, the following function
# should work on different JaxSimModel objects without raising any errors
# related to the comparison of Static fields.
_ = js.contact.estimate_good_soft_contacts_parameters(model=model2)
_ = js.contact.estimate_good_contact_parameters(model=model2)
stdout = buf.getvalue()

assert (
f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}"
f"Compiling {js.contact.estimate_good_contact_parameters.__name__}"
not in stdout
)

Expand Down
6 changes: 2 additions & 4 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ def test_box_with_external_forces(

# Initialize the integrator.
tf = 0.5
dt = 0.001
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=model.time_step)

# Copy the initial data...
data = data0.copy()
Expand All @@ -84,7 +83,6 @@ def test_box_with_external_forces(
data, integrator_state = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
link_forces=references.link_forces(model=model, data=data),
Expand Down

0 comments on commit 481b6d8

Please sign in to comment.