Add a persistent thread pool to rollout #2282

8 changes: 8 additions & 0 deletions doc/changelog.rst
@@ -9,6 +9,14 @@ Bug fixes
^^^^^^^^^
- Fixed a bug in the box-sphere collider, depth was incorrect for deep penetrations (:github:issue:`2206`).

Python bindings
^^^^^^^^^^^^^^^
- :ref:`rollout<PyRollout>` now features native multi-threading. If a sequence of MjData instances
  of length ``nthread`` is passed in, ``rollout`` will automatically create a thread pool and parallelize
  the computation. The thread pool can be reused across calls, but the function then cannot be called
  simultaneously from multiple threads. To run multiple threaded rollouts simultaneously, use the new class
  ``Rollout``, which encapsulates the thread pool. Contribution by :github:user:`aftersomemath`.

Version 3.2.6 (Dec 2, 2024)
---------------------------

44 changes: 37 additions & 7 deletions doc/python.rst
@@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco
rollout
-------

``mujoco.rollout`` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
``mujoco.rollout`` and ``mujoco.rollout.Rollout`` show how to add additional C/C++ functionality, exposed as a Python module
via pybind11. They are implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
and wrapped in `rollout.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.py>`__. The module
provides common functionality for which tight loops implemented outside of Python are beneficial: rolling out a trajectory
(i.e., calling :ref:`mj_step` in a loop), given an initial state and a sequence of controls, and returning the subsequent
states and sensor values. The rollouts are run in parallel with an internally managed thread pool if multiple MjData instances
(one per thread) are passed as an argument. The basic usage form is

.. code-block:: python

state, sensordata = rollout.rollout(model, data, initial_state, control)

``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``.
``data`` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread``.
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are
@@ -732,13 +734,41 @@ specified by passing an optional ``control_spec`` bitflag.
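
For illustration, here is a minimal sketch of a parallel call. The model path is a placeholder, and the
default ``control_spec`` is assumed to include ``mjSTATE_CTRL``, so that ``ncontrol == model.nu``:

.. code-block:: python

   import numpy as np
   import mujoco
   from mujoco import rollout

   nthread, nroll, nstep = 4, 64, 100

   model = mujoco.MjModel.from_xml_path('model.xml')      # placeholder model file
   data = [mujoco.MjData(model) for _ in range(nthread)]  # one MjData per worker thread

   # replicate the default reset state nroll times
   nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
   initial_state = np.zeros(nstate)
   mujoco.mj_getState(model, data[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
   initial_state = np.tile(initial_state, (nroll, 1))

   control = np.zeros((nroll, nstep, model.nu))           # zero controls for every rollout

   state, sensordata = rollout.rollout(model, data, initial_state, control)
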
If a rollout diverges, the current state and sensor values are used to fill the remainder of the trajectory.
Therefore, non-increasing time values can be used to detect diverged rollouts.
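
For example, a possible check over the returned ``state`` array (this sketch assumes that time occupies
the leading entry of the full physics state):

.. code-block:: python

   import numpy as np

   time = state[:, :, 0]                                  # (nroll, nstep) array of times
   diverged = (np.diff(time, axis=1) <= 0).any(axis=1)    # rollouts whose time stopped increasing
   print('diverged rollouts:', np.flatnonzero(diverged))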

The ``rollout`` function is designed to be completely stateless, so all inputs of the stepping pipeline are set and any
The ``rollout`` function is designed to be computationally stateless, so all inputs of the stepping pipeline are set and any
values already present in the given ``MjData`` instance will have no effect on the output.

Since the Global Interpreter Lock can be released, this function can be efficiently threaded using Python threads. See
the ``test_threading`` function in
By default, ``rollout.rollout`` creates a new thread pool on every call if ``len(data) > 1``. To reuse the thread pool
over multiple calls, use the ``persistent_pool`` argument. ``rollout.rollout`` is not thread-safe when using
a persistent pool. The basic usage form is

.. code-block:: python

state, sensordata = rollout.rollout(model, data, initial_state, persistent_pool=True)

The pool is shut down on interpreter shutdown or by a call to ``rollout.shutdown_persistent_pool``.
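
For example, a sketch of reusing the pool over several calls and then releasing it explicitly, using the
same names as in the examples above:

.. code-block:: python

   for _ in range(10):
     # the pool created on the first call is reused by subsequent calls
     state, sensordata = rollout.rollout(model, data, initial_state,
                                         persistent_pool=True)

   # optionally release the worker threads before interpreter exit
   rollout.shutdown_persistent_pool()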

To use multiple thread pools from multiple threads, use ``Rollout`` objects. The basic usage form is

.. code-block:: python

# Pool shutdown upon exiting block
with rollout.Rollout(nthread=nthread) as rollout_:
  rollout_.rollout(model, data, initial_state)

or

.. code-block:: python

# pool shutdown on object deletion or call to rollout_.close
# to ensure clean shutdown of threads, call close before interpreter exit
rollout_ = rollout.Rollout(nthread=nthread)
rollout_.rollout(model, data, initial_state)
rollout_.close()

Since the Global Interpreter Lock is released, this function can also be threaded using Python threads. However, this
is less efficient than using native threads. See the ``test_threading`` function in
`rollout_test.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout_test.py>`__ for an example
of threaded operation (and more generally for usage examples).
of threaded operation (and for more general usage examples).
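
As an illustrative sketch (not the test itself), the pattern below splits the rolls across Python threads,
each owning its own ``Rollout`` pool and its own MjData instances. It reuses ``model`` and ``initial_state``
from the sketches above and assumes that ``Rollout.rollout`` returns the same ``(state, sensordata)`` pair
as the module-level function:

.. code-block:: python

   import concurrent.futures
   import numpy as np

   nworker = 2   # Python threads
   nthread = 2   # native threads per pool

   def run_chunk(initial_state_chunk):
     data = [mujoco.MjData(model) for _ in range(nthread)]
     # each Python thread gets its own thread pool, shut down on block exit
     with rollout.Rollout(nthread=nthread) as rollout_:
       return rollout_.rollout(model, data, initial_state_chunk)

   chunks = np.array_split(initial_state, nworker)
   with concurrent.futures.ThreadPoolExecutor(max_workers=nworker) as executor:
     state_chunks, sensordata_chunks = zip(*executor.map(run_chunk, chunks))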

.. _PyMinimize:

Expand Down
2 changes: 1 addition & 1 deletion python/mujoco/CMakeLists.txt
@@ -383,7 +383,7 @@ target_link_libraries(
structs_header
)

mujoco_pybind11_module(_rollout rollout.cc)
mujoco_pybind11_module(_rollout rollout.cc threadpool.cc)
target_link_libraries(_rollout PRIVATE functions_header mujoco raw)

mujoco_pybind11_module(
224 changes: 162 additions & 62 deletions python/mujoco/rollout.cc
@@ -20,6 +20,7 @@
#include "errors.h"
#include "raw.h"
#include "structs.h"
#include "threadpool.h"
#include <pybind11/buffer_info.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
@@ -31,14 +32,24 @@ namespace {

namespace py = ::pybind11;

using PyCArray = py::array_t<mjtNum, py::array::c_style>;

// NOLINTBEGIN(whitespace/line_length)

const auto rollout_init_doc = R"(
Construct a rollout object containing a thread pool for parallel rollouts.

input arguments (optional):
nthread integer, number of threads in pool
if zero, this pool is not started and rollouts run on the calling thread
)";

const auto rollout_doc = R"(
Roll out open-loop trajectories from initial states, get resulting states and sensor values.

input arguments (required):
model list of MjModel instances of length nroll
data associated instance of MjData
data list of associated MjData instances of length nthread
nstep integer, number of steps to be taken for each trajectory
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
state0 (nroll x nstate) nroll initial state vectors,
@@ -49,12 +60,14 @@ Roll out open-loop trajectories from initial states, get resulting states and se
output arguments (optional):
state (nroll x nstep x nstate) nroll nstep states
sensordata (nroll x nstep x nsensordata) nroll trajectories of nstep sensordata vectors
chunk_size integer, determines threadpool chunk size. If unspecified
chunk_size = max(1, nroll / (nthread * 10))
)";

// C-style rollout function, assumes all arguments are valid
// all input fields of d are initialised, contents at call time do not matter
// after returning, d will contain the last step of the last rollout
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int nstep, unsigned int control_spec,
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
mjtNum* state, mjtNum* sensordata) {
// sizes
@@ -75,7 +88,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}

// loop over rollouts
for (int r = 0; r < nroll; r++) {
for (int r = start_roll; r < end_roll; r++) {
// clear user inputs if unspecified
if (!(control_spec & mjSTATE_MOCAP_POS)) {
for (int i = 0; i < nbody; i++) {
@@ -158,6 +171,42 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}
}

// C-style threaded version of _unsafe_rollout
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
int nroll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0,
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
ThreadPool* pool, int chunk_size) {
int nfulljobs = nroll / chunk_size;
int chunk_remainder = nroll % chunk_size;
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

// Reset the pool counter
pool->ResetCount();

// schedule all jobs of full (chunk) size
for (int j = 0; j < nfulljobs; j++) {
auto task = [=, &m, &d](void) {
int id = pool->WorkerId();
_unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool->Schedule(task);
}

// schedule any remaining jobs of size < chunk_size
if (chunk_remainder > 0) {
auto task = [=, &m, &d](void) {
_unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool->Schedule(task);
}

// wait for the job counter to be incremented up to the number of jobs submitted by this thread
pool->WaitCount(njobs);
}

// NOLINTEND(whitespace/line_length)

// check size of optional argument to rollout(), return raw pointer
@@ -181,71 +230,122 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
return static_cast<mjtNum*>(info.ptr);
}

class Rollout {
public:
Rollout(int nthread) : nthread_(nthread) {
if (this->nthread_ > 0) {
this->pool_ = std::make_shared<ThreadPool>(this->nthread_);
}
}

PYBIND11_MODULE(_rollout, pymodule) {
namespace py = ::pybind11;
using PyCArray = py::array_t<mjtNum, py::array::c_style>;

// roll out open loop trajectories from multiple initial states
// get subsequent states and corresponding sensor values
pymodule.def(
"rollout",
[](py::list m, MjDataWrapper& d,
int nstep, unsigned int control_spec,
const PyCArray state0,
std::optional<const PyCArray> warmstart0,
std::optional<const PyCArray> control,
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata
) {
// get raw pointers
int nroll = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nroll);
for (int r = 0; r < nroll; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}
raw::MjData* data = d.get();
void rollout(py::list m, py::list d,
int nstep, unsigned int control_spec,
const PyCArray state0,
std::optional<const PyCArray> warmstart0,
std::optional<const PyCArray> control,
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata,
std::optional<int> chunk_size
) {
// get raw pointers
int nroll = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nroll);
for (int r = 0; r < nroll; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}

// check that some steps need to be taken, return if not
if (nstep < 1) {
return;
}
// check that the length of data is consistent with nthread
if (this->nthread_ == 0 && py::len(d) > 1) {
std::ostringstream msg;
msg << "More than one data instance passed but "
<< "rollout is configured to run on main thread";
throw py::value_error(msg.str());
} else if (this->nthread_ > 0 && this->nthread_ != py::len(d)) {
std::ostringstream msg;
msg << "Length of data: " << py::len(d)
<< " not equal to nthread: " << this->nthread_;
throw py::value_error(msg.str());
}

std::vector<raw::MjData*> data_ptrs(py::len(d));
for (int t = 0; t < py::len(d); t++) {
data_ptrs[t] = d[t].cast<MjDataWrapper*>()->get();
}

// check that some steps need to be taken, return if not
if (nstep < 1) {
return;
}

// get sizes
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);

mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll,
1, model_ptrs[0]->nv);
mjtNum* control_ptr = get_array_ptr(control, "control", nroll,
nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
nstep, model_ptrs[0]->nsensordata);

// perform rollouts
{
// release the GIL
py::gil_scoped_release no_gil;

// call unsafe rollout function
// get sizes
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);

mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll,
1, model_ptrs[0]->nv);
mjtNum* control_ptr = get_array_ptr(control, "control", nroll,
nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
nstep, model_ptrs[0]->nsensordata);

// perform rollouts
{
// release the GIL
py::gil_scoped_release no_gil;

// call unsafe rollout function, multi or single threaded
if (this->nthread_ > 0 && nroll > 1) {
int chunk_size_final = 1;
if (!chunk_size.has_value()) {
chunk_size_final = std::max(1, nroll / (10 * this->nthread_));
} else {
chunk_size_final = *chunk_size;
}
InterceptMjErrors(_unsafe_rollout_threaded)(
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
this->pool_.get(), chunk_size_final);
} else {
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
}
},
py::arg("model"),
py::arg("data"),
py::arg("nstep"),
py::arg("control_spec"),
py::arg("state0"),
py::arg("warmstart0") = py::none(),
py::arg("control") = py::none(),
py::arg("state") = py::none(),
py::arg("sensordata") = py::none(),
py::doc(rollout_doc)
);
}
}

private:
int nthread_;
std::shared_ptr<ThreadPool> pool_;
};


PYBIND11_MODULE(_rollout, pymodule) {
namespace py = ::pybind11;

py::class_<Rollout>(pymodule, "Rollout")
.def(
py::init([](int nthread) {
return std::make_unique<Rollout>(nthread);
}),
py::kw_only(),
py::arg("nthread"),
py::doc(rollout_init_doc))
.def(
"rollout",
&Rollout::rollout,
py::arg("model"),
py::arg("data"),
py::arg("nstep"),
py::arg("control_spec"),
py::arg("state0"),
py::arg("warmstart0") = py::none(),
py::arg("control") = py::none(),
py::arg("state") = py::none(),
py::arg("sensordata") = py::none(),
py::arg("chunk_size") = py::none(),
py::doc(rollout_doc));
}

} // namespace