-
Notifications
You must be signed in to change notification settings - Fork 902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a persistent thread pool to rollout #2282
Changes from all commits
3a95b62
efd8be1
75603ee
06b90fe
298ab2f
169cf99
265af85
1e8bffa
ba78821
e4cb773
5a08d2e
f622378
50f3ebc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
#include "errors.h" | ||
#include "raw.h" | ||
#include "structs.h" | ||
#include "threadpool.h" | ||
#include <pybind11/buffer_info.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> | ||
|
@@ -31,14 +32,24 @@ namespace { | |
|
||
namespace py = ::pybind11; | ||
|
||
using PyCArray = py::array_t<mjtNum, py::array::c_style>; | ||
|
||
// NOLINTBEGIN(whitespace/line_length) | ||
|
||
// Docstring attached (via py::doc) to the Python-visible Rollout constructor.
const auto rollout_init_doc = R"( | ||
Construct a rollout object containing a thread pool for parallel rollouts. | ||
|
||
input arguments (optional): | ||
nthread integer, number of threads in pool | ||
if zero, this pool is not started and rollouts run on the calling thread | ||
)"; | ||
|
||
const auto rollout_doc = R"( | ||
Roll out open-loop trajectories from initial states, get resulting states and sensor values. | ||
|
||
input arguments (required): | ||
model list of MjModel instances of length nroll | ||
data associated instance of MjData | ||
data list of associated MjData instances of length nthread | ||
nstep integer, number of steps to be taken for each trajectory | ||
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec) | ||
state0 (nroll x nstate) nroll initial state vectors, | ||
|
@@ -49,12 +60,14 @@ Roll out open-loop trajectories from initial states, get resulting states and se | |
output arguments (optional): | ||
state (nroll x nstep x nstate) nroll nstep states | ||
sensordata (nroll x nstep x nsensordata) nroll trajectories of nstep sensordata vectors | ||
chunk_size integer, determines threadpool chunk size. If unspecified | ||
chunk_size = max(1, nroll / (nthread * 10)) | ||
)"; | ||
|
||
// C-style rollout function, assumes all arguments are valid | ||
// all input fields of d are initialised, contents at call time do not matter | ||
// after returning, d will contain the last step of the last rollout | ||
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int nstep, unsigned int control_spec, | ||
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec, | ||
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control, | ||
mjtNum* state, mjtNum* sensordata) { | ||
// sizes | ||
|
@@ -75,7 +88,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n | |
} | ||
|
||
// loop over rollouts | ||
for (int r = 0; r < nroll; r++) { | ||
for (int r = start_roll; r < end_roll; r++) { | ||
// clear user inputs if unspecified | ||
if (!(control_spec & mjSTATE_MOCAP_POS)) { | ||
for (int i = 0; i < nbody; i++) { | ||
|
@@ -158,6 +171,42 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n | |
} | ||
} | ||
|
||
// C-style threaded version of _unsafe_rollout
//
// Splits the rollouts [0, nroll) into contiguous chunks of at most chunk_size
// rollouts, schedules one job per chunk on the thread pool and blocks until
// the pool has counted all of them as done.
//
//   m           models, indexed by rollout
//   d           scratch mjData instances, indexed by pool->WorkerId(); the
//               caller must supply one entry for every worker id the pool can
//               return (presumably one per pool thread -- confirm against
//               ThreadPool)
//   pool        thread pool used to run the jobs, must not be null
//   chunk_size  number of rollouts per job; values below 1 are treated as 1
//
// All other arguments are forwarded unchanged to _unsafe_rollout. m and d are
// captured by reference in the scheduled jobs, which is safe only because this
// function does not return before WaitCount() does.
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
                              int nroll, int nstep, unsigned int control_spec,
                              const mjtNum* state0, const mjtNum* warmstart0,
                              const mjtNum* control, mjtNum* state, mjtNum* sensordata,
                              ThreadPool* pool, int chunk_size) {
  // a chunk size of zero would divide by zero below, and a negative one would
  // yield a negative job count and silently skip every rollout
  if (chunk_size < 1) {
    chunk_size = 1;
  }

  // number of jobs, rounded up so a trailing partial chunk gets its own job
  const int njobs = nroll / chunk_size + ((nroll % chunk_size > 0) ? 1 : 0);

  // Reset the pool counter
  pool->ResetCount();

  // schedule one job per chunk; only the last job can be smaller than chunk_size
  for (int j = 0; j < njobs; j++) {
    const int start_roll = j * chunk_size;
    const int end_roll = (nroll - start_roll < chunk_size)
                             ? nroll
                             : start_roll + chunk_size;
    auto task = [=, &m, &d](void) {
      // each job uses the mjData slot belonging to the worker that runs it
      int id = pool->WorkerId();
      _unsafe_rollout(m, d[id], start_roll, end_roll,
                      nstep, control_spec, state0, warmstart0, control, state, sensordata);
    };
    pool->Schedule(task);
  }

  // wait for the job counter to be incremented up to the number of jobs
  // submitted by this thread
  pool->WaitCount(njobs);
}
|
||
// NOLINTEND(whitespace/line_length) | ||
|
||
// check size of optional argument to rollout(), return raw pointer | ||
|
@@ -181,71 +230,122 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg, | |
return static_cast<mjtNum*>(info.ptr); | ||
} | ||
|
||
class Rollout { | ||
public: | ||
Rollout(int nthread) : nthread_(nthread) { | ||
if (this->nthread_ > 0) { | ||
this->pool_ = std::make_shared<ThreadPool>(this->nthread_); | ||
} | ||
} | ||
|
||
PYBIND11_MODULE(_rollout, pymodule) { | ||
namespace py = ::pybind11; | ||
using PyCArray = py::array_t<mjtNum, py::array::c_style>; | ||
|
||
// roll out open loop trajectories from multiple initial states | ||
// get subsequent states and corresponding sensor values | ||
pymodule.def( | ||
"rollout", | ||
[](py::list m, MjDataWrapper& d, | ||
int nstep, unsigned int control_spec, | ||
const PyCArray state0, | ||
std::optional<const PyCArray> warmstart0, | ||
std::optional<const PyCArray> control, | ||
std::optional<const PyCArray> state, | ||
std::optional<const PyCArray> sensordata | ||
) { | ||
// get raw pointers | ||
int nroll = state0.shape(0); | ||
std::vector<const raw::MjModel*> model_ptrs(nroll); | ||
for (int r = 0; r < nroll; r++) { | ||
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get(); | ||
} | ||
raw::MjData* data = d.get(); | ||
void rollout(py::list m, py::list d, | ||
int nstep, unsigned int control_spec, | ||
const PyCArray state0, | ||
std::optional<const PyCArray> warmstart0, | ||
std::optional<const PyCArray> control, | ||
std::optional<const PyCArray> state, | ||
std::optional<const PyCArray> sensordata, | ||
std::optional<int> chunk_size | ||
) { | ||
// get raw pointers | ||
int nroll = state0.shape(0); | ||
std::vector<const raw::MjModel*> model_ptrs(nroll); | ||
for (int r = 0; r < nroll; r++) { | ||
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get(); | ||
} | ||
|
||
// check that some steps need to be taken, return if not | ||
if (nstep < 1) { | ||
return; | ||
} | ||
// check length d and nthread are consistent | ||
if (this->nthread_ == 0 && py::len(d) > 1) { | ||
std::ostringstream msg; | ||
msg << "More than one data instance passed but " | ||
<< "rollout is configured to run on main thread"; | ||
py::value_error(msg.str()); | ||
} else if (this->nthread_ != py::len(d)) { | ||
std::ostringstream msg; | ||
msg << "Length of data: " << py::len(d) | ||
<< " not equal to nthread: " << this->nthread_; | ||
py::value_error(msg.str()); | ||
} | ||
|
||
std::vector<raw::MjData*> data_ptrs(py::len(d)); | ||
for (int t = 0; t < py::len(d); t++) { | ||
data_ptrs[t] = d[t].cast<MjDataWrapper*>()->get(); | ||
} | ||
|
||
// check that some steps need to be taken, return if not | ||
if (nstep < 1) { | ||
return; | ||
} | ||
|
||
// get sizes | ||
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS); | ||
int ncontrol = mj_stateSize(model_ptrs[0], control_spec); | ||
|
||
mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate); | ||
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll, | ||
1, model_ptrs[0]->nv); | ||
mjtNum* control_ptr = get_array_ptr(control, "control", nroll, | ||
nstep, ncontrol); | ||
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate); | ||
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll, | ||
nstep, model_ptrs[0]->nsensordata); | ||
|
||
// perform rollouts | ||
{ | ||
// release the GIL | ||
py::gil_scoped_release no_gil; | ||
|
||
// call unsafe rollout function | ||
// get sizes | ||
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS); | ||
int ncontrol = mj_stateSize(model_ptrs[0], control_spec); | ||
|
||
mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate); | ||
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll, | ||
1, model_ptrs[0]->nv); | ||
mjtNum* control_ptr = get_array_ptr(control, "control", nroll, | ||
nstep, ncontrol); | ||
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate); | ||
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll, | ||
nstep, model_ptrs[0]->nsensordata); | ||
|
||
// perform rollouts | ||
{ | ||
// release the GIL | ||
py::gil_scoped_release no_gil; | ||
|
||
// call unsafe rollout function, multi or single threaded | ||
if (this->nthread_ > 0 && nroll > 1) { | ||
int chunk_size_final = 1; | ||
if (!chunk_size.has_value()) { | ||
chunk_size_final = std::max(1, nroll / (10 * this->nthread_)); | ||
} else { | ||
chunk_size_final = *chunk_size; | ||
} | ||
InterceptMjErrors(_unsafe_rollout_threaded)( | ||
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr, | ||
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr, | ||
this->pool_.get(), chunk_size_final); | ||
} else { | ||
InterceptMjErrors(_unsafe_rollout)( | ||
model_ptrs, data, nroll, nstep, control_spec, state0_ptr, | ||
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr, | ||
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr); | ||
} | ||
}, | ||
py::arg("model"), | ||
py::arg("data"), | ||
py::arg("nstep"), | ||
py::arg("control_spec"), | ||
py::arg("state0"), | ||
py::arg("warmstart0") = py::none(), | ||
py::arg("control") = py::none(), | ||
py::arg("state") = py::none(), | ||
py::arg("sensordata") = py::none(), | ||
py::doc(rollout_doc) | ||
); | ||
} | ||
} | ||
|
||
private: | ||
int nthread_; | ||
std::shared_ptr<ThreadPool> pool_; | ||
}; | ||
|
||
|
||
// Defines the _rollout extension module, which exposes the Rollout class
// together with its keyword-only constructor and its rollout() method.
PYBIND11_MODULE(_rollout, pymodule) {
  namespace py = ::pybind11;

  py::class_<Rollout> rollout_class(pymodule, "Rollout");

  // constructor: Rollout(*, nthread)
  rollout_class.def(
      py::init([](int nthread) {
        return std::make_unique<Rollout>(nthread);
      }),
      py::kw_only(),
      py::arg("nthread"),
      py::doc(rollout_init_doc));

  // NOTE(review): a reviewer asked for an addition at this point of the
  // binding; the text of the request is truncated in this view.

  // rollout(model, data, nstep, control_spec, state0, ...), where every
  // argument after state0 defaults to None
  rollout_class.def(
      "rollout",
      &Rollout::rollout,
      py::arg("model"),
      py::arg("data"),
      py::arg("nstep"),
      py::arg("control_spec"),
      py::arg("state0"),
      py::arg("warmstart0") = py::none(),
      py::arg("control") = py::none(),
      py::arg("state") = py::none(),
      py::arg("sensordata") = py::none(),
      py::arg("chunk_size") = py::none(),
      py::doc(rollout_doc));
}
|
||
} // namespace | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// schedule all jobs of full (chunk) size