Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
3a95b62 by Levi Burner <[email protected]>:

rollout prototype native threadpool for comparing to python threads

--
efd8be1 by Levi Burner <[email protected]>:

copy mjpcs threadpool into python bindings

--
75603ee by Levi Burner <[email protected]>:

rollout use threadpool as translation unit

--
06b90fe by Levi Burner <[email protected]>:

rollout add chunk_divisor parameter

--
298ab2f by Levi Burner <[email protected]>:

rollout add native threading test

--
169cf99 by Levi Burner <[email protected]>:

rollout exchange chunk_divisor arg for chunk_size

--
265af85 by Levi Burner <[email protected]>:

rollout fix cosmetics

--
1e8bffa by Levi Burner <[email protected]>:

make native rollout a class instead of a function

--
ba78821 by Levi Burner <[email protected]>:

rollout update docs and changelog

--
e4cb773 by Levi Burner <[email protected]>:

rollout don't register atexit handler for Rollout objects

--
5a08d2e by Levi Burner <[email protected]>:

rollout nthread kwarg, rename shutdown_pool to close, fixups

--
f622378 by Levi Burner <[email protected]>:

rollout add missing .close() calls

--
50f3ebc by Levi Burner <[email protected]>:

rollout return immediately

COPYBARA_INTEGRATE_REVIEW=#2282 from aftersomemath:rollout-threaded 50f3ebc
PiperOrigin-RevId: 706744277
Change-Id: I1ab2263b7d6ce30cf1908aec8fd5f2eb976a19e6
  • Loading branch information
aftersomemath authored and copybara-github committed Dec 16, 2024
1 parent b26d6f0 commit a7eb6ef
Show file tree
Hide file tree
Showing 8 changed files with 761 additions and 203 deletions.
8 changes: 8 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ Changelog
Upcoming version (not yet released)
-----------------------------------

Python bindings
^^^^^^^^^^^^^^^
- :ref:`rollout<PyRollout>` now features native multi-threading. If a sequence of ``MjData`` instances
of length ``nthread`` is passed in, ``rollout`` will automatically create a thread pool and parallelize
the computation. The thread pool can be resused across calls, but then the function cannot be called simultaneously
from multiple threads. To run multiple threaded rollouts simultaneously, use the new class ``Rollout`` which
encapsulates the thread pool. Contribution by :github:user:`aftersomemath`.

Bug fixes
^^^^^^^^^
- Fixed a bug in the box-sphere collider, depth was incorrect for deep penetrations (:github:issue:`2206`).
Expand Down
44 changes: 37 additions & 7 deletions doc/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco
rollout
-------

``mujoco.rollout`` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
``mujoco.rollout`` and ``mujoco.rollout.Rollout`` shows how to add additional C/C++ functionality, exposed as a Python module
via pybind11. It is implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
and wrapped in `rollout.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.py>`__. The module
performs a common functionality where tight loops implemented outside of Python are beneficial: rolling out a trajectory
(i.e., calling :ref:`mj_step` in a loop), given an intial state and sequence of controls, and returning subsequent
states and sensor values. The basic usage form is
states and sensor values. The rollouts are run in parallel with an internally managed thread pool if multiple MjData instances
(one per thread) are passed as an argument. The basic usage form is

.. code-block:: python
state, sensordata = rollout.rollout(model, data, initial_state, control)
``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``.
``data`` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread``.
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are
Expand All @@ -732,13 +734,41 @@ specified by passing an optional ``control_spec`` bitflag.
If a rollout diverges, the current state and sensor values are used to fill the remainder of the trajectory.
Therefore, non-increasing time values can be used to detect diverged rollouts.

The ``rollout`` function is designed to be completely stateless, so all inputs of the stepping pipeline are set and any
The ``rollout`` function is designed to be computationally stateless, so all inputs of the stepping pipeline are set and any
values already present in the given ``MjData`` instance will have no effect on the output.

Since the Global Interpreter Lock can be released, this function can be efficiently threaded using Python threads. See
the ``test_threading`` function in
By default ``rollout.rollout`` creates a new thread pool every call if ``len(data) > 1``. To reuse the thread pool
over multiple calls use the ``persistent_pool`` argument. ``rollout.rollout`` is not thread safe when using
a persistent pool. The basic usage form is

.. code-block:: python
state, sensordata = rollout.rollout(model, data, initial_state, persistent_pool=True)
The pool is shutdown on interpreter shutdown or by a call to ``rollout.shutdown_persistent_pool``.

To use multiple thread pools from multiple threads, use ``Rollout`` objects. The basic usage form is

.. code-block:: python
# Pool shutdown upon exiting block.
with rollout.Rollout(nthread=nthread) as rollout_:
rollout_.rollout(model, data, initial_state)
or

.. code-block:: python
# Pool shutdown on object deletion or call to rollout_.close().
# To ensure clean shutdown of threads, call close() before interpreter exit.
rollout_ = rollout.Rollout(nthread=nthread)
rollout_.rollout(model, data, initial_state)
rollout_.close()
Since the Global Interpreter Lock is released, this function can also be threaded using Python threads. However, this
is less efficient than using native threads. See the ``test_threading`` function in
`rollout_test.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout_test.py>`__ for an example
of threaded operation (and more generally for usage examples).
of threaded operation (and for more general usage examples).

.. _PyMinimize:

Expand Down
2 changes: 1 addition & 1 deletion python/mujoco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ target_link_libraries(
structs_header
)

mujoco_pybind11_module(_rollout rollout.cc)
mujoco_pybind11_module(_rollout rollout.cc threadpool.cc)
target_link_libraries(_rollout PRIVATE functions_header mujoco raw)

mujoco_pybind11_module(
Expand Down
226 changes: 162 additions & 64 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
// limitations under the License.

#include <iostream>
#include <memory>
#include <optional>
#include <sstream>

#include <mujoco/mujoco.h>
#include "errors.h"
#include "raw.h"
#include "structs.h"
#include "threadpool.h"
#include <pybind11/buffer_info.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
Expand All @@ -31,14 +33,24 @@ namespace {

namespace py = ::pybind11;

using PyCArray = py::array_t<mjtNum, py::array::c_style>;

// NOLINTBEGIN(whitespace/line_length)

const auto rollout_init_doc = R"(
Construct a rollout object containing a thread pool for parallel rollouts.
input arguments (optional):
nthread integer, number of threads in pool
if zero, this pool is not started and rollouts run on the calling thread
)";

const auto rollout_doc = R"(
Roll out open-loop trajectories from initial states, get resulting states and sensor values.
input arguments (required):
model list of MjModel instances of length nroll
data associated instance of MjData
data list of associated MjData instances of length nthread
nstep integer, number of steps to be taken for each trajectory
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
state0 (nroll x nstate) nroll initial state vectors,
Expand All @@ -49,12 +61,14 @@ Roll out open-loop trajectories from initial states, get resulting states and se
output arguments (optional):
state (nroll x nstep x nstate) nroll nstep states
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
chunk_size integer, determines threadpool chunk size. If unspecified
chunk_size = max(1, nroll / (nthread * 10))
)";

// C-style rollout function, assumes all arguments are valid
// all input fields of d are initialised, contents at call time do not matter
// after returning, d will contain the last step of the last rollout
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int nstep, unsigned int control_spec,
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
mjtNum* state, mjtNum* sensordata) {
// sizes
Expand All @@ -75,7 +89,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}

// loop over rollouts
for (int r = 0; r < nroll; r++) {
for (int r = start_roll; r < end_roll; r++) {
// clear user inputs if unspecified
if (!(control_spec & mjSTATE_MOCAP_POS)) {
for (int i = 0; i < nbody; i++) {
Expand Down Expand Up @@ -158,6 +172,43 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}
}

// C-style threaded version of _unsafe_rollout
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
int nroll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0,
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
ThreadPool* pool, int chunk_size) {
int nfulljobs = nroll / chunk_size;
int chunk_remainder = nroll % chunk_size;
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

// Reset the pool counter
pool->ResetCount();

// schedule all jobs of full (chunk) size
for (int j = 0; j < nfulljobs; j++) {
auto task = [=, &m, &d](void) {
int id = pool->WorkerId();
_unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool->Schedule(task);
}

// schedule any remaining jobs of size < chunk_size
if (chunk_remainder > 0) {
auto task = [=, &m, &d](void) {
_unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size,
nfulljobs*chunk_size+chunk_remainder,
nstep, control_spec, state0, warmstart0, control, state, sensordata);
};
pool->Schedule(task);
}

// wait for job counter to incremented up to the number of jobs submitted by this thread
pool->WaitCount(njobs);
}

// NOLINTEND(whitespace/line_length)

// check size of optional argument to rollout(), return raw pointer
Expand All @@ -181,71 +232,118 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
return static_cast<mjtNum*>(info.ptr);
}

class Rollout {
public:
Rollout(int nthread) : nthread_(nthread) {
if (this->nthread_ > 0) {
this->pool_ = std::make_shared<ThreadPool>(this->nthread_);
}
}

PYBIND11_MODULE(_rollout, pymodule) {
namespace py = ::pybind11;
using PyCArray = py::array_t<mjtNum, py::array::c_style>;

// roll out open loop trajectories from multiple initial states
// get subsequent states and corresponding sensor values
pymodule.def(
"rollout",
[](py::list m, MjDataWrapper& d,
int nstep, unsigned int control_spec,
const PyCArray state0,
std::optional<const PyCArray> warmstart0,
std::optional<const PyCArray> control,
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata
) {
// get raw pointers
int nroll = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nroll);
for (int r = 0; r < nroll; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}
raw::MjData* data = d.get();
void rollout(py::list m, py::list d, int nstep, unsigned int control_spec,
const PyCArray state0, std::optional<const PyCArray> warmstart0,
std::optional<const PyCArray> control,
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata,
std::optional<int> chunk_size) {
// get raw pointers
int nroll = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nroll);
for (int r = 0; r < nroll; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}

// check that some steps need to be taken, return if not
if (nstep < 1) {
return;
}
// check length d and nthread are consistent
if (this->nthread_ == 0 && py::len(d) > 1) {
std::ostringstream msg;
msg << "More than one data instance passed but "
<< "rollout is configured to run on main thread";
py::value_error(msg.str());
} else if (this->nthread_ != py::len(d)) {
std::ostringstream msg;
msg << "Length of data: " << py::len(d)
<< " not equal to nthread: " << this->nthread_;
py::value_error(msg.str());
}

std::vector<raw::MjData*> data_ptrs(py::len(d));
for (int t = 0; t < py::len(d); t++) {
data_ptrs[t] = d[t].cast<MjDataWrapper*>()->get();
}

// check that some steps need to be taken, return if not
if (nstep < 1) {
return;
}

// get sizes
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);

mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll,
1, model_ptrs[0]->nv);
mjtNum* control_ptr = get_array_ptr(control, "control", nroll,
nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
nstep, model_ptrs[0]->nsensordata);

// perform rollouts
{
// release the GIL
py::gil_scoped_release no_gil;

// call unsafe rollout function
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
// get sizes
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);

mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
mjtNum* warmstart0_ptr =
get_array_ptr(warmstart0, "warmstart0", nroll, 1, model_ptrs[0]->nv);
mjtNum* control_ptr =
get_array_ptr(control, "control", nroll, nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
nstep, model_ptrs[0]->nsensordata);

// perform rollouts
{
// release the GIL
py::gil_scoped_release no_gil;

// call unsafe rollout function, multi or single threaded
if (this->nthread_ > 0 && nroll > 1) {
int chunk_size_final = 1;
if (!chunk_size.has_value()) {
chunk_size_final = std::max(1, nroll / (10 * this->nthread_));
} else {
chunk_size_final = *chunk_size;
}
},
py::arg("model"),
py::arg("data"),
py::arg("nstep"),
py::arg("control_spec"),
py::arg("state0"),
py::arg("warmstart0") = py::none(),
py::arg("control") = py::none(),
py::arg("state") = py::none(),
py::arg("sensordata") = py::none(),
py::doc(rollout_doc)
);
InterceptMjErrors(_unsafe_rollout_threaded)(
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
this->pool_.get(), chunk_size_final);
} else {
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
}
}
}

private:
int nthread_;
std::shared_ptr<ThreadPool> pool_;
};

PYBIND11_MODULE(_rollout, pymodule) {
namespace py = ::pybind11;

py::class_<Rollout>(pymodule, "Rollout")
.def(
py::init([](int nthread) {
return std::make_unique<Rollout>(nthread);
}),
py::kw_only(),
py::arg("nthread"),
py::doc(rollout_init_doc))
.def(
"rollout",
&Rollout::rollout,
py::arg("model"),
py::arg("data"),
py::arg("nstep"),
py::arg("control_spec"),
py::arg("state0"),
py::arg("warmstart0") = py::none(),
py::arg("control") = py::none(),
py::arg("state") = py::none(),
py::arg("sensordata") = py::none(),
py::arg("chunk_size") = py::none(),
py::doc(rollout_doc));
}

} // namespace
Expand Down
Loading

0 comments on commit a7eb6ef

Please sign in to comment.