Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 160 additions & 80 deletions include/oomph/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,42 @@ class communicator
return r;
}

template<typename T>
send_request send_multi(message_buffer<T> const& msg, std::vector<rank_type> const& neighs,
std::vector<tag_type> const& tags)
{
assert(msg);
assert(neighs.size() == tags.size());
auto& scheduled = m_schedule->scheduled_sends;
scheduled += neighs.size();
send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled));

const auto s = msg.size();
auto m_ptr = msg.m.m_heap_ptr.get();
auto counter = new int{(int)neighs.size()};

const auto n = neighs.size();
for (std::size_t i = 0; i < n; ++i)
{
const auto id = neighs[i];
const auto tag = tags[i];
send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
send(
m_ptr, s * sizeof(T), id, tag,
[rd = r.m_data, rdx = rx.m_data, counter]() mutable
{
if ((--(*counter)) == 0)
{
delete counter;
rd->m_ready = true;
}
--(*(rd->m_scheduled));
},
rx.m_data);
}
return r;
}

// callback versions
// =================

Expand Down Expand Up @@ -346,126 +382,139 @@ class communicator
}

template<typename T, typename CallBack>
send_request send_multi(message_buffer<T>&& msg, std::vector<rank_type> const& neighs,
tag_type tag, CallBack&& callback)
send_request send_multi(message_buffer<T>&& msg, std::vector<rank_type> neighs, tag_type tag,
CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI(CallBack)
assert(msg);
auto& scheduled = m_schedule->scheduled_sends;
scheduled += neighs.size();
struct msg_ref_count
{
message_buffer<T> msg;
int counter;
std::vector<rank_type> neighs;
tag_type tags;
message_buffer<T>&& message() noexcept { return std::move(msg); }
auto message_size() const noexcept { return msg.size() * sizeof(T); }
tag_type tag(std::size_t) const noexcept { return tags; }
auto m_ptr() const noexcept { return msg.m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(new msg_ref_count{std::move(msg), n, std::move(neighs), tag},
std::forward<CallBack>(callback));
}

send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
const auto s = msg.size();
auto m_ptr = msg.m.m_heap_ptr.get();
auto m = new msg_ref_count{std::move(msg), (int)neighs.size(), neighs};

for (auto id : neighs)
template<typename T, typename CallBack>
send_request send_multi(message_buffer<T>&& msg, std::vector<rank_type> const& neighs,
std::vector<tag_type> const& tags, CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack)
assert(neighs.size() == tags.size());
assert(msg);
struct msg_ref_count
{
send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
send(
m_ptr, s * sizeof(T), id, tag,
[rd = r.m_data, rdx = rx.m_data, m, tag,
cb = std::forward<CallBack>(callback)]() mutable
{
if ((--(m->counter)) == 0)
{
cb(std::move(m->msg), std::move(m->neighs), tag);
delete m;
rd->m_ready = true;
}
--(*(rd->m_scheduled));
},
rx.m_data);
}
return r;
message_buffer<T> msg;
int counter;
std::vector<rank_type> neighs;
std::vector<tag_type> tags;
message_buffer<T>&& message() noexcept { return std::move(msg); }
auto message_size() const noexcept { return msg.size() * sizeof(T); }
tag_type tag(std::size_t i) const noexcept { return tags[i]; }
auto m_ptr() const noexcept { return msg.m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(
new msg_ref_count{std::move(msg), n, std::move(neighs), std::move(tags)},
std::forward<CallBack>(callback));
}

template<typename T, typename CallBack>
send_request send_multi(message_buffer<T>& msg, std::vector<rank_type> const& neighs,
tag_type tag, CallBack&& callback)
send_request send_multi(message_buffer<T>& msg, std::vector<rank_type> neighs, tag_type tag,
CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack)
assert(msg);
auto& scheduled = m_schedule->scheduled_sends;
scheduled += neighs.size();
struct msg_ref_count
{
message_buffer<T>* msg;
int counter;
std::vector<rank_type> neighs;
tag_type tags;
message_buffer<T>& message() noexcept { return *msg; }
auto message_size() const noexcept { return msg->size() * sizeof(T); }
tag_type tag(std::size_t) const noexcept { return tags; }
auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), tag},
std::forward<CallBack>(callback));
}

send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
const auto s = msg.size();
auto m_ptr = msg.m.m_heap_ptr.get();
auto m = new msg_ref_count{&msg, {(int)neighs.size()}, neighs};

for (auto id : neighs)
template<typename T, typename CallBack>
send_request send_multi(message_buffer<T>& msg, std::vector<rank_type> neighs,
std::vector<tag_type> tags, CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack)
assert(neighs.size() == tags.size());
assert(msg);
struct msg_ref_count
{
send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
send(
m_ptr, s * sizeof(T), id, tag,
[rd = r.m_data, rdx = rx.m_data, m, tag,
cb = std::forward<CallBack>(callback)]() mutable
{
if ((--(m->counter)) == 0)
{
cb(*(m->msg), std::move(m->neighs), tag);
delete m;
rd->m_ready = true;
}
--(*(rd->m_scheduled));
},
rx.m_data);
}
return r;
message_buffer<T>* msg;
int counter;
std::vector<rank_type> neighs;
std::vector<tag_type> tags;
message_buffer<T>& message() noexcept { return *msg; }
auto message_size() const noexcept { return msg->size() * sizeof(T); }
tag_type tag(std::size_t i) const noexcept { return tags[i]; }
auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), std::move(tags)},
std::forward<CallBack>(callback));
}

template<typename T, typename CallBack>
send_request send_multi(message_buffer<T> const& msg, std::vector<rank_type> const& neighs,
send_request send_multi(message_buffer<T> const& msg, std::vector<rank_type> neighs,
tag_type tag, CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack)
assert(msg);
auto& scheduled = m_schedule->scheduled_sends;
scheduled += neighs.size();
struct msg_ref_count
{
message_buffer<T> const* msg;
int counter;
std::vector<rank_type> neighs;
tag_type tags;
message_buffer<T> const& message() const noexcept { return *msg; }
auto message_size() const noexcept { return msg->size() * sizeof(T); }
tag_type tag(std::size_t) const noexcept { return tags; }
auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), tag},
std::forward<CallBack>(callback));
}

send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
const auto s = msg.size();
auto m_ptr = msg.m.m_heap_ptr.get();
auto m = new msg_ref_count{&msg, {(int)neighs.size()}, neighs};

for (auto id : neighs)
template<typename T, typename CallBack>
send_request send_multi(message_buffer<T> const& msg, std::vector<rank_type> neighs,
std::vector<tag_type> tags, CallBack&& callback)
{
OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack)
assert(neighs.size() == tags.size());
assert(msg);
struct msg_ref_count
{
send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
send(
m_ptr, s * sizeof(T), id, tag,
[rd = r.m_data, rdx = rx.m_data, m, tag,
cb = std::forward<CallBack>(callback)]() mutable
{
if ((--(m->counter)) == 0)
{
cb(*(m->msg), std::move(m->neighs), tag);
delete m;
rd->m_ready = true;
}
--(*(rd->m_scheduled));
},
rx.m_data);
}
return r;
message_buffer<T> const* msg;
int counter;
std::vector<rank_type> neighs;
std::vector<tag_type> tags;
message_buffer<T> const& message() const noexcept { return *msg; }
auto message_size() const noexcept { return msg->size() * sizeof(T); }
tag_type tag(std::size_t i) const noexcept { return tags[i]; }
auto m_ptr() const noexcept { return msg->m.m_heap_ptr.get(); }
};
const int n = neighs.size();
return send_multi_impl(new msg_ref_count{&msg, n, std::move(neighs), std::move(tags)},
std::forward<CallBack>(callback));
}

void progress();
Expand All @@ -485,6 +534,37 @@ class communicator

void recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src,
tag_type tag, util::unique_function<void()> cb, shared_request_ptr req);

template<typename M, typename CallBack>
send_request send_multi_impl(M* m, CallBack&& callback)
{
auto& scheduled = m_schedule->scheduled_sends;
const auto n = m->neighs.size();
scheduled += n;
send_request r(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
auto m_ptr = m->m_ptr();

for (std::size_t i = 0; i < n; ++i)
{
const auto id = m->neighs[i];
const auto tag = m->tag(i);
send_request rx(shared_request_ptr(m_pool.get(), m_impl, &scheduled));
send(
m_ptr, m->message_size(), id, tag,
[rd = r.m_data, rdx = rx.m_data, m, cb = std::forward<CallBack>(callback)]() mutable
{
if ((--(m->counter)) == 0)
{
cb(m->message(), std::move(m->neighs), std::move(m->tags));
delete m;
rd->m_ready = true;
}
--(*(rd->m_scheduled));
},
rx.m_data);
}
return r;
}
};

} // namespace oomph
6 changes: 6 additions & 0 deletions include/oomph/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class context
~context();

public:
int rank() const noexcept;

int size() const noexcept;

int local_size() const noexcept;

MPI_Comm mpi_comm() const noexcept { return m_mpi_comm.get(); }

template<typename T>
Expand Down
34 changes: 26 additions & 8 deletions include/oomph/detail/communicator_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

#include <boost/callable_traits.hpp>

#define OOMPH_CHECK_CALLBACK_F(CALLBACK, RANK_TYPE) \
#define OOMPH_CHECK_CALLBACK_F(CALLBACK, RANK_TYPE, TAG_TYPE) \
using args_t = boost::callable_traits::args_t<std::remove_reference_t<CALLBACK>>; \
using arg0_t = std::tuple_element_t<0, args_t>; \
using arg1_t = std::tuple_element_t<1, args_t>; \
using arg2_t = std::tuple_element_t<2, args_t>; \
static_assert(std::tuple_size<args_t>::value == 3, "callback must have 3 arguments"); \
static_assert(std::is_same<arg1_t, RANK_TYPE>::value, \
"rank_type is not convertible to second callback argument type"); \
static_assert(std::is_same<arg2_t, tag_type>::value, \
static_assert(std::is_same<arg2_t, TAG_TYPE>::value, \
"tag_type is not convertible to third callback argument type"); \
using TT = typename std::remove_reference_t<arg0_t>::value_type;

Expand All @@ -38,36 +38,54 @@

#define OOMPH_CHECK_CALLBACK(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \
OOMPH_CHECK_CALLBACK_MSG \
}

#define OOMPH_CHECK_CALLBACK_MULTI(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, tag_type) \
OOMPH_CHECK_CALLBACK_MSG \
}

#define OOMPH_CHECK_CALLBACK_MULTI_TAGS(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, std::vector<tag_type>) \
OOMPH_CHECK_CALLBACK_MSG \
}

#define OOMPH_CHECK_CALLBACK_REF(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \
OOMPH_CHECK_CALLBACK_MSG_REF \
}

#define OOMPH_CHECK_CALLBACK_MULTI_REF(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, tag_type) \
OOMPH_CHECK_CALLBACK_MSG_REF \
}

#define OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, std::vector<tag_type>) \
OOMPH_CHECK_CALLBACK_MSG_REF \
}

#define OOMPH_CHECK_CALLBACK_CONST_REF(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \
OOMPH_CHECK_CALLBACK_MSG_CONST_REF \
}

#define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>) \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, tag_type) \
OOMPH_CHECK_CALLBACK_MSG_CONST_REF \
}

#define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CALLBACK) \
{ \
OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector<rank_type>, std::vector<tag_type>) \
OOMPH_CHECK_CALLBACK_MSG_CONST_REF \
}
Loading