4343#include " ttg/parsec/fwd.h"
4444
4545#include " ttg/parsec/broadcast_keygen.h"
46+ #include " ttg/parsec/broadcast.h"
4647#include " ttg/parsec/buffer.h"
4748#include " ttg/parsec/devicescratch.h"
4849#include " ttg/parsec/thread_local.h"
@@ -2381,7 +2382,7 @@ namespace ttg_parsec {
23812382 auto bcast_key_tuple = broadcast_keygen_tuple_type ();
23822383 auto local_bcast_keys_tuple = broadcast_keygen_tuple_type ();
23832384 std::optional<std::set<ttg::device::Device>> deviceset = std::set<ttg::device::Device>();
2384- std::optional<std::set<int >> procset = std::nullopt ;
2385+ std::optional<std::set<int >> procset = std::set< int >() ;
23852386 broadcast_keygen_cb (key, bcast_key_tuple);
23862387
23872388 /* collect the processes that are involved */
@@ -2392,14 +2393,13 @@ namespace ttg_parsec {
23922393 std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
23932394
23942395 set_arg_fetch_value_and_release<value_type>(msg, deviceset.value (),
2395- [this , local_bcast_keys_tuple = std::move (local_bcast_keys_tuple)]
2396+ [this , key, procset = std::move (procset), local_bcast_keys_tuple = std::move (local_bcast_keys_tuple)]
23962397 (detail::ttg_data_copy_t *copy) {
23972398 value_type& value = *reinterpret_cast <value_type *>(copy->get_ptr ());
23982399
23992400 auto dtm = DummyTaskManager (this , copy); // set the parsec_ttg_caller and holds on to it until the end of the function
24002401
2401- bcast_keygen_local (local_bcast_keys_tuple, value,
2402- std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
2402+ broadcast_keygen_select (local_bcast_keys_tuple, procset.value (), key, value);
24032403 });
24042404
24052405 } else {
@@ -3434,99 +3434,136 @@ namespace ttg_parsec {
34343434 }
34353435 }
34363436
3437- virtual void broadcast_keygen (const void *key_ptr, const void *value_ptr) override final {
3438- // assuming that all output types are the same
3439- if constexpr (std::tuple_size_v<output_terminalsT> > 0 && !ttg::meta::is_void_v<key_type>) {
3440- using value_type = std::tuple_element_t <0 , ttg::edges_to_output_value_types_t <output_edges_type>>;
3441- const key_type& key = *static_cast <const key_type*>(key_ptr);
3442- const value_type& value = *static_cast <const value_type*>(value_ptr);
3443- auto world = ttg::default_execution_context ();
3444- int myrank = world.rank ();
3445- auto bcast_key_tuple = broadcast_keygen_tuple_type ();
3446- auto local_bcast_keys_tuple = broadcast_keygen_tuple_type ();
3447- broadcast_keygen_cb (key, bcast_key_tuple);
34483437
3449- /* collect the processes that are involved */
3450- std::optional<std::set<int >> procset = std::set<int >();
3451- std::optional<std::set<ttg::device::Device>> deviceset = std::nullopt ;
3452- keygen_query_successor (bcast_key_tuple, procset, deviceset, // ignore the devices here
3453- myrank, local_bcast_keys_tuple,
3454- std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3438+ template <template <class > class Bcast , typename Key, typename Value>
3439+ void broadcast_keygen_impl (const auto & local_bcast_keys_tuple,
3440+ const auto & procset,
3441+ const Key& key,
3442+ const Value& value) {
34553443
3456- int num_remote_peers = (procset->contains (myrank) ? procset->size () - 1 : procset->size ());
3444+ int myrank = world.rank ();
3445+ int root = keymap (key);
3446+ auto bcast = Bcast (root, myrank, procset.begin (), procset.end ());
34573447
3458- auto &world_impl = world.impl ();
3459- /* broadcast the key and value to all successor processes */
3460- if (num_remote_peers > 0 ) {
3461- uint64_t pos = 0 ;
3462- using msg_type = detail::msg_t ;
3463- std::unique_ptr<msg_type> msg = std::make_unique<msg_type>(get_instance_id (), world_impl.taskpool ()->taskpool_id ,
3464- msg_header_t ::MSG_BCAST_KEYGEN, 0 , world_impl.rank ());
3465- auto * copy = detail::find_copy_in_task (detail::parsec_ttg_caller, &value);
3466- assert (nullptr != copy);
3467- /* TODO: this assumes the worst case: that all keys are packed at once (i.e., go to the same remote). Can we do better?*/
3468- bool inline_data = can_inline_data (&value, copy, key, 1 );
3469- msg->tt_id .inline_data = inline_data;
3470-
3471- /* register the memory regions */
3472- std::vector<std::pair<int32_t , std::shared_ptr<void >>> memregs;
3473- memregs = register_bcast_data (value, inline_data, msg, pos);
3474- msg->tt_id .num_iovecs = memregs.size ();
3475- int num_iovs = memregs.size ();
3476-
3477- /* Repack registrations and register readers for each peer
3478- * so that each peer can release the data once their transfer is done.
3479- * The last peer to complete will release the registration. */
3480- uint64_t save_pos = pos;
3481-
3482- for (auto iter = procset->begin (); iter != procset->end (); ++iter) {
3483- int proc = *iter;
3484- if (proc == myrank) continue ; // local rank will be handled below
3485- using msg_t = detail::msg_t ;
3486- pos = save_pos;
3487-
3488- if (!inline_data) {
3489- for (int idx = 0 ; idx < num_iovs; ++idx) {
3490- int32_t lreg_size;
3491- std::shared_ptr<void > lreg_ptr;
3492- std::tie (lreg_size, lreg_ptr) = memregs[idx];
3493- std::memcpy (msg->bytes + pos, &lreg_size, sizeof (lreg_size));
3494- pos += sizeof (lreg_size);
3495- std::memcpy (msg->bytes + pos, lreg_ptr.get (), lreg_size);
3496- pos += lreg_size;
3497- /* mark another reader on the copy */
3498- copy = detail::register_data_copy<value_type>(copy, nullptr , true );
3499- /* create a function that will be invoked upon RMA completion at the target */
3500- std::function<void (void )> *fn = new std::function<void (void )>([=]() mutable {
3501- /* shared_ptr of value and registration captured by value so resetting
3502- * them here (through get_remote_complete_cb) will eventually release
3503- * the memory/registration */
3504- lreg_ptr.reset ();
3505- detail::release_data_copy (copy);
3506- });
3507- std::intptr_t fn_ptr{reinterpret_cast <std::intptr_t >(fn)};
3508- std::memcpy (msg->bytes + pos, &fn_ptr, sizeof (fn_ptr));
3509- pos += sizeof (fn_ptr);
3510- }
3511- }
3448+ if (bcast.has_peers ()) {
35123449
3513- /* pack the key and set the right offset */
3514- msg->tt_id .key_offset = pos;
3515- pos = pack (key, msg->bytes , pos);
3450+ using msg_type = detail::msg_t ;
3451+ auto & world_impl = world.impl ();
3452+ uint64_t pos = 0 ;
3453+ std::unique_ptr<msg_type> msg = std::make_unique<msg_type>(get_instance_id (), world_impl.taskpool ()->taskpool_id ,
3454+ msg_header_t ::MSG_BCAST_KEYGEN, 0 , world_impl.rank ());
3455+ auto * copy = detail::find_copy_in_task (detail::parsec_ttg_caller, &value);
3456+ assert (nullptr != copy);
3457+ /* TODO: this assumes the worst case: that all keys are packed at once (i.e., go to the same remote). Can we do better?*/
3458+ bool inline_data = can_inline_data (&value, copy, key, 1 );
3459+ msg->tt_id .inline_data = inline_data;
3460+
3461+ /* register the memory regions */
3462+ std::vector<std::pair<int32_t , std::shared_ptr<void >>> memregs;
3463+ memregs = register_bcast_data (value, inline_data, msg, pos);
3464+ msg->tt_id .num_iovecs = memregs.size ();
3465+ int num_iovs = memregs.size ();
3466+
3467+ /* Repack registrations and register readers for each peer
3468+ * so that each peer can release the data once their transfer is done.
3469+ * The last peer to complete will release the registration. */
3470+ uint64_t save_pos = pos;
3471+
3472+ bcast ([&](int proc) {
3473+ assert (proc != myrank);
3474+ if (proc == myrank) return ; // local rank will be handled below
3475+ using msg_t = detail::msg_t ;
3476+ pos = save_pos;
35163477
3517- parsec_taskpool_t *tp = world_impl.taskpool ();
3518- tp->tdm .module ->outgoing_message_start (tp, proc, NULL );
3519- tp->tdm .module ->outgoing_message_pack (tp, proc, NULL , NULL , 0 );
3520- parsec_ce.send_am (&parsec_ce, world_impl.parsec_ttg_tag (), proc, static_cast <void *>(msg.get ()),
3521- sizeof (msg_header_t ) + pos);
3478+ if (!inline_data) {
3479+ for (int idx = 0 ; idx < num_iovs; ++idx) {
3480+ int32_t lreg_size;
3481+ std::shared_ptr<void > lreg_ptr;
3482+ std::tie (lreg_size, lreg_ptr) = memregs[idx];
3483+ std::memcpy (msg->bytes + pos, &lreg_size, sizeof (lreg_size));
3484+ pos += sizeof (lreg_size);
3485+ std::memcpy (msg->bytes + pos, lreg_ptr.get (), lreg_size);
3486+ pos += lreg_size;
3487+ /* mark another reader on the copy */
3488+ copy = detail::register_data_copy<Value>(copy, nullptr , true );
3489+ /* create a function that will be invoked upon RMA completion at the target */
3490+ std::function<void (void )> *fn = new std::function<void (void )>([=]() mutable {
3491+ /* shared_ptr of value and registration captured by value so resetting
3492+ * them here (through get_remote_complete_cb) will eventually release
3493+ * the memory/registration */
3494+ lreg_ptr.reset ();
3495+ detail::release_data_copy (copy);
3496+ });
3497+ std::intptr_t fn_ptr{reinterpret_cast <std::intptr_t >(fn)};
3498+ std::memcpy (msg->bytes + pos, &fn_ptr, sizeof (fn_ptr));
3499+ pos += sizeof (fn_ptr);
3500+ }
35223501 }
3523- }
35243502
3525- if (procset->contains (myrank)) {
3526- // local broadcast
3527- bcast_keygen_local (local_bcast_keys_tuple, value,
3528- std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3529- }
3503+ /* pack the key and set the right offset */
3504+ msg->tt_id .key_offset = pos;
3505+ pos = pack (key, msg->bytes , pos);
3506+
3507+ parsec_taskpool_t *tp = world_impl.taskpool ();
3508+ tp->tdm .module ->outgoing_message_start (tp, proc, NULL );
3509+ tp->tdm .module ->outgoing_message_pack (tp, proc, NULL , NULL , 0 );
3510+ parsec_ce.send_am (&parsec_ce, world_impl.parsec_ttg_tag (), proc, static_cast <void *>(msg.get ()),
3511+ sizeof (msg_header_t ) + pos);
3512+ });
3513+ }
3514+
3515+ if (procset.contains (myrank)) {
3516+ // local broadcast
3517+ bcast_keygen_local (local_bcast_keys_tuple, value,
3518+ std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3519+ }
3520+ }
3521+
3522+ template <typename Key, typename Value>
3523+ void broadcast_keygen (const Key& key, const Value& value) {
3524+ int myrank = world.rank ();
3525+ int root = keymap (key);
3526+ auto bcast_key_tuple = broadcast_keygen_tuple_type ();
3527+ auto local_bcast_keys_tuple = broadcast_keygen_tuple_type ();
3528+ broadcast_keygen_cb (key, bcast_key_tuple);
3529+
3530+ /* collect the processes that are involved */
3531+ std::optional<std::set<int >> procset = std::set<int >();
3532+ std::optional<std::set<ttg::device::Device>> deviceset = std::nullopt ;
3533+ keygen_query_successor (bcast_key_tuple, procset, deviceset, // ignore the devices here
3534+ myrank, local_bcast_keys_tuple,
3535+ std::make_index_sequence<std::tuple_size_v<broadcast_keygen_tuple_type>>{});
3536+
3537+ broadcast_keygen_select (local_bcast_keys_tuple, procset.value (), key, value);
3538+ }
3539+
3540+ template <typename Key, typename Value>
3541+ void broadcast_keygen_select (const auto & local_bcast_keys_tuple,
3542+ const auto & procset,
3543+ const Key& key,
3544+ const Value& value) {
3545+ BroadcastType bcast_type = get_broadcast_type ();
3546+ switch (bcast_type) {
3547+ case BroadcastType::Pipe:
3548+ broadcast_keygen_impl<BroadcastPipe>(local_bcast_keys_tuple, procset, key, value);
3549+ break ;
3550+ case BroadcastType::Star:
3551+ broadcast_keygen_impl<BroadcastStar>(local_bcast_keys_tuple, procset, key, value);
3552+ break ;
3553+ default :
3554+ throw std::runtime_error (" Error: unknown broadcast type" );
3555+ }
3556+ }
3557+
3558+
3559+ virtual void broadcast_keygen (const void *key_ptr, const void *value_ptr) override final {
3560+ // assuming that all output types are the same
3561+ if constexpr (std::tuple_size_v<output_terminalsT> > 0 && !ttg::meta::is_void_v<key_type>) {
3562+ using value_type = std::tuple_element_t <0 , ttg::edges_to_output_value_types_t <output_edges_type>>;
3563+ const key_type& key = *static_cast <const key_type*>(key_ptr);
3564+ const value_type& value = *static_cast <const value_type*>(value_ptr);
3565+
3566+ broadcast_keygen (key, value);
35303567 } else {
35313568 throw std::runtime_error (" Error: broadcast_keygen invoked on a ttg::Task with no output terminals" );
35323569 }
0 commit comments