Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more communication kernels #221

Merged
merged 52 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d8b3bc2
WIP
Binyang2014 Jul 3, 2024
c6b3583
WIP
Binyang2014 Jul 3, 2024
a3edacb
WIP
Binyang2014 Jul 4, 2024
1281769
Merge branch 'main' into binyli/comm
Binyang2014 Jul 4, 2024
16c67a4
WIP
Binyang2014 Jul 4, 2024
2143bf6
WIP
Binyang2014 Jul 5, 2024
4914c31
WIP
Binyang2014 Jul 5, 2024
dc72b38
packet work
Binyang2014 Jul 5, 2024
d1d109a
WIP
Binyang2014 Jul 5, 2024
496ebbc
done for today
Binyang2014 Jul 5, 2024
b23a7b1
WIP
Binyang2014 Jul 7, 2024
ced9b15
compile pass
Binyang2014 Jul 8, 2024
0ac0d3a
recv_reduce_write packet works
Binyang2014 Jul 8, 2024
5f36018
WIP
Binyang2014 Jul 10, 2024
06479bb
WIP
Binyang2014 Jul 10, 2024
b3e85bd
WIP
Binyang2014 Jul 11, 2024
cb4693b
add more tests
Binyang2014 Jul 12, 2024
1c3245e
WIP
Binyang2014 Jul 13, 2024
55894c0
Merge branch 'main' into binyli/comm
Binyang2014 Jul 13, 2024
6d72055
enable test
Binyang2014 Jul 13, 2024
4dac64b
Merge branch 'binyli/comm' of https://github.com/microsoft/ark into b…
Binyang2014 Jul 13, 2024
bb70b4b
add allreduce test
Binyang2014 Jul 15, 2024
bf37570
Fix UT
Binyang2014 Jul 15, 2024
745973e
For rocm
Binyang2014 Jul 15, 2024
6c693f1
Add sm algo
Binyang2014 Jul 15, 2024
e75901d
fix
Binyang2014 Jul 16, 2024
d87f63b
Merge branch 'main' into binyli/comm
Binyang2014 Jul 16, 2024
a589691
Expose C++ exceptions to the Python module
chhwang Aug 2, 2024
cb3af8f
Refine error types
chhwang Aug 3, 2024
bae5ec1
unit test
chhwang Aug 3, 2024
fad4876
Merge branch 'chhwang/errors' into binyli/comm
chhwang Aug 3, 2024
cf96f0e
error handling fixes & unit tests
chhwang Aug 3, 2024
cd99eee
Merge branch 'chhwang/errors' into binyli/comm
chhwang Aug 3, 2024
b442297
Merge branch 'main' into binyli/comm
chhwang Aug 3, 2024
144435b
minor changes
chhwang Aug 3, 2024
69beb1a
context manager
chhwang Aug 3, 2024
dbbdda5
minor changes
chhwang Aug 5, 2024
304fa59
changed ModelContextManager interface
chhwang Aug 5, 2024
24c5cc6
context interface
chhwang Aug 5, 2024
6f7b184
rename method & docs
chhwang Aug 5, 2024
d9bdcb6
internal interface
chhwang Aug 5, 2024
7e1a31c
planner context
chhwang Aug 5, 2024
e151513
rename DefaultPlanner to Planner
chhwang Aug 6, 2024
3930c1b
python interface
chhwang Aug 6, 2024
adf6a64
Add plan class & arch check
chhwang Aug 6, 2024
eb3215e
minor fixes & improve coverage
chhwang Aug 6, 2024
2abb518
update workflow
chhwang Aug 6, 2024
3efaf33
Merge branch 'chhwang/plan_context' into binyli/comm
chhwang Aug 6, 2024
b94d61f
interface updates
chhwang Aug 6, 2024
8eed342
Merge branch 'main' into binyli/comm
chhwang Aug 6, 2024
f768b2e
add plan file
Binyang2014 Aug 6, 2024
707ab13
lint
Binyang2014 Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions ark/include/ark/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ class Model : public ModelGraph {
/// @p strides should be greater than or equal to the padded shape. If the
/// @p strides are not provided, they are set to the padded shape. If the
/// padded shape is not provided, it is set to the @p shape.
/// @param rank Rank of the tensor. -1 means the rank of this model.
/// @param name Name of the tensor.
/// @return Pointer to a tensor object.
///
Tensor tensor(const Dims &shape, const DataType &data_type,
const Dims &strides = {}, const Dims &offsets = {},
const Dims &padded_shape = {}, const std::string &name = "");
const Dims &padded_shape = {}, int rank = -1,
const std::string &name = "");

Tensor refer(Tensor input, const Dims &shape = {}, const Dims &strides = {},
const Dims &offsets = {}, const Dims &padded_shape = {},
Expand Down Expand Up @@ -196,10 +198,23 @@ class Model : public ModelGraph {
// operator is completed.
Tensor recv(Tensor output, int remote_rank, int tag,
const std::string &name = "");
//
Tensor put_packet(Tensor input, Tensor local_tmp_buf, Tensor recv_buf,
int id, int rank, int dst_rank, size_t dst_offset,
int flag, const std::string &name = "");
Tensor send_packet(Tensor input, int remote_rank, int tag, int flag,
Tensor output = NullTensor,
const std::string &name = "");
Tensor recv_packet(Tensor output, int remote_rank, int tag, int flag,
Tensor scratch = NullTensor,
const std::string &name = "");
Tensor recv_reduce_send_packet(
Tensor input, const std::vector<int> &remote_ranks, int recv_tag,
int output_tag, unsigned int flag, Tensor output = NullTensor,
std::vector<Tensor> peer_outputs = {}, Tensor scratch = NullTensor,
const std::string &name = "");
Tensor recv_reduce_send(Tensor input, const std::vector<int> &remote_ranks,
int recv_tag, int output_tag,
Tensor output = NullTensor,
std::vector<Tensor> peer_outputs = {},
Tensor scratch = NullTensor,
const std::string &name = "");
// Performs an all-reduce operator across all ranks, aggregating the input
// tensors. Takes the `input` tensor, the current GPU's rank, and the
// total number of ranks `rank_num`.
Expand All @@ -220,7 +235,8 @@ class Model : public ModelGraph {
Tensor output = NullTensor, const std::string &name = "");

// sync across multi devices
Tensor device_sync(Tensor input, int npeers, const std::string &name = "");
Tensor device_sync(Tensor input, int rank, int rank_num,
const std::string &name = "");

// local reduce scatter
Tensor local_reduce_scatter(Tensor input, int gpu_id, int ngpus_per_node,
Expand All @@ -238,18 +254,7 @@ class Model : public ModelGraph {

Tensor local_all_reduce(Tensor input, int gpu_id, int gpu_num,
const std::string &name = "");
Tensor local_all_reduce_packet(Tensor input, int gpu_id, int gpu_num,
const std::string &name = "");

Tensor reduce_and_write_packet(Tensor input, Tensor scratch, Tensor output,
const std::vector<Tensor> &remote_peer_bufs,
int id, int rank, int npeers,
size_t elems_per_rank, size_t scratch_offset,
size_t remote_dst_offset, int flag,
const std::string &name = "");
Tensor get_packet(Tensor input, Tensor output, size_t src_offset,
size_t dst_offset, size_t npackets, int flag,
const std::string &name = "");
};

} // namespace ark
Expand Down
Loading
Loading