
Add a persistent thread pool to rollout #2282

Closed

Conversation

aftersomemath
Contributor

rollout now accepts a list of MjData of length nthread. If passed, the nroll rollouts are parallelized.

The thread pool is persistent across calls to rollout, so long as the number of MjData instances is constant.

The thread pool implementation was copied from MJPC. It is not thread-safe: if rollout is using threads, it must not itself be called from multiple threads.
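
For concreteness, a minimal sketch of the new calling convention, assuming the signature discussed in this PR (a model, a sequence of MjData of length nthread, and a batch of initial states; argument names beyond those are assumptions):

    import mujoco
    import numpy as np
    from mujoco import rollout

    model = mujoco.MjModel.from_xml_string(
        '<mujoco><worldbody><body><joint type="hinge"/>'
        '<geom size="0.1"/></body></worldbody></mujoco>')
    nthread, nroll, nstep = 4, 16, 100

    # one MjData per thread; a sequence of length nthread enables the
    # threaded path and its persistent pool
    datas = [mujoco.MjData(model) for _ in range(nthread)]

    # broadcast one full-physics state to all nroll rollouts
    nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
    full_state = np.empty(nstate)
    mujoco.mj_getState(model, datas[0], full_state,
                       mujoco.mjtState.mjSTATE_FULLPHYSICS)
    initial_state = np.tile(full_state, (nroll, 1))

    state, sensordata = rollout.rollout(model, datas, initial_state, nstep=nstep)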

@@ -158,6 +161,47 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
}
}

// C-style threaded version of _unsafe_rollout
static ThreadPool* pool = nullptr;
Contributor Author

I used a raw pointer here instead of a smart pointer because the style guide states that objects with static duration must be trivially destructible. My understanding is that smart pointers are not trivially destructible.

@yuvaltassa (Collaborator) left a comment

Initial cosmetic review

@@ -9,6 +9,11 @@ Bug fixes
^^^^^^^^^
- Fixed a bug in the box-sphere collider, depth was incorrect for deep penetrations (:github:issue:`2206`).

Python bindings
^^^^^^^^^^^^^^^
- :ref:`rollout<PyRollout>` can now accept sequences of MjData of length ``nthread``. If passed, :ref:`rollout<PyRollout>`
Collaborator

- :ref:`rollout<PyRollout>` now features native multi-threading. If a sequence of MjData instances
  of length ``nthread`` is passed in, ``rollout`` will automatically create a persistent threadpool
  and parallelize the computation. Contribution by :github:user:`aftersomemath`.

  if (!chunk_size.has_value()) {
    chunk_size_final = std::max(1, nroll / (10 * nthread));
  }
  else {
Collaborator

The else-on-a-new-line thing is something we do in C code, and only to facilitate putting a comment above the block. Let's attach all elses (`} else {`) in C++ code, as per the style guide.

  int chunk_remainder = nroll % chunk_size;
  int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

  if (pool == nullptr) {
Collaborator

added block comments, slightly refactored

  // if an existing thread pool has the wrong size, delete it
  if (pool != nullptr && pool->NumThreads() != nthread) {
    delete pool;
    pool = nullptr;
  }

  // make threadpool if required
  if (pool == nullptr) {
    pool = new ThreadPool(nthread);
  } else {
    pool->ResetCount();
  }

pool->ResetCount();
}

for (int j = 0; j < nfulljobs; j++) {
Collaborator

// schedule all jobs of full (chunk) size

};
pool->Schedule(task);
}

Collaborator

// schedule any remaining jobs of size < chunk_size

pool->Schedule(task);
}

pool->WaitCount(njobs);
Collaborator

// increment job counter for this thread

Contributor Author

I went with `// wait for job counter to be incremented up to the number of jobs submitted by this thread`

  InterceptMjErrors(_unsafe_rollout)(
      model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
      warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
if (nthread > 1 && nroll > 1) {
Collaborator

// call unsafe rollout function, multi or single threaded

@@ -59,6 +60,8 @@ def rollout(
        (nroll x nstep x nstate)
    sensordata: Sensor data output array (optional).
        (nroll x nstep x nsensordata)
    chunk_size: Determines threadpool chunk size. If unspecified,
        chunk_size = max(1, nroll / (nthread * 10)
Collaborator

missing closing paren
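
For reference, a worked instance of the intended default (paren restored, integer division assumed):

    # chunk_size = max(1, nroll // (10 * nthread))
    # e.g. nroll = 3000, nthread = 32: max(1, 3000 // 320) = 9 rollouts per chunk
    chunk_size = max(1, 3000 // (10 * 32))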

@yuvaltassa (Collaborator) left a comment

While we're here, fix an unused-var complaint from our internal linter: in rollout_test.py:358, replace `i` with `_`.

  int chunk_remainder = nroll % chunk_size;
  int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

  if (pool == nullptr) {
Collaborator

Ironically, the management of this pool object is not thread safe.

i.e. if the rollout function is called from multiple threads, and we drop the Python GIL (https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc#L230), then you may end up deleting a thread pool that is being used, or creating two threadpools.

I would lean towards creating a new threadpool on every call to rollout. Have you measured a serious performance issue with that?

Collaborator

@nimrod-gileadi, this is clearly stated in the PR description and also discussed in the related design doc 🙂

The short answer is that yes, there is a measurable cost to starting up the threadpool, especially for smallish tasks. @aftersomemath can perhaps provide some numbers here to guide the discussion. I think the options are:

  1. Come up with a mechanism to make this threadsafe, that does not hurt performance or overcomplicate the API.
  2. Clearly document that multi-threading is not thread safe.

I think option 2 is fine, even though when pressed I can think of use cases for multi-multi-threaded rollouts. Perhaps @kevinzakka and/or @btaba would have an opinion on options 1 vs 2?

Contributor Author

I just measured the threadpool creation overhead for the case where the pool is recreated on every call. I tested three models: Hopper, Unitree Go2, and Humanoid, with nroll = [50, 3000, 100] and nstep = [1000, 1, 50] respectively. The overhead is 2-30% for Hopper, 0.3-4% for Unitree Go2, and 0.5-9% for Humanoid. This seems significant, at least in my view.

Yes, the management of the pool is not thread-safe. But the pool also cannot be used from multiple threads, because of the way the condition variable and WaitCount work.

Some alternatives that allow a persistent threadpool and allow rollout to be threadsafe are:

  1. let the user pass a reference to the pool to rollout (requires wrapping the pool with pybind, or at least managing it)
  2. make rollout a function object (easy)
  3. let the user start/stop the threadpool and change the threadpools design to allow usage from multiple threads (this will result in a rather complex threadpool implementation)

I'm in favor of implementing the current approach and, if calling threaded rollouts from multiple threads is important, also option 2. The current approach is not mutually exclusive with option 2, and option 2 is a breaking API change that requires the user to manage an object instead of just calling a function.

Contributor Author

Technically, the threadpool is thread-safe; the issue is just that the pool is not compatible with multithreaded usage. My bad, I should have been clearer.

Collaborator

@aftersomemath, I like option 2 of making a class that can do rollouts and holds its own reference to a threadpool for as long as it's needed.

The issue with the current design is not only about thread safety, but also about leaking memory and threads once the threadpool is no longer needed.

Collaborator

Pardon my ignorance, can someone spell out Option 2 more clearly?

@aftersomemath (Contributor Author) Dec 12, 2024

Usage would look like this if implemented exactly as discussed above:

    rollout = rollout.Rollout()  # creates an object Rollout that acts like a function
    rollout(models, data)        # creates or reuses member variable `rollout.pool`
    rollout = None               # reference lost; the object is destroyed and the thread pool shut down cleanly

Usage could also look like this:

with Rollout() as rollout:
    rollout(models, data)

In either case the thread pool will be shut down properly and automatically. Currently, this does not happen and the OS has to shut down the threads when the interpreter exits.

Also, calling threaded rollouts from multiple threads would be possible by using multiple Rollout instances.

At first glance, rollout.rollout() would still exist, but it would have to start and stop its thread pool on every call. The result is that algorithms written in an otherwise functional style would not be able to use persistent threadpools without an additional argument carrying a Rollout instance. This seems verbose.

A solution (which I am inclined towards) is to have rollout.rollout() create both a persistent Rollout instance and a cleanup handler for it that runs on interpreter shutdown (via atexit).
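
A sketch of that wrapper, using the names that appear in later diffs in this thread (persistent_rollout, close); the helper itself is illustrative, not the final diff:

    import atexit

    persistent_rollout = None  # module-level Rollout instance, reused across calls

    def rollout(models, datas, *args, nthread=None, **kwargs):
      # create (or resize) the persistent instance, then delegate to it
      global persistent_rollout
      if persistent_rollout is None or persistent_rollout.nthread != nthread:
        persistent_rollout = Rollout(nthread=nthread)
      return persistent_rollout.rollout(models, datas, *args, **kwargs)

    def _shutdown_persistent_pool():
      # registered with atexit so the pool's threads are joined cleanly
      global persistent_rollout
      if persistent_rollout is not None:
        persistent_rollout.close()
        persistent_rollout = None

    atexit.register(_shutdown_persistent_pool)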

Collaborator

I get it now, thank you.

I'll let @nimrod-gileadi comment on the details but I will say this: the current single-threadedness and pain of using Python threads mean that this module is not highly used in the wild and when it is, it's by expert users. Ergo, it's okay to break existing behaviour. As long as the changelog includes a "breaking changes" section with migration instructions for existing code, we're good. No need to bend over backwards so existing code runs unmodified.

@milutter (Contributor) Dec 13, 2024

I only saw this discussion after I posted my comment.

I think making rollout an object is a good solution, as this approach also gives the option to start and stop the threadpool whenever one likes.

Collaborator

The examples you posted look good, @aftersomemath. Thanks. It would be good to have some configuration options in the Rollout constructor, to choose threadpool size, whether to use a threadpool at all, etc.

@yuvaltassa, I think it's very easy to maintain the current behaviour of the rollout function, as @aftersomemath said. I wouldn't say it's "bending over backwards" and I don't see a reason to break it unnecessarily.

@aftersomemath (Contributor Author)

fixed all cosmetics

@milutter (Contributor) left a comment

Thank you so much @aftersomemath & @yuvaltassa for making this happen. 🙏

I have two questions. How much would it take to make the ThreadPool explicit / thread-safe? As far as I understand, the ThreadPool is global and not accessible via Python, right? For example, in RL one might want to use two threadpools with different numbers of threads in parallel, e.g. n - 4 threads for training and 4 threads just for evaluation. Would that be possible, or would it require too much work?

Do I understand correctly that the GIL is released in the Python rollout function, so I can call rollout from a thread and do something else in Python in the meantime?

@aftersomemath (Contributor Author)

Glad you like this! 🙂

It looks like multiple threadpools will be supported. In the detailed comments above, we discussed changing the current approach to allow this. Basically, rollout would become an object where each instance has its own pool. It's not hard to implement, especially given a real use case.

The GIL is released during rollouts, so yes, it can run in parallel with another Python thread.
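
Putting both answers together, a speculative sketch — assuming the Rollout class discussed above, that Rollout.rollout mirrors rollout.rollout's signature, and model/initial_state prepared as in the earlier sketch:

    import threading

    train = rollout.Rollout(nthread=28)  # e.g. n - 4 threads for training
    eval_ = rollout.Rollout(nthread=4)   # 4 threads for evaluation
    train_datas = [mujoco.MjData(model) for _ in range(28)]
    eval_datas = [mujoco.MjData(model) for _ in range(4)]

    # because the GIL is released inside rollout, a plain Python thread
    # is enough to drive both pools concurrently
    t = threading.Thread(target=train.rollout,
                         args=(model, train_datas, initial_state))
    t.start()
    eval_.rollout(model, eval_datas, initial_state)  # runs in parallel
    t.join()

    train.close()  # shut both pools down cleanly
    eval_.close()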

@@ -461,7 +461,7 @@ def test_threading(self):
     model = mujoco.MjModel.from_xml_string(TEST_XML)
     nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
     num_workers = 32
-    nroll = 10000
+    nroll = 100
Contributor Author

I reduced the number of rollouts for all threading tests. They were taking a noticeable amount of time to run.

@aftersomemath (Contributor Author) commented Dec 14, 2024

Rollout has been converted to a class as discussed (rollout.Rollout). Tests were added and the associated documentation was updated.

rollout.rollout is still available. By default its thread pool is not persistent, but an optional argument makes the pool persistent (in which case the function is no longer thread-safe). rollout.rollout's persistent pool is shut down cleanly on interpreter exit; it can also be shut down manually by the user.
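
Sketched usage of the two modes, reusing the setup from the earlier sketch (the persistent_pool argument appears in the diffs below; the manual-shutdown helper name is an assumption):

    # default: pool is created and torn down inside the call
    state, sensordata = rollout.rollout(model, datas, initial_state, nstep=nstep)

    # persistent: pool survives across calls; not thread-safe across callers
    state, sensordata = rollout.rollout(model, datas, initial_state, nstep=nstep,
                                        persistent_pool=True)

    rollout.shutdown_persistent_pool()  # assumed name; manual shutdown before exit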

doc/python.rst Outdated
@@ -759,9 +759,11 @@ or

.. code-block:: python

-   # pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
+   # pool shutdown on object deletion or call to rollout_.shutdown_pool
+   # to ensure clean shutdown of threads, call shutdown_pool before interpreter exit
    rollout_ = rollout.Rollout(nthread)
Collaborator

Please use a keyword argument to specify nthread. I can imagine we'll want to add more at some point.

doc/python.rst Outdated
@@ -759,9 +759,11 @@ or

.. code-block:: python

-   # pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
+   # pool shutdown on object deletion or call to rollout_.shutdown_pool
+   # to ensure clean shutdown of threads, call shutdown_pool before interpreter exit
Collaborator

Please rename shutdown_pool to close, since we may want to release other resources in a future implementation, and close() is quite a common name for this kind of method.

@@ -176,7 +176,7 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
int nroll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0,
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
-    std::shared_ptr<ThreadPool> pool, int chunk_size) {
+    std::shared_ptr<ThreadPool>& pool, int chunk_size) {
Collaborator

Nit-picking: Since this function shouldn't modify the shared_ptr reference count, just make it take a raw ThreadPool* and let the caller dereference the shared_ptr.

doc/python.rst Outdated
@@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco
rollout
-------

-``mujoco.rollout`` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
-implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
+``mujoco.rollout`` and ``mujoco.Rollout`` shows how to add additional C/C++ functionality, exposed as a Python module
Collaborator

Are we exposing mujoco.Rollout, or mujoco.rollout.Rollout?

@aftersomemath (Contributor Author) Dec 16, 2024

mujoco.rollout.Rollout; it is fixed.

 public:
  Rollout(int nthread) : nthread_(nthread) {
    if (this->nthread_ > 0) {
      this->pool_ = std::shared_ptr<ThreadPool>(
Collaborator

Use std::make_shared<ThreadPool>(this->nthread_);

      py::init([](int nthread) {
        return std::make_unique<Rollout>(nthread);
      }),
      py::doc(rollout_init_doc))
Collaborator

Can you please add `py::kw_only(), py::arg("nthread")`, to require and allow the use of a keyword argument for nthread?

@@ -22,6 +23,228 @@
import numpy as np
from numpy import typing as npt

class Rollout:
  def __init__(self, nthread: int = None):
Collaborator

Could you make nthread a keyword-only argument?

@aftersomemath (Contributor Author)

All requested changes implemented. Also fixed a bug where close was not being called in rollout.rollout().

@aftersomemath (Contributor Author)

One question is, the docstring for rollout appears twice. Once in rollout.rollout and again in rollout.Rollout.rollout. Do we want to just have one copy and let one refer to the other?

@nimrod-gileadi (Collaborator)

> One question is, the docstring for rollout appears twice. Once in rollout.rollout and again in rollout.Rollout.rollout. Do we want to just have one copy and let one refer to the other?

Having two copies is OK.

@nimrod-gileadi (Collaborator) left a comment

Looks great. Just a couple of fixes needed.

        sensordata=sensordata,
        chunk_size=chunk_size)

    if not persistent_pool:
Collaborator

Can you put this in a try ... finally block, so close() is called even if an exception is thrown?
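
Something like this illustrative wrapper (not the exact diff; names are placeholders):

    def _rollout_with_cleanup(rollout_obj, persistent_pool, **kwargs):
      try:
        return rollout_obj.rollout(**kwargs)
      finally:
        # runs even if the rollout raised
        if not persistent_pool:
          rollout_obj.close()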

  global persistent_rollout
  # Create or restart persistent threadpool
  if persistent_rollout is None or persistent_rollout.nthread != nthread:
    persistent_rollout = Rollout(nthread=nthread)
Collaborator

If you replace an existing persistent_rollout, you need to close() the last one.
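
I.e., a sketch against the snippet above:

    global persistent_rollout
    # create or restart persistent threadpool, closing any pool we replace
    if persistent_rollout is None or persistent_rollout.nthread != nthread:
      if persistent_rollout is not None:
        persistent_rollout.close()  # join the old pool's threads first
      persistent_rollout = Rollout(nthread=nthread)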

  This is called automatically at interpreter shutdown, but can also be called manually.
  """  # fmt: skip
  global persistent_rollout
  persistent_rollout = None
Collaborator

Call persistent_rollout.close() if it exists.
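
I.e.:

    global persistent_rollout
    if persistent_rollout is not None:
      persistent_rollout.close()  # join the threads before dropping the reference
    persistent_rollout = None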

@aftersomemath (Contributor Author)

All changes implemented. I double checked and I think we caught all the missing .close() calls.

  if not persistent_pool:
    rollout.close()
  try:
    ret = rollout.rollout(
Collaborator

Now you can just return here instead of assigning to ret.

Contributor Author

Done

copybara-service bot pushed a commit that referenced this pull request Dec 16, 2024
--
3a95b62 by Levi Burner <[email protected]>:

rollout prototype native threadpool for comparing to python threads

--
efd8be1 by Levi Burner <[email protected]>:

copy mjpcs threadpool into python bindings

--
75603ee by Levi Burner <[email protected]>:

rollout use threadpool as translation unit

--
06b90fe by Levi Burner <[email protected]>:

rollout add chunk_divisor parameter

--
298ab2f by Levi Burner <[email protected]>:

rollout add native threading test

--
169cf99 by Levi Burner <[email protected]>:

rollout exchange chunk_divisor arg for chunk_size

--
265af85 by Levi Burner <[email protected]>:

rollout fix cosmetics

--
1e8bffa by Levi Burner <[email protected]>:

make native rollout a class instead of a function

--
ba78821 by Levi Burner <[email protected]>:

rollout update docs and changelog

--
e4cb773 by Levi Burner <[email protected]>:

rollout don't register atexit handler for Rollout objects

--
5a08d2e by Levi Burner <[email protected]>:

rollout nthread kwarg, rename shutdown_pool to close, fixups

--
f622378 by Levi Burner <[email protected]>:

rollout add missing .close() calls

--
50f3ebc by Levi Burner <[email protected]>:

rollout return immediately

COPYBARA_INTEGRATE_REVIEW=#2282 from aftersomemath:rollout-threaded 50f3ebc
PiperOrigin-RevId: 706744277
Change-Id: I1ab2263b7d6ce30cf1908aec8fd5f2eb976a19e6
@yuvaltassa (Collaborator)

Merged.

@yuvaltassa closed this Dec 16, 2024