From 789734b486e5320ca7c18d2e4ae099be28f14a8a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 18 Oct 2024 16:37:07 +0200 Subject: [PATCH 1/2] Update `estimate_good_soft_contacts_parameters` in tests --- tests/test_automatic_differentiation.py | 2 +- tests/test_pytree.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 84b1e498e..2a53e4182 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -295,7 +295,7 @@ def test_ad_soft_contacts( m = jax.random.uniform(subkey3, shape=(3,), minval=-1) # Get the soft contacts parameters. - parameters = js.contact.estimate_good_soft_contacts_parameters(model=model) + parameters = js.contact.estimate_good_contact_parameters(model=model) # ==== # Test diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 561d5ac53..c2fcc0149 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -29,7 +29,7 @@ def test_call_jit_compiled_function_passing_different_objects( # If this function has never been compiled by any other test, JAX will # jit-compile it here. - _ = js.contact.estimate_good_soft_contacts_parameters(model=model1) + _ = js.contact.estimate_good_contact_parameters(model=model1) # Now JAX should not compile it again. with jax.log_compiles(): @@ -37,11 +37,11 @@ def test_call_jit_compiled_function_passing_different_objects( # Beyond running without any JIT recompilations, the following function # should work on different JaxSimModel objects without raising any errors # related to the comparison of Static fields. - _ = js.contact.estimate_good_soft_contacts_parameters(model=model2) + _ = js.contact.estimate_good_contact_parameters(model=model2) stdout = buf.getvalue() assert ( - f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}" + f"Compiling {js.contact.estimate_good_contact_parameters.__name__}" not in stdout ) From 05134cbbc984ccff5d6d5f994a633823970fe072 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 18 Oct 2024 16:42:36 +0200 Subject: [PATCH 2/2] Remove `dt` definition in favor of `model.time_step` --- tests/test_automatic_differentiation.py | 5 ++--- tests/test_simulations.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 2a53e4182..ca4c75ecb 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -374,8 +374,8 @@ def test_ad_integration( ) # Initialize the integrator. - t0, dt = 0.0, 0.001 - integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt) + t0 = 0.0 + integrator_state = integrator.init(x0=data.state, t0=t0, dt=model.time_step) # Function exposing only the parameters to be differentiated. def step( @@ -408,7 +408,6 @@ def step( ) data_xf, _ = js.model.step( - dt=dt, model=model, data=data_x0, integrator=integrator, diff --git a/tests/test_simulations.py b/tests/test_simulations.py index f93edcf0b..7aabb90cb 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -71,9 +71,8 @@ def test_box_with_external_forces( # Initialize the integrator. tf = 0.5 - dt = 0.001 - T_ns = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int) - integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt) + T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) + integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=model.time_step) # Copy the initial data... data = data0.copy() @@ -84,7 +83,6 @@ def test_box_with_external_forces( data, integrator_state = js.model.step( model=model, data=data, - dt=dt, integrator=integrator, integrator_state=integrator_state, link_forces=references.link_forces(model=model, data=data),