Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rollout: fix bugs in checking list lengths, add tests #2377

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,20 @@ class Rollout {
}

// check length d and nthread are consistent
if (this->nthread_ == 0 && py::len(d) > 1) {
if (py::len(d) == 0) {
std::ostringstream msg;
msg << "The list of data instances is empty";
throw py::value_error(msg.str());
} else if (this->nthread_ == 0 && py::len(d) > 1) {
std::ostringstream msg;
msg << "More than one data instance passed but "
<< "rollout is configured to run on main thread";
py::value_error(msg.str());
} else if (this->nthread_ != py::len(d)) {
throw py::value_error(msg.str());
} else if (this->nthread_ > 0 && this->nthread_ != py::len(d)) {
std::ostringstream msg;
msg << "Length of data: " << py::len(d)
<< " not equal to nthread: " << this->nthread_;
py::value_error(msg.str());
throw py::value_error(msg.str());
}

std::vector<raw::MjData*> data_ptrs(py::len(d));
Expand Down
2 changes: 1 addition & 1 deletion python/mujoco/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def rollout(
if isinstance(model, list) and nroll == 1:
nroll = len(model)

if isinstance(model, list) and len(model) != nroll:
if isinstance(model, list) and len(model) > 1 and len(model) != nroll:
raise ValueError(
f'nroll inferred as {nroll} but model is length {len(model)}'
)
Expand Down
82 changes: 81 additions & 1 deletion python/mujoco/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""tests for rollout function."""

import concurrent.futures
import copy
import threading

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import mujoco
from mujoco import rollout
import numpy as np

# -------------------------- models used for testing ---------------------------

Expand Down Expand Up @@ -794,6 +796,84 @@ def test_stateless(self):
np.testing.assert_array_equal(state, state2)
np.testing.assert_array_equal(sensordata, sensordata2)

def test_length_one_model_list(self):
model = mujoco.MjModel.from_xml_string(TEST_XML)
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

initial_state = np.random.randn(nstate)
control = np.random.randn(3, 3, model.nu)

state, sensordata = rollout.rollout(model, data, initial_state, control)
state2, sensordata2 = rollout.rollout([model], data, initial_state, control)

# assert that we get same outputs
np.testing.assert_array_equal(state, state2)
np.testing.assert_array_equal(sensordata, sensordata2)

def test_data_sizes(self):
model = mujoco.MjModel.from_xml_string(TEST_XML)
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

initial_state = np.random.randn(nstate)
control = np.random.randn(3, 3, model.nu)

# Test passing empty lists for data
with self.assertRaisesWithLiteralMatch(
ValueError, 'The list of data instances is empty'
):
rollout.rollout(model, [], initial_state, control)

with self.assertRaisesWithLiteralMatch(
ValueError, 'The list of data instances is empty'
):
with rollout.Rollout(nthread=0) as rollout_:
rollout_.rollout(model, [], initial_state, control)

with self.assertRaisesWithLiteralMatch(
ValueError, 'The list of data instances is empty'
):
with rollout.Rollout(nthread=1) as rollout_:
rollout_.rollout(model, [], initial_state, control)

with self.assertRaisesWithLiteralMatch(
ValueError, 'The list of data instances is empty'
):
with rollout.Rollout(nthread=2) as rollout_:
rollout_.rollout(model, [], initial_state, control)

# Test checking that len(data) equals nthread
with self.assertRaisesWithLiteralMatch(
ValueError,
'More than one data instance passed but rollout is configured to run on'
' main thread',
):
with rollout.Rollout(nthread=0) as rollout_:
rollout_.rollout(
model, [copy.copy(data) for i in range(2)], initial_state, control
)

with self.assertRaisesWithLiteralMatch(
ValueError, 'Length of data: 1 not equal to nthread: 2'
):
with rollout.Rollout(nthread=2) as rollout_:
rollout_.rollout(model, data, initial_state, control)

with self.assertRaisesWithLiteralMatch(
ValueError, 'Length of data: 1 not equal to nthread: 2'
):
with rollout.Rollout(nthread=2) as rollout_:
rollout_.rollout(model, [data], initial_state, control)

with self.assertRaisesWithLiteralMatch(
ValueError, 'Length of data: 3 not equal to nthread: 2'
):
with rollout.Rollout(nthread=2) as rollout_:
rollout_.rollout(
model, [copy.copy(data) for i in range(3)], initial_state, control
)


# -------------- Python implementation of rollout functionality ----------------

Expand Down