diff --git a/csrc/nv_internal/cpp/common/envUtils.cpp b/csrc/nv_internal/cpp/common/envUtils.cpp
index e2ee31261c..2f60d0778b 100644
--- a/csrc/nv_internal/cpp/common/envUtils.cpp
+++ b/csrc/nv_internal/cpp/common/envUtils.cpp
@@ -222,11 +222,6 @@ bool getEnvDisaggLayerwise() {
   return disaggLayerwise;
 }
 
-bool getEnvParallelCacheSend() {
-  static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
-  return parallelCacheSend;
-}
-
 bool getEnvRequestKVCacheConcurrent() {
   static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT");
   return requestKVCacheConcurrent;
@@ -277,7 +272,7 @@ size_t getEnvAllReduceWorkspaceSize() {
   return workspaceSize;
 }
 
-std::string getEnvKVCacheTransferOutputPath() {
+std::string const& getEnvKVCacheTimeOutputPath() {
   static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
   return outputPath;
 }
@@ -328,4 +323,37 @@ uint16_t getEnvNixlPort() {
 
 bool getEnvDisaggBenchmarkGenOnly() { return getBoolEnv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"); }
 
+bool getEnvMoeA2AOneBlockPerToken() {
+  // Default true; return false only if env set to "0"
+  static std::optional<int> const val = getIntEnv("TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN");
+  if (!val.has_value()) {
+    return true;
+  }
+  return val.value() != 0;
+}
+
+static int sanitizeBlockSize(std::optional<int> const& val) {
+  // Default 256 when not set or invalid
+  int block = val.value_or(256);
+  // Clamp to sane CUDA bounds and warp multiples
+  if (block <= 0) block = 256;
+  if (block > 1024) block = 1024;
+  // Round up to a multiple of 32 (warp size)
+  block = (block + 31) / 32 * 32;
+  if (block == 0) block = 256;
+  return block;
+}
+
+int getEnvMoeA2ADispatchBlockSize() {
+  static int const kBlock = sanitizeBlockSize(getIntEnv("TLLM_MOE_A2A_DISPATCH_BLOCK_SIZE"));
+  return kBlock;
+}
+
+int getEnvMoeA2ACombineBlockSize() {
+  static int const kBlock = sanitizeBlockSize(getIntEnv("TLLM_MOE_A2A_COMBINE_BLOCK_SIZE"));
+  return kBlock;
+}
+
+bool getEnvEplbForceGdrcopy() { return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY"); }
+
 }  // namespace tensorrt_llm::common
diff --git a/csrc/nv_internal/tensorrt_llm/common/envUtils.h b/csrc/nv_internal/tensorrt_llm/common/envUtils.h
index 887162e786..cdbdd8c414 100644
--- a/csrc/nv_internal/tensorrt_llm/common/envUtils.h
+++ b/csrc/nv_internal/tensorrt_llm/common/envUtils.h
@@ -64,7 +64,7 @@ bool getEnvDisableKVCacheTransferOverlap();
 
 bool getEnvEnableReceiveKVCacheParallel();
 
-std::string getEnvKVCacheTransferOutputPath();
+std::string const& getEnvKVCacheTimeOutputPath();
 
 bool getEnvTryZCopyForKVCacheTransfer();
 
@@ -92,4 +92,13 @@ size_t getEnvKVCacheSendMaxConcurrenceNum();
 
 size_t getEnvMemSizeForKVCacheTransferBuffer();
 
+// Whether to use one block per token for MoE A2A kernels (default true).
+bool getEnvMoeA2AOneBlockPerToken();
+
+// TODO: Temporary, for development/tuning purposes only.
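+// Out-of-range values are sanitized: unset or non-positive -> 256, values above 1024 are clamped
+// to 1024, and the result is rounded up to a multiple of the warp size (32).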
+// Block size (threads per block) for MoE A2A Dispatch kernels (default 256 if unset or invalid) +int getEnvMoeA2ADispatchBlockSize(); +// Block size (threads per block) for MoE A2A Combine kernels (default 256 if unset or invalid) +int getEnvMoeA2ACombineBlockSize(); + } // namespace tensorrt_llm::common diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu new file mode 100644 index 0000000000..e46c0e9e63 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -0,0 +1,847 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include "flashinfer/exception.h" +#include "flashinfer/utils.cuh" +#include "flashinfer/vec_dtypes.cuh" +#include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h" + +namespace tensorrt_llm::kernels::mnnvl_throughput { + +#define ENABLE_DEBUG_PRINT 0 +#define DISABLE_SYNC_FOR_PROFILING 0 + +// Helper function for ceiling division +template +__host__ __device__ inline T ceilDiv(T m, T n) { + return (m + n - 1) / n; +} + +// Macros for concise launch-time specialization +#define SWITCH_BOOL(flag, NAME, ...) \ + if (flag) { \ + constexpr bool NAME = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool NAME = false; \ + __VA_ARGS__ \ + } + +#define SWITCH_TOP_K(top_k, TOP_K, ...) \ + switch (top_k) { \ + case 8: { \ + constexpr int TOP_K = 8; \ + __VA_ARGS__; \ + break; \ + } \ + case 4: { \ + constexpr int TOP_K = 4; \ + __VA_ARGS__; \ + break; \ + } \ + case 2: { \ + constexpr int TOP_K = 2; \ + __VA_ARGS__; \ + break; \ + } \ + case 1: { \ + constexpr int TOP_K = 1; \ + __VA_ARGS__; \ + break; \ + } \ + default: { \ + FLASHINFER_CHECK(false, "Unsupported top_k"); \ + } \ + } + +#define SWITCH_DTYPE(dtype, TYPE, ...) \ + switch (dtype) { \ + case nvinfer1::DataType::kHALF: { \ + using TYPE = half; \ + __VA_ARGS__; \ + break; \ + } \ + case nvinfer1::DataType::kBF16: { \ + using TYPE = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + case nvinfer1::DataType::kFLOAT: { \ + using TYPE = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: { \ + FLASHINFER_CHECK(false, "Unsupported dtype for moe_a2a_combine"); \ + } \ + } + +#define SWITCH_POLICY(one_block_per_token, POLICY, ...) 
\ + if (one_block_per_token) { \ + using POLICY = BlockPolicy; \ + __VA_ARGS__ \ + } else { \ + using POLICY = WarpPolicy; \ + __VA_ARGS__ \ + } + +// ============================================================================ +// Helper Functions for Expert-to-Rank Mapping +// ============================================================================ + +__device__ int compute_target_rank_id(int expert_id, int num_experts_per_rank) { + // Compute which rank owns a given expert using contiguous partitioning + // Experts are divided evenly across EP ranks: + // - Rank 0 gets experts [0, num_experts_per_rank) + // - Rank 1 gets experts [num_experts_per_rank, 2*num_experts_per_rank) + // - etc. + // Example: 32 experts, 4 ranks -> 8 experts per rank + // - Rank 0: experts 0-7 + // - Rank 1: experts 8-15 + // - Rank 2: experts 16-23 + // - Rank 3: experts 24-31 + return expert_id / num_experts_per_rank; +} + +// ============================================================================ +// Helper Functions for Vectorized Memory Operations +// ============================================================================ + +struct WarpPolicy { + __device__ static int stride() { return warpSize; } + + __device__ static int offset() { return (threadIdx.x % warpSize); } + + __device__ static int token_idx() { return (blockIdx.x * blockDim.x + threadIdx.x) / warpSize; } + + __device__ static void sync() { __syncwarp(); } +}; + +struct BlockPolicy { + __device__ static int stride() { return blockDim.x; } + + __device__ static int offset() { return threadIdx.x; } + + __device__ static int token_idx() { return blockIdx.x; } + + __device__ static void sync() { __syncthreads(); } +}; + +template +__device__ void vectorized_copy_impl(void* dst, void const* src, int size) { + using flashinfer::vec_t; + + uint8_t* dst_ptr = static_cast(dst); + uint8_t const* src_ptr = static_cast(src); + + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < size; offset += stride) { + vec_t v; + v.load(src_ptr + offset); + v.store(dst_ptr + offset); + } +} + +template +__device__ void vectorized_copy(void* dst, void const* src, int size) { + if (size % 16 == 0) { + vectorized_copy_impl<16, ThreadingPolicy>(dst, src, size); + } else if (size % 8 == 0) { + vectorized_copy_impl<8, ThreadingPolicy>(dst, src, size); + } else if (size % 4 == 0) { + vectorized_copy_impl<4, ThreadingPolicy>(dst, src, size); + } else if (size % 2 == 0) { + vectorized_copy_impl<2, ThreadingPolicy>(dst, src, size); + } else { + vectorized_copy_impl<1, ThreadingPolicy>(dst, src, size); + } +} + +// Vectorized dispatch: load one vec from source and write to up to TOP_K destinations +template +__device__ void vectorized_dispatch_impl(uint8_t const* src_ptr, int bytes_per_token, int rank_id, + int max_tokens_per_rank, int payload_idx, + DispatchKernelPointers const& ptrs, + int const* topk_target_ranks, + int const* topk_send_indices) { + using flashinfer::vec_t; + + // Precompute destination base pointers per k + uint8_t* dst_base_k[TOP_K]; +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + int dst_idx_k = topk_send_indices[k]; + int target_rank_k = topk_target_ranks[k]; + if (dst_idx_k < 0) { + dst_base_k[k] = nullptr; + continue; + } + uint8_t* dst_data = static_cast(ptrs.recv_buffers[target_rank_k][payload_idx]); + size_t base_source_rank = + static_cast(rank_id) * static_cast(max_tokens_per_rank) + + static_cast(dst_idx_k); + size_t base_token = base_source_rank * 
static_cast(bytes_per_token); + dst_base_k[k] = dst_data + base_token; + } + + // TODO: process all payloads. index could be reused. + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < bytes_per_token; + offset += stride) { + vec_t v; + v.load(src_ptr + offset); + +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + uint8_t* dst_base = dst_base_k[k]; + if (dst_base == nullptr) { + continue; + } + v.store(dst_base + offset); + } + } +} + +template +__device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token, int rank_id, + int max_tokens_per_rank, int payload_idx, + DispatchKernelPointers const& ptrs, + int const* topk_target_ranks, int const* topk_send_indices) { + if (bytes_per_token % 16 == 0) { + vectorized_dispatch_impl<16, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 8 == 0) { + vectorized_dispatch_impl<8, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 4 == 0) { + vectorized_dispatch_impl<4, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 2 == 0) { + vectorized_dispatch_impl<2, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else { + vectorized_dispatch_impl<1, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } +} + +__global__ void moeA2APrepareDispatchKernel(int* send_counters, int* local_token_counter, + int ep_size, uint32_t* flag_val_ptr) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + // Zero send_counters + if (idx < ep_size) { + send_counters[idx] = 0; + } + // Zero local_token_counter and increment flag_val + if (idx == 0) { + *local_token_counter = 0; + // Increment flag_val for this dispatch round + *flag_val_ptr = *flag_val_ptr + 1; + } +} + +// ============================================================================ +// Generic Dispatch Kernel Implementation +// One warp per token design: +// - Each CTA has 256 threads = 8 warps +// - Each warp independently processes one token and all its payloads +// - Better GPU utilization and reduced synchronization overhead +// ============================================================================ + +template +__global__ void moeA2ADispatchKernel( + int32_t const* token_selected_experts, // [local_num_tokens, TOP_K] + const DispatchKernelPointers ptrs, // Struct containing all kernel pointers + int num_payloads, // Number of payloads + int max_tokens_per_rank, // Maximum tokens per rank + int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank) { + int thread_idx = ThreadingPolicy::offset(); + int local_token_idx = ThreadingPolicy::token_idx(); + + if (local_token_idx >= local_num_tokens) { + return; + } + + // Prepare per-policy shared-memory tiles for this token + extern __shared__ int smem[]; + int* smem_topk_target_ranks; + int* smem_topk_send_indices; + int warps_per_block = blockDim.x / warpSize; + if constexpr (std::is_same::value) { + int lane_id = threadIdx.x / warpSize; + smem_topk_target_ranks = smem + lane_id * TOP_K; + 
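+    // Shared memory holds warps_per_block * TOP_K target ranks followed by warps_per_block * TOP_K
+    // send indices; each warp uses its own TOP_K-sized slice of both halves.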
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + } else { + smem_topk_target_ranks = smem; + smem_topk_send_indices = smem + TOP_K; + } + + uint64_t already_copied = 0; + for (int k = 0; k < TOP_K; k++) { + int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; + // Use contiguous partitioning to determine target rank + int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); + + if (already_copied & (1ULL << target_rank)) { + if (thread_idx == 0) { + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = -1; + smem_topk_send_indices[k] = -1; + } + continue; + } + + // Only one thread per warp should increment the counter + int dst_token_idx; + if (thread_idx == 0) { + dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); + + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = target_rank; + smem_topk_send_indices[k] = dst_token_idx; + } + already_copied |= 1ULL << target_rank; + } + // Sync before dispatching data + ThreadingPolicy::sync(); + + // Read staged routing once into registers per thread + int topk_target_ranks[TOP_K]; + int topk_send_indices[TOP_K]; +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + topk_target_ranks[k] = smem_topk_target_ranks[k]; + topk_send_indices[k] = smem_topk_send_indices[k]; + } + + // Perform a single source load and TOP_K fanout per payload + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { + uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); + int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; + uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; + + vectorized_dispatch(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } + + ThreadingPolicy::sync(); + + bool is_first_warp = threadIdx.x / warpSize == 0; + if (is_first_warp) { + int lane_id = threadIdx.x % warpSize; + + bool is_last_token = false; + if (lane_id == 0) { + int cnt = atomicAdd(ptrs.local_token_counter, 1); + is_last_token = cnt + 1 == local_num_tokens; + } + is_last_token = __shfl_sync(0xffffffff, is_last_token, 0); + + if (is_last_token) { +// Store send_counters to recv_counters +#pragma unroll 1 // No unroll as one iter is typically enough + for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + int send_count = ptrs.send_counters[target_rank]; + ptrs.recv_counters[target_rank][rank_id] = send_count; + } + +#if !DISABLE_SYNC_FOR_PROFILING + uint32_t expected_value = *ptrs.flag_val; + + asm volatile("fence.release.sys;"); +#pragma unroll 1 // No unroll as one iter is typically enough + for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id]; + asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); + +#if ENABLE_DEBUG_PRINT + printf("dispatch: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, + expected_value, target_rank); +#endif + } + +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + bool flag_set = false; + do { + uint32_t* flag_ptr = 
&ptrs.completion_flags[rank_id][peer_rank]; + uint32_t flag_value; + // Acquire load to ensure visibility of peer's release-store + asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); +#if ENABLE_DEBUG_PRINT + printf( + "dispatch: ---Rank %d received completion flag from rank %d, flag_value: %d, " + "expected_value: " + "%d, address: %p\n", + rank_id, peer_rank, flag_value, expected_value, flag_ptr); +#endif + flag_set = flag_value == expected_value; + } while (!flag_set); + } + // asm volatile("fence.acquire.sys;"); +#endif + } + } +} + +void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params) { + moeA2APrepareDispatchKernel<<<1, params.ep_size, 0, params.stream>>>( + params.send_counters, params.local_token_counter, params.ep_size, params.flag_val); +} + +// ============================================================================ +// Launch Functions +// ============================================================================ + +void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) { + // Validate parameters + TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); + TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); + + // Prepare kernel pointers struct + DispatchKernelPointers kernel_ptrs = {}; + + // Fill source data pointers and payload sizes + for (int i = 0; i < params.num_payloads; i++) { + kernel_ptrs.src_data_ptrs[i] = params.payloads[i].src_data; + kernel_ptrs.payload_bytes_per_token[i] = + params.payloads[i].element_size * params.payloads[i].elements_per_token; + } + + // Fill receive buffer pointers + for (int target_rank = 0; target_rank < params.ep_size; target_rank++) { + kernel_ptrs.recv_counters[target_rank] = params.recv_counters[target_rank]; + for (int payload = 0; payload < params.num_payloads; payload++) { + kernel_ptrs.recv_buffers[target_rank][payload] = params.recv_buffers[target_rank][payload]; + } + } + + // Copy completion flag pointers + for (int i = 0; i < params.ep_size; i++) { + kernel_ptrs.completion_flags[i] = params.completion_flags[i]; + } + kernel_ptrs.flag_val = params.flag_val; + + // Copy communication tracking pointers + kernel_ptrs.send_counters = params.send_counters; + kernel_ptrs.local_token_counter = params.local_token_counter; + kernel_ptrs.topk_target_ranks = params.topk_target_ranks; + kernel_ptrs.topk_send_indices = params.topk_send_indices; + + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize(); + constexpr int kWarpSize = 32; + int const kWarpsPerBlock = kBlockSize / kWarpSize; + + // Configure kernel launch + if (params.one_block_per_token) { + int grid_size = params.local_num_tokens; + int shared_bytes = 2 * params.top_k * (int)sizeof(int); + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>( + params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, params.num_experts_per_rank)) + } else { + int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int)sizeof(int); + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>( + params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, params.num_experts_per_rank)) + } +} + +// 
============================================================================ +// Combine kernels +// ============================================================================ + +// Accumulate across all valid ranks into registers, then store once per segment +template +__device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, int rank_id, + int max_tokens_per_rank, + CombineKernelPointers const& ptrs) { + constexpr int elems_per_vec = VEC_SIZE / sizeof(T); + using flashinfer::vec_t; + + uint8_t* dst_bytes = reinterpret_cast(dst_typed_base); + + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + int const local_token_idx = ThreadingPolicy::token_idx(); + + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < size_per_token; + offset += stride) { + vec_t acc[TOP_K]; + +// Unrolled K accumulation using compact top-k lists +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k]; + int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k]; + if (dst_idx < 0) { + acc[k].fill(0); + continue; + } + + uint8_t const* recv_buffer = static_cast(ptrs.recv_buffers[target_rank][0]); + size_t base_source_rank = + static_cast(rank_id) * static_cast(max_tokens_per_rank) + + static_cast(dst_idx); + size_t base_token = base_source_rank * static_cast(size_per_token); + + // Load directly into the per-k accumulator; reduce across k below + acc[k].load(recv_buffer + base_token + offset); + } + + // Reduce acc[TOP_K] into acc[0] + if constexpr (TOP_K == 8) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); + T* a2 = reinterpret_cast(&acc[2]); + T* a3 = reinterpret_cast(&acc[3]); + T* a4 = reinterpret_cast(&acc[4]); + T* a5 = reinterpret_cast(&acc[5]); + T* a6 = reinterpret_cast(&acc[6]); + T* a7 = reinterpret_cast(&acc[7]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + a2[j] += a3[j]; + a4[j] += a5[j]; + a6[j] += a7[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a2[j]; + a4[j] += a6[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a4[j]; + } + } else if constexpr (TOP_K == 4) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); + T* a2 = reinterpret_cast(&acc[2]); + T* a3 = reinterpret_cast(&acc[3]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + a2[j] += a3[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a2[j]; + } + } else if constexpr (TOP_K == 2) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + } + } else if constexpr (TOP_K == 1) { + // nothing to do + } else { + // Generic fallback: accumulate all into acc[0] + T* a0 = reinterpret_cast(&acc[0]); +#pragma unroll + for (int k = 1; k < TOP_K; ++k) { + T* ak = reinterpret_cast(&acc[k]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += ak[j]; + } + } + } + + acc[0].store(dst_bytes + offset); + } +} + +// Wrapper that selects vector width based on size_per_token alignment +template +__device__ void vectorized_combine(T* dst_typed_base, int size_per_token, int rank_id, + int max_tokens_per_rank, CombineKernelPointers const& ptrs) { + if (size_per_token % 16 == 0) { + vectorized_combine_impl<16, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if 
(size_per_token % 8 == 0) { + vectorized_combine_impl<8, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if (size_per_token % 4 == 0) { + vectorized_combine_impl<4, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if (size_per_token % 2 == 0) { + vectorized_combine_impl<2, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else { + vectorized_combine_impl<1, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } +} + +// Copy payload to recv buffer using vectorized copy; supports warp/block token mapping +template +__global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes, + int bytes_per_token, int ep_size, + int max_tokens_per_rank, uint32_t* flag_val_ptr, + int const* recv_counters) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + // Increment flag_val for this combine round + *flag_val_ptr = *flag_val_ptr + 1; + } + + if (payload_bytes == nullptr) return; + + int slot_idx = ThreadingPolicy::token_idx(); + + int total_slots = ep_size * max_tokens_per_rank; + if (slot_idx >= total_slots) return; + + // Map global token to (source_rank, token_idx) + int source_rank = slot_idx / max_tokens_per_rank; + int token_idx = slot_idx % max_tokens_per_rank; + + // Skip invalid tokens beyond per-source recv count + if (token_idx >= recv_counters[source_rank]) return; + + // Calculate source and destination pointers for this token + size_t slot_offset = static_cast(slot_idx) * bytes_per_token; + uint8_t* dst_ptr = recv_buffer_bytes + slot_offset; + uint8_t const* src_ptr = payload_bytes + slot_offset; + + // Copy one token's data using vectorized copy with policy + vectorized_copy(dst_ptr, src_ptr, bytes_per_token); +} + +// ============================================================================ +// Generic Combine Kernel Implementation (Templated by data type) +// ============================================================================ + +template +__global__ void moeA2ACombineKernel( + const CombineKernelPointers ptrs, // Combine-specific struct, src_data_ptrs[0] is output + int max_tokens_per_rank, int elements_per_token, int local_num_tokens, int rank_id, + int ep_size) { + int local_token_idx = ThreadingPolicy::token_idx(); + int const size_per_token = elements_per_token * sizeof(T); + + if (local_token_idx >= local_num_tokens) { + return; + } + +#if !DISABLE_SYNC_FOR_PROFILING + // In-kernel readiness synchronization at start of combine: + // - One warp signals readiness to all peers with current flag_val. + // - The first warp of each block waits for all peers' readiness (equality), then __syncthreads. 
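+  // - Readiness is keyed on *ptrs.flag_val, which the prepare kernels increment once per round,
+  //   so a flag value left over from a previous round can never satisfy the equality check.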
+ bool is_first_warp = threadIdx.x / warpSize == 0; + if (is_first_warp) { + int lane_id = threadIdx.x % warpSize; + uint32_t expected_value = *ptrs.flag_val; + + if (blockIdx.x == 0) { + // asm volatile("fence.release.sys;"); +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id]; + asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); +#if ENABLE_DEBUG_PRINT + printf("combine: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, + expected_value, peer_rank); +#endif + } + } + +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + bool flag_set = false; + do { + uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; + uint32_t flag_value; + // Acquire load to ensure visibility of peer's release-store + asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); +#if ENABLE_DEBUG_PRINT + printf( + "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, " + "expected_value: %d, " + "address: %p\n", + rank_id, peer_rank, flag_value, expected_value, flag_ptr); +#endif + flag_set = flag_value == expected_value; + } while (!flag_set); + } + asm volatile("fence.acquire.sys;"); + } + __syncthreads(); +#endif + + // Get output location for this token (using src_data_ptrs[0] as output) + T* token_output = static_cast(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token; + + // Accumulate across ranks in registers, then store once per segment + vectorized_combine(token_output, size_per_token, rank_id, + max_tokens_per_rank, ptrs); +} + +void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; // 8 warps per block + + // Calculate bytes per token based on dtype + int element_size; + switch (params.dtype) { + case nvinfer1::DataType::kHALF: + element_size = sizeof(half); + break; + case nvinfer1::DataType::kBF16: + element_size = sizeof(__nv_bfloat16); + break; + case nvinfer1::DataType::kFLOAT: + element_size = sizeof(float); + break; + default: + FLASHINFER_CHECK(false, "Unsupported dtype for combine prepare"); + return; + } + + int bytes_per_token = params.elements_per_token * element_size; + int total_slots = + params.prepare_payload == nullptr ? 
1 : params.ep_size * params.max_tokens_per_rank; + int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock); + int grid_size_block = total_slots; // one block per token + + if (params.one_block_per_token) { + moeA2APrepareCombineKernel<<>>( + static_cast(const_cast(params.recv_buffers[params.ep_rank])), + static_cast(params.prepare_payload), bytes_per_token, params.ep_size, + params.max_tokens_per_rank, params.flag_val, params.recv_counters); + } else { + moeA2APrepareCombineKernel<<>>( + static_cast(const_cast(params.recv_buffers[params.ep_rank])), + static_cast(params.prepare_payload), bytes_per_token, params.ep_size, + params.max_tokens_per_rank, params.flag_val, params.recv_counters); + } +} + +// ============================================================================ +// Combine Launch Function +// ============================================================================ + +void moe_a2a_combine_launch(MoeA2ACombineParams const& params) { + // Validate parameters + TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); + TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.elements_per_token > 0); + + // Configure kernel launch + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize(); + int const kWarpsPerBlock = kBlockSize / 32; // warpSize + int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + int grid_size_block = params.local_num_tokens; + + // Prepare kernel pointers struct for combine + CombineKernelPointers kernel_ptrs = {}; // Zero-initialize + + // Set output data pointer in src_data_ptrs[0] + kernel_ptrs.src_data_ptrs[0] = params.output_data; + + // Fill recv buffer pointers + for (int rank = 0; rank < params.ep_size; rank++) { + kernel_ptrs.recv_buffers[rank][0] = params.recv_buffers[rank]; + } + + // Copy completion flag pointers + for (int i = 0; i < params.ep_size; i++) { + kernel_ptrs.completion_flags[i] = params.completion_flags[i]; + } + kernel_ptrs.flag_val = params.flag_val; + + // Copy communication tracking pointers + kernel_ptrs.topk_target_ranks = params.topk_target_ranks; + kernel_ptrs.topk_send_indices = params.topk_send_indices; + + // Launch appropriate kernel with compact macros + SWITCH_DTYPE(params.dtype, TKernelType, { + SWITCH_POLICY(params.one_block_per_token, Policy, { + SWITCH_TOP_K(params.top_k, TOP_K, { + auto launch = [&](int grid_blocks, int block_threads) { + moeA2ACombineKernel + <<>>( + kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, + params.local_num_tokens, params.ep_rank, params.ep_size); + }; + int grid = params.one_block_per_token ? 
grid_size_block : grid_size_warp; + int cta = kBlockSize; + launch(grid, cta); + }); + }); + }); +} + +// Kernel to sanitize expert ids for invalid tokens +__global__ void moeA2ASanitizeExpertIdsKernel(int32_t* expert_ids_ptr, + int32_t const* recv_counters_ptr, int ep_size, + int max_tokens_per_rank, int top_k, + int32_t invalid_id) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total_tokens = ep_size * max_tokens_per_rank; + if (tid >= total_tokens) return; + + int source_rank = tid / max_tokens_per_rank; + int token_idx = tid % max_tokens_per_rank; + + if (token_idx >= recv_counters_ptr[source_rank]) { + int32_t* token_expert_ids = expert_ids_ptr + tid * top_k; + for (int k = 0; k < top_k; ++k) { + token_expert_ids[k] = invalid_id; + } + } +} + +void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, + int32_t invalid_id, int ep_size, int max_tokens_per_rank, + int top_k, cudaStream_t stream) { + constexpr int kBlockSize = 256; + int total_tokens = ep_size * max_tokens_per_rank; + int grid = ceilDiv(total_tokens, kBlockSize); + moeA2ASanitizeExpertIdsKernel<<>>( + expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id); +} + +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h new file mode 100644 index 0000000000..0e8dfd9b7c --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include + +namespace tensorrt_llm::kernels::mnnvl_throughput { + +// Configuration constants +static constexpr int kMaxExperts = 256; // Maximum number of experts per rank +static constexpr int kMaxTopK = 8; // Maximum top-k experts per token +static constexpr int kMaxPayloads = 8; // Maximum number of different payload types +static constexpr int kMaxRanks = 64; // Maximum supported EP size + +// Describes a single payload type to be communicated +struct PayloadDescriptor { + void const* src_data; // Source data pointer [local_num_tokens, elements_per_token] + int element_size; // Size of each element in bytes + int elements_per_token; // Number of elements per token (e.g., hidden_size, top_k) +}; + +// Kernel pointers packed into a struct for device access +// Dispatch kernel pointers - const source data +struct DispatchKernelPointers { + // Payload pointers + void const* src_data_ptrs[kMaxPayloads]; // Array of source data pointers + void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers + int payload_bytes_per_token[kMaxPayloads]; // Bytes per token for each payload + + // Completion flags for synchronization + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + + // Local aux data pointers + int* send_counters; // [ep_size] How many tokens have been sent to each target rank + int* recv_counters[kMaxRanks]; // How many tokens have been received from each source rank. Each + // rank has [ep_size] counters + int* local_token_counter; // Atomic counter for completed tokens + + // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) + int* topk_target_ranks; // target rank per k, -1 for duplicates + int* topk_send_indices; // dst index per k, -1 for duplicates +}; + +// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers +struct CombineKernelPointers { + // Payload pointers + void* src_data_ptrs[kMaxPayloads]; // src_data_ptrs[0] is output + void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const) + + // Completion flags for synchronization + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + + // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) + int const* topk_target_ranks; // target rank per k, -1 for duplicates + int const* topk_send_indices; // dst index per k, -1 for duplicates +}; + +// Dispatch phase parameters +struct MoeA2ADispatchParams { + bool one_block_per_token; // True: one block per token, False: one warp per token + + // Threading policy + // EP configuration + int ep_size; // Number of EP ranks + int ep_rank; // Current EP rank + int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size) + + // Token configuration + int local_num_tokens; // Number of tokens on this rank + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to + // runtime_max_tokens_per_rank + int top_k; // Number of experts per token + + // Expert routing information + int32_t const* token_selected_experts; // [local_num_tokens, top_k] + + // Generic payloads + int num_payloads; // 
Number of different payload types + PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors + + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* local_token_counter; // Atomic counter for completed tokens on this rank + int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), target rank per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), dst index per k, -1 for duplicates + + // Distributed aux data and recv buffers + int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has + // [ep_size] counters + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload + + // CUDA stream + cudaStream_t stream; +}; + +// Dispatch kernels +void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params); +// Prepare for dispatch: zero send_counters, local_token_counter and increment flag_val +void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params); + +// Combine phase parameters +struct MoeA2ACombineParams { + bool one_block_per_token; // True: one block per token, False: one warp per token + + // EP configuration + int ep_size; // Number of EP ranks + int ep_rank; // Current EP rank + + // Token configuration + int local_num_tokens; // Number of tokens on this rank + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to + // runtime_max_tokens_per_rank + int top_k; // Number of experts per token + + // Prepare-only field: original payload tensor pointer used to stage into workspace + void const* prepare_payload; + + // Output tensor + void* output_data; // Output buffer [local_num_tokens, elements_per_token] + // Payload information + int elements_per_token; // Number of elements per token + nvinfer1::DataType dtype; // Data type for proper summation + + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), target rank per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), dst index per k, -1 for duplicates + int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target + + // Distributed aux data and recv buffers + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload) + + // CUDA stream + cudaStream_t stream; +}; + +// Combine kernels +void moe_a2a_combine_launch(MoeA2ACombineParams const& params); + +void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params); + +// Sanitize expert IDs for invalid tokens +// expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32) +// recv_counters: [ep_size] (int32), number of valid tokens per source +// invalid_id: value to fill for invalid tokens' expert ids +void moe_a2a_sanitize_expert_ids_launch(int32_t* 
expert_ids, int32_t const* recv_counters, + int32_t invalid_id, int ep_size, int max_tokens_per_rank, + int top_k, cudaStream_t stream); + +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h b/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h new file mode 100644 index 0000000000..354365c1ac --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace torch_ext { +namespace mnnvl_throughput { + +// Enum for indexing into moe_a2a_metainfo tensor +enum MoeA2AMetaInfoIndex : int64_t { + FLAG_VAL_OFFSET_INDEX = 0, + LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1, + SEND_COUNTERS_OFFSET_INDEX = 2, + RECV_COUNTERS_OFFSET_INDEX = 3, + // Dispatch completion flags offset + DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4, + // Combine completion flags offset + COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5, + TOPK_TARGET_RANKS_OFFSET_INDEX = 6, + TOPK_SEND_INDICES_OFFSET_INDEX = 7, + PAYLOAD_DATA_OFFSET_INDEX = 8, + NUM_METAINFO_FIELDS = 9 +}; + +using MoeA2ADataOffsets = std::array; + +inline std::vector> getMoeA2AMetaInfoIndexPairs() { + return { + {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX}, + {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX}, + {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX}, + {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX}, + {"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS}, + }; +} + +} // namespace mnnvl_throughput +} // namespace torch_ext diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu new file mode 100644 index 0000000000..0407eb201a --- /dev/null +++ b/csrc/trtllm_moe_a2a.cu @@ -0,0 +1,399 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +#include "flashinfer/utils.cuh" +#include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h" +#include "tensorrt_llm/thop/moeAlltoAllMeta.h" +#include "tvm_ffi_utils.h" + +using tvm::ffi::Array; +using tvm::ffi::Shape; +using tvm::ffi::String; +using tvm::ffi::Tensor; +using tvm::ffi::TensorView; +using tvm::ffi::Tuple; + +namespace { + +namespace tl_throughput = tensorrt_llm::kernels::mnnvl_throughput; +namespace fi_throughput = torch_ext::mnnvl_throughput; + +constexpr size_t kCachelineAlignment = 128; +constexpr size_t kInt32Bytes = sizeof(int32_t); + +inline size_t alignOffset(size_t offset, size_t alignment = kCachelineAlignment) { + return (offset + alignment - 1) & ~(alignment - 1); +} + +fi_throughput::MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) { + fi_throughput::MoeA2ADataOffsets offsets{}; + size_t offset = 0; + + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX] = offset; + offset += kInt32Bytes; + + offsets[fi_throughput::LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = offset; + offset += kInt32Bytes; + + offsets[fi_throughput::SEND_COUNTERS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * tl_throughput::kMaxTopK * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * tl_throughput::kMaxTopK * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX] = offset; + return offsets; +} + +Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, + int64_t maxNumTokens) { + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2) << "workspace must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize) << "workspace first dim must equal ep_size"; + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize) << "epRank out of range"; + + auto stream = get_current_stream(); + auto* basePtr = static_cast(workspace.data_ptr()); + auto* rankPtr = basePtr + epRank * workspace.stride(0); + auto result = cudaMemsetAsync(rankPtr, 0, workspace.size(1), stream); + TVM_FFI_ICHECK(result == cudaSuccess) << "cudaMemsetAsync failed"; + + auto offsets = calculateOffsets(static_cast(epSize), static_cast(maxNumTokens)); + Tensor metainfo = alloc_tensor({fi_throughput::NUM_METAINFO_FIELDS}, dl_int64, cpu); + auto* metaPtr = static_cast(metainfo.data_ptr()); + std::copy(offsets.begin(), offsets.end(), metaPtr); + + auto err = cudaStreamSynchronize(stream); + TVM_FFI_ICHECK(err == cudaSuccess) << "cudaStreamSynchronize failed: " << cudaGetErrorString(err); + + return metainfo; +} + +Tuple, Array, int64_t> moeA2ADispatchOp( + TensorView tokenSelectedExperts, Array 
inputPayloads, TensorView workspace, + TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, + int64_t topK, int64_t numExperts) { + using tl_throughput::PayloadDescriptor; + fflush(stdout); + + CHECK_INPUT(tokenSelectedExperts); + CHECK_INPUT_TYPE(tokenSelectedExperts, dl_int32); + TVM_FFI_ICHECK_EQ(tokenSelectedExperts.ndim(), 2) << "token_selected_experts must be 2D"; + TVM_FFI_ICHECK_EQ(tokenSelectedExperts.size(1), topK) << "token_selected_experts shape mismatch"; + + int numPayloads = static_cast(inputPayloads.size()); + TVM_FFI_ICHECK(numPayloads > 0) << "At least one payload is required"; + TVM_FFI_ICHECK(numPayloads <= tl_throughput::kMaxPayloads) + << "Too many payloads: " << numPayloads << " > " << tl_throughput::kMaxPayloads; + + auto localNumTokens = static_cast(tokenSelectedExperts.size(0)); + TVM_FFI_ICHECK(localNumTokens > 0) << "local_num_tokens must be positive"; + + // Validate all payloads and calculate sizes + for (int i = 0; i < numPayloads; ++i) { + auto const& payload = inputPayloads[i]; + CHECK_INPUT(payload); + TVM_FFI_ICHECK_EQ(payload.ndim(), 2) << "payload " << i << " must be 2D"; + TVM_FFI_ICHECK_EQ(payload.size(0), localNumTokens) + << "payload " << i << " first dimension must match local_num_tokens"; + } + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); + TVM_FFI_ICHECK(runtimeMaxTokensPerRank > 0); + TVM_FFI_ICHECK(numExperts >= epSize && numExperts % epSize == 0) + << "num_experts must be divisible by ep_size"; + TVM_FFI_ICHECK(topK > 0 && topK <= tl_throughput::kMaxTopK); + + // Calculate payload descriptors and sizes from input tensors + std::vector payloadDescriptors(numPayloads); + std::vector payloadByteSizes(numPayloads); + int64_t totalBytesNeeded = 0; + + for (int i = 0; i < numPayloads; ++i) { + auto const& payload = inputPayloads[i]; + int elementsPerToken = static_cast(payload.size(1)); + int elementSize = static_cast(get_element_size(payload)); + + payloadDescriptors[i].src_data = payload.data_ptr(); + payloadDescriptors[i].element_size = elementSize; + payloadDescriptors[i].elements_per_token = elementsPerToken; + + int64_t bytesPerPayload = + static_cast(epSize) * runtimeMaxTokensPerRank * elementsPerToken * elementSize; + payloadByteSizes[i] = bytesPerPayload; + totalBytesNeeded += bytesPerPayload; + } + + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto strideBytes = workspace.stride(0); + size_t rankWorkspaceOffset = epRank * strideBytes; + auto* rankWorkspacePtr = workspaceBase + rankWorkspaceOffset; + int64_t sizePerRank = workspace.size(1); + + int64_t requiredSize = offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded; + TVM_FFI_ICHECK(sizePerRank >= requiredSize) << "workspace size per rank insufficient, need " + << requiredSize << " bytes but has " << sizePerRank; + + tl_throughput::MoeA2ADispatchParams params{}; + params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); + params.ep_size = static_cast(epSize); + params.ep_rank = 
static_cast(epRank); + params.num_experts_per_rank = static_cast(numExperts / epSize); + params.local_num_tokens = localNumTokens; + params.max_tokens_per_rank = static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + params.token_selected_experts = static_cast(tokenSelectedExperts.data_ptr()); + params.num_payloads = numPayloads; + std::copy(payloadDescriptors.begin(), payloadDescriptors.end(), params.payloads); + + params.flag_val = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX]); + params.local_token_counter = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::LOCAL_TOKEN_COUNTER_OFFSET_INDEX]); + params.send_counters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::SEND_COUNTERS_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX]); + + for (int targetRank = 0; targetRank < epSize; ++targetRank) { + auto* targetWorkspacePtr = workspaceBase + targetRank * strideBytes; + params.recv_counters[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + params.completion_flags[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]); + + size_t offset = static_cast(offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]); + for (int payloadIdx = 0; payloadIdx < numPayloads; ++payloadIdx) { + params.recv_buffers[targetRank][payloadIdx] = targetWorkspacePtr + offset; + offset += payloadByteSizes[payloadIdx]; + } + } + + params.stream = get_current_stream(); + + tl_throughput::moe_a2a_prepare_dispatch_launch(params); + tl_throughput::moe_a2a_dispatch_launch(params); + auto launchErr = cudaGetLastError(); + TVM_FFI_ICHECK(launchErr == cudaSuccess) + << "moe_a2a_dispatch launch failed: " << cudaGetErrorString(launchErr); + + Array recvOffsets; + Array recvByteSizes; + recvOffsets.reserve(numPayloads); + size_t localOffset = static_cast(offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]); + for (auto payloadByteSize : payloadByteSizes) { + recvOffsets.push_back(rankWorkspaceOffset + localOffset); + recvByteSizes.push_back(payloadByteSize); + localOffset += payloadByteSize; + } + + int64_t combinePayloadOffset = static_cast(alignOffset(localOffset)); + return Tuple(recvOffsets, recvByteSizes, combinePayloadOffset); +} + +nvinfer1::DataType toNvDataType(DLDataType dtype) { + auto code = encode_dlpack_dtype(dtype); + if (code == float16_code) { + return nvinfer1::DataType::kHALF; + } + if (code == bfloat16_code) { + return nvinfer1::DataType::kBF16; + } + if (code == float32_code) { + return nvinfer1::DataType::kFLOAT; + } + TVM_FFI_LOG_AND_THROW(TypeError) << "Unsupported dtype for MoE combine"; + return nvinfer1::DataType::kFLOAT; +} + +Tensor moeA2ACombineOp(TensorView payload, int64_t localNumTokens, TensorView workspace, + TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, + int64_t epSize, int64_t topK, int64_t combinePayloadOffset, + bool payloadInWorkspace) { + using tl_throughput::MoeA2ACombineParams; + CHECK_INPUT(payload); + TVM_FFI_ICHECK_EQ(payload.ndim(), 3) + << "payload must be [ep_size, runtime_max_tokens_per_rank, hidden]"; + TVM_FFI_ICHECK_EQ(payload.size(0), epSize); + TVM_FFI_ICHECK_EQ(payload.size(1), runtimeMaxTokensPerRank); + TVM_FFI_ICHECK(epRank >= 0 && epRank < 
epSize); + TVM_FFI_ICHECK(topK > 0 && topK <= tl_throughput::kMaxTopK); + TVM_FFI_ICHECK(localNumTokens > 0); + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto strideBytes = workspace.stride(0); + auto* rankWorkspacePtr = workspaceBase + epRank * strideBytes; + int64_t sizePerRank = workspace.size(1); + + int64_t elementsPerToken = payload.size(2); + int64_t payloadBytes = + payload.numel() * + get_element_size(payload); // includes all ranks * runtime_max_tokens_per_rank + TVM_FFI_ICHECK(combinePayloadOffset >= 0 && combinePayloadOffset + payloadBytes <= sizePerRank) + << "workspace insufficient for combine payload region"; + + if (payloadInWorkspace) { + auto* expectedPtr = rankWorkspacePtr + combinePayloadOffset; + TVM_FFI_ICHECK(payload.data_ptr() == expectedPtr) + << "payload_in_workspace is True but tensor pointer mismatch"; + } + + Tensor output = + alloc_tensor({localNumTokens, elementsPerToken}, payload.dtype(), payload.device()); + + MoeA2ACombineParams params{}; + params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); + params.ep_size = static_cast(epSize); + params.ep_rank = static_cast(epRank); + params.local_num_tokens = static_cast(localNumTokens); + params.max_tokens_per_rank = static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + params.prepare_payload = payloadInWorkspace ? 
nullptr : payload.data_ptr(); + params.output_data = output.data_ptr(); + params.elements_per_token = static_cast(elementsPerToken); + params.dtype = toNvDataType(payload.dtype()); + + params.flag_val = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX]); + params.recv_counters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + + for (int targetRank = 0; targetRank < epSize; ++targetRank) { + auto* targetWorkspacePtr = workspaceBase + targetRank * strideBytes; + params.completion_flags[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::COMBINE_COMPLETION_FLAGS_OFFSET_INDEX]); + params.recv_buffers[targetRank] = targetWorkspacePtr + combinePayloadOffset; + } + params.stream = get_current_stream(); + + tl_throughput::moe_a2a_prepare_combine_launch(params); + tl_throughput::moe_a2a_combine_launch(params); + auto err = cudaGetLastError(); + TVM_FFI_ICHECK(err == cudaSuccess) + << "moe_a2a_combine launch failed: " << cudaGetErrorString(err); + return output; +} + +void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, TensorView metainfo, + int64_t epRank, int64_t invalidExpertId) { + CHECK_INPUT(expertIds); + CHECK_INPUT_TYPE(expertIds, dl_int32); + TVM_FFI_ICHECK_EQ(expertIds.ndim(), 3); + int64_t epSize = expertIds.size(0); + int64_t runtimeMaxTokensPerRank = expertIds.size(1); + int64_t topK = expertIds.size(2); + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto* rankWorkspacePtr = workspaceBase + epRank * workspace.stride(0); + auto* recvCounters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + + tl_throughput::moe_a2a_sanitize_expert_ids_launch( + static_cast(expertIds.data_ptr()), recvCounters, + static_cast(invalidExpertId), static_cast(epSize), + static_cast(runtimeMaxTokensPerRank), static_cast(topK), get_current_stream()); +} + +// Expose metainfo index constants for Python access +// Returns a tuple of (names, values) for all metainfo constants +Tuple, Array> getMoeA2AMetaInfoIndexPairs() { + auto pairs = fi_throughput::getMoeA2AMetaInfoIndexPairs(); + + Array names; + Array values; + + for (const auto& pair : pairs) { + names.push_back(pair.first); + values.push_back(pair.second); + } + + return Tuple{names, values}; +} + +} // namespace + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_initialize, moeA2AInitializeOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_dispatch, moeA2ADispatchOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_combine, moeA2ACombineOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_sanitize_expert_ids, moeA2ASanitizeExpertIdsOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_get_metainfo_index_pairs, getMoeA2AMetaInfoIndexPairs); diff --git a/docs/api/comm.rst b/docs/api/comm.rst index 735f9a4294..f852073ae4 
100644 --- a/docs/api/comm.rst +++ b/docs/api/comm.rst @@ -128,3 +128,17 @@ TensorRT-LLM MNNVL AllReduce trtllm_mnnvl_all_reduce trtllm_mnnvl_fused_allreduce_rmsnorm mpi_barrier + +MNNVL A2A (Throughput Backend) +------------------------------- + +.. currentmodule:: flashinfer.comm + +.. autosummary:: + :toctree: ../generated + + MoeAlltoAll + moe_a2a_initialize + moe_a2a_dispatch + moe_a2a_combine + moe_a2a_sanitize_expert_ids diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 609e1bcbcf..84bc0ca199 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -512,12 +512,14 @@ def gen_all_modules( from .jit.comm import gen_nvshmem_module from .jit.comm import gen_comm_alltoall_module from .jit.comm import gen_trtllm_mnnvl_comm_module + from .jit.comm import gen_mnnvl_a2a_module jit_specs.append(gen_nvshmem_module()) jit_specs.append(gen_comm_alltoall_module()) if has_sm100: jit_specs.append(gen_trtllm_comm_module()) jit_specs.append(gen_trtllm_mnnvl_comm_module()) + jit_specs.append(gen_mnnvl_a2a_module()) jit_specs.append(gen_vllm_comm_module()) if add_misc: diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..7aa1ff1cfe 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,4 +39,13 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# MNNVL A2A (Throughput Backend) +from .trtllm_moe_a2a import MoeAlltoAll as MoeAlltoAll +from .trtllm_moe_a2a import moe_a2a_combine as moe_a2a_combine +from .trtllm_moe_a2a import moe_a2a_dispatch as moe_a2a_dispatch +from .trtllm_moe_a2a import moe_a2a_initialize as moe_a2a_initialize +from .trtllm_moe_a2a import ( + moe_a2a_sanitize_expert_ids as moe_a2a_sanitize_expert_ids, +) + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py new file mode 100644 index 0000000000..58f0a1834a --- /dev/null +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -0,0 +1,556 @@ +""" +MoE All-to-All Operations (Throughput Backend) + +This module provides the throughput-optimized all-to-all backend for MoE expert parallelism, +supporting multiple payloads per collective operation. 
+""" + +# TODO Review + +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Optional + +import torch +import functools + +from .mnnvl import MnnvlMemory +from .mapping import Mapping +from ..jit.comm import gen_mnnvl_a2a_module +from ..utils import register_custom_op + + +@dataclass +class _A2AState: + """Internal state tracking for MoeAlltoAll operations.""" + + phase: str = "idle" # idle | dispatched + local_num_tokens: Optional[int] = None + combine_payload_offset: Optional[int] = None + + +@functools.cache +def get_mnnvl_a2a_module(): + """Get or build the MNNVL A2A JIT module.""" + module = gen_mnnvl_a2a_module().build_and_load() + + @register_custom_op( + "flashinfer::moe_a2a_initialize", + mutates_args=[], + ) + def moe_a2a_initialize( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + max_num_tokens: int, + ): + return module.moe_a2a_initialize(workspace, ep_rank, ep_size, max_num_tokens) + + @register_custom_op( + "flashinfer::moe_a2a_dispatch", + mutates_args=[], + ) + def moe_a2a_dispatch( + token_selected_experts: torch.Tensor, + input_payloads: list[torch.Tensor], + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + num_experts: int, + ): + """ + Dispatch tokens and payloads to expert ranks. + + Args: + token_selected_experts: [local_num_tokens, top_k] int32 tensor + input_payloads: List of [local_num_tokens, *] tensors to dispatch + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + num_experts: Total number of experts + + Returns: + recv_offsets: List of offsets for each payload in the workspace + recv_sizes: List of sizes for each payload in the workspace + combine_payload_offset: Offset for combine payload region + """ + return module.moe_a2a_dispatch( + token_selected_experts, + input_payloads, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + num_experts, + ) + + @register_custom_op( + "flashinfer::moe_a2a_combine", + mutates_args=[], + ) + def moe_a2a_combine( + payload: torch.Tensor, + local_num_tokens: int, + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + combine_payload_offset: int, + payload_in_workspace: bool = False, + ) -> torch.Tensor: + """ + Combine expert outputs back to originating tokens. 
+ + Args: + payload: [ep_size, max_tokens, elements_per_token] tensor + local_num_tokens: Number of tokens on this rank + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + combine_payload_offset: Offset from dispatch + payload_in_workspace: If True, payload is workspace-backed + + Returns: + output: [local_num_tokens, elements_per_token] tensor + """ + return module.moe_a2a_combine( + payload, + local_num_tokens, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + combine_payload_offset, + payload_in_workspace, + ) + + @register_custom_op( + "flashinfer::moe_a2a_sanitize_expert_ids", + mutates_args=[], + ) + def moe_a2a_sanitize_expert_ids( + expert_ids: torch.Tensor, + workspace: torch.Tensor, + metainfo: torch.Tensor, + ep_rank: int, + invalid_expert_id: int, + ): + return module.moe_a2a_sanitize_expert_ids( + expert_ids, workspace, metainfo, ep_rank, invalid_expert_id + ) + + @register_custom_op( + "flashinfer::moe_a2a_get_metainfo_index_pairs", + mutates_args=[], + ) + def moe_a2a_get_metainfo_index_pairs(): + """ + Get all metainfo index constants from C++. + + Returns: + Tuple of (names, values) where names is a list of constant names + and values is a list of their corresponding integer values + """ + return module.moe_a2a_get_metainfo_index_pairs() + + return SimpleNamespace( + moe_a2a_initialize=moe_a2a_initialize, + moe_a2a_dispatch=moe_a2a_dispatch, + moe_a2a_combine=moe_a2a_combine, + moe_a2a_sanitize_expert_ids=moe_a2a_sanitize_expert_ids, + moe_a2a_get_metainfo_index_pairs=moe_a2a_get_metainfo_index_pairs, + ) + + +def moe_a2a_initialize( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + max_num_tokens: int, +): + return get_mnnvl_a2a_module().moe_a2a_initialize( + workspace, ep_rank, ep_size, max_num_tokens + ) + + +def moe_a2a_wrap_payload_tensor_in_workspace( + workspace: torch.Tensor, + leading_shape: list[int], + slice_start: int, + slice_end: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Wrap an offset in the workspace into a tensor. 
+
+    Args:
+        workspace: [ep_size, size_per_rank] workspace tensor
+        leading_shape: Leading dimensions of the returned view, e.g. [ep_size, runtime_max_tokens_per_rank]
+        slice_start: Start byte offset of the payload region in the flattened workspace
+        slice_end: End byte offset (exclusive) of the payload region
+        dtype: Data type for the tensor
+
+    Returns:
+        tensor: [*leading_shape, -1] workspace-backed tensor (last dim inferred from dtype)
+    """
+    workspace_base = workspace.flatten().view(dtype=torch.uint8)
+    result = (
+        workspace_base[slice_start:slice_end]
+        .view(leading_shape + [-1])
+        .view(dtype=dtype)
+    )
+    return result
+
+
+def moe_a2a_dispatch(
+    token_selected_experts: torch.Tensor,
+    input_payloads: list[torch.Tensor],
+    workspace: torch.Tensor,
+    metainfo: torch.Tensor,
+    runtime_max_tokens_per_rank: int,
+    ep_rank: int,
+    ep_size: int,
+    top_k: int,
+    num_experts: int,
+):
+    recv_offsets, recv_sizes, combine_payload_offset = (
+        get_mnnvl_a2a_module().moe_a2a_dispatch(
+            token_selected_experts,
+            input_payloads,
+            workspace,
+            metainfo,
+            runtime_max_tokens_per_rank,
+            ep_rank,
+            ep_size,
+            top_k,
+            num_experts,
+        )
+    )
+
+    output_payloads = []
+    for input_payload, offset, size in zip(
+        input_payloads, recv_offsets, recv_sizes, strict=True
+    ):
+        # recv_offsets are absolute byte offsets into the full workspace, so pass the
+        # whole workspace here instead of indexing into this rank's slice.
+        output_payloads.append(
+            moe_a2a_wrap_payload_tensor_in_workspace(
+                workspace,
+                [ep_size, runtime_max_tokens_per_rank],
+                offset,
+                offset + size,
+                input_payload.dtype,
+            )
+        )
+
+    return output_payloads, combine_payload_offset
+
+
+def moe_a2a_combine(
+    payload: torch.Tensor,
+    local_num_tokens: int,
+    workspace: torch.Tensor,
+    metainfo: torch.Tensor,
+    runtime_max_tokens_per_rank: int,
+    ep_rank: int,
+    ep_size: int,
+    top_k: int,
+    combine_payload_offset: int,
+    payload_in_workspace: bool = False,
+) -> torch.Tensor:
+    return get_mnnvl_a2a_module().moe_a2a_combine(
+        payload,
+        local_num_tokens,
+        workspace,
+        metainfo,
+        runtime_max_tokens_per_rank,
+        ep_rank,
+        ep_size,
+        top_k,
+        combine_payload_offset,
+        payload_in_workspace,
+    )
+
+
+def moe_a2a_sanitize_expert_ids(
+    expert_ids: torch.Tensor,
+    workspace: torch.Tensor,
+    metainfo: torch.Tensor,
+    ep_rank: int,
+    invalid_expert_id: int,
+):
+    return get_mnnvl_a2a_module().moe_a2a_sanitize_expert_ids(
+        expert_ids, workspace, metainfo, ep_rank, invalid_expert_id
+    )
+
+
+class MoeAlltoAll:
+    """
+    Manages MoE All-to-All operations with proper workspace allocation and synchronization.
+
+    This class provides the throughput-optimized backend that supports multiple payloads
+    per collective operation, explicit dispatch/combine phases, and workspace-backed tensors.
+ + Example: + >>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8) + >>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size) + >>> output = moe_a2a.combine(processed, batch_size) + """ + + # Single shared workspace across the process + _WORKSPACE: Optional[dict] = None + + # Metainfo index constants (loaded dynamically from C++) + # These offsets allow accessing internal workspace data for testing/debugging + _METAINFO_INDEX: Optional[dict] = None + + @classmethod + def _init_constants(cls): + """Initialize constants from C++ if not already done.""" + if cls._METAINFO_INDEX is None: + module = get_mnnvl_a2a_module() + names, values = module.moe_a2a_get_metainfo_index_pairs() + + # Convert TVM arrays to Python and build dictionary + # Strip "MOE_A2A_" prefix from names for cleaner API + cls._METAINFO_INDEX = {} + for name, value in zip(names, values, strict=True): + # Convert from "MOE_A2A_SEND_COUNTERS_OFFSET_INDEX" to "SEND_COUNTERS_OFFSET_INDEX" + clean_name = ( + name.replace("MOE_A2A_", "") + if name.startswith("MOE_A2A_") + else name + ) + cls._METAINFO_INDEX[clean_name] = int(value) + + def __init__( + self, + mapping: Mapping, + max_num_tokens: int, + top_k: int, + num_experts: int, + workspace_size_per_rank: int = 512 * 1024 * 1024, + ): + """ + Initialize MoeAlltoAll with workspace allocation. + + Args: + mapping: Mapping object containing rank information + max_num_tokens: Maximum number of tokens supported + top_k: Number of experts per token + num_experts: Total number of experts + workspace_size_per_rank: Size of workspace per rank in bytes (default: 512MB) + """ + # Initialize constants from C++ + self._init_constants() + + # Initialize MNNVL memory system + MnnvlMemory.initialize() + + self.workspace_size_per_rank = workspace_size_per_rank + self.max_num_tokens = max_num_tokens + self.ep_size = mapping.moe_ep_size + self.ep_rank = mapping.moe_ep_rank + self.top_k = top_k + self.num_experts = num_experts + + if not isinstance(self.top_k, int) or self.top_k <= 0: + raise ValueError("top_k must be a positive int") + if not isinstance(self.num_experts, int) or self.num_experts <= 0: + raise ValueError("num_experts must be a positive int") + + # Allocate or reuse workspace + if self._WORKSPACE is None: + mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) + metainfo = moe_a2a_initialize( + workspace, + self.ep_rank, + self.ep_size, + self.max_num_tokens, + ) + MoeAlltoAll._WORKSPACE = { + "workspace_size_per_rank": workspace_size_per_rank, + "max_num_tokens": self.max_num_tokens, + "ep_rank": self.ep_rank, + "ep_size": self.ep_size, + "mnnvl_mem": mnnvl_mem, + "workspace": workspace, + "metainfo": metainfo, + } + else: + # Validate workspace compatibility + assert ( + self._WORKSPACE["workspace_size_per_rank"] == workspace_size_per_rank + ), "Workspace size mismatch" + assert self._WORKSPACE["max_num_tokens"] == self.max_num_tokens, ( + "Max tokens mismatch" + ) + assert self._WORKSPACE["ep_rank"] == self.ep_rank, "EP rank mismatch" + assert self._WORKSPACE["ep_size"] == self.ep_size, "EP size mismatch" + + self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] + self.workspace = self._WORKSPACE["workspace"] + self.metainfo = self._WORKSPACE["metainfo"] + self._state = _A2AState() + + def dispatch( + self, + token_selected_experts: torch.Tensor, + input_payloads: list[torch.Tensor], + runtime_max_tokens_per_rank: int, + invalid_token_expert_id: Optional[int] = None, + 
expert_id_payload_index: Optional[int] = None, + ) -> list[torch.Tensor]: + """ + Perform MoE all-to-all dispatch operation. + + Args: + token_selected_experts: [local_num_tokens, top_k] expert indices + input_payloads: List of [local_num_tokens, *] tensors to dispatch + runtime_max_tokens_per_rank: Max tokens per rank in this batch + invalid_token_expert_id: If set, sanitize invalid tokens to this ID + expert_id_payload_index: Index of expert IDs in input_payloads (required if invalid_token_expert_id is set) + + Returns: + recv_tensors: List of [ep_size, max_tokens, *] tensors + """ + assert self._state.phase == "idle", "dispatch called twice without combine" + assert runtime_max_tokens_per_rank <= self.max_num_tokens, ( + "runtime_max_tokens_per_rank exceeds max_num_tokens" + ) + + recv_tensors, combine_payload_offset = moe_a2a_dispatch( + token_selected_experts, + input_payloads, + self.workspace, + self.metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self.num_experts, + ) + + # Update state + self._state.local_num_tokens = token_selected_experts.size(0) + self._state.combine_payload_offset = combine_payload_offset + self._state.phase = "dispatched" + + # Sanitize invalid tokens if requested + if invalid_token_expert_id is not None: + assert expert_id_payload_index is not None, ( + "expert_id_payload_index required when invalid_token_expert_id is set" + ) + recv_expert_ids = recv_tensors[expert_id_payload_index] + moe_a2a_sanitize_expert_ids( + recv_expert_ids, + self.workspace, + self.metainfo, + self.ep_rank, + invalid_token_expert_id, + ) + + return recv_tensors + + def combine( + self, + payload: torch.Tensor, + runtime_max_tokens_per_rank: int, + payload_in_workspace: bool = False, + ) -> torch.Tensor: + """ + Perform MoE all-to-all combine operation. + + Args: + payload: [ep_size, max_tokens, elements_per_token] tensor + runtime_max_tokens_per_rank: Max tokens per rank in this batch + payload_in_workspace: If True, payload is workspace-backed (skip staging) + + Returns: + output: [local_num_tokens, elements_per_token] tensor + """ + assert self._state.phase == "dispatched", ( + "combine called before successful dispatch" + ) + assert runtime_max_tokens_per_rank <= self.max_num_tokens, ( + "runtime_max_tokens_per_rank exceeds max_num_tokens" + ) + + output = moe_a2a_combine( + payload, + self._state.local_num_tokens, + self.workspace, + self.metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self._state.combine_payload_offset, + payload_in_workspace, + ) + + # Reset state for next round + self._state = _A2AState() + + return output + + def get_combine_payload_tensor_in_workspace( + self, + runtime_max_tokens_per_rank: int, + hidden_size: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Get combine payload tensor backed by workspace (zero-copy). + + This tensor can be written to directly by expert processing, avoiding + a staging copy in the combine operation. 
+ + Args: + runtime_max_tokens_per_rank: Max tokens per rank in this batch + hidden_size: Hidden dimension size + dtype: Data type for the tensor + + Returns: + tensor: [ep_size, max_tokens, hidden_size] workspace-backed tensor + """ + if self._state.phase != "dispatched": + raise RuntimeError( + "get_combine_payload_tensor_in_workspace called before successful dispatch" + ) + + element_size = torch.tensor([], dtype=dtype).element_size() + return moe_a2a_wrap_payload_tensor_in_workspace( + self.workspace[self.ep_rank, :], + [self.ep_size, runtime_max_tokens_per_rank], + self._state.combine_payload_offset, + self._state.combine_payload_offset + + self.ep_size * runtime_max_tokens_per_rank * hidden_size * element_size, + dtype, + ) + + +__all__ = [ + "MoeAlltoAll", + "moe_a2a_initialize", + "moe_a2a_dispatch", + "moe_a2a_combine", + "moe_a2a_sanitize_expert_ids", +] diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 0bacf2d28b..bd1934ff62 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -77,6 +77,7 @@ from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_nvshmem_module as gen_nvshmem_module +from .comm import gen_mnnvl_a2a_module as gen_mnnvl_a2a_module from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 27661b1fe2..4c350ddf22 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -78,3 +78,32 @@ def gen_vllm_comm_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "vllm_custom_all_reduce.cu", ], ) + + +def gen_mnnvl_a2a_module() -> JitSpec: + return gen_jit_spec( + "mnnvl_a2a", + [ + jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_a2a.cu", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "communicationKernels" + / "moeAlltoAllKernels.cu", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "cpp" + / "common" + / "envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "cpp" + / "common" + / "tllmException.cpp", + ], + extra_include_paths=[ + str(jit_env.FLASHINFER_CSRC_DIR / "nv_internal"), + str(jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include"), + ], + ) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py new file mode 100644 index 0000000000..8a1e84c94e --- /dev/null +++ b/tests/comm/test_mnnvl_a2a.py @@ -0,0 +1,787 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
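For orientation before the tests: the snippet below is a minimal sketch of how the new MoeAlltoAll API is meant to be driven (dispatch, local expert compute written into the workspace-backed combine payload, then combine). It is editorial and not part of the diff; it assumes an MPI launch with one GPU per rank, working MNNVL support, and a hypothetical caller-supplied run_local_experts function that returns a [ep_size, max_tokens, hidden] tensor.

# Sketch only: assumes MPI (one rank per GPU), MNNVL support, and a hypothetical
# run_local_experts(hidden, expert_ids, scales) -> [ep_size, max_tokens, hidden].
from flashinfer.comm import MoeAlltoAll
from flashinfer.comm.mapping import Mapping


def moe_layer_step(rank, world_size, hidden_states, token_selected_experts,
                   token_final_scales, run_local_experts,
                   max_num_tokens, top_k, num_experts):
    mapping = Mapping(rank=rank, moe_ep_size=world_size, tp_size=world_size,
                      world_size=world_size)
    moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k, num_experts)

    # Dispatch: scatter all payloads to the ranks that own each token's selected experts.
    hidden_recv, experts_recv, scales_recv = moe_a2a.dispatch(
        token_selected_experts,
        [hidden_states, token_selected_experts, token_final_scales],
        runtime_max_tokens_per_rank=max_num_tokens,
    )

    # Write expert outputs directly into the workspace-backed combine payload
    # so combine() can skip the staging copy (payload_in_workspace=True).
    moe_output = moe_a2a.get_combine_payload_tensor_in_workspace(
        runtime_max_tokens_per_rank=max_num_tokens,
        hidden_size=hidden_states.shape[-1],
        dtype=hidden_states.dtype,
    )
    moe_output.copy_(run_local_experts(hidden_recv, experts_recv, scales_recv))

    # Combine: route partial results back and reduce them per originating token.
    return moe_a2a.combine(
        moe_output,
        runtime_max_tokens_per_rank=max_num_tokens,
        payload_in_workspace=True,
    )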
+ +import traceback + +import pytest +import torch +from mpi4py import MPI + +from flashinfer.comm import MoeAlltoAll +from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import MnnvlMemory + + +class MPIExit(Exception): + pass + + +def check_any_rank_failed(): + comm = MPI.COMM_WORLD + if any(comm.allgather(False)): + raise MPIExit("Another rank failed") + + +def safe_run(func, *args, **kwargs): + comm = MPI.COMM_WORLD + try: + func(*args, **kwargs) + except MPIExit as e: + raise e + except Exception as e: + traceback.print_exc() + comm.allgather(True) + raise e + + +@pytest.fixture(autouse=True) +def setup_test(): + torch.manual_seed(0x1234) + + +def compute_target_rank_id(expert_id, num_experts_per_rank): + """Compute the rank that owns a given expert using contiguous partitioning. + Experts are divided evenly across ranks: + - Rank 0: experts [0, num_experts_per_rank) + - Rank 1: experts [num_experts_per_rank, 2 * num_experts_per_rank) + - ... + For example, with 32 experts and 4 ranks (8 experts per rank): + - Rank 0: experts 0-7 + - Rank 1: experts 8-15 + - Rank 2: experts 16-23 + - Rank 3: experts 24-31 + """ + return expert_id // num_experts_per_rank + + +def generate_token_selected_experts( + local_num_tokens: int, ep_size: int, num_experts_per_rank: int, top_k: int +) -> torch.Tensor: + """Generate global expert IDs tensor, aligned with single-GPU test semantics.""" + return torch.randint( + 0, + ep_size * num_experts_per_rank, + (local_num_tokens, top_k), + dtype=torch.int32, + device="cuda", + ) + + +def create_experts( + num_experts_per_rank, hidden_size, ep_rank, device, dtype=torch.bfloat16 +): + """ + Create a 3D tensor of expert weights for a given rank. + + Args: + num_experts_per_rank: Number of experts on this rank + hidden_size: Hidden dimension size + ep_rank: EP rank ID + device: Device to create experts on + + Returns: + experts: Tensor of shape [num_experts_per_rank, hidden_size, hidden_size] + """ + # For reproducibility, set the seed based on rank + experts = torch.empty( + (num_experts_per_rank, hidden_size, hidden_size), dtype=dtype, device=device + ) + for i in range(num_experts_per_rank): + torch.manual_seed(ep_rank * 1000 + i) + # Xavier uniform initialization for each expert + torch.nn.init.xavier_uniform_(experts[i]) + return experts + + +def fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + experts, + is_ep=False, + ep_rank=None, + num_experts_per_rank=None, +): + """ + Emulate MoE computation by scaling tokens based on which experts belong to this rank. 
+ + Args: + hidden_states: [num_tokens, hidden_size] - input hidden states + token_selected_experts: [num_tokens, top_k] - selected expert indices + token_final_scales: [num_tokens, top_k] - scaling factors for each expert + experts: [num_experts_per_rank, hidden_size, hidden_size] if is_ep, otherwise [num_experts, hidden_size, hidden_size] - expert weights + is_ep: If true, emulate MoE on a EP rank; otherwise, emulate MoE with all experts + ep_rank: EP rank ID + num_experts_per_rank: Number of experts per rank + + Returns: + processed_states: [num_tokens, hidden_size] - processed hidden states + """ + num_tokens, _ = hidden_states.shape + _, top_k = token_selected_experts.shape + + if is_ep: + assert ep_rank is not None and num_experts_per_rank is not None + + # Initialize output + processed_states = torch.zeros_like(hidden_states) + + # Process each token + for token_idx in range(num_tokens): + # For each expert selected for this token/ + for k in range(top_k): + expert_id = token_selected_experts[token_idx, k].item() + if is_ep: + if not ( + expert_id >= ep_rank * num_experts_per_rank + and expert_id < (ep_rank + 1) * num_experts_per_rank + ): + continue + # Convert global expert ID to local expert ID for this rank + local_expert_id = expert_id - ep_rank * num_experts_per_rank + expert = experts[local_expert_id] + else: + expert = experts[expert_id] + + scale = token_final_scales[token_idx, k] + processed_states[token_idx] += hidden_states[token_idx] @ expert * scale + + return processed_states + + +def make_nvfp4_payloads( + local_num_tokens: int, + hidden_size: int, + top_k: int, + rank: int, + token_selected_experts: torch.Tensor, +) -> tuple[list, int]: + """Create the four NV FP4 payloads exactly as in single-GPU test.""" + payloads = [] + # Payload 0: Packed FP4 tokens (uint8) + packed_hidden_size = hidden_size // 2 + packed_hidden_states = torch.randint( + 0, 256, (local_num_tokens, packed_hidden_size), dtype=torch.uint8, device="cuda" + ) + payloads.append(packed_hidden_states) + + # Payload 1: Scaling factors (fp8) + num_elts_per_sf = 16 + num_scaling_factors = hidden_size // num_elts_per_sf + scaling_factors = torch.randn( + local_num_tokens, num_scaling_factors, dtype=torch.float32, device="cuda" + ) # .to(torch.float8_e4m3fn) TODO: Test failed. 
+ scaling_factors += rank + payloads.append(scaling_factors) + + # Payload 2: token_selected_experts + payloads.append(token_selected_experts) + + # Payload 3: token_final_scales (bfloat16) + token_final_scales = torch.rand( + local_num_tokens, top_k, dtype=torch.bfloat16, device="cuda" + ) + + # Construct the data to contain info about send rank and local_token_idx, which is used for debugging + # token_final_scales[:, 0] = rank + # token_final_scales[:, 1] = torch.linspace(0, local_num_tokens - 1, local_num_tokens, dtype=torch.bfloat16, device='cuda') + + payloads.append(token_final_scales) + return payloads, 2 + + +def make_bfloat16_payloads( + local_num_tokens: int, + hidden_size: int, + top_k: int, + rank: int, + token_selected_experts: torch.Tensor, +) -> tuple[list, int]: + """Create bfloat16 test payloads matching nvfp4 structure but without scaling factors.""" + payloads = [] + + # Payload 0: Hidden states (bfloat16) + hidden_states = torch.randn( + local_num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda" + ) + # Add rank-specific pattern for verification + hidden_states += rank + payloads.append(hidden_states) + + # Payload 1: token_selected_experts + payloads.append(token_selected_experts) + + # Payload 2: token_final_scales (bfloat16) - similar to nvfp4's payload 4 + token_final_scales = torch.rand( + local_num_tokens, top_k, dtype=torch.bfloat16, device="cuda" + ) + + # Optional: Construct the data that is easier to debug + # token_final_scales[:, 0] = rank + # token_final_scales[:, 1] = torch.linspace(0, local_num_tokens - 1, local_num_tokens, dtype=torch.bfloat16, device='cuda') + + payloads.append(token_final_scales) + + return payloads, 1 + + +def run_moe_a2a_dispatch_single_rank( + ep_size, + all_num_tokens, + top_k, + workspace_size_per_rank, + num_experts_per_rank, + hidden_size, + invalid_token_expert_id, +): + """Worker function for MPI testing.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + torch.cuda.set_device(rank) + + check_any_rank_failed() + + mapping = Mapping( + rank=rank, + tp_size=ep_size, + moe_ep_size=ep_size, + world_size=ep_size, + gpus_per_node=ep_size, + pp_size=1, + cp_size=1, + ) + + # Create MoeAlltoAll manager + max_num_tokens = max(all_num_tokens) + + MoeAlltoAll._WORKSPACE = None + moe_a2a = MoeAlltoAll( + mapping, + max_num_tokens, + top_k, + ep_size * num_experts_per_rank, + workspace_size_per_rank, + ) + + # Get the number of tokens for this specific rank (same as single-GPU) + rank_local_tokens = all_num_tokens[rank] + + # Generate data using helper functions + token_selected_experts = generate_token_selected_experts( + rank_local_tokens, ep_size, num_experts_per_rank, top_k + ) + payloads, expert_id_payload_index = make_nvfp4_payloads( + rank_local_tokens, hidden_size, top_k, rank, token_selected_experts + ) + + check_any_rank_failed() + + recv_tensors = moe_a2a.dispatch( + token_selected_experts, + payloads, + max_num_tokens, + invalid_token_expert_id=invalid_token_expert_id, + expert_id_payload_index=expert_id_payload_index, + ) + + # Read counters and compact routing tensors from workspace + send_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"] + ].item() + recv_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"] + ].item() + topk_target_ranks_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"] + ].item() + topk_send_indices_offset = moe_a2a.metainfo[ + 
MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"] + ].item() + + send_counters = ( + moe_a2a.workspace[ + rank, send_counters_offset : send_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + recv_counters = ( + moe_a2a.workspace[ + rank, recv_counters_offset : recv_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + topk_target_ranks = ( + moe_a2a.workspace[ + rank, + topk_target_ranks_offset : topk_target_ranks_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + topk_send_indices = ( + moe_a2a.workspace[ + rank, + topk_send_indices_offset : topk_send_indices_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + + # Return results to be collected (move to CPU for MPI transfer) + return ( + token_selected_experts.cpu(), + [p.cpu() for p in payloads], + [rt.cpu() for rt in recv_tensors], + send_counters, + topk_send_indices, + topk_target_ranks, + recv_counters, + expert_id_payload_index, + ) + + +def verify_dispatch( + all_token_selected_experts, + all_payloads, + all_recv_tensors, + all_send_counters, + all_topk_send_indices, + all_topk_target_ranks, + all_recv_counters, + ep_size, + all_num_tokens, + top_k, + num_experts_per_rank, + expert_id_payload_index, + invalid_token_expert_id, +): + """Verify dispatch results including actual content verification""" + + max_num_tokens = max(all_num_tokens) + # Verify dimensions and dtypes + for send_rank in range(ep_size): + local_num_tokens = all_num_tokens[send_rank] + + token_selected_experts = all_token_selected_experts[send_rank] + assert len(token_selected_experts.shape) == 2, ( + "token_selected_experts should be a 2D tensor" + ) + assert token_selected_experts.dtype == torch.int32, ( + "token_selected_experts should be a 32-bit integer tensor" + ) + assert token_selected_experts.shape[0] == local_num_tokens, ( + "token_selected_experts.shape[0] should be local_num_tokens" + ) + assert token_selected_experts.shape[1] == top_k, ( + "token_selected_experts.shape[1] should be top_k" + ) + + payloads = all_payloads[send_rank] + recv_tensors = all_recv_tensors[send_rank] + num_payloads = len(payloads) + assert len(recv_tensors) == num_payloads, ( + "recv_tensors should have the same number of payloads as payloads" + ) + for i in range(num_payloads): + payload = payloads[i] + assert len(payload.shape) == 2, "payload should be a 2D tensor" + assert payload.shape[0] == local_num_tokens, ( + "payload.shape[0] should be local_num_tokens" + ) + + recv_tensor = recv_tensors[i] + assert len(recv_tensor.shape) == 3, "recv_tensor should be a 3D tensor" + assert recv_tensor.shape[0] == ep_size, ( + "recv_tensor.shape[0] should be ep_size" + ) + assert recv_tensor.shape[1] == max_num_tokens, ( + "recv_tensor.shape[1] should be max_num_tokens" + ) + assert recv_tensor.shape[2] == payload.shape[1], ( + "recv_tensor.shape[2] should be payload.shape[1]" + ) + assert recv_tensor.dtype == payload.dtype, ( + "recv_tensor.dtype should be payload.dtype" + ) + + # Verify counters and compact routing tensors + send_counters = all_send_counters[send_rank] + assert len(send_counters.shape) == 1, "send_counters should be a 1D tensor" + assert send_counters.shape[0] == ep_size + assert send_counters.dtype == torch.int32 + + recv_counters = all_recv_counters[send_rank] + assert len(recv_counters.shape) == 1, "recv_counters should be a 1D tensor" + assert recv_counters.shape[0] == ep_size + assert 
recv_counters.dtype == torch.int32 + + topk_send_indices = all_topk_send_indices[send_rank] + topk_target_ranks = all_topk_target_ranks[send_rank] + assert topk_send_indices.shape == (max_num_tokens, top_k), ( + "topk_send_indices shape" + ) + assert topk_target_ranks.shape == (max_num_tokens, top_k), ( + "topk_target_ranks shape" + ) + assert topk_send_indices.dtype == torch.int32 + assert topk_target_ranks.dtype == torch.int32 + + # Verify send_counters per (send_rank -> target_rank) + for send_rank in range(ep_size): + expected_sends = {} + token_experts = all_token_selected_experts[send_rank] + sent_to_rank = set() + + for token_idx in range(token_experts.shape[0]): + experts = token_experts[token_idx] + target_ranks = compute_target_rank_id(experts, num_experts_per_rank) + sent_to_rank.clear() + + for target_rank in target_ranks.tolist(): + if target_rank not in sent_to_rank: + if target_rank not in expected_sends: + expected_sends[target_rank] = 0 + expected_sends[target_rank] += 1 + sent_to_rank.add(target_rank) + + for target_rank in range(ep_size): + expected_to_rank = expected_sends.get(target_rank, 0) + actual_to_rank = all_send_counters[send_rank][target_rank].item() + assert actual_to_rank == expected_to_rank, ( + f"Rank {send_rank} sent {actual_to_rank} tokens to rank {target_rank}, expected {expected_to_rank}" + ) + + # Verify recv_counters match send_counters + for recv_rank in range(ep_size): + for send_rank in range(ep_size): + expected_recv = all_send_counters[send_rank][recv_rank].item() + actual_recv = all_recv_counters[recv_rank][send_rank].item() + assert actual_recv == expected_recv, ( + f"Rank {recv_rank} received {actual_recv} tokens from rank {send_rank}, expected {expected_recv}" + ) + + # Verify payload content using topk_send_indices and topk_target_ranks + for send_rank in range(ep_size): + token_selected_experts = all_token_selected_experts[send_rank] + payloads = all_payloads[send_rank] + topk_send_indices = all_topk_send_indices[send_rank] + topk_target_ranks = all_topk_target_ranks[send_rank] + local_num_tokens = all_num_tokens[send_rank] + + for token_idx in range(local_num_tokens): + experts = token_selected_experts[token_idx] + target_ranks = compute_target_rank_id(experts, num_experts_per_rank) + # Deduplicate target ranks per token + topk_target_ranks_ref = target_ranks.clone() + seen = set() + for kk in range(top_k): + tr = int(topk_target_ranks_ref[kk].item()) + if tr in seen: + topk_target_ranks_ref[kk] = -1 + else: + seen.add(tr) + + assert ( + topk_target_ranks[token_idx, :].tolist() + == topk_target_ranks_ref.tolist() + ) + + for k in range(top_k): + dst_pos = topk_send_indices[token_idx, k].item() + target_rank = topk_target_ranks[token_idx, k].item() + if dst_pos == -1: + assert target_rank == -1 + continue + recv_tensors = all_recv_tensors[target_rank] + for payload_idx, payload in enumerate(payloads): + recv_tensor = recv_tensors[payload_idx] + source_data = payload[token_idx] + received_data = recv_tensor[send_rank, dst_pos] + torch.testing.assert_close( + received_data, source_data, atol=0, rtol=0 + ) + + # Verify token_selected_experts of invalid tokens are correctly sanitized + for recv_rank in range(ep_size): + expert_ids_recv = all_recv_tensors[recv_rank][expert_id_payload_index] + for source_rank in range(ep_size): + valid = int(all_recv_counters[recv_rank][source_rank].item()) + for token_idx in range(max_num_tokens): + token_expert_ids = expert_ids_recv[source_rank, token_idx] + if token_idx >= valid: + assert 
torch.all(token_expert_ids == invalid_token_expert_id) + + +def test_moe_a2a_dispatch_impl(ep_size, all_num_tokens, top_k): + """Test MoE A2A dispatch operation.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + + comm = MPI.COMM_WORLD + # rank = comm.Get_rank() + world_size = comm.Get_size() + + if world_size != ep_size: + pytest.skip(f"Test requires exactly {ep_size} ranks") + + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") + + hidden_size = 1024 + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + invalid_token_expert_id = -1 + + check_any_rank_failed() + + # Run dispatch on this rank + result = run_moe_a2a_dispatch_single_rank( + ep_size, + all_num_tokens, + top_k, + workspace_size_per_rank, + num_experts_per_rank, + hidden_size, + invalid_token_expert_id, + ) + + check_any_rank_failed() + + # Gather results from all ranks + all_results = comm.allgather(result) + + # Extract results + all_token_selected_experts = [r[0] for r in all_results] + all_payloads = [r[1] for r in all_results] + all_recv_tensors = [r[2] for r in all_results] + all_send_counters = [r[3] for r in all_results] + all_topk_send_indices = [r[4] for r in all_results] + all_topk_target_ranks = [r[5] for r in all_results] + all_recv_counters = [r[6] for r in all_results] + all_expert_id_payload_index = [r[7] for r in all_results] + expert_id_payload_index = all_expert_id_payload_index[0] + + assert all(i == expert_id_payload_index for i in all_expert_id_payload_index), ( + "all_expert_id_payload_index should be the same" + ) + + # Verify dispatch results with full counter verification + verify_dispatch( + all_token_selected_experts, + all_payloads, + all_recv_tensors, + all_send_counters, + all_topk_send_indices, + all_topk_target_ranks, + all_recv_counters, + ep_size, + all_num_tokens, + top_k, + num_experts_per_rank, + expert_id_payload_index, + invalid_token_expert_id, + ) + + +@pytest.mark.parametrize( + "ep_size,all_num_tokens,top_k", + [ + # Basic configurations + (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution + (4, [16, 32, 64, 48], 2), # Four ranks with non-uniform distribution + (2, [100, 50], 2), # Two ranks with different loads + (8, [10, 20, 30, 40, 50, 60, 70, 80], 2), # Eight ranks with increasing load + # Different top_k values + (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 + (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 + # Edge cases + (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank + ], +) +def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): + """Test MoE A2A dispatch operation.""" + safe_run(test_moe_a2a_dispatch_impl, ep_size, all_num_tokens, top_k) + + +def test_moe_a2a_dispatch_moe_combine_impl(ep_size, all_num_tokens, top_k): + """Test full MoE A2A dispatch + expert processing + combine cycle.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + if world_size != ep_size: + pytest.skip(f"Test requires exactly {ep_size} ranks") + + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") 
+ + torch.cuda.set_device(rank) + + check_any_rank_failed() + + hidden_size = 2880 # gpt-oss + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + mapping = Mapping( + rank=rank, + moe_ep_size=world_size, + tp_size=world_size, + world_size=world_size, + ) + + local_num_tokens = all_num_tokens[rank] + max_num_tokens = max(all_num_tokens) + + # Generate inputs + token_selected_experts = generate_token_selected_experts( + local_num_tokens, ep_size, num_experts_per_rank, top_k + ) + + payloads, expert_id_payload_index = make_bfloat16_payloads( + local_num_tokens, hidden_size, top_k, rank, token_selected_experts + ) + + hidden_states = payloads[0] + token_final_scales = payloads[2] + + # Compute reference (single-GPU MoE) + all_experts = torch.cat( + [ + create_experts( + num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 + ) + for r in range(ep_size) + ], + dim=0, + ) + + rank_experts = create_experts( + num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 + ) + + reference_output = fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + all_experts, + is_ep=False, + ) + + # Initialize MoeAlltoAll + MoeAlltoAll._WORKSPACE = None + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=ep_size * num_experts_per_rank, + workspace_size_per_rank=workspace_size_per_rank, + ) + + check_any_rank_failed() + + # Dispatch + recv_tensors = moe_a2a.dispatch( + token_selected_experts=token_selected_experts, + input_payloads=payloads, + runtime_max_tokens_per_rank=max_num_tokens, + ) + + # Unpack received tensors + hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] + token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] + token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] + + # Get workspace-backed tensor for output + moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank=max_num_tokens, + hidden_size=hidden_size, + dtype=torch.bfloat16, + ) + moe_output.zero_() + + # Process each rank's tokens with local experts + moe_output.copy_( + fake_moe( + hidden_states_recv.view( + ep_size * max_num_tokens, hidden_states_recv.shape[-1] + ), + token_selected_experts_recv.view( + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] + ), + token_final_scales_recv.view( + ep_size * max_num_tokens, token_final_scales_recv.shape[-1] + ), + rank_experts, # experts for current rank + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, + ).view(ep_size, max_num_tokens, hidden_size) + ) + + check_any_rank_failed() + + # Combine + combined_output = moe_a2a.combine( + payload=moe_output, + runtime_max_tokens_per_rank=max_num_tokens, + payload_in_workspace=True, + ) + + # Verify against reference + torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) + + check_any_rank_failed() + + +@pytest.mark.parametrize( + "ep_size,all_num_tokens,top_k", + [ + (4, [32, 32, 32, 32], 2), + (4, [16, 32, 64, 48], 2), + (2, [100, 50], 2), + (4, [32, 32, 32, 32], 4), + (4, [1, 1, 1, 1], 2), + (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + ], +) +def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): + """Test full MoE A2A dispatch + expert processing + combine cycle.""" + safe_run(test_moe_a2a_dispatch_moe_combine_impl, ep_size, all_num_tokens, top_k) + + +if __name__ == "__main__": + # Run with: mpirun -n 2 python -m pytest 
tests/comm/test_mnnvl_a2a.py -v + pytest.main([__file__, "-v", "-s"])
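
A closing note on the routing semantics these tests verify: experts are partitioned contiguously across ranks (target rank = expert_id // num_experts_per_rank), each token is sent at most once to any given target rank, and duplicate targets within a token's top_k slots are recorded as -1 in topk_target_ranks. The pure-Python reference below mirrors that bookkeeping from verify_dispatch; route_tokens is illustrative only and not an API in this PR.

# Reference routing for one rank's tokens, mirroring verify_dispatch above.
def route_tokens(token_selected_experts, ep_size, num_experts_per_rank):
    """Return (topk_target_ranks, send_counters) with -1 marking deduplicated slots."""
    send_counters = [0] * ep_size
    topk_target_ranks = []
    for experts in token_selected_experts:            # experts: top_k global expert ids
        seen, targets = set(), []
        for expert_id in experts:
            rank = expert_id // num_experts_per_rank  # contiguous expert-to-rank mapping
            if rank in seen:
                targets.append(-1)                    # token already sent to this rank
            else:
                seen.add(rank)
                targets.append(rank)
                send_counters[rank] += 1
        topk_target_ranks.append(targets)
    return topk_target_ranks, send_counters


# Example: 4 ranks, 8 experts per rank, top_k = 2.
ranks, counts = route_tokens([[3, 5], [7, 18], [30, 31]], ep_size=4, num_experts_per_rank=8)
# ranks  -> [[0, -1], [0, 2], [3, -1]]
# counts -> [2, 0, 1, 1]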