diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 48f0c2f0..d628bab2 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -251,6 +251,42 @@ class communicator return r; } + template + send_request send_multi(message_buffer const& msg, std::vector const& neighs, + std::vector const& tags) + { + assert(msg); + assert(neighs.size() == tags.size()); + auto& scheduled = m_schedule->scheduled_sends; + scheduled += neighs.size(); + send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); + + const auto s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto counter = new int{(int)neighs.size()}; + + const auto n = neighs.size(); + for (std::size_t i = 0; i < n; ++i) + { + const auto id = neighs[i]; + const auto tag = tags[i]; + send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); + send( + m_ptr, s * sizeof(T), id, tag, + [rd = r.m_data, rdx = rx.m_data, counter]() mutable + { + if ((--(*counter)) == 0) + { + delete counter; + rd->m_ready = true; + } + --(*(rd->m_scheduled)); + }, + rx.m_data); + } + return r; + } + // callback versions // ================= @@ -346,126 +382,139 @@ class communicator } template - send_request send_multi(message_buffer&& msg, std::vector const& neighs, - tag_type tag, CallBack&& callback) + send_request send_multi(message_buffer&& msg, std::vector neighs, tag_type tag, + CallBack&& callback) { OOMPH_CHECK_CALLBACK_MULTI(CallBack) assert(msg); - auto& scheduled = m_schedule->scheduled_sends; - scheduled += neighs.size(); struct msg_ref_count { message_buffer msg; int counter; std::vector neighs; + tag_type tags; + message_buffer&& message() noexcept { return std::move(msg); } + auto message_size() const noexcept { return msg.size() * sizeof(T); } + tag_type tag(std::size_t) const noexcept { return tags; } + auto m_ptr() const noexcept { return msg.m.m_heap_ptr.get(); } }; + const int n = neighs.size(); + return send_multi_impl(new msg_ref_count{std::move(msg), n, std::move(neighs), tag}, + std::forward(callback)); + } - send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto m = new msg_ref_count{std::move(msg), (int)neighs.size(), neighs}; - - for (auto id : neighs) + template + send_request send_multi(message_buffer&& msg, std::vector const& neighs, + std::vector const& tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack) + assert(neighs.size() == tags.size()); + assert(msg); + struct msg_ref_count { - send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - send( - m_ptr, s * sizeof(T), id, tag, - [rd = r.m_data, rdx = rx.m_data, m, tag, - cb = std::forward(callback)]() mutable - { - if ((--(m->counter)) == 0) - { - cb(std::move(m->msg), std::move(m->neighs), tag); - delete m; - rd->m_ready = true; - } - --(*(rd->m_scheduled)); - }, - rx.m_data); - } - return r; + message_buffer msg; + int counter; + std::vector neighs; + std::vector tags; + message_buffer&& message() noexcept { return std::move(msg); } + auto message_size() const noexcept { return msg.size() * sizeof(T); } + tag_type tag(std::size_t i) const noexcept { return tags[i]; } + auto m_ptr() const noexcept { return msg.m.m_heap_ptr.get(); } + }; + const int n = neighs.size(); + return send_multi_impl( + new msg_ref_count{std::move(msg), n, std::move(neighs), std::move(tags)}, + std::forward(callback)); } template - send_request send_multi(message_buffer& msg, std::vector const& neighs, - tag_type tag, CallBack&& callback) + send_request send_multi(message_buffer& msg, std::vector neighs, tag_type tag, + CallBack&& callback) { OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack) assert(msg); - auto& scheduled = m_schedule->scheduled_sends; - scheduled += neighs.size(); struct msg_ref_count { message_buffer* msg; int counter; std::vector neighs; + tag_type tags; + message_buffer& message() noexcept { return *msg; } + auto message_size() const noexcept { return msg->size() * sizeof(T); } + tag_type tag(std::size_t) const noexcept { return tags; } + auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); } }; + const int n = neighs.size(); + return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), tag}, + std::forward(callback)); + } - send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto m = new msg_ref_count{&msg, {(int)neighs.size()}, neighs}; - - for (auto id : neighs) + template + send_request send_multi(message_buffer& msg, std::vector neighs, + std::vector tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack) + assert(neighs.size() == tags.size()); + assert(msg); + struct msg_ref_count { - send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - send( - m_ptr, s * sizeof(T), id, tag, - [rd = r.m_data, rdx = rx.m_data, m, tag, - cb = std::forward(callback)]() mutable - { - if ((--(m->counter)) == 0) - { - cb(*(m->msg), std::move(m->neighs), tag); - delete m; - rd->m_ready = true; - } - --(*(rd->m_scheduled)); - }, - rx.m_data); - } - return r; + message_buffer* msg; + int counter; + std::vector neighs; + std::vector tags; + message_buffer& message() noexcept { return *msg; } + auto message_size() const noexcept { return msg->size() * sizeof(T); } + tag_type tag(std::size_t i) const noexcept { return tags[i]; } + auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); } + }; + const int n = neighs.size(); + return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), std::move(tags)}, + std::forward(callback)); } template - send_request send_multi(message_buffer const& msg, std::vector const& neighs, + send_request send_multi(message_buffer const& msg, std::vector neighs, tag_type tag, CallBack&& callback) { OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack) assert(msg); - auto& scheduled = m_schedule->scheduled_sends; - scheduled += neighs.size(); struct msg_ref_count { message_buffer const* msg; int counter; std::vector neighs; + tag_type tags; + message_buffer const& message() const noexcept { return *msg; } + auto message_size() const noexcept { return msg->size() * sizeof(T); } + tag_type tag(std::size_t) const noexcept { return tags; } + auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); } }; + const int n = neighs.size(); + return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), tag}, + std::forward(callback)); + } - send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto m = new msg_ref_count{&msg, {(int)neighs.size()}, neighs}; - - for (auto id : neighs) + template + send_request send_multi(message_buffer const& msg, std::vector neighs, + std::vector tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack) + assert(neighs.size() == tags.size()); + assert(msg); + struct msg_ref_count { - send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); - send( - m_ptr, s * sizeof(T), id, tag, - [rd = r.m_data, rdx = rx.m_data, m, tag, - cb = std::forward(callback)]() mutable - { - if ((--(m->counter)) == 0) - { - cb(*(m->msg), std::move(m->neighs), tag); - delete m; - rd->m_ready = true; - } - --(*(rd->m_scheduled)); - }, - rx.m_data); - } - return r; + message_buffer const* msg; + int counter; + std::vector neighs; + std::vector tags; + message_buffer const& message() const noexcept { return *msg; } + auto message_size() const noexcept { return msg->size() * sizeof(T); } + tag_type tag(std::size_t i) const noexcept { return tags[i]; } + auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); } + }; + const int n = neighs.size(); + return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), std::move(tags)}, + std::forward(callback)); } void progress(); @@ -485,6 +534,37 @@ class communicator void recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function cb, shared_request_ptr req); + + template + send_request send_multi_impl(M* m, CallBack&& callback) + { + auto& scheduled = m_schedule->scheduled_sends; + const auto n = m->neighs.size(); + scheduled += n; + send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); + auto m_ptr = m->m_ptr(); + + for (std::size_t i = 0; i < n; ++i) + { + const auto id = m->neighs[i]; + const auto tag = m->tag(i); + send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled)); + send( + m_ptr, m->message_size(), id, tag, + [rd = r.m_data, rdx = rx.m_data, m, cb = std::forward(callback)]() mutable + { + if ((--(m->counter)) == 0) + { + cb(m->message(), std::move(m->neighs), std::move(m->tags)); + delete m; + rd->m_ready = true; + } + --(*(rd->m_scheduled)); + }, + rx.m_data); + } + return r; + } }; } // namespace oomph diff --git a/include/oomph/context.hpp b/include/oomph/context.hpp index f44c94f1..d9a64d09 100644 --- a/include/oomph/context.hpp +++ b/include/oomph/context.hpp @@ -45,6 +45,12 @@ class context ~context(); public: + int rank() const noexcept; + + int size() const noexcept; + + int local_size() const noexcept; + MPI_Comm mpi_comm() const noexcept { return m_mpi_comm.get(); } template diff --git a/include/oomph/detail/communicator_helper.hpp b/include/oomph/detail/communicator_helper.hpp index b5925e54..39311749 100644 --- a/include/oomph/detail/communicator_helper.hpp +++ b/include/oomph/detail/communicator_helper.hpp @@ -11,7 +11,7 @@ #include -#define OOMPH_CHECK_CALLBACK_F(CALLBACK, RANK_TYPE) \ +#define OOMPH_CHECK_CALLBACK_F(CALLBACK, RANK_TYPE, TAG_TYPE) \ using args_t = boost::callable_traits::args_t>; \ using arg0_t = std::tuple_element_t<0, args_t>; \ using arg1_t = std::tuple_element_t<1, args_t>; \ @@ -19,7 +19,7 @@ static_assert(std::tuple_size::value == 3, "callback must have 3 arguments"); \ static_assert(std::is_same::value, \ "rank_type is not convertible to second callback argument type"); \ - static_assert(std::is_same::value, \ + static_assert(std::is_same::value, \ "tag_type is not convertible to third callback argument type"); \ using TT = typename std::remove_reference_t::value_type; @@ -38,36 +38,54 @@ #define OOMPH_CHECK_CALLBACK(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ OOMPH_CHECK_CALLBACK_MSG \ } #define OOMPH_CHECK_CALLBACK_MULTI(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ + OOMPH_CHECK_CALLBACK_MSG \ + } + +#define OOMPH_CHECK_CALLBACK_MULTI_TAGS(CALLBACK) \ + { \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ OOMPH_CHECK_CALLBACK_MSG \ } #define OOMPH_CHECK_CALLBACK_REF(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ OOMPH_CHECK_CALLBACK_MSG_REF \ } #define OOMPH_CHECK_CALLBACK_MULTI_REF(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ + OOMPH_CHECK_CALLBACK_MSG_REF \ + } + +#define OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CALLBACK) \ + { \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ OOMPH_CHECK_CALLBACK_MSG_REF \ } #define OOMPH_CHECK_CALLBACK_CONST_REF(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ } #define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CALLBACK) \ { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector) \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ + OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ + } + +#define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CALLBACK) \ + { \ + OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ } diff --git a/src/src.cpp b/src/src.cpp index a01bcf66..9966e8ab 100644 --- a/src/src.cpp +++ b/src/src.cpp @@ -60,6 +60,24 @@ context::~context() { comm_map.erase(m.get()); } context::~context() = default; #endif +int +context::rank() const noexcept +{ + return m->rank(); +} + +int +context::size() const noexcept +{ + return m->size(); +} + +int +context::local_size() const noexcept +{ + return m->topology().local_size(); +} + communicator context::get_communicator() {