diff --git a/examples/distributed_array.cpp b/examples/distributed_array.cpp index 39d8b5b9..f7b74d49 100644 --- a/examples/distributed_array.cpp +++ b/examples/distributed_array.cpp @@ -52,9 +52,9 @@ class distributed_array_t lci::status_t status; T value; do { - status = lci::post_get(target_rank, &value, sizeof(T), - lci::COMP_NULL_EXPECT_DONE_OR_RETRY, - local_index * sizeof(T), m_rmrs[target_rank]); + status = + lci::post_get(target_rank, &value, sizeof(T), lci::COMP_NULL_RETRY, + local_index * sizeof(T), m_rmrs[target_rank]); lci::progress(); } while (status.is_retry()); assert(status.is_done()); @@ -70,7 +70,7 @@ class distributed_array_t do { status = lci::post_put_x(target_rank, static_cast(const_cast(&value)), - sizeof(T), lci::COMP_NULL_EXPECT_DONE_OR_RETRY, + sizeof(T), lci::COMP_NULL_RETRY, local_index * sizeof(T), m_rmrs[target_rank]) .comp_semantic(lci::comp_semantic_t::network)(); lci::progress(); diff --git a/examples/pingpong_am_mt.cpp b/examples/pingpong_am_mt.cpp index 591c336f..4e922737 100644 --- a/examples/pingpong_am_mt.cpp +++ b/examples/pingpong_am_mt.cpp @@ -46,8 +46,7 @@ void worker(int thread_id) // sender for (int i = 0; i < nmsgs; i++) { // send a message - lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL_EXPECT_DONE, - rcomp) + lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL, rcomp) .device(device) .tag(thread_id)(); // wait for an incoming message @@ -85,8 +84,7 @@ void worker(int thread_id) } free(recv_buf.base); // send a message - lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL_EXPECT_DONE, - rcomp) + lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL, rcomp) .device(device) .tag(thread_id)(); } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d52846ac..a4fafca1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,7 +23,14 @@ target_sources_relative( runtime/runtime.cpp core/communicate.cpp core/progress.cpp - collective/collective.cpp) + collective/collective.cpp + collective/alltoall.cpp + collective/barrier.cpp + collective/broadcast.cpp + collective/gather.cpp + collective/reduce_scatter.cpp + collective/allreduce.cpp + collective/reduce.cpp) if(LCI_BACKEND_ENABLE_OFI) target_sources_relative(LCI PRIVATE network/ofi/backend_ofi.cpp) diff --git a/src/api/lci.hpp b/src/api/lci.hpp index 7b990dea..58dbe470 100644 --- a/src/api/lci.hpp +++ b/src/api/lci.hpp @@ -222,6 +222,60 @@ enum class net_opcode_t { */ const char* get_net_opcode_str(net_opcode_t opcode); +/** + * @ingroup LCI_BASIC + * @brief The type of broadcast algorithm. + */ +enum class broadcast_algorithm_t { + none, /**< automatically select the best algorithm */ + direct, /**< direct algorithm */ + tree, /**< binomial tree algorithm */ + ring, /**< ring algorithm */ +}; + +/** + * @brief Get the string representation of a collective algorithm. + * @param opcode The collective algorithm. + * @return The string representation of the collective algorithm. + */ +const char* get_broadcast_algorithm_str(broadcast_algorithm_t algorithm); + +/** + * @ingroup LCI_BASIC + * @brief The type of reduce scatter algorithm. + */ +enum class reduce_scatter_algorithm_t { + none, /**< automatically select the best algorithm */ + direct, /**< direct algorithm */ + tree, /**< reduce followed by broadcast */ + ring, /**< ring algorithm */ +}; + +/** + * @brief Get the string representation of a collective algorithm. + * @param opcode The collective algorithm. + * @return The string representation of the collective algorithm. 
+ */ +const char* get_reduce_scatter_algorithm_str(broadcast_algorithm_t algorithm); + +/** + * @ingroup LCI_BASIC + * @brief The type of allreduce algorithm. + */ +enum class allreduce_algorithm_t { + none, /**< automatically select the best algorithm */ + direct, /**< direct algorithm */ + tree, /**< reduce followed by broadcast */ + ring, /**< ring algorithm */ +}; + +/** + * @brief Get the string representation of a collective algorithm. + * @param opcode The collective algorithm. + * @return The string representation of the collective algorithm. + */ +const char* get_allreduce_algorithm_str(broadcast_algorithm_t algorithm); + /** * @ingroup LCI_BASIC * @brief The type of network-layer immediate data field. @@ -581,16 +635,33 @@ struct status_t { * @ingroup LCI_BASIC * @brief Special completion object setting `allow_posted` to false. */ +const comp_t COMP_NULL = comp_t(reinterpret_cast(0x0)); + +/** + * @ingroup LCI_BASIC + * @brief Deprecated. Same as COMP_NULL. + */ const comp_t COMP_NULL_EXPECT_DONE = - comp_t(reinterpret_cast(0x1)); + comp_t(reinterpret_cast(0x0)); /** * @ingroup LCI_BASIC * @brief Special completion object setting `allow_posted` and `allow_retry` to * false. */ +const comp_t COMP_NULL_RETRY = comp_t(reinterpret_cast(0x1)); + +/** + * @ingroup LCI_BASIC + * @brief Deprecated. Same as COMP_NULL_RETRY. + */ const comp_t COMP_NULL_EXPECT_DONE_OR_RETRY = - comp_t(reinterpret_cast(0x2)); + comp_t(reinterpret_cast(0x1)); + +inline bool comp_t::is_empty() const +{ + return reinterpret_cast(p_impl) <= 1; +} /** * @ingroup LCI_BASIC @@ -641,6 +712,14 @@ const graph_node_t GRAPH_END = reinterpret_cast(0x2); */ using graph_node_run_cb_t = status_t (*)(void* value); +/** + * @ingroup LCI_BASIC + * @brief A dummy callback function for a graph node. + * @details This function can be used as a placeholder for a graph node that + * does not perform any operation. 
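+ * An illustrative sketch (based on how the ring broadcast later in this patch
+ * uses it): a dummy node can serve as a pure synchronization point that joins
+ * a receive and a send before the next step of the graph.
+ * @code
+ * graph_node_t sync = graph_add_node(graph, GRAPH_NODE_DUMMY_CB);
+ * graph_add_edge(graph, recv_node, sync);
+ * graph_add_edge(graph, send_node, sync);
+ * @endcode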
+ */ +const graph_node_run_cb_t GRAPH_NODE_DUMMY_CB = nullptr; + /** * @ingroup LCI_BASIC * @brief The function signature for a callback that will be triggered when the diff --git a/src/binding/input/collective.py b/src/binding/input/collective.py index ffef7592..a897cf1b 100644 --- a/src/binding/input/collective.py +++ b/src/binding/input/collective.py @@ -14,7 +14,7 @@ optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."), optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."), optional_arg("comp_semantic_t", "comp_semantic", "comp_semantic_t::buffer", comment="The completion semantic."), - optional_arg("comp_t", "comp", "comp_t()", comment="The completion to signal when the operation completes."), + optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."), ], doc = { "in_group": "LCI_COLL", @@ -31,10 +31,13 @@ optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."), optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."), optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."), + optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."), + optional_arg("broadcast_algorithm_t", "algorithm", "broadcast_algorithm_t::none", comment="The collective algorithm to use."), + optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."), ], doc = { "in_group": "LCI_COLL", - "brief": "A blocking broadcast operation.", + "brief": "A broadcast operation.", } ), operation( @@ -56,6 +59,50 @@ "brief": "A blocking reduce operation.", } ), +operation( + "reduce_scatter", + [ + optional_runtime_args, + positional_arg("const void*", "sendbuf", comment="The local buffer base address to send."), + positional_arg("void*", "recvbuf", comment="The local buffer base address to recv."), + positional_arg("size_t", "recvcount", comment="The number of data items to receive one each rank."), + positional_arg("size_t", "item_size", comment="The size of each data item."), + positional_arg("reduce_op_t", "op", comment="The reduction operation."), + optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."), + optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."), + optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."), + optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."), + optional_arg("reduce_scatter_algorithm_t", "algorithm", "reduce_scatter_algorithm_t::none", comment="The collective algorithm to use."), + optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."), + ], + doc = { + "in_group": "LCI_COLL", + "brief": "A reduce scatter operation.", + "details": "This operation assumes the send count is equal to `recvcount * item_size` and " + "`sendbuf` is of size at least `recvcount * item_size * get_rank_n()`.", + } +), +operation( + "allreduce", + [ + optional_runtime_args, + positional_arg("const void*", "sendbuf", comment="The local buffer 
base address to send."), + positional_arg("void*", "recvbuf", comment="The local buffer base address to recv."), + positional_arg("size_t", "count", comment="The number of data items in the buffer."), + positional_arg("size_t", "item_size", comment="The size of each data item."), + positional_arg("reduce_op_t", "op", comment="The reduction operation."), + optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."), + optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."), + optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."), + optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."), + optional_arg("allreduce_algorithm_t", "algorithm", "allreduce_algorithm_t::none", comment="The collective algorithm to use."), + optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."), + ], + doc = { + "in_group": "LCI_COLL", + "brief": "An allreduce operation.", + } +), operation( "allgather", [ diff --git a/src/binding/input/comp.py b/src/binding/input/comp.py index bdbc1977..7e960dc4 100644 --- a/src/binding/input/comp.py +++ b/src/binding/input/comp.py @@ -16,6 +16,7 @@ attr_enum("cq_type", enum_options=["array_atomic", "lcrq"], default_value="lcrq", comment="The completion object type."), attr("int", "cq_default_length", default_value=65536, comment="The default length of the completion queue."), ], + custom_is_empty_method=True, doc = { "in_group": "LCI_COMPLETION", "brief": "The completion object resource.", @@ -160,7 +161,7 @@ operation( "alloc_graph", [ - optional_arg("comp_t", "comp", "comp_t()", comment="Another completion object to signal when the graph is completed. The graph will be automatically destroyed afterwards."), + optional_arg("comp_t", "comp", "COMP_NULL", comment="Another completion object to signal when the graph is completed. The graph will be automatically destroyed afterwards."), optional_arg("void*", "user_context", "nullptr", comment="The arbitrary user-defined context associated with this completion object."), optional_runtime_args, return_val("comp_t", "comp", comment="The allocated completion handler."), @@ -234,7 +235,7 @@ "brief": "Test a graph.", "details": "Successful test will reset the graph to the state that is ready to be started again.", } -) +), ] def get_input(): diff --git a/src/collective/allreduce.cpp b/src/collective/allreduce.cpp new file mode 100644 index 00000000..0d7c373b --- /dev/null +++ b/src/collective/allreduce.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +namespace +{ +struct reduce_wrapper_args_t { + void* tmp_buffer; + void* recvbuf; + size_t count; + reduce_op_t op; +}; + +status_t reduce_wrapper_fn(void* args) +{ + auto* wrapper_args = static_cast(args); + wrapper_args->op(wrapper_args->tmp_buffer, wrapper_args->recvbuf, + wrapper_args->recvbuf, wrapper_args->count); + delete wrapper_args; + return errorcode_t::done; +} + +void build_graph_direct(comp_t graph, post_send_x send_op, post_recv_x recv_op, + const void* sendbuf, void* recvbuf, size_t count, + size_t item_size, reduce_op_t op) +{ + int rank = get_rank_me(); + int nranks = get_rank_n(); + + size_t tmp_buffer_size = (recvbuf != sendbuf) + ? 
count * item_size * (nranks - 1) + : count * item_size * nranks; + void* tmp_buffer = malloc(tmp_buffer_size); + void* p = tmp_buffer; + if (recvbuf != sendbuf) { + // copy the send buffer to the receive buffer + memcpy(recvbuf, sendbuf, count * item_size); + } else { + // we need a temporary buffer to hold the send data + memcpy(p, sendbuf, count * item_size); + send_op.local_buffer(p); + p = static_cast(p) + count * item_size; + } + + graph_node_t prev_reduce_node = nullptr; + graph_node_t free_node = + graph_add_node_x(graph, [](void* tmp_buffer) -> status_t { + free(tmp_buffer); + return errorcode_t::done; + }).value(tmp_buffer)(); + for (int i = 1; i < nranks; ++i) { + int target = (i + rank) % nranks; + graph_node_t send_node = graph_add_node_op(graph, send_op.rank(target)); + graph_add_edge(graph, GRAPH_START, send_node); + graph_add_edge(graph, send_node, free_node); + graph_node_t recv_node = + graph_add_node_op(graph, recv_op.rank(target).local_buffer(p)); + graph_node_t reduce_node = + graph_add_node_x(graph, reduce_wrapper_fn) + .value(new reduce_wrapper_args_t{p, recvbuf, count, op})(); + graph_add_edge(graph, GRAPH_START, recv_node); + graph_add_edge(graph, recv_node, reduce_node); + if (prev_reduce_node) { + graph_add_edge(graph, prev_reduce_node, reduce_node); + } + prev_reduce_node = reduce_node; + p = static_cast(p) + count * item_size; + } + graph_add_edge(graph, prev_reduce_node, free_node); + graph_add_edge(graph, free_node, GRAPH_END); +} +} // namespace + +void allreduce_x::call_impl(const void* sendbuf, void* recvbuf, size_t count, + size_t item_size, reduce_op_t op, runtime_t runtime, + device_t device, endpoint_t endpoint, + matching_engine_t matching_engine, comp_t comp, + allreduce_algorithm_t algorithm, + [[maybe_unused]] int ring_nsteps) const +{ + int seqnum = get_sequence_number(); + int nranks = get_rank_n(); + + LCI_DBG_Log( + LOG_TRACE, "collective", + "enter allreduce %d (sendbuf %p recvbuf %p item_size %lu count %lu)\n", + seqnum, sendbuf, recvbuf, item_size, count); + if (nranks == 1) { + if (recvbuf != sendbuf) { + memcpy(recvbuf, sendbuf, item_size * count); + } + if (!comp.is_empty()) { + lci::comp_signal(comp, status_t(errorcode_t::done)); + } + return; + } + + comp_t graph = alloc_graph_x().runtime(runtime).comp(comp)(); + auto send_op = post_send_x(-1, const_cast(sendbuf), count * item_size, + seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + auto recv_op = post_recv_x(-1, recvbuf, count * item_size, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + + if (algorithm == allreduce_algorithm_t::none) { + // auto select the best algorithm + algorithm = allreduce_algorithm_t::direct; + } + + if (algorithm == allreduce_algorithm_t::direct) { + // direct algorithm + build_graph_direct(graph, send_op, recv_op, sendbuf, recvbuf, count, + item_size, op); + } else { + LCI_Assert(false, "Unsupported broadcast algorithm %d", + static_cast(algorithm)); + } + + graph_start(graph); + if (comp.is_empty()) { + // blocking wait + status_t status; + do { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + status = graph_test(graph); + } while (status.is_retry()); + free_comp(&graph); + } + + LCI_DBG_Log( + LOG_TRACE, "collective", + "leave allreduce %d (sendbuf %p recvbuf %p item_size %lu count %lu)\n", + seqnum, sendbuf, recvbuf, item_size, count); +} + +} // namespace lci \ No 
newline at end of file diff --git a/src/collective/alltoall.cpp b/src/collective/alltoall.cpp new file mode 100644 index 00000000..7e0ea8ea --- /dev/null +++ b/src/collective/alltoall.cpp @@ -0,0 +1,60 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +void alltoall_x::call_impl(const void* sendbuf, void* recvbuf, size_t size, + runtime_t runtime, device_t device, + endpoint_t endpoint, + matching_engine_t matching_engine) const +{ + int seqnum = get_sequence_number(); + + int rank = get_rank_me(); + int nranks = get_rank_n(); + LCI_DBG_Log(LOG_TRACE, "collective", + "enter alltoall %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, + sendbuf, recvbuf, size); + + comp_t comp = alloc_sync_x().threshold(2 * nranks - 2).runtime(runtime)(); + + for (int i = 0; i < nranks; ++i) { + void* current_recvbuf = + static_cast(static_cast(recvbuf) + i * size); + void* current_sendbuf = static_cast( + static_cast(const_cast(sendbuf)) + i * size); + if (i == rank) { + memcpy(current_recvbuf, current_sendbuf, size); + continue; + } + status_t status; + do { + status = post_recv_x(i, current_recvbuf, size, seqnum, comp) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_done(false)(); + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } while (status.is_retry()); + do { + status = post_send_x(i, current_sendbuf, size, seqnum, comp) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_done(false)(); + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } while (status.is_retry()); + } + + // sync_wait_x(comp, nullptr).runtime(runtime)(); + while (!sync_test_x(comp, nullptr).runtime(runtime)()) { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } + free_comp(&comp); +} + +} // namespace lci \ No newline at end of file diff --git a/src/collective/barrier.cpp b/src/collective/barrier.cpp new file mode 100644 index 00000000..ebdf2be9 --- /dev/null +++ b/src/collective/barrier.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +void barrier_x::call_impl(runtime_t runtime, device_t device, + endpoint_t endpoint, + matching_engine_t matching_engine, + comp_semantic_t comp_semantic, comp_t comp) const +{ + int seqnum = get_sequence_number(); + [[maybe_unused]] int round = 0; + int rank = get_rank_me(); + int nranks = get_rank_n(); + + // dissemination algorithm + LCI_DBG_Log(LOG_TRACE, "collective", "enter barrier %d\n", seqnum); + if (nranks == 1) { + if (!comp.is_empty()) { + lci::comp_signal(comp, status_t(errorcode_t::done)); + } + return; + } + comp_t graph = alloc_graph_x().runtime(runtime).comp(comp)(); + graph_node_t old_node = GRAPH_START; + graph_node_t dummy_node; + for (int jump = 1; jump < nranks; jump *= 2) { + int rank_to_recv = (rank - jump + nranks) % nranks; + int rank_to_send = (rank + jump) % nranks; + LCI_DBG_Log(LOG_TRACE, "collective", + "barrier %d round %d recv from %d send to %d\n", seqnum, + round++, rank_to_recv, rank_to_send); + + auto recv = post_recv_x(rank_to_recv, nullptr, 0, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + auto send = post_send_x(rank_to_send, nullptr, 0, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + 
.matching_engine(matching_engine) + .allow_retry(false); + if (jump * 2 >= nranks) { + // this is the last round, need to take care of the completion semantic + send = send.comp_semantic(comp_semantic); + dummy_node = GRAPH_END; + } else { + dummy_node = graph_add_node( + graph, [](void*) -> status_t { return errorcode_t::done; }); + } + auto recv_node = graph_add_node_op(graph, recv); + auto send_node = graph_add_node_op(graph, send); + graph_add_edge(graph, old_node, recv_node); + graph_add_edge(graph, old_node, send_node); + graph_add_edge(graph, recv_node, dummy_node); + graph_add_edge(graph, send_node, dummy_node); + old_node = dummy_node; + } + graph_start(graph); + if (comp.is_empty()) { + // blocking wait + status_t status; + do { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + status = graph_test(graph); + } while (status.is_retry()); + free_comp(&graph); + } + LCI_DBG_Log(LOG_TRACE, "collective", "leave barrier %d\n", seqnum); +} + +} // namespace lci \ No newline at end of file diff --git a/src/collective/broadcast.cpp b/src/collective/broadcast.cpp new file mode 100644 index 00000000..93f93101 --- /dev/null +++ b/src/collective/broadcast.cpp @@ -0,0 +1,234 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +namespace +{ +void build_graph_direct(comp_t graph, post_send_x send_op, post_recv_x recv_op, + int root) +{ + int rank = get_rank_me(); + int nranks = get_rank_n(); + + if (rank == root) { + for (int i = 0; i < nranks; ++i) { + int target = (i + rank) % nranks; + if (target == rank) { + // skip self + continue; + } + graph_node_t send_node = graph_add_node_op(graph, send_op.rank(target)); + graph_add_edge(graph, GRAPH_START, send_node); + graph_add_edge(graph, send_node, GRAPH_END); + } + } else { + // non-root ranks only receive + graph_node_t recv_node = graph_add_node_op(graph, recv_op.rank(root)); + graph_add_edge(graph, GRAPH_START, recv_node); + graph_add_edge(graph, recv_node, GRAPH_END); + } +} + +void build_graph_tree(comp_t graph, post_send_x send_op, post_recv_x recv_op, + int root) +{ + [[maybe_unused]] int round = 0; + + int rank = get_rank_me(); + int nranks = get_rank_n(); + + // binomial tree algorithm + graph_node_t old_node = GRAPH_START; + + bool has_data = (rank == root); + int distance_left = + (rank + nranks - root) % nranks; // distance to the first rank on the + // left that has the data (can be 0) + int distance_right = (root - 1 - rank + nranks) % + nranks; // number of empty ranks on the right + int jump = std::ceil(nranks / 2.0); + while (true) { + if (has_data && jump <= distance_right) { + // send to the right + int rank_to_send = (rank + jump) % nranks; + LCI_DBG_Log(LOG_TRACE, "collective", + "broadcast (tree) round %d send to %d\n", round, + rank_to_send); + auto send_node = graph_add_node_op(graph, send_op.rank(rank_to_send)); + graph_add_edge(graph, old_node, send_node); + old_node = send_node; + } else if (distance_left == jump) { + // receive from the right + int rank_to_recv = (rank - jump + nranks) % nranks; + LCI_DBG_Log(LOG_TRACE, "collective", + "broadcast (tree) round %d recv from %d\n", round, + rank_to_recv); + auto recv_node = graph_add_node_op(graph, recv_op.rank(rank_to_recv)); + graph_add_edge(graph, old_node, recv_node); + old_node = recv_node; + has_data = true; + } + // The rank on your left (or yourself) sends the data to a rank right of it + // by `jump` distance. 
update the distances accordingly + if (distance_left >= jump) { + distance_left -= jump; + } else { + // distance_left < jump + distance_right = std::min(jump - distance_left - 1, distance_right); + } + // LCI_DBG_Log( + // LOG_TRACE, "collective", + // "broadcast %d round %d jump %d distance_left %d distance_right %d\n", + // seqnum, round, jump, distance_left, distance_right); + ++round; + if (jump == 1) { + break; + } else { + jump = std::ceil(jump / 2.0); + } + } + graph_add_edge(graph, old_node, GRAPH_END); +} + +void build_graph_ring(comp_t graph, post_send_x send_op, post_recv_x recv_op, + int root, void* buffer, size_t size, int nsteps) +{ + int rank = get_rank_me(); + int nranks = get_rank_n(); + + size_t step_size = (size + nsteps - 1) / nsteps; + int left = (rank - 1 + nranks) % nranks; + int right = (rank + 1) % nranks; + send_op = send_op.rank(right); + recv_op = recv_op.rank(left); + + graph_node_t old_node = GRAPH_START; + for (int step = 0; step <= nsteps; ++step) { + graph_node_t next_node; + + if ((step == 0 && rank == root) || (step == nsteps && right == root)) { + // they won't do anything in this step + if (step == nsteps) { + graph_add_edge(graph, old_node, GRAPH_END); + } + continue; + } + + if (step == nsteps) + next_node = GRAPH_END; + else + next_node = graph_add_node(graph, GRAPH_NODE_DUMMY_CB); + + if (step != nsteps && rank != root) { + // receive from the previous rank + int idx = step; + void* step_buffer = static_cast(buffer) + idx * step_size; + size_t actual_size = std::min(step_size, size - idx * step_size); + graph_node_t recv_node = graph_add_node_op( + graph, recv_op.local_buffer(step_buffer).size(actual_size)); + graph_add_edge(graph, old_node, recv_node); + graph_add_edge(graph, recv_node, next_node); + LCI_Log(LOG_TRACE, "collective", + "broadcast (ring) step %d recv from %d size %lu\n", step, left, + actual_size); + } + + if (step != 0 && right != root) { + // send to the next rank + int idx = step - 1; + void* step_buffer = static_cast(buffer) + idx * step_size; + size_t actual_size = std::min(step_size, size - idx * step_size); + graph_node_t send_node = graph_add_node_op( + graph, send_op.local_buffer(step_buffer).size(actual_size)); + graph_add_edge(graph, old_node, send_node); + graph_add_edge(graph, send_node, next_node); + LCI_Log(LOG_TRACE, "collective", + "broadcast (ring) step %d send to %d size %lu\n", step, left, + actual_size); + } + + old_node = next_node; + } +} +} // namespace + +void broadcast_x::call_impl(void* buffer, size_t size, int root, + runtime_t runtime, device_t device, + endpoint_t endpoint, + matching_engine_t matching_engine, comp_t comp, + broadcast_algorithm_t algorithm, + int ring_nsteps) const +{ + int seqnum = get_sequence_number(); + int nranks = get_rank_n(); + + LCI_DBG_Log(LOG_TRACE, "collective", + "enter broadcast %d (root %d buffer %p size %lu)\n", seqnum, root, + buffer, size); + if (nranks == 1) { + if (!comp.is_empty()) { + lci::comp_signal(comp, status_t(errorcode_t::done)); + } + return; + } + + comp_t graph = alloc_graph_x().runtime(runtime).comp(comp)(); + auto send_op = post_send_x(-1, buffer, size, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + auto recv_op = post_recv_x(-1, buffer, size, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + + if (algorithm == broadcast_algorithm_t::none) { + // auto select the best algorithm + if 
(size <= 65536 /* FIXME: magic number */) { + if (nranks <= 8) { + algorithm = broadcast_algorithm_t::direct; + } else { + algorithm = broadcast_algorithm_t::tree; + } + } else { + algorithm = broadcast_algorithm_t::ring; + } + } + + if (algorithm == broadcast_algorithm_t::direct) { + // direct algorithm + build_graph_direct(graph, send_op, recv_op, root); + } else if (algorithm == broadcast_algorithm_t::tree) { + // binomial tree algorithm + build_graph_tree(graph, send_op, recv_op, root); + } else if (algorithm == broadcast_algorithm_t::ring) { + build_graph_ring(graph, send_op, recv_op, root, buffer, size, ring_nsteps); + } else { + LCI_Assert(false, "Unsupported broadcast algorithm %d", + static_cast(algorithm)); + } + + graph_start(graph); + if (comp.is_empty()) { + // blocking wait + status_t status; + do { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + status = graph_test(graph); + } while (status.is_retry()); + free_comp(&graph); + } + + LCI_DBG_Log(LOG_TRACE, "collective", + "leave broadcast %d (root %d buffer %p size %lu)\n", seqnum, root, + buffer, size); +} + +} // namespace lci \ No newline at end of file diff --git a/src/collective/collective.cpp b/src/collective/collective.cpp index 66ffb00c..9d0f64dd 100644 --- a/src/collective/collective.cpp +++ b/src/collective/collective.cpp @@ -5,441 +5,5 @@ namespace lci { -const int MAX_SEQUENCE_NUMBER = 65536; std::atomic g_sequence_number(0); - -void barrier_x::call_impl(runtime_t runtime, device_t device, - endpoint_t endpoint, - matching_engine_t matching_engine, - comp_semantic_t comp_semantic, comp_t comp) const -{ - int seqnum = g_sequence_number.fetch_add(1, std::memory_order_relaxed) % - MAX_SEQUENCE_NUMBER; - [[maybe_unused]] int round = 0; - int rank = get_rank_me(); - int nranks = get_rank_n(); - - // dissemination algorithm - LCI_DBG_Log(LOG_TRACE, "collective", "enter barrier %d\n", seqnum); - - if (comp.is_empty() || comp == COMP_NULL_EXPECT_DONE || - comp == COMP_NULL_EXPECT_DONE_OR_RETRY) { - if (nranks == 1) { - return; - } - // blocking barrier - for (int jump = 1; jump < nranks; jump *= 2) { - int rank_to_recv = (rank - jump + nranks) % nranks; - int rank_to_send = (rank + jump) % nranks; - LCI_DBG_Log(LOG_TRACE, "collective", - "barrier %d round %d recv from %d send to %d\n", seqnum, - round++, rank_to_recv, rank_to_send); - comp_t comp = alloc_sync_x().threshold(2).runtime(runtime)(); - post_recv_x(rank_to_recv, nullptr, 0, seqnum, comp) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_retry(false) - .allow_done(false)(); - auto post_send_op = post_send_x(rank_to_send, nullptr, 0, seqnum, comp) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_retry(false) - .allow_done(false); - if (jump * 2 >= nranks) { - // this is the last round, need to take care of the completion semantic - post_send_op = post_send_op.comp_semantic(comp_semantic); - } - post_send_op(); - while (!sync_test(comp, nullptr)) { - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } - free_comp(&comp); - } - LCI_DBG_Log(LOG_TRACE, "collective", "leave barrier %d\n", seqnum); - } else { - // nonblocking barrier - if (nranks == 1) { - lci::comp_signal(comp, status_t(errorcode_t::done)); - return; - } - comp_t graph = alloc_graph_x().runtime(runtime).comp(comp)(); - graph_node_t old_node = GRAPH_START; - graph_node_t dummy_node; - for (int jump = 1; jump < nranks; jump *= 2) { - int rank_to_recv = 
(rank - jump + nranks) % nranks; - int rank_to_send = (rank + jump) % nranks; - LCI_DBG_Log(LOG_TRACE, "collective", - "barrier %d round %d recv from %d send to %d\n", seqnum, - round++, rank_to_recv, rank_to_send); - - auto recv = post_recv_x(rank_to_recv, nullptr, 0, seqnum, graph) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_retry(false); - auto send = post_send_x(rank_to_send, nullptr, 0, seqnum, graph) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_retry(false); - if (jump * 2 >= nranks) { - // this is the last round, need to take care of the completion semantic - send = send.comp_semantic(comp_semantic); - dummy_node = GRAPH_END; - } else { - dummy_node = graph_add_node( - graph, [](void*) -> status_t { return errorcode_t::done; }); - } - auto recv_node = graph_add_node_op(graph, recv); - auto send_node = graph_add_node_op(graph, send); - graph_add_edge(graph, old_node, recv_node); - graph_add_edge(graph, old_node, send_node); - graph_add_edge(graph, recv_node, dummy_node); - graph_add_edge(graph, send_node, dummy_node); - old_node = dummy_node; - } - graph_start(graph); - } -} - -void broadcast_x::call_impl(void* buffer, size_t size, int root, - runtime_t runtime, device_t device, - endpoint_t endpoint, - matching_engine_t matching_engine) const -{ - int seqnum = g_sequence_number.fetch_add(1, std::memory_order_relaxed) % - MAX_SEQUENCE_NUMBER; - - [[maybe_unused]] int round = 0; - int rank = get_rank_me(); - int nranks = get_rank_n(); - - if (nranks == 1) { - return; - } - // binomial tree algorithm - LCI_DBG_Log(LOG_TRACE, "collective", - "enter broadcast %d (root %d buffer %p size %lu)\n", seqnum, root, - buffer, size); - bool has_data = (rank == root); - int distance_left = - (rank + nranks - root) % nranks; // distance to the first rank on the - // left that has the data (can be 0) - int distance_right = (root - 1 - rank + nranks) % - nranks; // number of empty ranks on the right - int jump = std::ceil(nranks / 2.0); - while (true) { - if (has_data && jump <= distance_right) { - // send to the right - int rank_to_send = (rank + jump) % nranks; - LCI_DBG_Log(LOG_TRACE, "collective", "broadcast %d round %d send to %d\n", - seqnum, round, rank_to_send); - post_send_x(rank_to_send, buffer, size, seqnum, COMP_NULL_EXPECT_DONE) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine)(); - } else if (distance_left == jump) { - // receive from the right - int rank_to_recv = (rank - jump + nranks) % nranks; - LCI_DBG_Log(LOG_TRACE, "collective", - "broadcast %d round %d recv from %d\n", seqnum, round, - rank_to_recv); - post_recv_x(rank_to_recv, buffer, size, seqnum, COMP_NULL_EXPECT_DONE) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine)(); - has_data = true; - } - // The rank on your left (or yourself) sends the data to a rank right of it - // by `jump` distance. 
update the distances accordingly - if (distance_left >= jump) { - distance_left -= jump; - } else { - // distance_left < jump - distance_right = std::min(jump - distance_left - 1, distance_right); - } - LCI_DBG_Log( - LOG_TRACE, "collective", - "broadcast %d round %d jump %d distance_left %d distance_right %d\n", - seqnum, round, jump, distance_left, distance_right); - ++round; - if (jump == 1) { - break; - } else { - jump = std::ceil(jump / 2.0); - } - } - LCI_DBG_Log(LOG_TRACE, "collective", - "leave broadcast %d (root %d buffer %p size %lu)\n", seqnum, root, - buffer, size); -} - -void allgather_x::call_impl(const void* sendbuf, void* recvbuf, size_t size, - runtime_t runtime, device_t device, - endpoint_t endpoint, - matching_engine_t matching_engine) const -{ - int seqnum = g_sequence_number.fetch_add(1, std::memory_order_relaxed) % - MAX_SEQUENCE_NUMBER; - - int rank = get_rank_me(); - int nranks = get_rank_n(); - - if (nranks == 1) { - memcpy(recvbuf, sendbuf, size); - return; - } - // alltoall algorithm - comp_t sync = alloc_sync_x().threshold(2 * nranks - 2).runtime(runtime)(); - LCI_DBG_Log(LOG_TRACE, "collective", - "enter allgather %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, - sendbuf, recvbuf, size); - status_t status; - for (int i = 1; i < nranks; ++i) { - int peer_rank = (rank + i) % nranks; - do { - status = - post_recv_x(peer_rank, static_cast(recvbuf) + peer_rank * size, - size, seqnum, sync) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_done(false)(); - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } while (status.is_retry()); - } - for (int i = 1; i < nranks; ++i) { - int peer_rank = (rank + i) % nranks; - do { - status = - post_send_x(peer_rank, const_cast(sendbuf), size, seqnum, sync) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_done(false)(); - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } while (status.is_retry()); - } - memcpy(static_cast(recvbuf) + rank * size, sendbuf, size); - while (!sync_test(sync, nullptr)) { - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } - free_comp(&sync); - LCI_DBG_Log(LOG_TRACE, "collective", - "leave allgather %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, - sendbuf, recvbuf, size); -} - -void reduce_x::call_impl(const void* sendbuf, void* recvbuf, size_t count, - size_t item_size, reduce_op_t op, int root, - runtime_t runtime, device_t device, - endpoint_t endpoint, - matching_engine_t matching_engine) const -{ - int seqnum = g_sequence_number.fetch_add(1, std::memory_order_relaxed) % - MAX_SEQUENCE_NUMBER; - - int round = 0; - int rank = get_rank_me(); - int nranks = get_rank_n(); - - if (nranks == 1) { - if (recvbuf != sendbuf) { - memcpy(recvbuf, sendbuf, item_size * count); - } - return; - } - // binomial tree algorithm - LCI_DBG_Log(LOG_TRACE, "collective", - "enter reduce %d (sendbuf %p recvbuf %p item_size %lu count %lu " - "root %d)\n", - seqnum, sendbuf, recvbuf, item_size, count, root); - std::vector> actions_per_round( - std::ceil(std::log2(nranks)), {-1, false}); - // First compute the binary tree from the root to the leaves - int nchildren = 0; - bool has_data = (rank == root); - int distance_left = - (rank + nranks - root) % nranks; // distance to the first rank on the - // left that has the data (can be 0) - int distance_right = (root - 1 - rank + nranks) % - nranks; // number of empty ranks on the right - int jump = 
std::ceil(nranks / 2.0); - while (true) { - if (has_data && jump <= distance_right) { - // send to the right - int rank_to_send = (rank + jump) % nranks; - actions_per_round[round] = {rank_to_send, true}; - ++nchildren; // if there is a send, then there is one more child - } else if (distance_left == jump) { - // receive from the right - int rank_to_recv = (rank - jump + nranks) % nranks; - actions_per_round[round] = {rank_to_recv, false}; - has_data = true; - } - // The rank on your left (or yourself) sends the data to a rank right of it - // by `jump` distance. update the distances accordingly - if (distance_left >= jump) { - distance_left -= jump; - } else { - // distance_left < jump - distance_right = std::min(jump - distance_left - 1, distance_right); - } - ++round; - if (jump == 1) { - break; - } else { - jump = std::ceil(jump / 2.0); - } - } - // Then replay the binary tree from the leaves to the root - // also reverse the message direction - bool to_free_tmp_buffer = false; - bool to_free_data_buffer = false; - void* tmp_buffer; // to receive data - void* data_buffer; // to hold intermediate result - if (rank == root) { - // for the root, we can always use the recvbuf to hold the intermediate - // result - data_buffer = recvbuf; - if (nchildren == 1 && sendbuf != recvbuf) { - // if there is only one child, we can (almost always) use the data buffer - // to receive, except for the case sendbuf = recvbuf = data_buffer - tmp_buffer = data_buffer; - } else { - tmp_buffer = malloc(item_size * count); - to_free_tmp_buffer = true; - } - } else { - // for non-root - if (nchildren == 1) { - // if there is only one child, tmp_buffer is the data_buffer - tmp_buffer = malloc(item_size * count); - data_buffer = tmp_buffer; - to_free_tmp_buffer = true; - } else { - tmp_buffer = malloc(item_size * count); - to_free_tmp_buffer = true; - data_buffer = malloc(item_size * count); - to_free_data_buffer = true; - } - } - bool has_received = false; - for (int i = round - 1; i >= 0; --i) { - int target_rank = actions_per_round[i].first; - if (target_rank < 0) { - continue; - } - bool is_send = !actions_per_round[i].second; // reverse the direction - if (is_send) { - LCI_DBG_Log(LOG_TRACE, "collective", "reduce %d round %d send to %d\n", - seqnum, i, target_rank); - void* buffer_to_send = const_cast(sendbuf); - if (has_received) { - buffer_to_send = data_buffer; - } - post_send_x(target_rank, buffer_to_send, item_size * count, seqnum, - COMP_NULL_EXPECT_DONE) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine)(); - break; - } else { - LCI_DBG_Log(LOG_TRACE, "collective", "reduce %d round %d recv from %d\n", - seqnum, i, target_rank); - post_recv_x(target_rank, tmp_buffer, item_size * count, seqnum, - COMP_NULL_EXPECT_DONE) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine)(); - // fprintf(stderr, "rank %d reduce %d round %d recv %lu from %d\n", - // rank, seqnum, i, *(uint64_t*)tmp_buffer, target_rank); - const void* right_buffer = data_buffer; - if (!has_received) { - has_received = true; - right_buffer = sendbuf; - // fprintf(stderr, "rank %d reduce %d round %d right=sendbuf %lu\n", - // rank, seqnum, i, *(uint64_t*)right_buffer); - // } else { - // fprintf(stderr, "rank %d reduce %d round %d right=data_buffer %lu\n", - // rank, seqnum, i, *(uint64_t*)right_buffer); - } - op(tmp_buffer, right_buffer, data_buffer, count); - // fprintf(stderr, "rank %d reduce %d round %d current data %lu\n", - // rank, seqnum, i, 
*(uint64_t*)data_buffer); - } - } - if (to_free_tmp_buffer) free(tmp_buffer); - if (to_free_data_buffer) free(data_buffer); - LCI_DBG_Log(LOG_TRACE, "collective", - "leave reduce %d (root %d buffer %p item_size %lu n %lu)\n", - seqnum, root, tmp_buffer, item_size, count); -} - -void alltoall_x::call_impl(const void* sendbuf, void* recvbuf, size_t size, - runtime_t runtime, device_t device, - endpoint_t endpoint, - matching_engine_t matching_engine) const -{ - int seqnum = g_sequence_number.fetch_add(1, std::memory_order_relaxed) % - MAX_SEQUENCE_NUMBER; - - int rank = get_rank_me(); - int nranks = get_rank_n(); - - LCI_DBG_Log(LOG_TRACE, "collective", - "enter alltoall %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, - sendbuf, recvbuf, size); - - comp_t comp = alloc_sync_x().threshold(2 * nranks - 2).runtime(runtime)(); - - for (int i = 0; i < nranks; ++i) { - void* current_recvbuf = - static_cast(static_cast(recvbuf) + i * size); - void* current_sendbuf = static_cast( - static_cast(const_cast(sendbuf)) + i * size); - if (i == rank) { - memcpy(current_recvbuf, current_sendbuf, size); - continue; - } - status_t status; - do { - status = post_recv_x(i, current_recvbuf, size, seqnum, comp) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_done(false)(); - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } while (status.is_retry()); - do { - status = post_send_x(i, current_sendbuf, size, seqnum, comp) - .runtime(runtime) - .device(device) - .endpoint(endpoint) - .matching_engine(matching_engine) - .allow_done(false)(); - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } while (status.is_retry()); - } - - // sync_wait_x(comp, nullptr).runtime(runtime)(); - while (!sync_test_x(comp, nullptr).runtime(runtime)()) { - progress_x().runtime(runtime).device(device).endpoint(endpoint)(); - } - free_comp(&comp); -} - } // namespace lci \ No newline at end of file diff --git a/src/collective/collective.hpp b/src/collective/collective.hpp new file mode 100644 index 00000000..5ce89dd5 --- /dev/null +++ b/src/collective/collective.hpp @@ -0,0 +1,20 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#ifndef LCI_COLLECTIVE_HPP +#define LCI_COLLECTIVE_HPP + +namespace lci +{ +extern std::atomic g_sequence_number; + +static inline uint64_t get_sequence_number() +{ + const int MAX_SEQUENCE_NUMBER = 65536; + return g_sequence_number.fetch_add(1, std::memory_order_relaxed) % + MAX_SEQUENCE_NUMBER; +} + +} // namespace lci + +#endif // LCI_COLLECTIVE_HPP \ No newline at end of file diff --git a/src/collective/gather.cpp b/src/collective/gather.cpp new file mode 100644 index 00000000..0b4e56d4 --- /dev/null +++ b/src/collective/gather.cpp @@ -0,0 +1,65 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +void allgather_x::call_impl(const void* sendbuf, void* recvbuf, size_t size, + runtime_t runtime, device_t device, + endpoint_t endpoint, + matching_engine_t matching_engine) const +{ + int seqnum = get_sequence_number(); + + int rank = get_rank_me(); + int nranks = get_rank_n(); + + if (nranks == 1) { + memcpy(recvbuf, sendbuf, size); + return; + } + // alltoall algorithm + comp_t sync = alloc_sync_x().threshold(2 * nranks - 2).runtime(runtime)(); + LCI_DBG_Log(LOG_TRACE, "collective", + "enter allgather %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, + sendbuf, recvbuf, size); + status_t status; + for 
(int i = 1; i < nranks; ++i) { + int peer_rank = (rank + i) % nranks; + do { + status = + post_recv_x(peer_rank, static_cast(recvbuf) + peer_rank * size, + size, seqnum, sync) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_done(false)(); + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } while (status.is_retry()); + } + for (int i = 1; i < nranks; ++i) { + int peer_rank = (rank + i) % nranks; + do { + status = + post_send_x(peer_rank, const_cast(sendbuf), size, seqnum, sync) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_done(false)(); + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } while (status.is_retry()); + } + memcpy(static_cast(recvbuf) + rank * size, sendbuf, size); + while (!sync_test(sync, nullptr)) { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + } + free_comp(&sync); + LCI_DBG_Log(LOG_TRACE, "collective", + "leave allgather %d (sendbuf %p recvbuf %p size %lu)\n", seqnum, + sendbuf, recvbuf, size); +} + +} // namespace lci \ No newline at end of file diff --git a/src/collective/reduce.cpp b/src/collective/reduce.cpp new file mode 100644 index 00000000..1b8473e1 --- /dev/null +++ b/src/collective/reduce.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2025 The LCI Project Authors +// SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +void reduce_x::call_impl(const void* sendbuf, void* recvbuf, size_t count, + size_t item_size, reduce_op_t op, int root, + runtime_t runtime, device_t device, + endpoint_t endpoint, + matching_engine_t matching_engine) const +{ + int seqnum = get_sequence_number(); + + int round = 0; + int rank = get_rank_me(); + int nranks = get_rank_n(); + + if (nranks == 1) { + if (recvbuf != sendbuf) { + memcpy(recvbuf, sendbuf, item_size * count); + } + return; + } + // binomial tree algorithm + LCI_DBG_Log(LOG_TRACE, "collective", + "enter reduce %d (sendbuf %p recvbuf %p item_size %lu count %lu " + "root %d)\n", + seqnum, sendbuf, recvbuf, item_size, count, root); + std::vector> actions_per_round( + std::ceil(std::log2(nranks)), {-1, false}); + // First compute the binary tree from the root to the leaves + int nchildren = 0; + bool has_data = (rank == root); + int distance_left = + (rank + nranks - root) % nranks; // distance to the first rank on the + // left that has the data (can be 0) + int distance_right = (root - 1 - rank + nranks) % + nranks; // number of empty ranks on the right + int jump = std::ceil(nranks / 2.0); + while (true) { + if (has_data && jump <= distance_right) { + // send to the right + int rank_to_send = (rank + jump) % nranks; + actions_per_round[round] = {rank_to_send, true}; + ++nchildren; // if there is a send, then there is one more child + } else if (distance_left == jump) { + // receive from the right + int rank_to_recv = (rank - jump + nranks) % nranks; + actions_per_round[round] = {rank_to_recv, false}; + has_data = true; + } + // The rank on your left (or yourself) sends the data to a rank right of it + // by `jump` distance. 
update the distances accordingly + if (distance_left >= jump) { + distance_left -= jump; + } else { + // distance_left < jump + distance_right = std::min(jump - distance_left - 1, distance_right); + } + ++round; + if (jump == 1) { + break; + } else { + jump = std::ceil(jump / 2.0); + } + } + // Then replay the binary tree from the leaves to the root + // also reverse the message direction + bool to_free_tmp_buffer = false; + bool to_free_data_buffer = false; + void* tmp_buffer; // to receive data + void* data_buffer; // to hold intermediate result + if (rank == root) { + // for the root, we can always use the recvbuf to hold the intermediate + // result + data_buffer = recvbuf; + if (nchildren == 1 && sendbuf != recvbuf) { + // if there is only one child, we can (almost always) use the data buffer + // to receive, except for the case sendbuf = recvbuf = data_buffer + tmp_buffer = data_buffer; + } else { + tmp_buffer = malloc(item_size * count); + to_free_tmp_buffer = true; + } + } else { + // for non-root + if (nchildren == 1) { + // if there is only one child, tmp_buffer is the data_buffer + tmp_buffer = malloc(item_size * count); + data_buffer = tmp_buffer; + to_free_tmp_buffer = true; + } else { + tmp_buffer = malloc(item_size * count); + to_free_tmp_buffer = true; + data_buffer = malloc(item_size * count); + to_free_data_buffer = true; + } + } + bool has_received = false; + for (int i = round - 1; i >= 0; --i) { + int target_rank = actions_per_round[i].first; + if (target_rank < 0) { + continue; + } + bool is_send = !actions_per_round[i].second; // reverse the direction + if (is_send) { + LCI_DBG_Log(LOG_TRACE, "collective", "reduce %d round %d send to %d\n", + seqnum, i, target_rank); + void* buffer_to_send = const_cast(sendbuf); + if (has_received) { + buffer_to_send = data_buffer; + } + post_send_x(target_rank, buffer_to_send, item_size * count, seqnum, + COMP_NULL) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine)(); + break; + } else { + LCI_DBG_Log(LOG_TRACE, "collective", "reduce %d round %d recv from %d\n", + seqnum, i, target_rank); + post_recv_x(target_rank, tmp_buffer, item_size * count, seqnum, COMP_NULL) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine)(); + // fprintf(stderr, "rank %d reduce %d round %d recv %lu from %d\n", + // rank, seqnum, i, *(uint64_t*)tmp_buffer, target_rank); + const void* right_buffer = data_buffer; + if (!has_received) { + has_received = true; + right_buffer = sendbuf; + // fprintf(stderr, "rank %d reduce %d round %d right=sendbuf %lu\n", + // rank, seqnum, i, *(uint64_t*)right_buffer); + // } else { + // fprintf(stderr, "rank %d reduce %d round %d right=data_buffer %lu\n", + // rank, seqnum, i, *(uint64_t*)right_buffer); + } + op(tmp_buffer, right_buffer, data_buffer, count); + // fprintf(stderr, "rank %d reduce %d round %d current data %lu\n", + // rank, seqnum, i, *(uint64_t*)data_buffer); + } + } + if (to_free_tmp_buffer) free(tmp_buffer); + if (to_free_data_buffer) free(data_buffer); + LCI_DBG_Log(LOG_TRACE, "collective", + "leave reduce %d (root %d buffer %p item_size %lu n %lu)\n", + seqnum, root, tmp_buffer, item_size, count); +} + +} // namespace lci \ No newline at end of file diff --git a/src/collective/reduce_scatter.cpp b/src/collective/reduce_scatter.cpp new file mode 100644 index 00000000..f506d3fd --- /dev/null +++ b/src/collective/reduce_scatter.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2025 The LCI Project Authors +// 
SPDX-License-Identifier: MIT + +#include "lci_internal.hpp" + +namespace lci +{ +namespace +{ +struct reduce_wrapper_args_t { + void* tmp_buffer; + void* recvbuf; + size_t count; + reduce_op_t op; +}; + +status_t reduce_wrapper_fn(void* args) +{ + auto* wrapper_args = static_cast(args); + wrapper_args->op(wrapper_args->tmp_buffer, wrapper_args->recvbuf, + wrapper_args->recvbuf, wrapper_args->count); + delete wrapper_args; + return errorcode_t::done; +} + +void build_graph_direct(comp_t graph, post_send_x send_op, post_recv_x recv_op, + const void* sendbuf, void* recvbuf, size_t recvcount, + size_t item_size, reduce_op_t op) +{ + int rank = get_rank_me(); + int nranks = get_rank_n(); + + char* sendbuf_c = const_cast(static_cast(sendbuf)); + char* recvbuf_c = static_cast(recvbuf); + + bool in_place = (recvbuf_c == sendbuf_c + rank * recvcount * item_size); + void* tmp_buffer = malloc(recvcount * item_size * (nranks - 1)); + void* p = tmp_buffer; + if (!in_place) { + // copy the send buffer to the receive buffer + memcpy(recvbuf, sendbuf_c + rank * recvcount * item_size, + recvcount * item_size); + } + + graph_node_t prev_reduce_node = nullptr; + graph_node_t free_node = + graph_add_node_x(graph, [](void* tmp_buffer) -> status_t { + free(tmp_buffer); + return errorcode_t::done; + }).value(tmp_buffer)(); + for (int i = 1; i < nranks; ++i) { + int target = (i + rank) % nranks; + graph_node_t send_node = graph_add_node_op( + graph, send_op.rank(target).local_buffer( + sendbuf_c + target * recvcount * item_size)); + graph_add_edge(graph, GRAPH_START, send_node); + graph_add_edge(graph, send_node, GRAPH_END); + graph_node_t recv_node = + graph_add_node_op(graph, recv_op.rank(target).local_buffer(p)); + graph_node_t reduce_node = + graph_add_node_x(graph, reduce_wrapper_fn) + .value(new reduce_wrapper_args_t{p, recvbuf, recvcount, op})(); + graph_add_edge(graph, GRAPH_START, recv_node); + graph_add_edge(graph, recv_node, reduce_node); + if (prev_reduce_node) { + graph_add_edge(graph, prev_reduce_node, reduce_node); + } + prev_reduce_node = reduce_node; + p = static_cast(p) + recvcount * item_size; + } + graph_add_edge(graph, prev_reduce_node, free_node); + graph_add_edge(graph, free_node, GRAPH_END); +} +} // namespace + +void reduce_scatter_x::call_impl(const void* sendbuf, void* recvbuf, + size_t recvcount, size_t item_size, + reduce_op_t op, runtime_t runtime, + device_t device, endpoint_t endpoint, + matching_engine_t matching_engine, comp_t comp, + reduce_scatter_algorithm_t algorithm, + [[maybe_unused]] int ring_nsteps) const +{ + int seqnum = get_sequence_number(); + int nranks = get_rank_n(); + + LCI_DBG_Log(LOG_TRACE, "collective", + "enter reduce_scatter %d (sendbuf %p recvbuf %p item_size %lu " + "recvcount %lu)\n", + seqnum, sendbuf, recvbuf, item_size, recvcount); + if (nranks == 1) { + if (recvbuf != sendbuf) { + memcpy(recvbuf, sendbuf, item_size * recvcount); + } + if (!comp.is_empty()) { + lci::comp_signal(comp, status_t(errorcode_t::done)); + } + return; + } + + comp_t graph = alloc_graph_x().runtime(runtime).comp(comp)(); + auto send_op = post_send_x(-1, const_cast(sendbuf), + recvcount * item_size, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + auto recv_op = post_recv_x(-1, recvbuf, recvcount * item_size, seqnum, graph) + .runtime(runtime) + .device(device) + .endpoint(endpoint) + .matching_engine(matching_engine) + .allow_retry(false); + + if (algorithm == reduce_scatter_algorithm_t::none) 
{ + // auto select the best algorithm + algorithm = reduce_scatter_algorithm_t::direct; + } + + if (algorithm == reduce_scatter_algorithm_t::direct) { + // direct algorithm + build_graph_direct(graph, send_op, recv_op, sendbuf, recvbuf, recvcount, + item_size, op); + } else { + LCI_Assert(false, "Unsupported broadcast algorithm %d", + static_cast(algorithm)); + } + + graph_start(graph); + if (comp.is_empty()) { + // blocking wait + status_t status; + do { + progress_x().runtime(runtime).device(device).endpoint(endpoint)(); + status = graph_test(graph); + } while (status.is_retry()); + free_comp(&graph); + } + + LCI_DBG_Log(LOG_TRACE, "collective", + "leave reduce_scatter %d (sendbuf %p recvbuf %p item_size %lu " + "count %lu)\n", + seqnum, sendbuf, recvbuf, item_size, recvcount); +} + +} // namespace lci \ No newline at end of file diff --git a/src/comp/graph.hpp b/src/comp/graph.hpp index 19f36963..c3ed2781 100644 --- a/src/comp/graph.hpp +++ b/src/comp/graph.hpp @@ -47,6 +47,15 @@ class graph_t : public comp_impl_t void* value; graph_node_free_cb_t free_cb; std::vector> out_edges; + node_impl_t() + : signals_received(0), + signals_expected(0), + graph(nullptr), + fn(nullptr), + value(nullptr), + free_cb(nullptr) + { + } node_impl_t(graph_t* graph_, graph_node_run_cb_t fn_, void* value_, graph_node_free_cb_t free_cb_) : signals_received(0), @@ -67,7 +76,8 @@ class graph_t : public comp_impl_t comp_t m_comp; // the graph structure - std::vector m_start_nodes; + node_impl_t m_start_node; + // std::vector m_start_nodes; std::vector nodes; int m_end_signals_expected; void* m_end_value; @@ -82,6 +92,8 @@ inline void graph_t::trigger_node(graph_node_t node_) status_t status; if (node->fn) { status = node->fn(node->value); + } else { + status.set_done(); } LCI_DBG_Log(LOG_TRACE, "graph", "graph %p trigger node %p, status %s\n", node->graph, node_, status.error.get_str()); @@ -174,11 +186,12 @@ inline void graph_t::add_edge(graph_node_t src_, graph_node_t dst_, LCI_Assert(dst_ != GRAPH_START, "The destination node should not be the start node"); LCI_Assert(src_ != GRAPH_END, "The source node should not be the end node"); + node_impl_t* src; if (src_ == GRAPH_START) { - m_start_nodes.push_back(dst_); - return; + src = &m_start_node; + } else { + src = reinterpret_cast(src_); } - auto src = reinterpret_cast(src_); auto dst = reinterpret_cast(dst_); src->out_edges.push_back({dst_, fn}); if (dst_ == GRAPH_END) { @@ -192,13 +205,8 @@ inline void graph_t::start() { LCI_DBG_Log(LOG_TRACE, "graph", "graph %p start\n", this); m_end_signals_received = 0; - for (auto node_ : m_start_nodes) { - auto node = reinterpret_cast(node_); - LCI_Assert(node->signals_expected == 0, - "Start node should not have any incoming edges (signals: %d)", - node->signals_expected); - trigger_node(node_); - } + mark_complete(reinterpret_cast(&m_start_node), + status_t(errorcode_t::done)); } inline void graph_t::signal(status_t status) diff --git a/src/core/communicate.cpp b/src/core/communicate.cpp index 89787ac0..f44ce4fd 100644 --- a/src/core/communicate.cpp +++ b/src/core/communicate.cpp @@ -54,10 +54,10 @@ status_t post_comm_x::call_impl( !(direction == direction_t::IN && !local_buffer_only && !local_comp_only), "get with signal has not been implemented yet\n"); - if (local_comp == COMP_NULL_EXPECT_DONE) { + if (local_comp == COMP_NULL) { allow_retry = false; allow_posted = false; - } else if (local_comp == COMP_NULL_EXPECT_DONE_OR_RETRY) { + } else if (local_comp == COMP_NULL_RETRY) { allow_posted = false; } @@ -240,10 
+  if (comp.is_empty()) {
+    // blocking wait
+    status_t status;
+    do {
+      progress_x().runtime(runtime).device(device).endpoint(endpoint)();
+      status = graph_test(graph);
+    } while (status.is_retry());
+    free_comp(&graph);
+  }
+
+  LCI_DBG_Log(LOG_TRACE, "collective",
+              "leave reduce_scatter %d (sendbuf %p recvbuf %p item_size %lu "
+              "count %lu)\n",
+              seqnum, sendbuf, recvbuf, item_size, recvcount);
+}
+
+}  // namespace lci
\ No newline at end of file
diff --git a/src/comp/graph.hpp b/src/comp/graph.hpp
index 19f36963..c3ed2781 100644
--- a/src/comp/graph.hpp
+++ b/src/comp/graph.hpp
@@ -47,6 +47,15 @@ class graph_t : public comp_impl_t
     void* value;
     graph_node_free_cb_t free_cb;
     std::vector<std::pair<graph_node_t, graph_edge_run_cb_t>> out_edges;
+    node_impl_t()
+        : signals_received(0),
+          signals_expected(0),
+          graph(nullptr),
+          fn(nullptr),
+          value(nullptr),
+          free_cb(nullptr)
+    {
+    }
     node_impl_t(graph_t* graph_, graph_node_run_cb_t fn_, void* value_,
                 graph_node_free_cb_t free_cb_)
         : signals_received(0),
@@ -67,7 +76,8 @@ class graph_t : public comp_impl_t
   comp_t m_comp;
 
   // the graph structure
-  std::vector<graph_node_t> m_start_nodes;
+  node_impl_t m_start_node;
+  // std::vector<graph_node_t> m_start_nodes;
   std::vector<graph_node_t> nodes;
   int m_end_signals_expected;
   void* m_end_value;
@@ -82,6 +92,8 @@ inline void graph_t::trigger_node(graph_node_t node_)
   status_t status;
   if (node->fn) {
     status = node->fn(node->value);
+  } else {
+    status.set_done();
   }
   LCI_DBG_Log(LOG_TRACE, "graph", "graph %p trigger node %p, status %s\n",
               node->graph, node_, status.error.get_str());
@@ -174,11 +186,12 @@
   LCI_Assert(dst_ != GRAPH_START,
              "The destination node should not be the start node");
   LCI_Assert(src_ != GRAPH_END, "The source node should not be the end node");
+  node_impl_t* src;
   if (src_ == GRAPH_START) {
-    m_start_nodes.push_back(dst_);
-    return;
+    src = &m_start_node;
+  } else {
+    src = reinterpret_cast<node_impl_t*>(src_);
   }
-  auto src = reinterpret_cast<node_impl_t*>(src_);
   auto dst = reinterpret_cast<node_impl_t*>(dst_);
   src->out_edges.push_back({dst_, fn});
   if (dst_ == GRAPH_END) {
@@ -192,13 +205,8 @@ inline void graph_t::start()
 {
   LCI_DBG_Log(LOG_TRACE, "graph", "graph %p start\n", this);
   m_end_signals_received = 0;
-  for (auto node_ : m_start_nodes) {
-    auto node = reinterpret_cast<node_impl_t*>(node_);
-    LCI_Assert(node->signals_expected == 0,
-               "Start node should not have any incoming edges (signals: %d)",
-               node->signals_expected);
-    trigger_node(node_);
-  }
+  mark_complete(reinterpret_cast<graph_node_t>(&m_start_node),
+                status_t(errorcode_t::done));
 }
 
 inline void graph_t::signal(status_t status)
diff --git a/src/core/communicate.cpp b/src/core/communicate.cpp
index 89787ac0..f44ce4fd 100644
--- a/src/core/communicate.cpp
+++ b/src/core/communicate.cpp
@@ -54,10 +54,10 @@ status_t post_comm_x::call_impl(
       !(direction == direction_t::IN && !local_buffer_only &&
         !local_comp_only),
       "get with signal has not been implemented yet\n");
-  if (local_comp == COMP_NULL_EXPECT_DONE) {
+  if (local_comp == COMP_NULL) {
     allow_retry = false;
     allow_posted = false;
-  } else if (local_comp == COMP_NULL_EXPECT_DONE_OR_RETRY) {
+  } else if (local_comp == COMP_NULL_RETRY) {
     allow_posted = false;
   }
@@ -240,10 +240,10 @@ status_t post_comm_x::call_impl(
       comp_semantic == comp_semantic_t::network ||
       direction == direction_t::IN) {
     // process COMP_BLOCK
-    if (local_comp == COMP_NULL_EXPECT_DONE) {
+    if (local_comp == COMP_NULL) {
       local_comp = alloc_sync();
       free_local_comp = true;
-    } else if (local_comp == COMP_NULL_EXPECT_DONE_OR_RETRY) {
+    } else if (local_comp == COMP_NULL_RETRY) {
       local_comp = alloc_sync();
       free_local_comp = true;
     }
diff --git a/src/core/lci.cpp b/src/core/lci.cpp
index 3e7cedbc..dade1832 100644
--- a/src/core/lci.cpp
+++ b/src/core/lci.cpp
@@ -25,4 +25,23 @@ const char* get_net_opcode_str(net_opcode_t opcode)
   return opcode_str[static_cast<int>(opcode)];
 }
 
+const char* get_broadcast_algorithm_str(broadcast_algorithm_t algorithm)
+{
+  static const char algorithm_str[][8] = {"none", "direct", "tree", "ring"};
+  return algorithm_str[static_cast<int>(algorithm)];
+}
+
+const char* get_reduce_scatter_algorithm_str(
+    reduce_scatter_algorithm_t algorithm)
+{
+  static const char algorithm_str[][8] = {"none", "direct", "tree", "ring"};
+  return algorithm_str[static_cast<int>(algorithm)];
+}
+
+const char* get_allreduce_algorithm_str(allreduce_algorithm_t algorithm)
+{
+  static const char algorithm_str[][8] = {"none", "direct", "tree", "ring"};
+  return algorithm_str[static_cast<int>(algorithm)];
+}
+
 }  // namespace lci
\ No newline at end of file
diff --git a/src/core/protocol.hpp b/src/core/protocol.hpp
index ce97292f..96ec8203 100644
--- a/src/core/protocol.hpp
+++ b/src/core/protocol.hpp
@@ -50,7 +50,7 @@ struct alignas(LCI_CACHE_LINE) internal_context_t {
         rdv_type(rdv_type_t::single),
         rank(-1),
         tag(0),
-        comp(comp_t()),
+        comp(COMP_NULL),
         user_context(nullptr)
   {
   }
diff --git a/src/lci_internal.hpp b/src/lci_internal.hpp
index c1f4df3c..6a0e3c19 100644
--- a/src/lci_internal.hpp
+++ b/src/lci_internal.hpp
@@ -48,6 +48,7 @@
 #include "packet_pool/packet_pool.hpp"
 #include "runtime/runtime.hpp"
 #include "core/rendezvous.hpp"
+#include "collective/collective.hpp"
 #ifdef LCI_USE_CUDA
 #include "accelerator/accelerator.hpp"
 #endif
diff --git a/tests/unit/collective/all.cpp b/tests/unit/collective/all.cpp
index 538f9692..913f90e0 100644
--- a/tests/unit/collective/all.cpp
+++ b/tests/unit/collective/all.cpp
@@ -16,13 +16,22 @@ TEST(COMM_COLL, broadcast)
   int rank = lci::get_rank_me();
   int nranks = lci::get_rank_n();
 
-  for (int root = 0; root < nranks; ++root) {
-    uint64_t data = 0;
-    if (rank == root) {
-      data = 0xdeadbeef;
+  lci::broadcast_algorithm_t algorithms[] = {
+      lci::broadcast_algorithm_t::direct,
+      lci::broadcast_algorithm_t::tree,
+      lci::broadcast_algorithm_t::ring,
+  };
+  for (auto algorithm : algorithms) {
+    fprintf(stderr, "Testing broadcast with algorithm %s\n",
+            lci::get_broadcast_algorithm_str(algorithm));
+    for (int root = 0; root < nranks; ++root) {
+      uint64_t data = 0;
+      if (rank == root) {
+        data = 0xdeadbeef;
+      }
+      lci::broadcast_x(&data, sizeof(data), root).algorithm(algorithm)();
+      ASSERT_EQ(data, 0xdeadbeef);
     }
-    lci::broadcast(&data, sizeof(data), root);
-    ASSERT_EQ(data, 0xdeadbeef);
   }
   lci::g_runtime_fina();
 }
@@ -56,7 +65,7 @@ void reduce_op(const void* left, const void* right, void* dst, size_t n)
   }
 }
 
-TEST(COMM_COLL, reduce_in_place)
+TEST(COMM_COLL, reduce)
 {
   lci::g_runtime_init();
 
@@ -65,6 +74,15 @@
   int rank = lci::get_rank_me();
   int nranks = lci::get_rank_n();
 
   for (int root = 0; root < nranks; ++root) {
     uint64_t data = rank;
+    uint64_t result = -1;
+    // Check non-in-place reduction
+    lci::reduce(&data, &result, 1, sizeof(data), reduce_op, root);
+    if (rank == root) {
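+      // 0 + 1 + ... + (nranks - 1) == nranks * (nranks - 1) / 2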
+      ASSERT_EQ(result, (nranks - 1) * nranks / 2);
+    } else {
+      ASSERT_EQ(data, rank);
+    }
+    // Check in-place reduction
     lci::reduce(&data, &data, 1, sizeof(data), reduce_op, root);
     if (rank == root) {
       ASSERT_EQ(data, (nranks - 1) * nranks / 2);
@@ -75,19 +93,55 @@
   lci::g_runtime_fina();
 }
 
-TEST(COMM_COLL, reduce)
+TEST(COMM_COLL, reduce_scatter)
 {
   lci::g_runtime_init();
 
   int rank = lci::get_rank_me();
   int nranks = lci::get_rank_n();
 
-  for (int root = 0; root < nranks; ++root) {
+  lci::reduce_scatter_algorithm_t algorithms[] = {
+      lci::reduce_scatter_algorithm_t::direct,
+  };
+  for (auto algorithm : algorithms) {
+    std::vector<uint64_t> data(nranks, rank);
+    for (int i = 0; i < nranks; ++i) {
+      data[i] += i;
+    }
+    uint64_t result = -1;
+    // Check non-in-place reduction
+    lci::reduce_scatter_x(&data[0], &result, 1, sizeof(uint64_t), reduce_op)
+        .algorithm(algorithm)();
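+    // slot r receives sum_j (j + r) = nranks * (nranks - 1) / 2 + r * nranks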
+    ASSERT_EQ(result, (nranks - 1) * nranks / 2 + rank * nranks);
+    // Check in-place reduction
+    lci::reduce_scatter_x(&data[0], &data[rank], 1, sizeof(uint64_t), reduce_op)
+        .algorithm(algorithm)();
+    ASSERT_EQ(data[rank], (nranks - 1) * nranks / 2 + rank * nranks);
+  }
+  lci::g_runtime_fina();
+}
+
+TEST(COMM_COLL, allreduce)
+{
+  lci::g_runtime_init();
+
+  int rank = lci::get_rank_me();
+  int nranks = lci::get_rank_n();
+
+  lci::allreduce_algorithm_t algorithms[] = {
+      lci::allreduce_algorithm_t::direct,
+  };
+  for (auto algorithm : algorithms) {
     uint64_t data = rank;
     uint64_t result = -1;
-    lci::reduce(&data, &result, 1, sizeof(data), reduce_op, root);
-    if (rank == root) ASSERT_EQ(result, (nranks - 1) * nranks / 2);
-    ASSERT_EQ(data, rank);
+    // Check non-in-place reduction
+    lci::allreduce_x(&data, &result, 1, sizeof(data), reduce_op)
+        .algorithm(algorithm)();
+    ASSERT_EQ(result, (nranks - 1) * nranks / 2);
+    // Check in-place reduction
+    lci::allreduce_x(&data, &data, 1, sizeof(data), reduce_op)
+        .algorithm(algorithm)();
+    ASSERT_EQ(data, (nranks - 1) * nranks / 2);
   }
   lci::g_runtime_fina();
 }
diff --git a/tests/unit/loopback/test_matching_policy.hpp b/tests/unit/loopback/test_matching_policy.hpp
index fb075c04..61134565 100644
--- a/tests/unit/loopback/test_matching_policy.hpp
+++ b/tests/unit/loopback/test_matching_policy.hpp
@@ -20,13 +20,13 @@ TEST(MATCHING_POLICY, test_rank_tag)
   uint64_t data = 0xdeadbeef;
   for (int i = 0; i < n; ++i) {
-    lci::status_t status = lci::post_send(0, &data, sizeof(data), in[i],
-                                          lci::COMP_NULL_EXPECT_DONE);
+    lci::status_t status =
+        lci::post_send(0, &data, sizeof(data), in[i], lci::COMP_NULL);
     ASSERT_EQ(status.is_done(), true);
   }
   for (int i = 0; i < n; ++i) {
-    lci::status_t status = lci::post_recv(0, &data, sizeof(data), in[i],
-                                          lci::COMP_NULL_EXPECT_DONE);
+    lci::status_t status =
+        lci::post_recv(0, &data, sizeof(data), in[i], lci::COMP_NULL);
     ASSERT_EQ(status.is_done(), true);
     ASSERT_EQ(status.tag, in[i]);
   }
@@ -50,16 +50,15 @@ TEST(MATCHING_POLICY, test_rank_only)
   uint64_t data = 0xdeadbeef;
   for (int i = 0; i < n; ++i) {
     lci::status_t status =
-        lci::post_send_x(0, &data, sizeof(data), in[i],
-                         lci::COMP_NULL_EXPECT_DONE)
+        lci::post_send_x(0, &data, sizeof(data), in[i], lci::COMP_NULL)
             .matching_policy(lci::matching_policy_t::rank_only)();
     ASSERT_EQ(status.is_done(), true);
   }
   bool flags[n];
   memset(flags, false, sizeof(flags));
   for (int i = 0; i < n; ++i) {
-    lci::status_t status = lci::post_recv(0, &data, sizeof(data), lci::ANY_TAG,
-                                          lci::COMP_NULL_EXPECT_DONE);
+    lci::status_t status =
+        lci::post_recv(0, &data, sizeof(data), lci::ANY_TAG, lci::COMP_NULL);
     ASSERT_EQ(status.is_done(), true);
     int idx = status.tag;
     ASSERT_EQ(idx >= 0 && idx < n, true);
@@ -88,17 +87,16 @@ TEST(MATCHING_POLICY, test_none)
   uint64_t data = 0xdeadbeef;
   for (int i = 0; i < n; ++i) {
-    lci::status_t status = lci::post_send_x(0, &data, sizeof(data), in[i],
-                                            lci::COMP_NULL_EXPECT_DONE)
-                               .matching_policy(lci::matching_policy_t::none)();
+    lci::status_t status =
+        lci::post_send_x(0, &data, sizeof(data), in[i], lci::COMP_NULL)
+            .matching_policy(lci::matching_policy_t::none)();
     ASSERT_EQ(status.is_done(), true);
   }
   bool flags[n];
   memset(flags, false, sizeof(flags));
   for (int i = 0; i < n; ++i) {
-    lci::status_t status =
-        lci::post_recv(lci::ANY_SOURCE, &data, sizeof(data), lci::ANY_TAG,
-                       lci::COMP_NULL_EXPECT_DONE);
+    lci::status_t status = lci::post_recv(lci::ANY_SOURCE, &data, sizeof(data),
+                                          lci::ANY_TAG, lci::COMP_NULL);
     ASSERT_EQ(status.is_done(), true);
     ASSERT_EQ(status.rank, 0);
     int idx = status.tag;