diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a102ae3..e680bc1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: rev: v2.1.0 hooks: - id: codespell - args: ["--skip=docs/imgs/*,csrc/json.hpp,csrc/python/pybind_json/pybind_json.hpp"] + args: ["--skip=docs/imgs/*,csrc/jring.h,csrc/json.hpp,csrc/python/pybind_json/pybind_json.hpp"] # - repo: https://github.com/myint/docformatter # rev: v1.4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ef92c9..7757817 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,8 +11,9 @@ slime_option(BUILD_NVSHMEM "Build NVSHMEM" OFF) slime_option(BUILD_ASCEND_DIRECT "Build Ascend direct transport" OFF) # Slime options for ops -slime_option(BUILD_INTRA_OPS "Build intra LL collective ops" OFF) -slime_option(BUILD_INTER_OPS "Build inter LL collective ops" OFF) +slime_option(BUILD_IBVERBS_OPS "Build ibverbs collective ops" OFF) +slime_option(BUILD_INTRA_OPS "Build intra collective ops" OFF) +slime_option(BUILD_INTER_OPS "Build inter collective ops" OFF) # Slime options for custom python wrapper slime_option(BUILD_PYTHON "Build python wrapper" OFF) @@ -25,6 +26,11 @@ slime_option(BUILD_TORCH_PLUGIN "Build torch plugin" OFF) slime_option(BUILD_BENCH "Build transfer engine benchmark" OFF) slime_option(BUILD_TEST "Build test" OFF) +if(BUILD_IBVERBS_OPS AND NOT BUILD_RDMA) + message(STATUS "BUILD_IBVERBS_OPS requires BUILD_RDMA, enabling RDMA...") + set(BUILD_RDMA ON CACHE BOOL "Build RDMA" FORCE) +endif() + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) @@ -47,7 +53,7 @@ else() set(DLSLIME_INSTALL_PATH "lib") endif() -if (BUILD_TORCH_PLUGIN OR BUILD_INTRA_OPS OR BUILD_INTER_OPS) +if (BUILD_TORCH_PLUGIN OR BUILD_IBVERBS_OPS OR BUILD_INTRA_OPS OR BUILD_INTER_OPS) include(${CMAKE_CURRENT_LIST_DIR}/cmake/torch.cmake) endif() diff --git a/README.md b/README.md index 88b69cd..14ef4be 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ torchrun --nnodes 2 --master-addr 10.130.8.143 --node-rank 1 --nproc-per-node 8 ``` > \[!Note\] -> The intra- and inter- examples example above enables CUDA Graph by default. --eager-mode falls back to eager mode. +> The `intra-` and `inter-` examples above enable CUDA Graph by default; `--eager-mode` falls back to eager mode. ## Install @@ -157,17 +157,18 @@ mkdir -p DLSlime/build && cmake -DFLAG= ..
The `FLAG` can be -| Flag | Description | Platform | default | -| :----------------------- | :------------------------------------ | :------- | ------: | -| `BUILD_RDMA` | Build RDMA Transfer Engine | Hetero | ON | -| `BUILD_PYTHON` | Build Python wrapper | Hetero | ON | -| `BUILD_NVLINK` | Build NVLINK Transfer Engine | GPGPU | OFF | -| `BUILD_NVSHMEM` | Build NVShmem Transfer Engine | NVIDIA | OFF | -| `BUILD_ASCEND_DIRECT` | Build Ascend direct transport | ASCEND | OFF | -| `BUILD_TORCH_PLUGIN` | Build DLSlime as a torch backend | Hetero | OFF | -| `USE_GLOO_BACKEND` | Use GLOO RDMA Send/Recv torch backend | Hetero | OFF | -| `BUILD_INTRA_OPS` | Use INTRA Collective OPS | GPGPU | OFF | -| `BUILD_INTER_OPS` | Use INTER Collective OPS (NVSHMEM) | NVIDIA | OFF | +| Flag | Description | Platform | default | +| :-------------------- | :------------------------------------ | :------- | ------: | +| `BUILD_RDMA` | Build RDMA Transfer Engine | Hetero | ON | +| `BUILD_PYTHON` | Build Python wrapper | Hetero | ON | +| `BUILD_NVLINK` | Build NVLINK Transfer Engine | GPGPU | OFF | +| `BUILD_NVSHMEM` | Build NVShmem Transfer Engine | NVIDIA | OFF | +| `BUILD_ASCEND_DIRECT` | Build Ascend direct transport | ASCEND | OFF | +| `BUILD_TORCH_PLUGIN` | Build DLSlime as a torch backend | Hetero | OFF | +| `USE_GLOO_BACKEND` | Use GLOO RDMA Send/Recv torch backend | Hetero | OFF | +| `BUILD_IBVERBS_OPS` | Build IBVERBS Collective OPS | Hetero | OFF | +| `BUILD_INTRA_OPS` | Build INTRA Collective OPS | GPGPU | OFF | +| `BUILD_INTER_OPS` | Build INTER Collective OPS (NVSHMEM) | NVIDIA | OFF | > \[!Note\] > Please enable `USE_MECA` when using DLSlime as a torch backend in Metax platform. diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 52d99f6..96fa1eb 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,8 +1,6 @@ add_subdirectory(engine) -if (BUILD_INTRA_OPS OR BUILD_INTER_OPS) - add_subdirectory(ops) -endif() +add_subdirectory(ops) if (BUILD_PYTHON) add_subdirectory(python) diff --git a/csrc/engine/rdma/rdma_assignment.h b/csrc/engine/rdma/rdma_assignment.h index e443ac0..c59ff29 100644 --- a/csrc/engine/rdma/rdma_assignment.h +++ b/csrc/engine/rdma/rdma_assignment.h @@ -127,8 +127,28 @@ class RDMASchedulerAssignment { rdma_assignment_batch_(std::move(rdma_assignment_batch)) { } + ~RDMASchedulerAssignment(); + int merge(std::shared_ptr<RDMASchedulerAssignment> assign) { + if (!assign) { + return 0; + } + + int original_size = rdma_assignment_batch_.size(); + + rdma_assignment_batch_.reserve(original_size + assign->rdma_assignment_batch_.size()); + rdma_assignment_batch_.insert( + rdma_assignment_batch_.end(), + assign->rdma_assignment_batch_.begin(), + assign->rdma_assignment_batch_.end() + ); + + assign->rdma_assignment_batch_.clear(); + + return rdma_assignment_batch_.size() - original_size; // number of merged assignments + } + void query(); void wait(); diff --git a/csrc/engine/rdma/rdma_context.cpp b/csrc/engine/rdma/rdma_context.cpp index f60357e..4b3efe0 100644 --- a/csrc/engine/rdma/rdma_context.cpp +++ b/csrc/engine/rdma/rdma_context.cpp @@ -383,7 +383,7 @@ void RDMAContext::stop_future() cq_thread_.join(); } } - +namespace { void split_assign_by_max_length(OpCode opcode, AssignmentBatch& batch, AssignmentBatch& batch_split_after_max_length, @@ -426,6 +426,7 @@ void nsplit_assign_by_step(OpCode opcode, int step = (bsize + nstep - 1) / nstep; split_assign_by_step(opcode, batch, batch_nsplit, step); } +} // namespace std::shared_ptr<RDMAAssignment> RDMAContext::submit(OpCode opcode, AssignmentBatch& batch, callback_fn_t callback, int
qpi, int32_t imm_data) @@ -528,8 +529,18 @@ int64_t RDMAContext::post_recv_batch(int qpi, RDMAAssignmentSharedPtr assign) int64_t ret = 0; size_t batch_size = assign->batch_size(); struct ibv_recv_wr* bad_wr = nullptr; - struct ibv_recv_wr* wr = new ibv_recv_wr[batch_size]; - struct ibv_sge* sge = new ibv_sge[batch_size]; + struct ibv_recv_wr* wr; + struct ibv_sge* sge; + if (assign->batch_size() == 0) { + wr = new ibv_recv_wr{.wr_id = (uintptr_t)(new callback_info_with_qpi_t{assign->callback_info_, qpi}), + .next = nullptr, + .sg_list = nullptr, + .num_sge = 0}; + } + else { + wr = new ibv_recv_wr[batch_size]; + sge = new ibv_sge[batch_size]; + } for (size_t i = 0; i < batch_size; ++i) { Assignment& subassign = assign->batch_[i]; diff --git a/csrc/jring.h b/csrc/jring.h new file mode 100644 index 0000000..1cc82b2 --- /dev/null +++ b/csrc/jring.h @@ -0,0 +1,512 @@ +/*- + * SPDX-License-Identifier: BSD-2-Clause-FreeBSD + * + * Copyright (c) 2019 Arm Limited + * Copyright (c) 2010-2017 Intel Corporation* + * Copyright (c) 2007-2009 Kip Macy + * Copyright (c) 2022 Ilias Marinos + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * $FreeBSD$ + * + */ + +#ifndef SRC_EXT_JRING_H_ +#define SRC_EXT_JRING_H_ + +/** + * @file + * Buffer Ring with user defined element size. + */ + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include + +#define RING_SZ_MASK (unsigned)(0x0fffffff) /**< Ring slots mask */ +#define CACHE_LINE_SIZE 64 + +#define ISPOWEROF2(x) (((((x)-1) & (x)) == 0) && x) +#define __ROUND_MASK(x, y) ((__typeof__(x))((y)-1)) +#define ALIGN_UP_POW2(x, y) ((((x)-1) | __ROUND_MASK(x, y)) + 1) +#define ROUND_DOWN_POW2(x, y) ((x) & ~__ROUND_MASK(x, y)) + +/** enqueue/dequeue behavior types */ +enum jring_queue_behavior { + /** Enq/Deq a fixed number of items from a ring */ + JRING_QUEUE_FIXED = 0, + /** Enq/Deq as many items as possible from ring */ + JRING_QUEUE_VARIABLE +}; + +/** prod/cons sync types */ +enum jring_sync_type { + JRING_SYNC_MT = 0, /**< multi-thread safe */ + JRING_SYNC_ST = 1, /**< single-thread */ +}; + +struct jring_headtail { + volatile uint32_t head; + volatile uint32_t tail; + enum jring_sync_type sync; +}; + +struct jring { + uint32_t size; /**< Size of ring. 
*/ + uint32_t mask; /**< Mask (size-1) of ring. */ + uint32_t capacity; /**< Usable size of ring */ + uint32_t esize; /**< Size of each element in the ring. */ + uint32_t reserved[2]; + + // Producer head/tail. + struct jring_headtail prod __attribute__((aligned(CACHE_LINE_SIZE))); + + // Empty cache line. + char pad0 __attribute__((aligned(CACHE_LINE_SIZE))); + + // Consumer head/tail. + struct jring_headtail cons __attribute__((aligned(CACHE_LINE_SIZE))); + + // Empty cache line. + char pad1 __attribute__((aligned(CACHE_LINE_SIZE))); + void *ring[0] __attribute__((aligned(CACHE_LINE_SIZE))); +}; +typedef struct jring jring_t; + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#include "jring_elem_private.h" +#pragma GCC diagnostic pop + +/** + * Calculate the memory size needed for a ring with given element size. + * + * This function returns the number of bytes needed for a ring, given + * the number of elements in it and the size of the element. This value + * is the sum of the size of the structure buf_ring and the size of the + * memory needed for storing the elements. The value is aligned to a cache + * line size. + * + * @param element_size + * The size of ring element, in bytes. It must be a multiple of 4B. + * @param count + * The number of elements in the ring (must be a power of 2). + * @return + * - The memory size needed for the ring on success. + * - (size_t)-1 - element_size is not a multiple of 4 or count + * provided is not a power of 2. + */ +static inline size_t jring_get_buf_ring_size(uint32_t element_size, + uint32_t count) { + if ((element_size % 4) != 0) return -1; + + if (!(ISPOWEROF2(count)) || (count > RING_SZ_MASK)) return -1; + + size_t sz = sizeof(struct jring) + (count * element_size); + + sz = ALIGN_UP_POW2(sz, CACHE_LINE_SIZE); + return sz; +} + +/** + * Function to initialize a ring. + * + * @param r Pointer to the ring structure. + * @param count The number of elements in the ring (must be a power of 2). + * @param esize Element size in bytes. + * @param mp Set to 1 if the ring is to be multi-producer safe. + * @param mc Set to 1 if the ring is to be multi-consumer safe. + * @return 0 on success, -1 on failure. + */ +static inline int jring_init(struct jring *r, uint32_t count, uint32_t esize, + int mp, int mc) { + // The buffer ring needs to be a power of two. + if (!ISPOWEROF2(count)) { + return -1; + } + // Element size needs to be a multiple of 4. + if (esize % 4 != 0) return -1; + + r->size = count; + r->mask = r->size - 1; + r->capacity = r->mask; // Usable size of the ring. + r->esize = esize; + + r->prod.head = 0; + r->prod.tail = 0; + !!mp ? (r->prod.sync = JRING_SYNC_MT) : (r->prod.sync = JRING_SYNC_ST); + r->cons.head = 0; + r->cons.tail = 0; + !!mc ? (r->cons.sync = JRING_SYNC_MT) : (r->cons.sync = JRING_SYNC_ST); + + return 0; +} + +/** + * Enqueue a specific amount of objects on a ring (Single producer only). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. 
+ * @return + * The number of objects enqueued, either 0 or n + */ +static __attribute__((always_inline)) inline unsigned int jring_sp_enqueue_bulk( + struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return __jring_do_enqueue_elem(r, obj_table, r->esize, n, JRING_QUEUE_FIXED, + JRING_SYNC_ST, free_space); +} + +/** + * Enqueue a specific amount of objects on a ring (multi-producer safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. + * @return + * The number of objects enqueued, either 0 or n + */ +static __attribute__((always_inline)) inline unsigned int jring_mp_enqueue_bulk( + struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return __jring_do_enqueue_elem(r, obj_table, r->esize, n, JRING_QUEUE_FIXED, + JRING_SYNC_MT, free_space); +} + +/** + * Enqueue a specific amount of objects on a ring. This function performs + * a runtime check to find the ring's synchronization type. + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. + * @return + * The number of objects enqueued, either 0 or n + */ +static __attribute((always_inline)) inline unsigned int jring_enqueue_bulk( + struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return (r->prod.sync == JRING_SYNC_ST) + ? jring_sp_enqueue_bulk(r, obj_table, n, free_space) + : jring_mp_enqueue_bulk(r, obj_table, n, free_space); +} + +/** + * Enqueue up to a specific amount of objects on a ring (Single producer only). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. + * @return + * The number of objects enqueued, ranging in [0,n]. + */ +static __attribute__((always_inline)) inline unsigned int +jring_sp_enqueue_burst(struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return __jring_do_enqueue_elem(r, obj_table, r->esize, n, + JRING_QUEUE_VARIABLE, JRING_SYNC_ST, + free_space); +} + +/** + * Enqueue up to a specific amount of objects on a ring (multi-producer safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. + * @return + * The number of objects enqueued, ranging in [0,n]. 
+ */ +static __attribute__((always_inline)) inline unsigned int +jring_mp_enqueue_burst(struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return __jring_do_enqueue_elem(r, obj_table, r->esize, n, + JRING_QUEUE_VARIABLE, JRING_SYNC_MT, + free_space); +} + +/** + * Enqueue up to a specific amount of objects on a ring. This function performs + * a runtime check to find the ring's synchronization type. + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param free_space + * if non-NULL, returns the amount of space in the ring after the + * enqueue operation has finished. + * @return + * The number of objects enqueued, ranging in [0,n]. + */ +static __attribute__((always_inline)) inline unsigned int jring_enqueue_burst( + struct jring *r, const void *obj_table, unsigned int n, + unsigned int *free_space) { + return (r->prod.sync == JRING_SYNC_ST) + ? jring_sp_enqueue_burst(r, obj_table, n, free_space) + : jring_mp_enqueue_burst(r, obj_table, n, free_space); +} + +/** + * Dequeue a specific amount of objects from a ring (NOT multi-consumers safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects that will be filled. + * @param n + * The number of objects to dequeue from the ring to the obj_table, + * must be strictly positive. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, either 0 or n + */ +static __attribute((always_inline)) inline unsigned int jring_sc_dequeue_bulk( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return __jring_do_dequeue_elem(r, obj_table, r->esize, n, JRING_QUEUE_FIXED, + JRING_SYNC_ST, available); +} + +/** + * Dequeue a specific amount of objects from a ring (Multi-consumer safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects that will be filled. + * @param n + * The number of objects to dequeue from the ring to the obj_table, + * must be strictly positive. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, either 0 or n + */ +static __attribute((always_inline)) inline unsigned int jring_mc_dequeue_bulk( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return __jring_do_dequeue_elem(r, obj_table, r->esize, n, JRING_QUEUE_FIXED, + JRING_SYNC_MT, available); +} + +/** + * Dequeue a specific amount of objects on a ring. This function performs + * a runtime check to find the ring's synchronization type. + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, either 0 or n. + */ +static __attribute((always_inline)) inline unsigned int jring_dequeue_bulk( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return (r->cons.sync == JRING_SYNC_ST) + ? 
jring_sc_dequeue_bulk(r, obj_table, n, available) + : jring_mc_dequeue_bulk(r, obj_table, n, available); +} + +/** + * Dequeue up to a specific amount of objects from a ring (NOT multi-consumer + * safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects that will be filled. + * @param n + * The number of objects to dequeue from the ring to the obj_table, + * must be strictly positive. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, in the range [0, n] + */ +static __attribute((always_inline)) inline unsigned int jring_sc_dequeue_burst( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return __jring_do_dequeue_elem(r, obj_table, r->esize, n, + JRING_QUEUE_VARIABLE, JRING_SYNC_ST, + available); +} + +/** + * Dequeue up to a specific amount of objects from a ring (multi-consumer + * safe). + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects that will be filled. + * @param n + * The number of objects to dequeue from the ring to the obj_table, + * must be strictly positive. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, in the range [0, n] + */ +static __attribute((always_inline)) inline unsigned int jring_mc_dequeue_burst( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return __jring_do_dequeue_elem(r, obj_table, r->esize, n, + JRING_QUEUE_VARIABLE, JRING_SYNC_MT, + available); +} + +/** + * Dequeue up to a specific amount of objects on a ring. This function performs + * a runtime check to find the ring's synchronization type. + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of void * pointers (objects). + * @param n + * The number of objects to add in the ring from the obj_table. + * @param available + * If non-NULL, returns the number of remaining ring entries after the + * dequeue has finished. + * @return + * The number of objects dequeued, ranging in [0,n]. + */ +static __attribute((always_inline)) inline unsigned int jring_dequeue_burst( + struct jring *r, void *obj_table, unsigned int n, unsigned int *available) { + return (r->cons.sync == JRING_SYNC_ST) + ? jring_sc_dequeue_burst(r, obj_table, n, available) + : jring_mc_dequeue_burst(r, obj_table, n, available); +} + +/** + * Return the number of entries in a ring. + * + * @param r + * A pointer to the ring structure. + * @return + * The number of entries in the ring. + */ +static __attribute__((always_inline)) inline unsigned int jring_count( + const struct jring *r) { + uint32_t prod_tail = r->prod.tail; + uint32_t cons_tail = r->cons.tail; + uint32_t count = (prod_tail - cons_tail) & r->mask; + return (count > r->capacity) ? r->capacity : count; +} + +/** + * Return the number of free entries in a ring. + * + * @param r + * A pointer to the ring structure. + * @return + * The number of free entries in the ring. + */ +static __attribute__((always_inline)) inline unsigned int jring_free_count( + const struct jring *r) { + return r->capacity - jring_count(r); +} + +/** + * Test if a ring is full. + * + * @param r + * A pointer to the ring structure. + * @return + * - 1: The ring is full. + * - 0: The ring is not full. 
+ */ +static __attribute__((always_inline)) inline int jring_full( + const struct jring *r) { + return jring_free_count(r) == 0; +} + +/** + * Test if a ring is empty. + * + * @param r + * A pointer to the ring structure. + * @return + * - 1: The ring is empty. + * - 0: The ring is not empty. + */ +static __attribute__((always_inline)) inline int jring_empty( + const struct jring *r) { + uint32_t prod_tail = r->prod.tail; + uint32_t cons_tail = r->cons.tail; + return cons_tail == prod_tail; +} + +#ifdef __cplusplus +} +#endif + +#endif // SRC_EXT_JRING_H_ diff --git a/csrc/jring_elem_private.h b/csrc/jring_elem_private.h new file mode 100644 index 0000000..73b6085 --- /dev/null +++ b/csrc/jring_elem_private.h @@ -0,0 +1,523 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * + * Copyright (c) 2017,2018 HXT-semitech Corporation. + * Copyright (c) 2007-2009 Kip Macy kmacy@freebsd.org + * All rights reserved. + * Derived from FreeBSD's bufring.h + * Used as BSD-3 Licensed with permission from Kip Macy. + */ + +#ifndef SRC_EXT_JRING_ELEM_PRIVATE_H_ +#define SRC_EXT_JRING_ELEM_PRIVATE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +#include "pause.h" + +#ifndef likely +#define likely(x) __builtin_expect((x), 1) +#endif +#ifndef unlikely +#define unlikely(x) __builtin_expect((x), 0) +#endif + +typedef struct { + union { + uint64_t val[2]; + }; +} __attribute__((aligned(16))) j_int128_t; + +static inline __attribute__((always_inline)) void __jring_enqueue_elems_32( + struct jring *r, const uint32_t size, uint32_t idx, const void *obj_table, + uint32_t n) { + unsigned int i; + uint32_t *ring = (uint32_t *)&r[1]; + const uint32_t *obj = (const uint32_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x7); i += 8, idx += 8) { + ring[idx] = obj[i]; + ring[idx + 1] = obj[i + 1]; + ring[idx + 2] = obj[i + 2]; + ring[idx + 3] = obj[i + 3]; + ring[idx + 4] = obj[i + 4]; + ring[idx + 5] = obj[i + 5]; + ring[idx + 6] = obj[i + 6]; + ring[idx + 7] = obj[i + 7]; + } + switch (n & 0x7) { + case 7: + ring[idx++] = obj[i++]; /* fallthrough */ + case 6: + ring[idx++] = obj[i++]; /* fallthrough */ + case 5: + ring[idx++] = obj[i++]; /* fallthrough */ + case 4: + ring[idx++] = obj[i++]; /* fallthrough */ + case 3: + ring[idx++] = obj[i++]; /* fallthrough */ + case 2: + ring[idx++] = obj[i++]; /* fallthrough */ + case 1: + ring[idx++] = obj[i++]; /* fallthrough */ + } + } else { + for (i = 0; idx < size; i++, idx++) ring[idx] = obj[i]; + /* Start at the beginning */ + for (idx = 0; i < n; i++, idx++) ring[idx] = obj[i]; + } +} + +static __attribute__((always_inline)) inline void __jring_enqueue_elems_64( + struct jring *r, uint32_t prod_head, const void *obj_table, uint32_t n) { + unsigned int i; + const uint32_t size = r->size; + uint32_t idx = prod_head & r->mask; + uint64_t *ring = (uint64_t *)&r[1]; + const uint64_t *obj = (const uint64_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x3); i += 4, idx += 4) { + ring[idx] = obj[i]; + ring[idx + 1] = obj[i + 1]; + ring[idx + 2] = obj[i + 2]; + ring[idx + 3] = obj[i + 3]; + } + switch (n & 0x3) { + case 3: + ring[idx++] = obj[i++]; /* fallthrough */ + case 2: + ring[idx++] = obj[i++]; /* fallthrough */ + case 1: + ring[idx++] = obj[i++]; + } + } else { + for (i = 0; idx < size; i++, idx++) ring[idx] = obj[i]; + /* Start at the beginning */ + for (idx = 0; i < n; i++, idx++) ring[idx] = obj[i]; + } +} + +static inline __attribute__((always_inline)) void __jring_enqueue_elems_128( + struct jring *r, 
uint32_t prod_head, const void *obj_table, uint32_t n) { + unsigned int i; + const uint32_t size = r->size; + uint32_t idx = prod_head & r->mask; + j_int128_t *ring = (j_int128_t *)&r[1]; + const j_int128_t *obj = (const j_int128_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x1); i += 2, idx += 2) + memcpy((void *)(ring + idx), (const void *)(obj + i), 32); + switch (n & 0x1) { + case 1: + memcpy((void *)(ring + idx), (const void *)(obj + i), 16); + } + } else { + for (i = 0; idx < size; i++, idx++) + memcpy((void *)(ring + idx), (const void *)(obj + i), 16); + /* Start at the beginning */ + for (idx = 0; i < n; i++, idx++) + memcpy((void *)(ring + idx), (const void *)(obj + i), 16); + } +} + +/* the actual enqueue of elements on the ring. + * Placed here since identical code needed in both + * single and multi producer enqueue functions. + */ +static inline __attribute__((always_inline)) void __jring_enqueue_elems( + struct jring *r, uint32_t prod_head, const void *obj_table, uint32_t esize, + uint32_t num) { + /* 8B and 16B copies implemented individually to retain + * the current performance. + */ + if (esize == 8) { + __jring_enqueue_elems_64(r, prod_head, obj_table, num); + } else if (esize == 16) { + __jring_enqueue_elems_128(r, prod_head, obj_table, num); + } else { + uint32_t idx, scale, nr_idx, nr_num, nr_size; + + /* Normalize to uint32_t */ + scale = esize / sizeof(uint32_t); + nr_num = num * scale; + idx = prod_head & r->mask; + nr_idx = idx * scale; + nr_size = r->size * scale; + __jring_enqueue_elems_32(r, nr_size, nr_idx, obj_table, nr_num); + } +} + +static inline __attribute__((always_inline)) void __jring_dequeue_elems_32( + struct jring *r, const uint32_t size, uint32_t idx, void *obj_table, + uint32_t n) { + unsigned int i; + uint32_t *ring = (uint32_t *)&r[1]; + uint32_t *obj = (uint32_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x7); i += 8, idx += 8) { + obj[i] = ring[idx]; + obj[i + 1] = ring[idx + 1]; + obj[i + 2] = ring[idx + 2]; + obj[i + 3] = ring[idx + 3]; + obj[i + 4] = ring[idx + 4]; + obj[i + 5] = ring[idx + 5]; + obj[i + 6] = ring[idx + 6]; + obj[i + 7] = ring[idx + 7]; + } + switch (n & 0x7) { + case 7: + obj[i++] = ring[idx++]; /* fallthrough */ + case 6: + obj[i++] = ring[idx++]; /* fallthrough */ + case 5: + obj[i++] = ring[idx++]; /* fallthrough */ + case 4: + obj[i++] = ring[idx++]; /* fallthrough */ + case 3: + obj[i++] = ring[idx++]; /* fallthrough */ + case 2: + obj[i++] = ring[idx++]; /* fallthrough */ + case 1: + obj[i++] = ring[idx++]; /* fallthrough */ + } + } else { + for (i = 0; idx < size; i++, idx++) obj[i] = ring[idx]; + /* Start at the beginning */ + for (idx = 0; i < n; i++, idx++) obj[i] = ring[idx]; + } +} + +static inline __attribute__((always_inline)) void __jring_dequeue_elems_64( + struct jring *r, uint32_t prod_head, void *obj_table, uint32_t n) { + unsigned int i; + const uint32_t size = r->size; + uint32_t idx = prod_head & r->mask; + uint64_t *ring = (uint64_t *)&r[1]; + uint64_t *obj = (uint64_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x3); i += 4, idx += 4) { + obj[i] = ring[idx]; + obj[i + 1] = ring[idx + 1]; + obj[i + 2] = ring[idx + 2]; + obj[i + 3] = ring[idx + 3]; + } + switch (n & 0x3) { + case 3: + obj[i++] = ring[idx++]; /* fallthrough */ + case 2: + obj[i++] = ring[idx++]; /* fallthrough */ + case 1: + obj[i++] = ring[idx++]; /* fallthrough */ + } + } else { + for (i = 0; idx < size; i++, idx++) obj[i] = ring[idx]; + /* Start at the 
beginning */ + for (idx = 0; i < n; i++, idx++) obj[i] = ring[idx]; + } +} + +static inline __attribute__((always_inline)) void __jring_dequeue_elems_128( + struct jring *r, uint32_t prod_head, void *obj_table, uint32_t n) { + unsigned int i; + const uint32_t size = r->size; + uint32_t idx = prod_head & r->mask; + j_int128_t *ring = (j_int128_t *)&r[1]; + j_int128_t *obj = (j_int128_t *)obj_table; + if (likely(idx + n < size)) { + for (i = 0; i < (n & ~0x1); i += 2, idx += 2) + memcpy((void *)(obj + i), (void *)(ring + idx), 32); + switch (n & 0x1) { + case 1: + memcpy((void *)(obj + i), (void *)(ring + idx), 16); + } + } else { + for (i = 0; idx < size; i++, idx++) + memcpy((void *)(obj + i), (void *)(ring + idx), 16); + /* Start at the beginning */ + for (idx = 0; i < n; i++, idx++) + memcpy((void *)(obj + i), (void *)(ring + idx), 16); + } +} + +/* the actual dequeue of elements from the ring. + * Placed here since identical code needed in both + * single and multi producer enqueue functions. + */ +static inline __attribute__((always_inline)) void __jring_dequeue_elems( + struct jring *r, uint32_t cons_head, void *obj_table, uint32_t esize, + uint32_t num) { + /* 8B and 16B copies implemented individually to retain + * the current performance. + */ + if (esize == 8) { + __jring_dequeue_elems_64(r, cons_head, obj_table, num); + } else if (esize == 16) { + __jring_dequeue_elems_128(r, cons_head, obj_table, num); + } else { + uint32_t idx, scale, nr_idx, nr_num, nr_size; + + /* Normalize to uint32_t */ + scale = esize / sizeof(uint32_t); + nr_num = num * scale; + idx = cons_head & r->mask; + nr_idx = idx * scale; + nr_size = r->size * scale; + __jring_dequeue_elems_32(r, nr_size, nr_idx, obj_table, nr_num); + } +} + +static inline __attribute__((always_inline)) void __jring_wait_until_equal_32( + volatile uint32_t *addr, uint32_t expected, int memorder) { + // assert(memorder == __ATOMIC_ACQUIRE || memorder == __ATOMIC_RELAXED); + + while (__atomic_load_n(addr, memorder) != expected) machnet_pause(); +} + +static inline __attribute__((always_inline)) void __jring_update_tail( + struct jring_headtail *ht, uint32_t old_val, uint32_t new_val, + uint32_t single, __attribute__((unused)) uint32_t enqueue) { + /* + * If there are other enqueues/dequeues in progress that preceded us, + * we need to wait for them to complete + */ + if (!single) + __jring_wait_until_equal_32(&ht->tail, old_val, __ATOMIC_RELAXED); + + __atomic_store_n(&ht->tail, new_val, __ATOMIC_RELEASE); +} + +/** + * @internal This function updates the producer head for enqueue + * + * @param r + * A pointer to the ring structure + * @param is_sp + * Indicates whether multi-producer path is needed or not + * @param n + * The number of elements we will want to enqueue, i.e. how far should the + * head be moved + * @param behavior + * JRING_QUEUE_FIXED: Enqueue a fixed number of items from a ring + * JRING_QUEUE_VARIABLE: Enqueue as many items as possible from ring + * @param old_head + * Returns head value as it was before the move, i.e. where enqueue starts + * @param new_head + * Returns the current/new head value i.e. where enqueue finishes + * @param free_entries + * Returns the amount of free space in the ring BEFORE head was moved + * @return + * Actual number of objects enqueued. + * If behavior == JRING_QUEUE_FIXED, this will be 0 or n only. 
+ */ +static inline __attribute__((always_inline)) unsigned int +__jring_move_prod_head(struct jring *r, unsigned int is_sp, unsigned int n, + enum jring_queue_behavior behavior, uint32_t *old_head, + uint32_t *new_head, uint32_t *free_entries) { + const uint32_t capacity = r->capacity; + uint32_t cons_tail; + unsigned int max = n; + int success; + + *old_head = __atomic_load_n(&r->prod.head, __ATOMIC_RELAXED); + do { + /* Reset n to the initial burst count */ + n = max; + + /* Ensure the head is read before tail */ + __atomic_thread_fence(__ATOMIC_ACQUIRE); + + /* load-acquire synchronize with store-release of ht->tail + * in update_tail. + */ + cons_tail = __atomic_load_n(&r->cons.tail, __ATOMIC_ACQUIRE); + + /* The subtraction is done between two unsigned 32bits value + * (the result is always modulo 32 bits even if we have + * *old_head > cons_tail). So 'free_entries' is always between 0 + * and capacity (which is < size). + */ + *free_entries = (capacity + cons_tail - *old_head); + + /* check that we have enough room in ring */ + if (unlikely(n > *free_entries)) + n = (behavior == JRING_QUEUE_FIXED) ? 0 : *free_entries; + + if (n == 0) return 0; + + *new_head = *old_head + n; + if (is_sp) + r->prod.head = *new_head, success = 1; + else + /* on failure, *old_head is updated */ + success = + __atomic_compare_exchange_n(&r->prod.head, old_head, *new_head, 0, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } while (unlikely(success == 0)); + return n; +} + +/** + * @internal This function updates the consumer head for dequeue + * + * @param r + * A pointer to the ring structure + * @param is_sc + * Indicates whether multi-consumer path is needed or not + * @param n + * The number of elements we will want to dequeue, i.e. how far should the + * head be moved + * @param behavior + * JRING_QUEUE_FIXED: Dequeue a fixed number of items from a ring + * JRING_QUEUE_VARIABLE: Dequeue as many items as possible from ring + * @param old_head + * Returns head value as it was before the move, i.e. where dequeue starts + * @param new_head + * Returns the current/new head value i.e. where dequeue finishes + * @param entries + * Returns the number of entries in the ring BEFORE head was moved + * @return + * - Actual number of objects dequeued. + * If behavior == JRING_QUEUE_FIXED, this will be 0 or n only. + */ +static inline __attribute__((always_inline)) unsigned int +__jring_move_cons_head(struct jring *r, int is_sc, unsigned int n, + enum jring_queue_behavior behavior, uint32_t *old_head, + uint32_t *new_head, uint32_t *entries) { + unsigned int max = n; + uint32_t prod_tail; + int success; + + /* move cons.head atomically */ + *old_head = __atomic_load_n(&r->cons.head, __ATOMIC_RELAXED); + do { + /* Restore n as it may change every loop */ + n = max; + + /* Ensure the head is read before tail */ + __atomic_thread_fence(__ATOMIC_ACQUIRE); + + /* this load-acquire synchronize with store-release of ht->tail + * in update_tail. + */ + prod_tail = __atomic_load_n(&r->prod.tail, __ATOMIC_ACQUIRE); + + /* The subtraction is done between two unsigned 32bits value + * (the result is always modulo 32 bits even if we have + * cons_head > prod_tail). So 'entries' is always between 0 + * and size(ring)-1. + */ + *entries = (prod_tail - *old_head); + + /* Set the actual entries for dequeue */ + if (n > *entries) n = (behavior == JRING_QUEUE_FIXED) ? 
0 : *entries; + + if (unlikely(n == 0)) return 0; + + *new_head = *old_head + n; + if (is_sc) + r->cons.head = *new_head, success = 1; + else + /* on failure, *old_head will be updated */ + success = + __atomic_compare_exchange_n(&r->cons.head, old_head, *new_head, 0, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } while (unlikely(success == 0)); + return n; +} + +/** + * @internal Enqueue several objects on the ring + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects. + * @param esize + * The size of ring element, in bytes. It must be a multiple of 4. + * This must be the same value used while creating the ring. Otherwise + * the results are undefined. + * @param n + * The number of objects to add in the ring from the obj_table. + * @param behavior + * jRING_QUEUE_FIXED: Enqueue a fixed number of items from a ring + * jRING_QUEUE_VARIABLE: Enqueue as many items as possible from ring + * @param is_sp + * Indicates whether to use single producer or multi-producer head update + * @param free_space + * returns the amount of space after the enqueue operation has finished + * @return + * Actual number of objects enqueued. + * If behavior == jRING_QUEUE_FIXED, this will be 0 or n only. + */ +static inline __attribute__((always_inline)) unsigned int +__jring_do_enqueue_elem(struct jring *r, const void *obj_table, + unsigned int esize, unsigned int n, + enum jring_queue_behavior behavior, unsigned int is_sp, + unsigned int *free_space) { + uint32_t prod_head, prod_next; + uint32_t free_entries; + + n = __jring_move_prod_head(r, is_sp, n, behavior, &prod_head, &prod_next, + &free_entries); + if (n == 0) goto end; + + __jring_enqueue_elems(r, prod_head, obj_table, esize, n); + + __jring_update_tail(&r->prod, prod_head, prod_next, is_sp, 1); +end: + if (free_space != NULL) *free_space = free_entries - n; + return n; +} + +/** + * @internal Dequeue several objects from the ring + * + * @param r + * A pointer to the ring structure. + * @param obj_table + * A pointer to a table of objects. + * @param esize + * The size of ring element, in bytes. It must be a multiple of 4. + * This must be the same value used while creating the ring. Otherwise + * the results are undefined. + * @param n + * The number of objects to pull from the ring. + * @param behavior + * jRING_QUEUE_FIXED: Dequeue a fixed number of items from a ring + * jRING_QUEUE_VARIABLE: Dequeue as many items as possible from ring + * @param is_sc + * Indicates whether to use single consumer or multi-consumer head update + * @param available + * returns the number of remaining ring entries after the dequeue has finished + * @return + * - Actual number of objects dequeued. + * If behavior == jRING_QUEUE_FIXED, this will be 0 or n only. 
+ */ +static inline __attribute__((always_inline)) unsigned int +__jring_do_dequeue_elem(struct jring *r, void *obj_table, unsigned int esize, + unsigned int n, enum jring_queue_behavior behavior, + unsigned int is_sc, unsigned int *available) { + uint32_t cons_head, cons_next; + uint32_t entries; + + n = __jring_move_cons_head(r, (int)is_sc, n, behavior, &cons_head, &cons_next, + &entries); + if (n == 0) goto end; + + __jring_dequeue_elems(r, cons_head, obj_table, esize, n); + + __jring_update_tail(&r->cons, cons_head, cons_next, is_sc, 0); + +end: + if (available != NULL) *available = entries - n; + return n; +} + +#ifdef __cplusplus +} +#endif +#endif // SRC_EXT_JRING_ELEM_PRIVATE_H_ diff --git a/csrc/ops/CMakeLists.txt b/csrc/ops/CMakeLists.txt index 785f6e2..465282a 100644 --- a/csrc/ops/CMakeLists.txt +++ b/csrc/ops/CMakeLists.txt @@ -1,3 +1,5 @@ +if (BUILD_IBVERBS_OPS OR BUILD_INTRA_OPS OR BUILD_INTER_OPS) + set(CMAKE_VERBOSE_MAKEFILE ON) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") @@ -24,11 +26,24 @@ set(OPS_SRC) set(OPS_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}) set(OPS_LIBRARIES ${TORCH_LIBRARIES}) +if (BUILD_IBVERBS_OPS) + set( + IBVERBS_OPS_SRC + ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.cpp + ) + list(APPEND OPS_SRC ${IBVERBS_OPS_SRC}) + set( + IBVERBS_OPS_LIBRARIES + _slime_rdma + ) + list(APPEND OPS_LIBRARIES ${IBVERBS_OPS_LIBRARIES}) +endif() + if (BUILD_INTRA_OPS) set( INTRA_OPS_SRC - intra_ll/all_gather_intra_ll/all_gather_intra_ll.cu - intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp + intra/all_gather_intra_ll/all_gather_intra_ll.cu + intra/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp ) list(APPEND OPS_SRC ${INTRA_OPS_SRC}) set( @@ -48,8 +63,8 @@ if (BUILD_INTER_OPS) set( INTER_OPS_SRC - inter_ll/all_gather_inter_ll/all_gather_inter_ll.cu - inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp + inter/all_gather_inter_ll/all_gather_inter_ll.cu + inter/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp nvshmem_runtime.cu ) list(APPEND OPS_SRC ${INTER_OPS_SRC}) @@ -103,3 +118,5 @@ install( _slime_ops LIBRARY DESTINATION ${DLSLIME_INSTALL_PATH} ) + +endif() diff --git a/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.cpp b/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.cpp new file mode 100644 index 0000000..dcc4c43 --- /dev/null +++ b/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.cpp @@ -0,0 +1,166 @@ +#include +#include +#include +#include + +#include "ATen/ops/empty.h" +#include "engine/assignment.h" +#include "engine/rdma/rdma_assignment.h" +#include "jring.h" +#include "json.hpp" +#include "logging.h" + +#include "engine/rdma/rdma_context.h" +#include "engine/rdma/utils.h" + +#include "m2n_ibverbs_rc_ll_buffer.h" +#include "torch/types.h" + +namespace slime { + +M2NIBVerbsRCLLBuffer::M2NIBVerbsRCLLBuffer(int64_t max_bs, + int64_t msg_size, + std::string device, + int64_t role, + int64_t m_world_size, + int64_t n_world_size, + int64_t rank, + int64_t num_concurrency, + int64_t qp_num, + std::string link_type): + max_bs_(max_bs), + msg_size_(msg_size), + device_(device), + role_(role), + m_world_size_(m_world_size), + n_world_size_(n_world_size), + rank_(rank), + num_concurrency_(num_concurrency), + link_type_(link_type) +{ + + // Step 1: Alloc Buffer + allocBuffer(); + + // Step 2. Init all RDMA Context + size_t ctx_size = role == 0 ? 
m_world_size : n_world_size; + for (int i = 0; i < ctx_size; ++i) { + auto rdma_context = std::make_shared(qp_num); + // TODO (JimyMa): Bynow, set link type to "RoCE", automatically in the future. + auto nics = available_nic(); + auto selected_nic = nics[rank % nics.size()]; + // TODO (JimyMa): Affinity of nic and comp device + rdma_context->init(selected_nic, 1, "RoCE"); + rdma_context->register_memory_region( + "buffer", reinterpret_cast(buffer_.data_ptr()), getBufferSize()); + ctx_.emplace_back(rdma_context); + } + + // Step 3: Init Recv Queue + jring_t* ring = reinterpret_cast(aligned_alloc(sizeof(m2n_recv_task_t), num_concurrency_ + 16)); + jring_init(ring, 16, sizeof(m2n_recv_task_t), 0, 0); + posted_recv_queue_ = ring; + + // Step 4: Prepost Some recv + for (int i = 0; i < num_concurrency_; ++i) + recvQueuePut(); +} + +inline int64_t M2NIBVerbsRCLLBuffer::peerWorldSize() +{ + if (role_ == 0) + return n_world_size_; + else + return m_world_size_; +} + +int M2NIBVerbsRCLLBuffer::recvQueuePut() +{ + // prepost a recv + RDMAAssignmentSharedPtrBatch assign_batch; + + for (int i = 0; i < peerWorldSize(); ++i) { + auto batch = AssignmentBatch{}; + // assign_batch.push_back(ctx_[i]->submit(OpCode::RECV, batch)); + } + + auto rdma_sch_assign = std::make_shared(assign_batch); + + auto task = m2n_recv_task_t{.assign = rdma_sch_assign}; + while (jring_mp_enqueue_bulk(posted_recv_queue_, &task, 1, nullptr) != 1) {} + return 0; +} + +m2n_recv_task_t M2NIBVerbsRCLLBuffer::recvQueueGet() +{ + m2n_recv_task_t task; + jring_sc_dequeue_bulk(posted_recv_queue_, &task, 1, nullptr); + return task; +} + +M2NIBVerbsRCLLBuffer::~M2NIBVerbsRCLLBuffer() +{ + free(posted_recv_queue_); +} + +size_t M2NIBVerbsRCLLBuffer::getBufferSize() +{ + size_t send_buffer_size = max_bs_ * msg_size_ * m_world_size_; + size_t recv_buffer_size = max_bs_ * msg_size_ * n_world_size_; + size_t send_signal_size = n_world_size_; + size_t recv_signal_size = m_world_size_; + return std::max(send_buffer_size, recv_buffer_size) + std::max(send_signal_size, recv_signal_size); +} + +int M2NIBVerbsRCLLBuffer::allocBuffer() +{ + auto options = torch::TensorOptions().dtype(torch::kInt8).device(device_); + buffer_ = torch::empty({static_cast(getBufferSize())}, options); + return 0; +} + +json M2NIBVerbsRCLLBuffer::bufferInfo() +{ + json info{}; + std::vector endpoint_info; + + for (int i = 0; i < ctx_.size(); ++i) { + endpoint_info.emplace_back(ctx_[i]->endpoint_info()); + } + + info["endpoint_info"] = endpoint_info; + + return info; +} + +int M2NIBVerbsRCLLBuffer::connectFullMesh(const std::vector& all_buffer_info) +{ + for (int i = 0; i < ctx_.size(); ++i) { + ctx_[i]->connect(all_buffer_info[i]["endpoint_info"][rank_]); + } + + return 0; +} + +std::shared_ptr M2NIBVerbsRCLLBuffer::M2NRecv(int tag) +{ + m2n_recv_task_t task = recvQueueGet(); + return task.assign; +} + +std::shared_ptr M2NIBVerbsRCLLBuffer::M2NSend(int tag) +{ + + RDMAAssignmentSharedPtrBatch assign_batch; + + for (int i = 0; i < ctx_.size(); ++i) { + AssignmentBatch batch = {Assignment("buffer", 0, 0, msg_size_ * max_bs_)}; + ctx_[i]->submit(OpCode::WRITE, batch, nullptr, -1, tag); + } + + auto rdma_sch_assign = std::make_shared(assign_batch); + + return rdma_sch_assign; +} + +} // namespace slime diff --git a/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.h b/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.h new file mode 100644 index 0000000..76a7e90 --- /dev/null +++ b/csrc/ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.h @@ -0,0 +1,82 
@@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "engine/rdma/rdma_assignment.h" +#include "engine/rdma/rdma_context.h" +#include "engine/rdma/rdma_endpoint.h" +#include "jring.h" +#include "json.hpp" + +namespace slime { + +using json = nlohmann::json; + +typedef struct __attribute__((aligned(64))) M2NRecvTask { + std::shared_ptr assign; +} m2n_recv_task_t; + +class M2NIBVerbsRCLLBuffer { + +public: + M2NIBVerbsRCLLBuffer(int64_t max_bs, + int64_t msg_size, + std::string device, + int64_t role, + int64_t m_world_size, + int64_t n_world_size, + int64_t rank, + int64_t num_concurrency, + int64_t qp_num = 2, + std::string link_type = "RoCE"); + + ~M2NIBVerbsRCLLBuffer(); + + size_t getBufferSize(); + + int allocBuffer(); + + json bufferInfo(); + + int connectFullMesh(const std::vector& all_buffer_info); + + std::shared_ptr M2NRecv(int tag); + + std::shared_ptr M2NSend(int tag); + +private: + int recvQueuePut(); + + m2n_recv_task_t recvQueueGet(); + + inline int64_t peerWorldSize(); + + torch::Tensor buffer_; + + std::vector> ctx_; + + int64_t max_bs_; + int64_t msg_size_; + + std::string device_; + + int64_t role_; + + int64_t m_world_size_; + int64_t n_world_size_; + + int64_t rank_; + + int64_t num_concurrency_; + + std::string link_type_; + + jring_t* posted_recv_queue_; +}; + +} // namespace slime diff --git a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll.cu b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll.cu similarity index 100% rename from csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll.cu rename to csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll.cu diff --git a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll.h b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll.h similarity index 100% rename from csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll.h rename to csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll.h diff --git a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp similarity index 77% rename from csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp rename to csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp index 50bd18d..7a62b4b 100644 --- a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp +++ b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.cpp @@ -12,30 +12,18 @@ namespace slime { AllGatherInterLLBuffer::AllGatherInterLLBuffer( - int64_t max_bs, int64_t msg_size, torch::Dtype dtype, int64_t world_size, int64_t rank, int64_t num_concurrency): + int64_t max_bs, int64_t msg_size, torch::Dtype dtype, int64_t world_size, int64_t rank, int64_t num_concurrency, bool allow_nvlink): max_bs_(max_bs), msg_size_(msg_size), dtype_(dtype), world_size_(world_size), rank_(rank), - num_concurrency_(num_concurrency) + num_concurrency_(num_concurrency), + allow_nvlink_(allow_nvlink) { - cudaSetDevice(localRank()); SLIME_ASSERT((msg_size * itemsize()) % 16 == 0, "By now, msg size must be divided by 16"); } -AllGatherInterLLBuffer::AllGatherInterLLBuffer(int64_t max_bs, - int64_t msg_size, - torch::Dtype dtype, - int64_t world_size, - int64_t rank, - int64_t num_concurrency, - bool allow_nvlink): - AllGatherInterLLBuffer(max_bs, msg_size, dtype, world_size, rank, num_concurrency) -{ - allow_nvlink_ = allow_nvlink; -} - size_t AllGatherInterLLBuffer::getBufferSize() { size_t buffer_size = static_cast(max_bs_) * msg_size_ * itemsize() * world_size_ 
* num_concurrency_; @@ -57,7 +45,7 @@ int AllGatherInterLLBuffer::allocSymBuffer() { size_t buffer_size = getBufferSize(); - sym_buffer_ = reinterpret_cast(nvshmem_api::alloc(buffer_size, nvshmem_alignment)); + sym_buffer_ = reinterpret_cast(nvshmem_api::alloc(buffer_size, nvshmem_alignment)); SLIME_ASSERT(sym_buffer_ != NULL, "failure of symbuffer allocation!"); nvshmem_api::barrier(); sym_signal_ = reinterpret_cast(nvshmem_api::alloc(world_size_ * sizeof(int), nvshmem_alignment)); @@ -79,12 +67,13 @@ json AllGatherInterLLBuffer::bufferInfo() int AllGatherInterLLBuffer::connectFullMesh(std::vector all_buffer_info) { - auto unique_ids = all_buffer_info[root_rank]["nvshmem_info"]["unique_id"]; - int nvshmem_rank = nvshmem_api::init(unique_ids, rank_, world_size_); + auto unique_id = all_buffer_info[root_rank]["nvshmem_info"]["unique_id"]; + int nvshmem_rank = nvshmem_api::init(unique_id, rank_, world_size_); nvshmem_api::barrier(); allocSymBuffer(); nvshmem_api::barrier(); - SLIME_ASSERT(nvshmem_rank == rank_, "nvshmem_rank != rank_"); + if (rank_ == 0) + SLIME_ASSERT(nvshmem_rank == rank_, "nvshmem_rank (" << nvshmem_rank << ") != rank_ (" << rank_ << ")"); return 0; } diff --git a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.h b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.h similarity index 79% rename from csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.h rename to csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.h index b69d474..e01ac9c 100644 --- a/csrc/ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.h +++ b/csrc/ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.h @@ -22,19 +22,13 @@ class AllGatherInterLLBuffer { static constexpr int32_t root_rank = 0; public: - AllGatherInterLLBuffer(int64_t max_bs, - int64_t msg_size, - torch::Dtype dtype, - int64_t world_size, - int64_t rank, - int64_t num_concurrency); AllGatherInterLLBuffer(int64_t max_bs, int64_t msg_size, torch::Dtype dtype, int64_t world_size, int64_t rank, int64_t num_concurrency, - bool allow_nvlink); + bool allow_nvlink = false); size_t getBufferSize(); diff --git a/csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll.cu b/csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll.cu similarity index 100% rename from csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll.cu rename to csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll.cu diff --git a/csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll.h b/csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll.h similarity index 100% rename from csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll.h rename to csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll.h diff --git a/csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp b/csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp similarity index 100% rename from csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp rename to csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll_buffer.cpp diff --git a/csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.h b/csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll_buffer.h similarity index 100% rename from csrc/ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.h rename to csrc/ops/intra/all_gather_intra_ll/all_gather_intra_ll_buffer.h diff --git a/csrc/pause.h b/csrc/pause.h new file mode 100644 index 0000000..fdf0c4e --- /dev/null +++ b/csrc/pause.h @@ -0,0 +1,34 @@ +#ifndef SRC_INCLUDE_PAUSE_H_ +#define 
SRC_INCLUDE_PAUSE_H_ + +#ifdef __cplusplus extern "C" { #endif + +#if defined(__x86_64__) +#include +#elif defined(__aarch64__) || defined(_M_ARM64) +#include +#else +static_assert(false, + "Unsupported architecture, please add the pause intrinsic for " + "the architecture."); +#endif + +static void inline machnet_pause() { +#if defined(__x86_64__) + _mm_pause(); +#elif defined(__aarch64__) || defined(_M_ARM64) + __asm__ volatile("yield" ::: "memory"); +#else + static_assert(false, + "Unsupported architecture, please add the pause intrinsic for " + "the architecture."); +#endif +} + +#ifdef __cplusplus +} +#endif + +#endif // SRC_INCLUDE_PAUSE_H_ diff --git a/csrc/python/CMakeLists.txt b/csrc/python/CMakeLists.txt index b6670d6..565d608 100644 --- a/csrc/python/CMakeLists.txt +++ b/csrc/python/CMakeLists.txt @@ -28,7 +28,10 @@ if (BUILD_NVSHMEM) list(APPEND _slime_c_link_libraries _slime_nvshmem) endif() -if (BUILD_INTRA_OPS OR BUILD_INTER_OPS) +if (BUILD_INTRA_OPS OR BUILD_INTER_OPS OR BUILD_IBVERBS_OPS) + if (BUILD_IBVERBS_OPS) + target_compile_definitions(_slime_c PRIVATE -DBUILD_IBVERBS_OPS) + endif() if (BUILD_INTRA_OPS) target_compile_definitions(_slime_c PRIVATE -DBUILD_INTRA_OPS) endif() diff --git a/csrc/python/bind.cpp b/csrc/python/bind.cpp index 1fbfe87..bd22ce7 100644 --- a/csrc/python/bind.cpp +++ b/csrc/python/bind.cpp @@ -31,15 +31,19 @@ #include "engine/rdma/utils.h" #endif -#if defined(BUILD_INTRA_OPS) || defined(BUILD_INTER_OPS) +#if defined(BUILD_INTRA_OPS) || defined(BUILD_INTER_OPS) || defined(BUILD_IBVERBS_OPS) #include #ifdef BUILD_INTRA_OPS -#include "ops/intra_ll/all_gather_intra_ll/all_gather_intra_ll_buffer.h" +#include "ops/intra/all_gather_intra_ll/all_gather_intra_ll_buffer.h" #endif #ifdef BUILD_INTER_OPS -#include "ops/inter_ll/all_gather_inter_ll/all_gather_inter_ll_buffer.h" +#include "ops/inter/all_gather_inter_ll/all_gather_inter_ll_buffer.h" +#endif + +#ifdef BUILD_IBVERBS_OPS +#include "ops/ibverbs/m2n_ibverbs_rc_ll/m2n_ibverbs_rc_ll_buffer.h" #endif #endif @@ -108,6 +112,12 @@ py::object alloc_dlpack_tensor(slime::NVShmemContext& self, size_t size, size_t #define BUILD_NVLINK_ENABLED false #endif +#ifdef BUILD_IBVERBS_OPS +#define BUILD_IBVERBS_OPS_ENABLED true +#else +#define BUILD_IBVERBS_OPS_ENABLED false +#endif + #ifdef BUILD_INTRA_OPS #define BUILD_INTRA_OPS_ENABLED true #else @@ -120,13 +130,14 @@ py::object alloc_dlpack_tensor(slime::NVShmemContext& self, size_t size, size_t #define BUILD_INTER_OPS_ENABLED false #endif -#define EXPOSE_BUILD_FLAG(m, flag) m.attr("_"#flag) = flag##_ENABLED +#define EXPOSE_BUILD_FLAG(m, flag) m.attr("_" #flag) = flag##_ENABLED PYBIND11_MODULE(_slime_c, m) { EXPOSE_BUILD_FLAG(m, BUILD_RDMA); EXPOSE_BUILD_FLAG(m, BUILD_NVSHMEM); EXPOSE_BUILD_FLAG(m, BUILD_NVLINK); + EXPOSE_BUILD_FLAG(m, BUILD_IBVERBS_OPS); EXPOSE_BUILD_FLAG(m, BUILD_INTRA_OPS); EXPOSE_BUILD_FLAG(m, BUILD_INTER_OPS); @@ -224,6 +235,26 @@ .def("read_batch", &slime::NVLinkContext::read_batch); #endif +#ifdef BUILD_IBVERBS_OPS + py::class_(m, "M2NIBVerbsRCLLBuffer") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("connect_full_mesh", &slime::M2NIBVerbsRCLLBuffer::connectFullMesh) + .def("buffer_info", &slime::M2NIBVerbsRCLLBuffer::bufferInfo) + .def("m2n_send", &slime::M2NIBVerbsRCLLBuffer::M2NSend) + .def("m2n_recv", &slime::M2NIBVerbsRCLLBuffer::M2NRecv); +#endif + #ifdef BUILD_INTRA_OPS py::class_(m, "AllGatherIntraLLBuffer") .def(py::init()) diff --git a/dlslime/buffer/ibverbs/__init__.py
b/dlslime/buffer/ibverbs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlslime/buffer/ibverbs/m2n_ibverbs_rc_ll_buffer.py b/dlslime/buffer/ibverbs/m2n_ibverbs_rc_ll_buffer.py new file mode 100644 index 0000000..1d9d7d9 --- /dev/null +++ b/dlslime/buffer/ibverbs/m2n_ibverbs_rc_ll_buffer.py @@ -0,0 +1,50 @@ +from dlslime import _slime_c + + +class M2NIBVerbsRCLLBuffer: + def __init__( + self, + max_bs: int, + msg_size: int, + device: str, + role: int, + m_world_size: int, + n_world_size: int, + rank: int, + num_concurrency: int, + qp_num: int, + ): + self.max_bs = max_bs + self.msg_size = msg_size + self.device = device + self.role = role + self.m_world_size = m_world_size + self.n_world_size = n_world_size + self.rank = rank + self.num_concurrency = num_concurrency + self.qp_num = qp_num + + self._buffer = _slime_c.M2NIBVerbsRCLLBuffer( + max_bs, + msg_size, + device, + role, + m_world_size, + n_world_size, + rank, + num_concurrency, + qp_num, + ) + + @property + def buffer_info(self): + return self._buffer.buffer_info() + + def connect_full_mesh(self, peer_all_buffer_info): + return self._buffer.connect_full_mesh(peer_all_buffer_info) + + def send(self): + raise NotImplementedError + + def recv(self): + raise NotImplementedError diff --git a/dlslime/buffer/inter/all_gather_inter_ll_buffer.py b/dlslime/buffer/inter/all_gather_inter_ll_buffer.py index 27403bb..79d0a8f 100644 --- a/dlslime/buffer/inter/all_gather_inter_ll_buffer.py +++ b/dlslime/buffer/inter/all_gather_inter_ll_buffer.py @@ -27,9 +27,9 @@ def __init__( self.allow_nvlink = allow_nvlink - setup_nvshmem_env(qp_num=qp_num) + setup_nvshmem_env(qp_num=qp_num, allow_nvlink=allow_nvlink) - self._buffer = self.buffer = _slime_c.AllGatherInterLLBuffer( + self._buffer = _slime_c.AllGatherInterLLBuffer( self.bs, self.msg_size, self.dtype, diff --git a/dlslime/buffer/inter/init_nvshmem.py b/dlslime/buffer/inter/init_nvshmem.py index c93dcdf..beed143 100644 --- a/dlslime/buffer/inter/init_nvshmem.py +++ b/dlslime/buffer/inter/init_nvshmem.py @@ -1,10 +1,10 @@ import os -def setup_nvshmem_env(qp_num: int = 8): +def setup_nvshmem_env(qp_num: int = 8, allow_nvlink: bool = False): # NVSHMEM ENVS # Adapted from https://github.com/Deepseek-ai/DeepEP.git - os.environ["NVSHMEM_DISABLE_P2P"] = "0" + os.environ["NVSHMEM_DISABLE_P2P"] = "0" if allow_nvlink else "1" os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "1" os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = str(qp_num) # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check diff --git a/docs/roadmap.md b/docs/roadmap.md index 450b998..bb92aaf 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -6,7 +6,7 @@ DLSlime is dedicated to supporting efficient transmission over a variety of diff ### Transfer Engine -DLSlime provides a flexible and efficient P2P Transfer Engine, enabling AI-workload-aware customized functions such as Prefill-Decode separation and checkpoint transmission. +DLSlime provides a flexible and efficient P2P Transfer Engine, enabling AI-workload-aware customized functions such as Prefill-Decode disaggregation and checkpoint transmission. ### Collective Ops @@ -16,7 +16,7 @@ Referring to [DeepEP](https://github.com/deeplink-org/DeepEP.git), DLSlime provi To meet the heterogeneous requirements of SPMD programs such as heterogeneous pipeline parallel training, a Torch communication backend is provided. 
-## Transfer Engine Roadmap +## Transfer Engine - IBVerbs Transfer Engine - ✅ SendRecv Endpoint @@ -30,10 +30,12 @@ To meet the heterogeneous requirements of SPMD programs such as heterogeneous pi - CUDA IPC - ✅ support CUDAIPC Read/Write Endpoint - PCIE - - ⏳ High performance Shared Memory transfer engine - - ⏳ High performance data offloading + - ⏳ Shared Memory transfer engine + - ⏳ data offloading - Ascend - ✅ Ascned direct transfer engine +- OpenShmem + - 💭 Planning - NVME-oF - 💭 Planning - UB Mesh @@ -55,6 +57,9 @@ To meet the heterogeneous requirements of SPMD programs such as heterogeneous pi - CUDA IPC - ✅ AllGather - ⚡ High performance AllGather using CUDA Multi-Mem + - ⏳ AllGather + - ⏳ AllReduce + - ⏳ All2All ## Torch Wrapper diff --git a/example/python/m2n_ibverbs_rc_ll.py b/example/python/m2n_ibverbs_rc_ll.py new file mode 100644 index 0000000..dd281cb --- /dev/null +++ b/example/python/m2n_ibverbs_rc_ll.py @@ -0,0 +1,79 @@ +import argparse +import os + +import torch +import torch.distributed as dist + +from dlslime.buffer.ibverbs.m2n_ibverbs_rc_ll_buffer import M2NIBVerbsRCLLBuffer + + +# distributed config +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +local_rank = int(os.environ["LOCAL_RANK"]) +local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) +if rank == 0: + print(f"{rank=}, {world_size=}, {local_rank=}, {local_world_size=}") + +torch.cuda.set_device(local_rank) + +parser = argparse.ArgumentParser() +parser.add_argument("--m-size", type=int, default=None) +parser.add_argument("--n-size", type=int, default=None) +args = parser.parse_args() + +# Peer Size +m_size = args.m_size or world_size // 2 +n_size = args.n_size or world_size // 2 + +assert ( + m_size + n_size == world_size +), f"m_size({m_size}) + n_size({n_size}) != world_size({world_size})" + +max_bs = 128 +msg_size = 2048 +device = f"cuda:{rank}" +role = 0 if rank < m_size else 1 +num_concurrency = 1 +qp_nums_per_rank = 2 + +num_topk = 8 + +m2n_rank = rank if role == 0 else rank - m_size + + +def feed_buffer(): + raise NotImplementedError + + +if __name__ == "__main__": + dist.init_process_group("cuda:nccl,cpu:gloo") + + num_topk = 2 + x = torch.empty([max_bs, msg_size], device=device) + top_k_idx = torch.rand(16, 8, device="cuda").argsort(dim=1)[:, :num_topk] + + # initialize buffer + buffer = M2NIBVerbsRCLLBuffer( + max_bs, + msg_size, + device, + role, + m_size, + n_size, + m2n_rank, + num_concurrency, + qp_nums_per_rank, + ) + buffer_info = buffer.buffer_info + + all_buffer_info = [None for _ in range(world_size)] + dist.all_gather_object(all_buffer_info, buffer_info) + peer_buffer_info = ( + all_buffer_info[:m_size] if role == 1 else all_buffer_info[m_size:] + ) + + buffer.connect_full_mesh(peer_buffer_info) + + dist.barrier() + dist.destroy_process_group()
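The vendored `csrc/jring.h` ring is what backs `posted_recv_queue_` in `M2NIBVerbsRCLLBuffer`. Below is a minimal sketch of how that API is driven, assuming an illustrative 16-byte `demo_task_t` element and a capacity of 16 slots (the element type, capacity, and `main` harness are illustrative and not part of this patch):

```cpp
// Minimal sketch (illustrative, not part of the patch): exercises the vendored
// jring API the same way M2NIBVerbsRCLLBuffer drives its posted-recv queue.
#include <cstdint>
#include <cstdlib>

#include "jring.h"

// Illustrative element: jring requires the element size to be a multiple of
// 4 bytes and the slot count to be a power of two.
struct demo_task_t {
    uint64_t id;
    uint64_t payload;
};

int main() {
    const uint32_t slots = 16;
    // jring_get_buf_ring_size already rounds the allocation up to a cache line.
    size_t bytes = jring_get_buf_ring_size(sizeof(demo_task_t), slots);
    jring_t* ring = static_cast<jring_t*>(std::aligned_alloc(CACHE_LINE_SIZE, bytes));
    if (ring == nullptr)
        return 1;
    // Multi-producer enqueue side, single-consumer dequeue side.
    if (jring_init(ring, slots, sizeof(demo_task_t), /*mp=*/1, /*mc=*/0) != 0)
        return 1;

    demo_task_t in{1, 42};
    // Bulk calls move either all n elements or none; spin until the slot is taken.
    while (jring_mp_enqueue_bulk(ring, &in, 1, nullptr) != 1) {}

    demo_task_t out{};
    while (jring_sc_dequeue_bulk(ring, &out, 1, nullptr) != 1) {}

    std::free(ring);
    return out.payload == 42 ? 0 : 1;
}
```

Because the `_bulk` variants return either 0 or n, the spin-until-success loops above mirror the `while (jring_mp_enqueue_bulk(...) != 1) {}` pattern used in `recvQueuePut`.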