Skip to content

Commit e7b4667

Browse files
committed
Add Star and Pipe broadcasts for keygen broadcasts
Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
1 parent 462fed0 commit e7b4667

2 files changed

Lines changed: 227 additions & 90 deletions

File tree

ttg/ttg/parsec/broadcast.h

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#ifndef TTG_PARSEC_BROADCAST_H
2+
#define TTG_PARSEC_BROADCAST_H
3+
4+
#include <ttg/util/span.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <mutex>
7+
8+
9+
namespace ttg_parsec {
10+
11+
/// Communication pattern used to distribute a keygen broadcast.
enum class BroadcastType {
  Star,  ///< the root sends directly to every other listed process
  Pipe   ///< each process forwards to the next process in the sorted list
};

/**
 * Returns the broadcast pattern selected through the TTG_BCAST_TYPE
 * environment variable.
 *
 * The variable is read once, on the first call, and the result is cached:
 * "pipe" selects BroadcastType::Pipe; "star", any other value, or an unset
 * variable yields BroadcastType::Star.
 *
 * The function is declared inline because it is defined in a header; without
 * it every translation unit including this header would emit its own
 * definition and the program would fail to link.
 */
inline BroadcastType get_broadcast_type() {
  // Initialization of a function-local static runs exactly once and is
  // thread-safe, so no separate once_flag is needed.
  static const BroadcastType bcast_type = []() {
    const char *bcast_type_env = std::getenv("TTG_BCAST_TYPE");
    if (bcast_type_env != nullptr && std::strcmp(bcast_type_env, "pipe") == 0) {
      return BroadcastType::Pipe;
    }
    return BroadcastType::Star;
  }();
  return bcast_type;
}
31+
32+
/**
 * Star-shaped broadcast: the root sends to every other listed process itself;
 * all non-root processes send nothing.
 *
 * @tparam Iter iterator over the participating process ranks; dereferenced
 *              values are converted to int.
 */
template<typename Iter>
class BroadcastStar {

  int m_root;
  int m_me;
  Iter m_procs_begin;
  Iter m_procs_end;

  /* a rank we actually have to send to: neither the root nor ourselves */
  bool is_remote_peer(int proc) const {
    return proc != m_root && proc != m_me;
  }

public:

  /**
   * @param root        rank that originates the broadcast
   * @param me          rank of the calling process
   * @param procs_begin start of the range of participating ranks
   * @param procs_end   end of the range of participating ranks
   */
  BroadcastStar(int root, int me, Iter procs_begin, Iter procs_end)
  : m_root(root), m_me(me), m_procs_begin(procs_begin), m_procs_end(procs_end)
  {}

  /**
   * @return true iff operator() will invoke the send callback at least once,
   *         i.e. we are the root and the range holds a rank other than us.
   *         A range containing only the root reports false so that callers
   *         do not prepare a message nobody will receive.
   */
  bool has_peers() const {
    return m_me == m_root &&
           std::any_of(m_procs_begin, m_procs_end,
                       [this](int proc) { return is_remote_peer(proc); });
  }

  /**
   * Invokes send_fn(rank) once for every rank the calling process has to
   * send to, in range order. Does nothing on non-root processes.
   */
  template<typename SendFn>
  void operator()(SendFn&& send_fn) {
    if (m_me != m_root) return;
    for (auto it = m_procs_begin; it != m_procs_end; ++it) {
      const int p = *it;
      if (is_remote_peer(p)) {
        send_fn(p);
      }
    }
  }
};
62+
63+
/**
 * Pipelined (ring-ordered) broadcast: every process forwards to the next
 * listed rank in ascending order, wrapping around at the end of the list.
 * The pipe starts at the root and ends at the last rank before the root's
 * position in that cyclic order.
 *
 * The root does not have to be part of the list: it then only injects the
 * message, and the pipe ends where the cyclic order would pass the root's
 * rank again.
 *
 * @tparam Iter forward iterator over the participating ranks, which must be
 *              sorted in ascending order (checked by assert).
 */
template<typename Iter>
class BroadcastPipe {

  int m_root;
  int m_me;
  Iter m_procs_begin;
  Iter m_procs_end;

  /**
   * @return iterator to the rank the calling process forwards to, or
   *         m_procs_end if the calling process does not send.
   */
  Iter next_peer() const {
    if (m_procs_begin == m_procs_end) return m_procs_end;

    // cyclic successor of our rank in the sorted list
    Iter next = std::find_if(m_procs_begin, m_procs_end,
                             [me = m_me](const auto& proc) { return proc > me; });
    // wrap around if we reached the end
    if (next == m_procs_end) next = m_procs_begin;

    const int peer = *next;
    // we are the only listed rank
    if (peer == m_me) return m_procs_end;
    // the root starts the pipe, whether or not it is listed itself
    if (m_me == m_root) return next;
    // ranks that are neither root nor listed do not take part
    if (std::find(m_procs_begin, m_procs_end, m_me) == m_procs_end) return m_procs_end;

    // the pipe ends once the step from us to peer reaches or passes the root
    const bool reaches_root = (m_me < peer) ? (m_me < m_root && m_root <= peer)
                                            : (m_root > m_me || m_root <= peer);
    return reaches_root ? m_procs_end : next;
  }

public:

  /**
   * @param root        rank that originates the broadcast
   * @param me          rank of the calling process
   * @param procs_begin start of the sorted range of participating ranks
   * @param procs_end   end of the sorted range of participating ranks
   */
  BroadcastPipe(int root, int me, Iter procs_begin, Iter procs_end)
  : m_root(root), m_me(me), m_procs_begin(procs_begin), m_procs_end(procs_end)
  {
    assert(std::is_sorted(m_procs_begin, m_procs_end));
  }

  /**
   * @return true iff operator() will invoke the send callback.
   */
  bool has_peers() const {
    return next_peer() != m_procs_end;
  }

  /**
   * Invokes send_fn(rank) for the next rank in the pipe, if there is one.
   * The calling process must be the root or be part of the list.
   */
  template<typename SendFn>
  void operator()(SendFn&& send_fn) {
    assert(m_me == m_root ||
           std::find(m_procs_begin, m_procs_end, m_me) != m_procs_end);
    const Iter next = next_peer();
    if (next != m_procs_end) {
      send_fn(*next);
    }
  }
};
97+
98+
} // namespace ttg_parsec
99+
100+
#endif // TTG_PARSEC_BROADCAST_H

ttg/ttg/parsec/ttg.h

Lines changed: 127 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "ttg/parsec/fwd.h"
4444

4545
#include "ttg/parsec/broadcast_keygen.h"
46+
#include "ttg/parsec/broadcast.h"
4647
#include "ttg/parsec/buffer.h"
4748
#include "ttg/parsec/devicescratch.h"
4849
#include "ttg/parsec/thread_local.h"
@@ -2381,7 +2382,7 @@ namespace ttg_parsec {
23812382
auto bcast_key_tuple = broadcast_keygen_tuple_type();
23822383
auto local_bcast_keys_tuple = broadcast_keygen_tuple_type();
23832384
std::optional<std::set<ttg::device::Device>> deviceset = std::set<ttg::device::Device>();
2384-
std::optional<std::set<int>> procset = std::nullopt;
2385+
std::optional<std::set<int>> procset = std::set<int>();
23852386
broadcast_keygen_cb(key, bcast_key_tuple);
23862387

23872388
/* collect the processes that are involved */
@@ -2392,14 +2393,13 @@ namespace ttg_parsec {
23922393
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
23932394

23942395
set_arg_fetch_value_and_release<value_type>(msg, deviceset.value(),
2395-
[this, local_bcast_keys_tuple = std::move(local_bcast_keys_tuple)]
2396+
[this, key, procset = std::move(procset), local_bcast_keys_tuple = std::move(local_bcast_keys_tuple)]
23962397
(detail::ttg_data_copy_t *copy) {
23972398
value_type& value = *reinterpret_cast<value_type *>(copy->get_ptr());
23982399

23992400
auto dtm = DummyTaskManager(this, copy); // set the parsec_ttg_caller and holds on to it until the end of the function
24002401

2401-
bcast_keygen_local(local_bcast_keys_tuple, value,
2402-
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
2402+
broadcast_keygen_select(local_bcast_keys_tuple, procset.value(), key, value);
24032403
});
24042404

24052405
} else {
@@ -3434,99 +3434,136 @@ namespace ttg_parsec {
34343434
}
34353435
}
34363436

3437-
virtual void broadcast_keygen(const void *key_ptr, const void *value_ptr) override final {
3438-
// assuming that all output types are the same
3439-
if constexpr (std::tuple_size_v<output_terminalsT> > 0 && !ttg::meta::is_void_v<key_type>) {
3440-
using value_type = std::tuple_element_t<0, ttg::edges_to_output_value_types_t<output_edges_type>>;
3441-
const key_type& key = *static_cast<const key_type*>(key_ptr);
3442-
const value_type& value = *static_cast<const value_type*>(value_ptr);
3443-
auto world = ttg::default_execution_context();
3444-
int myrank = world.rank();
3445-
auto bcast_key_tuple = broadcast_keygen_tuple_type();
3446-
auto local_bcast_keys_tuple = broadcast_keygen_tuple_type();
3447-
broadcast_keygen_cb(key, bcast_key_tuple);
34483437

3449-
/* collect the processes that are involved */
3450-
std::optional<std::set<int>> procset = std::set<int>();
3451-
std::optional<std::set<ttg::device::Device>> deviceset = std::nullopt;
3452-
keygen_query_successor(bcast_key_tuple, procset, deviceset, // ignore the devices here
3453-
myrank, local_bcast_keys_tuple,
3454-
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3438+
template<template<class> class Bcast, typename Key, typename Value>
3439+
void broadcast_keygen_impl(const auto& local_bcast_keys_tuple,
3440+
const auto& procset,
3441+
const Key& key,
3442+
const Value& value) {
34553443

3456-
int num_remote_peers = (procset->contains(myrank) ? procset->size() - 1 : procset->size());
3444+
int myrank = world.rank();
3445+
int root = keymap(key);
3446+
auto bcast = Bcast(root, myrank, procset.begin(), procset.end());
34573447

3458-
auto &world_impl = world.impl();
3459-
/* broadcast the key and value to all successor processes */
3460-
if (num_remote_peers > 0) {
3461-
uint64_t pos = 0;
3462-
using msg_type = detail::msg_t;
3463-
std::unique_ptr<msg_type> msg = std::make_unique<msg_type>(get_instance_id(), world_impl.taskpool()->taskpool_id,
3464-
msg_header_t::MSG_BCAST_KEYGEN, 0, world_impl.rank());
3465-
auto* copy = detail::find_copy_in_task(detail::parsec_ttg_caller, &value);
3466-
assert(nullptr != copy);
3467-
/* TODO: this assumes the worst case: that all keys are packed at once (i.e., go to the same remote). Can we do better?*/
3468-
bool inline_data = can_inline_data(&value, copy, key, 1);
3469-
msg->tt_id.inline_data = inline_data;
3470-
3471-
/* register the memory regions */
3472-
std::vector<std::pair<int32_t, std::shared_ptr<void>>> memregs;
3473-
memregs = register_bcast_data(value, inline_data, msg, pos);
3474-
msg->tt_id.num_iovecs = memregs.size();
3475-
int num_iovs = memregs.size();
3476-
3477-
/* Repack registrations and register readers for each peer
3478-
* so that each peer can release the data once their transfer is done.
3479-
* The last peer to complete will release the registration. */
3480-
uint64_t save_pos = pos;
3481-
3482-
for (auto iter = procset->begin(); iter != procset->end(); ++iter) {
3483-
int proc = *iter;
3484-
if (proc == myrank) continue; // local rank will be handled below
3485-
using msg_t = detail::msg_t;
3486-
pos = save_pos;
3487-
3488-
if (!inline_data) {
3489-
for (int idx = 0; idx < num_iovs; ++idx) {
3490-
int32_t lreg_size;
3491-
std::shared_ptr<void> lreg_ptr;
3492-
std::tie(lreg_size, lreg_ptr) = memregs[idx];
3493-
std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size));
3494-
pos += sizeof(lreg_size);
3495-
std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size);
3496-
pos += lreg_size;
3497-
/* mark another reader on the copy */
3498-
copy = detail::register_data_copy<value_type>(copy, nullptr, true);
3499-
/* create a function that will be invoked upon RMA completion at the target */
3500-
std::function<void(void)> *fn = new std::function<void(void)>([=]() mutable {
3501-
/* shared_ptr of value and registration captured by value so resetting
3502-
* them here (through get_remote_complete_cb) will eventually release
3503-
* the memory/registration */
3504-
lreg_ptr.reset();
3505-
detail::release_data_copy(copy);
3506-
});
3507-
std::intptr_t fn_ptr{reinterpret_cast<std::intptr_t>(fn)};
3508-
std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr));
3509-
pos += sizeof(fn_ptr);
3510-
}
3511-
}
3448+
if (bcast.has_peers()) {
35123449

3513-
/* pack the key and set the right offset */
3514-
msg->tt_id.key_offset = pos;
3515-
pos = pack(key, msg->bytes, pos);
3450+
using msg_type = detail::msg_t;
3451+
auto& world_impl = world.impl();
3452+
uint64_t pos = 0;
3453+
std::unique_ptr<msg_type> msg = std::make_unique<msg_type>(get_instance_id(), world_impl.taskpool()->taskpool_id,
3454+
msg_header_t::MSG_BCAST_KEYGEN, 0, world_impl.rank());
3455+
auto* copy = detail::find_copy_in_task(detail::parsec_ttg_caller, &value);
3456+
assert(nullptr != copy);
3457+
/* TODO: this assumes the worst case: that all keys are packed at once (i.e., go to the same remote). Can we do better?*/
3458+
bool inline_data = can_inline_data(&value, copy, key, 1);
3459+
msg->tt_id.inline_data = inline_data;
3460+
3461+
/* register the memory regions */
3462+
std::vector<std::pair<int32_t, std::shared_ptr<void>>> memregs;
3463+
memregs = register_bcast_data(value, inline_data, msg, pos);
3464+
msg->tt_id.num_iovecs = memregs.size();
3465+
int num_iovs = memregs.size();
3466+
3467+
/* Repack registrations and register readers for each peer
3468+
* so that each peer can release the data once their transfer is done.
3469+
* The last peer to complete will release the registration. */
3470+
uint64_t save_pos = pos;
3471+
3472+
bcast([&](int proc) {
3473+
assert(proc != myrank);
3474+
if (proc == myrank) return; // local rank will be handled below
3475+
using msg_t = detail::msg_t;
3476+
pos = save_pos;
35163477

3517-
parsec_taskpool_t *tp = world_impl.taskpool();
3518-
tp->tdm.module->outgoing_message_start(tp, proc, NULL);
3519-
tp->tdm.module->outgoing_message_pack(tp, proc, NULL, NULL, 0);
3520-
parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), proc, static_cast<void *>(msg.get()),
3521-
sizeof(msg_header_t) + pos);
3478+
if (!inline_data) {
3479+
for (int idx = 0; idx < num_iovs; ++idx) {
3480+
int32_t lreg_size;
3481+
std::shared_ptr<void> lreg_ptr;
3482+
std::tie(lreg_size, lreg_ptr) = memregs[idx];
3483+
std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size));
3484+
pos += sizeof(lreg_size);
3485+
std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size);
3486+
pos += lreg_size;
3487+
/* mark another reader on the copy */
3488+
copy = detail::register_data_copy<Value>(copy, nullptr, true);
3489+
/* create a function that will be invoked upon RMA completion at the target */
3490+
std::function<void(void)> *fn = new std::function<void(void)>([=]() mutable {
3491+
/* shared_ptr of value and registration captured by value so resetting
3492+
* them here (through get_remote_complete_cb) will eventually release
3493+
* the memory/registration */
3494+
lreg_ptr.reset();
3495+
detail::release_data_copy(copy);
3496+
});
3497+
std::intptr_t fn_ptr{reinterpret_cast<std::intptr_t>(fn)};
3498+
std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr));
3499+
pos += sizeof(fn_ptr);
3500+
}
35223501
}
3523-
}
35243502

3525-
if (procset->contains(myrank)) {
3526-
// local broadcast
3527-
bcast_keygen_local(local_bcast_keys_tuple, value,
3528-
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3529-
}
3503+
/* pack the key and set the right offset */
3504+
msg->tt_id.key_offset = pos;
3505+
pos = pack(key, msg->bytes, pos);
3506+
3507+
parsec_taskpool_t *tp = world_impl.taskpool();
3508+
tp->tdm.module->outgoing_message_start(tp, proc, NULL);
3509+
tp->tdm.module->outgoing_message_pack(tp, proc, NULL, NULL, 0);
3510+
parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), proc, static_cast<void *>(msg.get()),
3511+
sizeof(msg_header_t) + pos);
3512+
});
3513+
}
3514+
3515+
if (procset.contains(myrank)) {
3516+
// local broadcast
3517+
bcast_keygen_local(local_bcast_keys_tuple, value,
3518+
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3519+
}
3520+
}
3521+
3522+
template<typename Key, typename Value>
3523+
void broadcast_keygen(const Key& key, const Value& value) {
3524+
int myrank = world.rank();
3525+
int root = keymap(key);
3526+
auto bcast_key_tuple = broadcast_keygen_tuple_type();
3527+
auto local_bcast_keys_tuple = broadcast_keygen_tuple_type();
3528+
broadcast_keygen_cb(key, bcast_key_tuple);
3529+
3530+
/* collect the processes that are involved */
3531+
std::optional<std::set<int>> procset = std::set<int>();
3532+
std::optional<std::set<ttg::device::Device>> deviceset = std::nullopt;
3533+
keygen_query_successor(bcast_key_tuple, procset, deviceset, // ignore the devices here
3534+
myrank, local_bcast_keys_tuple,
3535+
std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3536+
3537+
broadcast_keygen_select(local_bcast_keys_tuple, procset.value(), key, value);
3538+
}
3539+
3540+
template<typename Key, typename Value>
3541+
void broadcast_keygen_select(const auto& local_bcast_keys_tuple,
3542+
const auto& procset,
3543+
const Key& key,
3544+
const Value& value) {
3545+
BroadcastType bcast_type = get_broadcast_type();
3546+
switch (bcast_type) {
3547+
case BroadcastType::Pipe:
3548+
broadcast_keygen_impl<BroadcastPipe>(local_bcast_keys_tuple, procset, key, value);
3549+
break;
3550+
case BroadcastType::Star:
3551+
broadcast_keygen_impl<BroadcastStar>(local_bcast_keys_tuple, procset, key, value);
3552+
break;
3553+
default:
3554+
throw std::runtime_error("Error: unknown broadcast type");
3555+
}
3556+
}
3557+
3558+
3559+
virtual void broadcast_keygen(const void *key_ptr, const void *value_ptr) override final {
3560+
// assuming that all output types are the same
3561+
if constexpr (std::tuple_size_v<output_terminalsT> > 0 && !ttg::meta::is_void_v<key_type>) {
3562+
using value_type = std::tuple_element_t<0, ttg::edges_to_output_value_types_t<output_edges_type>>;
3563+
const key_type& key = *static_cast<const key_type*>(key_ptr);
3564+
const value_type& value = *static_cast<const value_type*>(value_ptr);
3565+
3566+
broadcast_keygen(key, value);
35303567
} else {
35313568
throw std::runtime_error("Error: broadcast_keygen invoked on a ttg::Task with no output terminals");
35323569
}

0 commit comments

Comments
 (0)