Skip to content

Commit

Permalink
add nroll inference tests to rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Nov 25, 2024
1 parent 2f3810d commit d45d30b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/mujoco/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def rollout(model: mujoco.MjModel,
"""Rolls out open-loop trajectories from initial states, get subsequent states and sensor values.
Python wrapper for rollout.cc, see documentation therein.
Infers nstep.
Infers nroll and nstep.
Tiles inputs with singleton dimensions.
Allocates outputs if none are given.
Expand Down Expand Up @@ -67,7 +67,7 @@ def rollout(model: mujoco.MjModel,
ValueError: bad shapes or sizes.
"""
# skip_checks shortcut:
# don't infer nstep
# don't infer nroll/nstep
# don't support singleton expansion
# don't allocate output arrays
# just call rollout and return
Expand Down
104 changes: 104 additions & 0 deletions python/mujoco/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,110 @@ def test_multi_step(self, model_name):
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_initial_state(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nroll, nstate)
control = np.random.randn(nstep, model.nu)
state, sensordata = rollout.rollout(model, data, initial_state, control)

mujoco.mj_resetData(model, data)
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_control(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nroll, nstep, model.nu)
state, sensordata = rollout.rollout(model, data, initial_state, control)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_warmstart(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
initial_warmstart = np.tile(data.qacc_warmstart.copy(), (nroll, 1))
state, sensordata = rollout.rollout(model, data, initial_state, control,
initial_warmstart=initial_warmstart)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_state(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
state = np.empty((nroll, nstep, nstate))
state, sensordata = rollout.rollout(model, data, initial_state, control,
state=state)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_infer_nroll_sensordata(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
data = mujoco.MjData(model)

nroll = 5 # number of rollouts
nstep = 1 # number of steps

initial_state = np.random.randn(nstate)
control = np.random.randn(nstep, model.nu)
sensordata = np.empty((nroll, nstep, model.nsensordata))
state, sensordata = rollout.rollout(model, data, initial_state, control,
sensordata=sensordata)

mujoco.mj_resetData(model, data)
initial_state = np.tile(initial_state, (nroll, 1))
control = np.tile(control, (nroll, 1, 1))
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
np.testing.assert_array_equal(state, py_state)
np.testing.assert_array_equal(sensordata, py_sensordata)

@parameterized.parameters(ALL_MODELS.keys())
def test_one_rollout_fixed_ctrl(self, model_name):
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
Expand Down

0 comments on commit d45d30b

Please sign in to comment.