Add a persistent thread pool to rollout #2282
Conversation
python/mujoco/rollout.cc
Outdated
@@ -158,6 +161,47 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
  }
}

// C-style threaded version of _unsafe_rollout
static ThreadPool* pool = nullptr;
I used a raw pointer here instead of a smart pointer because the style guide states that objects with static duration must be trivially destructible. My understanding is that smart pointers are not trivially destructible.
Initial cosmetic review
doc/changelog.rst
Outdated
@@ -9,6 +9,11 @@ Bug fixes
^^^^^^^^^
- Fixed a bug in the box-sphere collider, depth was incorrect for deep penetrations (:github:issue:`2206`).

Python bindings
^^^^^^^^^^^^^^^
- :ref:`rollout<PyRollout>` can now accept sequences of MjData of length ``nthread``. If passed, :ref:`rollout<PyRollout>`
- :ref:`rollout<PyRollout>` now features native multi-threading. If a sequence of MjData instances
of length ``nthread`` is passed in, ``rollout`` will automatically create a persistent threadpool
and parallelize the computation. Contribution by :github:user:`aftersomemath`.
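For illustration, a minimal usage sketch of what this changelog entry describes. The toy model XML and the exact call signature are assumptions for the example, not taken verbatim from the PR:

import mujoco
from mujoco import rollout
import numpy as np

# a tiny stand-in model; any MjModel works
model = mujoco.MjModel.from_xml_string(
    '<mujoco><worldbody><body><joint type="hinge"/><geom size="0.1"/></body></worldbody></mujoco>')

nthread, nroll, nstep = 4, 64, 100

# a sequence of MjData of length nthread enables the native threadpool
data = [mujoco.MjData(model) for _ in range(nthread)]

nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
initial_state = np.zeros((nroll, nstate))

state, sensordata = rollout.rollout(model, data, initial_state, nstep=nstep)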
python/mujoco/rollout.cc
Outdated
if (!chunk_size.has_value()) {
  chunk_size_final = std::max(1, nroll / (10 * nthread));
}
else {
The else on a new line thing is something that we do in C code, and only to facilitate putting a comment above the block. Let's attach all elses (`} else {`) in C++ code, as in the style guide.
python/mujoco/rollout.cc
Outdated
int chunk_remainder = nroll % chunk_size;
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

if (pool == nullptr) {
added block comments, slightly refactored:

// if existing thread pool is of the wrong size, delete it
if (pool != nullptr && pool->NumThreads() != nthread) {
  delete pool;
  pool = nullptr;  // recreated below
}
// make threadpool if required
if (pool == nullptr) {
  pool = new ThreadPool(nthread);
} else {
  pool->ResetCount();
}
for (int j = 0; j < nfulljobs; j++) { |
// schedule all jobs of full (chunk) size
  };
  pool->Schedule(task);
}
// schedule any remaining jobs of size < chunk_size
  pool->Schedule(task);
}

pool->WaitCount(njobs);
// increment job counter for this thread
I went with // wait for job counter to be incremented up to the number of jobs submitted by this thread
python/mujoco/rollout.cc
Outdated
InterceptMjErrors(_unsafe_rollout)(
    model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
    warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
if (nthread > 1 && nroll > 1) {
// call unsafe rollout function, multi or single threaded
python/mujoco/rollout.py
Outdated
@@ -59,6 +60,8 @@ def rollout(
        (nroll x nstep x nstate)
      sensordata: Sensor data output array (optional).
        (nroll x nstep x nsensordata)
      chunk_size: Determines threadpool chunk size. If unspecified,
                  chunk_size = max(1, nroll / (nthread * 10)
missing closing paren
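As an aside, the chunking arithmetic being documented here works out as follows. A Python sketch of the formulas quoted in the diffs; nfulljobs is inferred from the surrounding code, not spelled out in the PR text:

nroll, nthread = 1000, 8

# default chunk size when the caller does not specify one
chunk_size = max(1, nroll // (10 * nthread))

nfulljobs = nroll // chunk_size        # jobs containing exactly chunk_size rollouts
chunk_remainder = nroll % chunk_size   # rollouts left over after the full jobs
njobs = nfulljobs + 1 if chunk_remainder > 0 else nfulljobs

# every rollout is covered exactly once
assert nfulljobs * chunk_size + chunk_remainder == nroll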
While we're here, fix an unused var complaint by our internal linter: in rollout_test.py:358, replace i with _.
python/mujoco/rollout.cc
Outdated
int chunk_remainder = nroll % chunk_size;
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

if (pool == nullptr) {
Ironically, the management of this pool object is not thread safe.
i.e. if the rollout function is called from multiple threads, and we drop the Python GIL (https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc#L230), then you may end up deleting a thread pool that is being used, or creating two threadpools.
I would lean towards creating a new threadpool on every call to rollout. Have you measured a serious performance issue with that?
@nimrod-gileadi, this is clearly stated in the PR description and also discussed in the related design doc 🙂
The short answer is that yes, there is a measurable cost to starting up the threadpool, especially for smallish tasks. @aftersomemath can perhaps provide some numbers here to guide the discussion. I think the options are:
- Come up with a mechanism to make this threadsafe, that does not hurt performance or overcomplicate the API.
- Clearly document that multi-threading is not thread safe.
I think option 2 is fine, even though when pressed I can think of use cases for multi-multi-threaded rollouts. Perhaps @kevinzakka and/or @btaba would have an opinion on options 1 vs 2?
I just measured threadpool creation overhead for the case where the pool is recreated on every call. I tested three models: hopper, unitree go2, and humanoid, with nroll = [50, 3000, 100] and nstep = [1000, 1, 50] respectively. For hopper the overhead is 2-30%, for unitree go2 0.3-4%, and for humanoid 0.5-9%. This seems significant, at least in my view.
Yes, the management of the pool is not threadsafe. But also the pool cannot be used from multiple threads because of the way the condition variable and WaitCount work.
Some alternatives that allow a persistent threadpool and allow rollout to be threadsafe are:
- let the user pass a reference to the pool to rollout (requires wrapping the pool with pybind, or at least managing it)
- make rollout a function object (easy)
- let the user start/stop the threadpool and change the threadpools design to allow usage from multiple threads (this will result in a rather complex threadpool implementation)
I'm in favor of implementing the current approach now and, if calling threaded rollout from multiple threads turns out to be important, option 2 as well. The current approach is not mutually exclusive with 2, and 2 is a breaking API change that requires the user to manage an object instead of just calling a function.
Technically, the threadpool is threadsafe; the issue is just that the pool is not compatible with multithreaded usage. My bad, I should have been clearer.
@aftersomemath, I like option 2 of making a class that can do rollouts and holds its own reference to a threadpool for as long as it's needed.
The issue with the current design is not only about thread safety, but also about leaking memory and threads once the threadpool is no longer needed.
Pardon my ignorance, can someone spell out Option 2 more clearly?
Usage would look like this if implemented exactly as discussed above:
rollout = rollout.Rollout() # creates an object Rollout that acts like a function
rollout(models, data) # Creates or reuses member variable `rollout.pool`
rollout = None # reference lost, the object will be destroyed and the thread pool shut down cleanly
Usage could also look like this:

with Rollout() as rollout:
  rollout(models, data)

In either case the thread pool will be shut down properly and automatically. Currently, this does not happen and the OS has to shut down the threads when the interpreter exits.
Also, calling threaded rollout from multiple threads would be possible by using multiple Rollout instances.
At first glance, rollout.rollout() would still exist, but it would have to start and stop its thread pool on every call. The result is that otherwise functional algorithms would not be able to use persistent threadpools without an additional argument for an instance of Rollout. This seems verbose.
A solution (which I am inclined towards) is to have rollout.rollout() create both a persistent Rollout instance and a cleanup handler for it that runs on interpreter shutdown (via atexit).
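A rough sketch of that atexit idea, assuming it lives inside the rollout module where the Rollout class is defined. All names here are illustrative and the shipped code may differ (see the later commits and review comments):

import atexit

persistent_rollout = None  # module-level Rollout reused by rollout.rollout()

def rollout(model, data, *args, **kwargs):
  """Functional wrapper that lazily creates and reuses a persistent Rollout."""
  global persistent_rollout
  nthread = len(data) if isinstance(data, (list, tuple)) else 1
  if persistent_rollout is None or persistent_rollout.nthread != nthread:
    persistent_rollout = Rollout(nthread=nthread)
  return persistent_rollout.rollout(model, data, *args, **kwargs)

def _shutdown_persistent_rollout():
  """Cleanup handler so pool threads are joined before the interpreter exits."""
  global persistent_rollout
  if persistent_rollout is not None:
    persistent_rollout.close()
    persistent_rollout = None

atexit.register(_shutdown_persistent_rollout)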
I get it now, thank you.
I'll let @nimrod-gileadi comment on the details but I will say this: the current single-threadedness and pain of using Python threads mean that this module is not highly used in the wild and when it is, it's by expert users. Ergo, it's okay to break existing behaviour. As long as the changelog includes a "breaking changes" section with migration instructions for existing code, we're good. No need to bend over backwards so existing code runs unmodified.
I only saw this discussion after I posted my comment.
I think making the rollout an object is a good solution, as this approach also gives the option to start and stop the threadpool whenever one likes.
The examples you posted look good, @aftersomemath. Thanks. It would be good to have some configuration options in the Rollout constructor, to choose threadpool size, whether to use a threadpool at all, etc.
@yuvaltassa, I think it's very easy to maintain the current behaviour of the rollout function, as @aftersomemath said. I wouldn't say it's "bending over backwards" and I don't see a reason to break it unnecessarily.
fixed all cosmetics
Thank you so much @aftersomemath & @yuvaltassa for making this happen. 🙏
I have two questions. How much would it take to make the ThreadPool explicit / threadsafe? As far as I understand, the ThreadPool is global and not accessible via Python, right? For example, in RL one might want to use two threadpools with different numbers of threads in parallel, e.g. n - 4 threads for training and 4 threads just for evaluation. Would that be possible or would it require too much work?
Do I understand it correctly that the GIL is released in the Python rollout function, so I can call the rollout function from a thread and do something else in Python?
Glad you like this! 🙂 It looks like multiple threadpools will be supported. In the detailed comments above, we discussed changing the current approach to allow this. Basically, rollout would become an object where each instance has its own pool. It's not hard to implement, especially given a real use case. The GIL is released during rollouts, so yes, it can run in parallel with another Python thread.
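For the RL scenario above, a sketch with the Rollout class being discussed. Thread counts and names are illustrative:

import os
from mujoco import rollout

n = os.cpu_count()

train_rollout = rollout.Rollout(nthread=n - 4)  # large pool for training rollouts
eval_rollout = rollout.Rollout(nthread=4)       # small, separate pool for evaluation

# Each instance owns its own threadpool, so the two can be driven from two
# different Python threads; the GIL is released while a rollout is running.
# ... train_rollout.rollout(model, train_data, train_states) ...
# ... eval_rollout.rollout(model, eval_data, eval_states) ...

train_rollout.close()
eval_rollout.close()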
@@ -461,7 +461,7 @@ def test_threading(self):
     model = mujoco.MjModel.from_xml_string(TEST_XML)
     nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
     num_workers = 32
-    nroll = 10000
+    nroll = 100
I reduced the number of rollouts for all threading tests. They were taking a noticeable amount of time to run.
Rollout has been converted to a class as discussed above.
doc/python.rst
Outdated
@@ -759,9 +759,11 @@ or

.. code-block:: python

   # pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
   # pool shutdown on object deletion or call to rollout_.shutdown_pool
   # to ensure clean shutdown of threads, call shutdown_pool before interpreter exit
   rollout_ = rollout.Rollout(nthread)
Please use a keyword argument to specify nthread. I can imagine we'll want to add more at some point.
doc/python.rst
Outdated
@@ -759,9 +759,11 @@ or

.. code-block:: python

   # pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
   # pool shutdown on object deletion or call to rollout_.shutdown_pool
   # to ensure clean shutdown of threads, call shutdown_pool before interpreter exit
Please rename shutdown_pool to close, since we may want to release other resources in a future implementation, and close() is quite a common name for this kind of method.
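One nice consequence of the close() name (an observation, not something stated in the PR): it composes with contextlib.closing, so the pool is shut down even if the block raises. A sketch, with the toy model setup and call signature assumed:

from contextlib import closing
import mujoco
from mujoco import rollout
import numpy as np

model = mujoco.MjModel.from_xml_string(
    '<mujoco><worldbody><body><joint type="hinge"/><geom size="0.1"/></body></worldbody></mujoco>')
nthread, nroll = 4, 16
data = [mujoco.MjData(model) for _ in range(nthread)]
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
initial_state = np.zeros((nroll, nstate))

# closing() calls rollout_.close() on exit, mirroring the documented manual call
with closing(rollout.Rollout(nthread=nthread)) as rollout_:
  state, sensordata = rollout_.rollout(model, data, initial_state, nstep=50)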
python/mujoco/rollout.cc
Outdated
@@ -176,7 +176,7 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
                               int nroll, int nstep, unsigned int control_spec,
                               const mjtNum* state0, const mjtNum* warmstart0,
                               const mjtNum* control, mjtNum* state, mjtNum* sensordata,
-                              std::shared_ptr<ThreadPool> pool, int chunk_size) {
+                              std::shared_ptr<ThreadPool>& pool, int chunk_size) {
Nit-picking: since this function shouldn't modify the shared_ptr reference count, just make it take a raw ThreadPool* and let the caller dereference the shared_ptr.
doc/python.rst
Outdated
@@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco

rollout
-------

``mujoco.rollout`` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
``mujoco.rollout`` and ``mujoco.Rollout`` shows how to add additional C/C++ functionality, exposed as a Python module
Are we exposing mujoco.Rollout, or mujoco.rollout.Rollout?
mujoco.rollout.Rollout, it is fixed.
python/mujoco/rollout.cc
Outdated
 public:
  Rollout(int nthread) : nthread_(nthread) {
    if (this->nthread_ > 0) {
      this->pool_ = std::shared_ptr<ThreadPool>(
Use std::make_shared<ThreadPool>(this->nthread_);
      py::init([](int nthread) {
        return std::make_unique<Rollout>(nthread);
      }),
      py::doc(rollout_init_doc))
Can you please add py::kw_only(), py::arg("nthread"), to require and allow the use of a keyword argument for nthread?
python/mujoco/rollout.py
Outdated
@@ -22,6 +23,228 @@

import numpy as np
from numpy import typing as npt

class Rollout:
  def __init__(self, nthread: int = None):
Could you make nthread a keyword-only argument?
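That is, something like the following signature (a sketch, not the exact shipped code):

from typing import Optional

class Rollout:
  def __init__(self, *, nthread: Optional[int] = None):
    # '*' makes nthread keyword-only: Rollout(nthread=4) works, Rollout(4) raises
    self.nthread = nthread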
All requested changes implemented. Also fixed a bug where ...
One question is, the docstring for rollout appears twice. Once in ...
Having two copies is OK.
Looks great. Just a couple of fixes needed.
python/mujoco/rollout.py
Outdated
        sensordata=sensordata,
        chunk_size=chunk_size)

  if not persistent_pool:
Can you put this in a try ... finally block, so close() is called even if an exception is thrown?
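I.e., roughly this shape. A sketch in which the names follow the diff above, wrapped in a hypothetical helper so it stands alone; per a later comment the final version returns directly from inside the try:

def _call_and_maybe_close(rollout, persistent_pool, **kwargs):
  # hypothetical wrapper for illustration; in the PR this lives inside the
  # module-level rollout() function
  try:
    return rollout.rollout(**kwargs)
  finally:
    if not persistent_pool:
      rollout.close()  # shut the temporary pool down even if rollout() raised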
  global persistent_rollout
  # Create or restart persistent threadpool
  if persistent_rollout is None or persistent_rollout.nthread != nthread:
    persistent_rollout = Rollout(nthread=nthread)
If you replace an existing persistent_rollout, you need to close() the last one.
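I.e., something along these lines, a sketch that mirrors the diff above with the missing close() added (the helper name is illustrative; in the PR this logic is inline):

def _get_persistent_rollout(nthread):
  global persistent_rollout
  if persistent_rollout is None or persistent_rollout.nthread != nthread:
    if persistent_rollout is not None:
      persistent_rollout.close()  # release the old pool's threads before replacing it
    persistent_rollout = Rollout(nthread=nthread)
  return persistent_rollout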
    This is called automatically interpreter shutdown, but can also be called manually.
    """  # fmt: skip
  global persistent_rollout
  persistent_rollout = None
Call persistent_rollout.close() if it exists.
All changes implemented. I double checked and I think we caught all the missing close() calls.
python/mujoco/rollout.py
Outdated
  if not persistent_pool:
    rollout.close()
  try:
    ret = rollout.rollout(
Now you can just return here instead of assigning to ret.
Done
-- 3a95b62 by Levi Burner <[email protected]>: rollout prototype native threadpool for comparing to python threads
-- efd8be1 by Levi Burner <[email protected]>: copy mjpcs threadpool into python bindings
-- 75603ee by Levi Burner <[email protected]>: rollout use threadpool as translation unit
-- 06b90fe by Levi Burner <[email protected]>: rollout add chunk_divisor parameter
-- 298ab2f by Levi Burner <[email protected]>: rollout add native threading test
-- 169cf99 by Levi Burner <[email protected]>: rollout exchange chunk_divisor arg for chunk_size
-- 265af85 by Levi Burner <[email protected]>: rollout fix cosmetics
-- 1e8bffa by Levi Burner <[email protected]>: make native rollout a class instead of a function
-- ba78821 by Levi Burner <[email protected]>: rollout update docs and changelog
-- e4cb773 by Levi Burner <[email protected]>: rollout don't register atexit handler for Rollout objects
-- 5a08d2e by Levi Burner <[email protected]>: rollout nthread kwarg, rename shutdown_pool to close, fixups
-- f622378 by Levi Burner <[email protected]>: rollout add missing .close() calls
-- 50f3ebc by Levi Burner <[email protected]>: rollout return immediately

COPYBARA_INTEGRATE_REVIEW=#2282 from aftersomemath:rollout-threaded 50f3ebc
PiperOrigin-RevId: 706744277
Change-Id: I1ab2263b7d6ce30cf1908aec8fd5f2eb976a19e6
Merged.
rollout now accepts a list of MjData of length nthread. If passed, the nroll rollouts are parallelized. The thread pool is persistent across calls to rollout, so long as the number of MjData instances is constant.
The thread pool implementation was copied from MJPC. It is not threadsafe. This means that if rollout is using threads, it should not be called from multiple threads.