Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove nroll argument from rollout #2246

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Python bindings
- Provide prebuilt wheels for Python 3.13.
- Added ``bind`` method and removed id attribute from :ref:`mjSpec` objects. Using ids is error prone in scenarios of repeated attachment and
detachment. Python users are encouraged to use names for unique identification of model elements.
- Removed ``nroll`` argument from :ref:`rollout<PyRollout>` because its value can always be inferred.

Bug fixes
^^^^^^^^^
Expand Down
7 changes: 3 additions & 4 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ Roll out open-loop trajectories from initial states, get resulting states and se
input arguments (required):
model instance of MjModel
data associated instance of MjData
nroll integer, number of initial states from which to roll out trajectories
nstep integer, number of steps to be taken for each trajectory
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
state0 (nroll x nstate) nroll initial state vectors,
Expand Down Expand Up @@ -190,7 +189,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
pymodule.def(
"rollout",
[](const MjModelWrapper& m, MjDataWrapper& d,
int nroll, int nstep, unsigned int control_spec,
int nstep, unsigned int control_spec,
const PyCArray state0,
std::optional<const PyCArray> warmstart0,
std::optional<const PyCArray> control,
Expand All @@ -201,13 +200,14 @@ PYBIND11_MODULE(_rollout, pymodule) {
raw::MjData* data = d.get();

// check that some steps need to be taken, return if not
if (nroll < 1 || nstep < 1) {
if (nstep < 1) {
return;
}

// get sizes
int nstate = mj_stateSize(model, mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model, control_spec);
int nroll = state0.shape(0);

// get raw pointers
mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
Expand All @@ -232,7 +232,6 @@ PYBIND11_MODULE(_rollout, pymodule) {
},
py::arg("model"),
py::arg("data"),
py::arg("nroll"),
py::arg("nstep"),
py::arg("control_spec"),
py::arg("state0"),
Expand Down
10 changes: 3 additions & 7 deletions python/mujoco/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def rollout(model: mujoco.MjModel,
*, # require subsequent arguments to be named
control_spec: int = mujoco.mjtState.mjSTATE_CTRL.value,
skip_checks: bool = False,
nroll: Optional[int] = None,
nstep: Optional[int] = None,
initial_warmstart: Optional[npt.ArrayLike] = None,
state: Optional[npt.ArrayLike] = None,
Expand All @@ -50,7 +49,6 @@ def rollout(model: mujoco.MjModel,
([nroll or 1] x [nstep or 1] x ncontrol)
control_spec: mjtState specification of control vectors.
skip_checks: Whether to skip internal shape and type checks.
nroll: Number of rollouts (inferred if unspecified).
nstep: Number of steps in rollouts (inferred if unspecified).
initial_warmstart: Initial qfrc_warmstart array (optional).
([nroll or 1] x nv)
Expand All @@ -74,7 +72,7 @@ def rollout(model: mujoco.MjModel,
# don't allocate output arrays
# just call rollout and return
if skip_checks:
_rollout.rollout(model, data, nroll, nstep, control_spec, initial_state,
_rollout.rollout(model, data, nstep, control_spec, initial_state,
initial_warmstart, control, state, sensordata)
return state, sensordata

Expand All @@ -83,8 +81,6 @@ def rollout(model: mujoco.MjModel,
raise ValueError('control_spec can only contain bits in mjSTATE_USER')

# check types
if nroll and not isinstance(nroll, int):
raise ValueError('nroll must be an integer')
if nstep and not isinstance(nstep, int):
raise ValueError('nstep must be an integer')
_check_must_be_numeric(
Expand Down Expand Up @@ -121,7 +117,7 @@ def rollout(model: mujoco.MjModel,
_check_trailing_dimension(model.nsensordata, sensordata=sensordata)

# infer nroll, check for incompatibilities
nroll = _infer_dimension(0, nroll or 1,
nroll = _infer_dimension(0, 1,
initial_state=initial_state,
initial_warmstart=initial_warmstart,
control=control,
Expand All @@ -146,7 +142,7 @@ def rollout(model: mujoco.MjModel,
sensordata = np.empty((nroll, nstep, model.nsensordata))

# call rollout
_rollout.rollout(model, data, nroll, nstep, control_spec, initial_state,
_rollout.rollout(model, data, nstep, control_spec, initial_state,
initial_warmstart, control, state, sensordata)

# return outputs
Expand Down
106 changes: 105 additions & 1 deletion python/mujoco/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,110 @@ def test_multi_step(self, model_name):
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_initial_state(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nroll, nstate)
control = np.random.randn(nstep, model.nu)
state, sensordata = rollout.rollout(model, data, initial_state, control)

mujoco.mj_resetData(model, data)
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_control(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nroll, nstep, model.nu)
state, sensordata = rollout.rollout(model, data, initial_state, control)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_warmstart(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
initial_warmstart = np.tile(data.qacc_warmstart.copy(), (nroll, 1))
state, sensordata = rollout.rollout(model, data, initial_state, control,
initial_warmstart=initial_warmstart)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_state(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
state = np.empty((nroll, nstep, nstate))
state, sensordata = rollout.rollout(model, data, initial_state, control,
state=state)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_sensordata(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
sensordata = np.empty((nroll, nstep, model.nsensordata))
state, sensordata = rollout.rollout(model, data, initial_state, control,
sensordata=sensordata)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_one_rollout_fixed_ctrl(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
Expand Down Expand Up @@ -328,7 +432,7 @@ def thread_initializer():

def call_rollout(initial_state, control, state, sensordata):
rollout.rollout(model, thread_local.data, initial_state, control,
skip_checks=True, nroll=initial_state.shape[0],
skip_checks=True,
nstep=nstep, state=state, sensordata=sensordata)

n = nroll // num_workers # integer division
Expand Down
Loading