Skip to content

improve collective communication #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/distributed_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ class distributed_array_t
lci::status_t status;
T value;
do {
status = lci::post_get(target_rank, &value, sizeof(T),
lci::COMP_NULL_EXPECT_DONE_OR_RETRY,
local_index * sizeof(T), m_rmrs[target_rank]);
status =
lci::post_get(target_rank, &value, sizeof(T), lci::COMP_NULL_RETRY,
local_index * sizeof(T), m_rmrs[target_rank]);
lci::progress();
} while (status.is_retry());
assert(status.is_done());
Expand All @@ -70,7 +70,7 @@ class distributed_array_t
do {
status = lci::post_put_x(target_rank,
static_cast<void*>(const_cast<int*>(&value)),
sizeof(T), lci::COMP_NULL_EXPECT_DONE_OR_RETRY,
sizeof(T), lci::COMP_NULL_RETRY,
local_index * sizeof(T), m_rmrs[target_rank])
.comp_semantic(lci::comp_semantic_t::network)();
lci::progress();
Expand Down
6 changes: 2 additions & 4 deletions examples/pingpong_am_mt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ void worker(int thread_id)
// sender
for (int i = 0; i < nmsgs; i++) {
// send a message
lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL_EXPECT_DONE,
rcomp)
lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL, rcomp)
.device(device)
.tag(thread_id)();
// wait for an incoming message
Expand Down Expand Up @@ -85,8 +84,7 @@ void worker(int thread_id)
}
free(recv_buf.base);
// send a message
lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL_EXPECT_DONE,
rcomp)
lci::post_am_x(peer_rank, send_buf, msg_size, lci::COMP_NULL, rcomp)
.device(device)
.tag(thread_id)();
}
Expand Down
9 changes: 8 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ target_sources_relative(
runtime/runtime.cpp
core/communicate.cpp
core/progress.cpp
collective/collective.cpp)
collective/collective.cpp
collective/alltoall.cpp
collective/barrier.cpp
collective/broadcast.cpp
collective/gather.cpp
collective/reduce_scatter.cpp
collective/allreduce.cpp
collective/reduce.cpp)

if(LCI_BACKEND_ENABLE_OFI)
target_sources_relative(LCI PRIVATE network/ofi/backend_ofi.cpp)
Expand Down
83 changes: 81 additions & 2 deletions src/api/lci.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,60 @@ enum class net_opcode_t {
*/
const char* get_net_opcode_str(net_opcode_t opcode);

/**
* @ingroup LCI_BASIC
* @brief The type of broadcast algorithm.
*/
enum class broadcast_algorithm_t {
none, /**< automatically select the best algorithm */
direct, /**< direct algorithm */
tree, /**< binomial tree algorithm */
ring, /**< ring algorithm */
};

/**
 * @brief Get the string representation of a broadcast algorithm.
 * @param algorithm The broadcast algorithm.
 * @return The string representation of the broadcast algorithm.
*/
const char* get_broadcast_algorithm_str(broadcast_algorithm_t algorithm);

/**
* @ingroup LCI_BASIC
* @brief The type of reduce scatter algorithm.
*/
enum class reduce_scatter_algorithm_t {
none, /**< automatically select the best algorithm */
direct, /**< direct algorithm */
tree, /**< reduce followed by broadcast */
ring, /**< ring algorithm */
};

/**
 * @brief Get the string representation of a reduce-scatter algorithm.
 * @param algorithm The reduce-scatter algorithm.
 * @return The string representation of the reduce-scatter algorithm.
 */
const char* get_reduce_scatter_algorithm_str(reduce_scatter_algorithm_t algorithm);

/**
* @ingroup LCI_BASIC
* @brief The type of allreduce algorithm.
*/
enum class allreduce_algorithm_t {
none, /**< automatically select the best algorithm */
direct, /**< direct algorithm */
tree, /**< reduce followed by broadcast */
ring, /**< ring algorithm */
};

/**
 * @brief Get the string representation of an allreduce algorithm.
 * @param algorithm The allreduce algorithm.
 * @return The string representation of the allreduce algorithm.
 */
const char* get_allreduce_algorithm_str(allreduce_algorithm_t algorithm);

/**
* @ingroup LCI_BASIC
* @brief The type of network-layer immediate data field.
Expand Down Expand Up @@ -581,16 +635,33 @@ struct status_t {
* @ingroup LCI_BASIC
* @brief Special completion object setting `allow_posted` to false.
*/
const comp_t COMP_NULL = comp_t(reinterpret_cast<comp_impl_t*>(0x0));

/**
* @ingroup LCI_BASIC
* @brief Deprecated. Same as COMP_NULL.
*/
const comp_t COMP_NULL_EXPECT_DONE =
comp_t(reinterpret_cast<comp_impl_t*>(0x1));
comp_t(reinterpret_cast<comp_impl_t*>(0x0));

/**
* @ingroup LCI_BASIC
* @brief Special completion object setting `allow_posted` and `allow_retry` to
* false.
*/
const comp_t COMP_NULL_RETRY = comp_t(reinterpret_cast<comp_impl_t*>(0x1));

/**
* @ingroup LCI_BASIC
* @brief Deprecated. Same as COMP_NULL_RETRY.
*/
const comp_t COMP_NULL_EXPECT_DONE_OR_RETRY =
comp_t(reinterpret_cast<comp_impl_t*>(0x2));
comp_t(reinterpret_cast<comp_impl_t*>(0x1));

inline bool comp_t::is_empty() const
{
return reinterpret_cast<uintptr_t>(p_impl) <= 1;
}

/**
* @ingroup LCI_BASIC
Expand Down Expand Up @@ -641,6 +712,14 @@ const graph_node_t GRAPH_END = reinterpret_cast<graph_node_t>(0x2);
*/
using graph_node_run_cb_t = status_t (*)(void* value);

/**
* @ingroup LCI_BASIC
* @brief A dummy callback function for a graph node.
* @details This function can be used as a placeholder for a graph node that
* does not perform any operation.
*/
const graph_node_run_cb_t GRAPH_NODE_DUMMY_CB = nullptr;

/**
* @ingroup LCI_BASIC
* @brief The function signature for a callback that will be triggered when the
Expand Down
51 changes: 49 additions & 2 deletions src/binding/input/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."),
optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."),
optional_arg("comp_semantic_t", "comp_semantic", "comp_semantic_t::buffer", comment="The completion semantic."),
optional_arg("comp_t", "comp", "comp_t()", comment="The completion to signal when the operation completes."),
optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."),
],
doc = {
"in_group": "LCI_COLL",
Expand All @@ -31,10 +31,13 @@
optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."),
optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."),
optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."),
optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."),
optional_arg("broadcast_algorithm_t", "algorithm", "broadcast_algorithm_t::none", comment="The collective algorithm to use."),
optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."),
],
doc = {
"in_group": "LCI_COLL",
"brief": "A blocking broadcast operation.",
"brief": "A broadcast operation.",
}
),
operation(
Expand All @@ -56,6 +59,50 @@
"brief": "A blocking reduce operation.",
}
),
operation(
"reduce_scatter",
[
optional_runtime_args,
positional_arg("const void*", "sendbuf", comment="The local buffer base address to send."),
positional_arg("void*", "recvbuf", comment="The local buffer base address to recv."),
positional_arg("size_t", "recvcount", comment="The number of data items to receive on each rank."),
positional_arg("size_t", "item_size", comment="The size of each data item."),
positional_arg("reduce_op_t", "op", comment="The reduction operation."),
optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."),
optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."),
optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."),
optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."),
optional_arg("reduce_scatter_algorithm_t", "algorithm", "reduce_scatter_algorithm_t::none", comment="The collective algorithm to use."),
optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."),
],
doc = {
"in_group": "LCI_COLL",
"brief": "A reduce scatter operation.",
"details": "This operation assumes the send count is equal to `recvcount * item_size` and "
"`sendbuf` is of size at least `recvcount * item_size * get_rank_n()`.",
}
),
operation(
"allreduce",
[
optional_runtime_args,
positional_arg("const void*", "sendbuf", comment="The local buffer base address to send."),
positional_arg("void*", "recvbuf", comment="The local buffer base address to recv."),
positional_arg("size_t", "count", comment="The number of data items in the buffer."),
positional_arg("size_t", "item_size", comment="The size of each data item."),
positional_arg("reduce_op_t", "op", comment="The reduction operation."),
optional_arg("device_t", "device", "runtime.get_impl()->default_device", comment="The device to use."),
optional_arg("endpoint_t", "endpoint", "device.get_impl()->default_endpoint", comment="The endpoint to use."),
optional_arg("matching_engine_t", "matching_engine", "runtime.get_impl()->default_coll_matching_engine", comment="The matching engine to use."),
optional_arg("comp_t", "comp", "COMP_NULL", comment="The completion to signal when the operation completes."),
optional_arg("allreduce_algorithm_t", "algorithm", "allreduce_algorithm_t::none", comment="The collective algorithm to use."),
optional_arg("int", "ring_nsteps", "get_rank_n() - 1", comment="The number of steps in the ring algorithm."),
],
doc = {
"in_group": "LCI_COLL",
"brief": "An allreduce operation.",
}
),
operation(
"allgather",
[
Expand Down
5 changes: 3 additions & 2 deletions src/binding/input/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
attr_enum("cq_type", enum_options=["array_atomic", "lcrq"], default_value="lcrq", comment="The completion object type."),
attr("int", "cq_default_length", default_value=65536, comment="The default length of the completion queue."),
],
custom_is_empty_method=True,
doc = {
"in_group": "LCI_COMPLETION",
"brief": "The completion object resource.",
Expand Down Expand Up @@ -160,7 +161,7 @@
operation(
"alloc_graph",
[
optional_arg("comp_t", "comp", "comp_t()", comment="Another completion object to signal when the graph is completed. The graph will be automatically destroyed afterwards."),
optional_arg("comp_t", "comp", "COMP_NULL", comment="Another completion object to signal when the graph is completed. The graph will be automatically destroyed afterwards."),
optional_arg("void*", "user_context", "nullptr", comment="The arbitrary user-defined context associated with this completion object."),
optional_runtime_args,
return_val("comp_t", "comp", comment="The allocated completion handler."),
Expand Down Expand Up @@ -234,7 +235,7 @@
"brief": "Test a graph.",
"details": "Successful test will reset the graph to the state that is ready to be started again.",
}
)
),
]

def get_input():
Expand Down
Loading
Loading