diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index b31f6bb745e..144aadbc746 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -23,7 +23,7 @@ #include #include -namespace tensorrt_llm::kernels::moe_a2a +namespace tensorrt_llm::kernels::mnnvl_throughput { #define ENABLE_DEBUG_PRINT 0 @@ -506,7 +506,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); // Prepare kernel pointers struct - DispatchKernelPointers kernel_ptrs = {}; // Zero-initialize + DispatchKernelPointers kernel_ptrs = {}; // Fill source data pointers and payload sizes for (int i = 0; i < params.num_payloads; i++) @@ -958,4 +958,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id); } -} // namespace tensorrt_llm::kernels::moe_a2a +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h index ad8fae07b44..27b6f926d16 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h @@ -19,7 +19,7 @@ #include #include -namespace tensorrt_llm::kernels::moe_a2a +namespace tensorrt_llm::kernels::mnnvl_throughput { // Configuration constants @@ -91,7 +91,7 @@ struct MoeA2ADispatchParams // Token configuration int local_num_tokens; // Number of tokens on this rank - int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank int top_k; // Number of experts per token // Expert routing information @@ -101,23 +101,22 @@ struct MoeA2ADispatchParams int num_payloads; // Number of different payload types PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors - // Receive buffers and synchronization - void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* local_token_counter; // Atomic counter for completed tokens on this rank + int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank + // per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index + // per k, -1 for duplicates - // Synchronization + // Distributed aux data and recv buffers + int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. 
Each rank has [ep_size] counters uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source // rank has signaled the target rank - uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) - - // Communication tracking - int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank - int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters - int* local_token_counter; // Atomic counter for completed tokens on this rank - - // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) - int* topk_target_ranks; // target rank per k, -1 for duplicates - int* topk_send_indices; // dst index per k, -1 for duplicates + void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload + // CUDA stream cudaStream_t stream; }; @@ -137,30 +136,33 @@ struct MoeA2ACombineParams // Token configuration int local_num_tokens; // Number of tokens on this rank - int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank int top_k; // Number of experts per token - // Expert routing information - int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target - - // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) - int const* topk_target_ranks; // target rank per k, -1 for duplicates - int const* topk_send_indices; // dst index per k, -1 for duplicates + // Prepare-only field: original payload tensor pointer used to stage into workspace + void const* prepare_payload; - // Single payload information - void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload) - void* output_data; // Output buffer [local_num_tokens, elements_per_token] - int elements_per_token; // Number of elements per token - nvinfer1::DataType dtype; // Data type for proper summation + // Output tensor + void* output_data; // Output buffer [local_num_tokens, elements_per_token] + // Payload information + int elements_per_token; // Number of elements per token + nvinfer1::DataType dtype; // Data type for proper summation + + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank + // per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index + // per k, -1 for duplicates + int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target - // Synchronization + // Distributed aux data and recv buffers uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source // rank has signaled the target rank - uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload) + // CUDA stream cudaStream_t stream; - // Prepare-only field: original payload tensor pointer used to stage into workspace - void const* prepare_payload; }; // Combine kernels @@ -175,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params); void moe_a2a_sanitize_expert_ids_launch(int32_t* 
expert_ids, int32_t const* recv_counters, int32_t invalid_id, int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream); -} // namespace tensorrt_llm::kernels::moe_a2a +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index f22f9ee06a5..b5c9e18391e 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop void initBindings(nb::module_& m) { // Export MoE A2A constants - for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs()) + for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs()) { m.attr(kv.first) = kv.second; } diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.cpp b/cpp/tensorrt_llm/pybind/thop/bindings.cpp index 8f2c96ede43..e50e1e6ac0e 100644 --- a/cpp/tensorrt_llm/pybind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/thop/bindings.cpp @@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop void initBindings(pybind11::module_& m) { // Export MoE A2A constants - for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs()) + for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs()) { m.attr(kv.first) = py::int_(kv.second); } diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h index 5ef2de2d808..4f84cb845fc 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h +++ b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h @@ -16,15 +16,18 @@ #pragma once +#include #include #include #include namespace torch_ext { +namespace mnnvl_throughput +{ // Enum for indexing into moe_a2a_metainfo tensor -enum MoeA2AMetaInfoIndex +enum MoeA2AMetaInfoIndex : int64_t { FLAG_VAL_OFFSET_INDEX = 0, LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1, @@ -34,21 +37,29 @@ enum MoeA2AMetaInfoIndex DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4, // Combine completion flags offset COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5, - PAYLOAD_DATA_OFFSET_INDEX = 6, - NUM_METAINFO_FIELDS = 7 + TOPK_TARGET_RANKS_OFFSET_INDEX = 6, + TOPK_SEND_INDICES_OFFSET_INDEX = 7, + PAYLOAD_DATA_OFFSET_INDEX = 8, + NUM_METAINFO_FIELDS = 9 }; +using MoeA2ADataOffsets = std::array; + inline std::vector> getMoeA2AMetaInfoIndexPairs() { return { - {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", static_cast(FLAG_VAL_OFFSET_INDEX)}, - {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", static_cast(LOCAL_TOKEN_COUNTER_OFFSET_INDEX)}, - {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", static_cast(SEND_COUNTERS_OFFSET_INDEX)}, - {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", static_cast(RECV_COUNTERS_OFFSET_INDEX)}, - {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", - static_cast(DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX)}, - {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", static_cast(COMBINE_COMPLETION_FLAGS_OFFSET_INDEX)}, - {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", static_cast(PAYLOAD_DATA_OFFSET_INDEX)}, + {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX}, + {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX}, + {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", 
TOPK_SEND_INDICES_OFFSET_INDEX}, + {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX}, + {"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS}, }; } + +} // namespace mnnvl_throughput } // namespace torch_ext diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index 56823f97adf..d6e5b7465cc 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -28,126 +28,164 @@ namespace torch_ext { -namespace +namespace mnnvl_throughput { +// TODO: Is Alignment necessary?obu guo // Helper function to align offset to specified byte boundary inline size_t alignOffset(size_t offset, size_t alignment) { return (offset + alignment - 1) & ~(alignment - 1); } -// Structure to hold auxiliary data offsets -struct MoeA2ADataOffsets -{ - size_t flag_val_offset; - size_t local_token_counter_offset; - size_t send_counters_offset; - size_t recv_counters_offset; - size_t dispatch_completion_flags_offset; - size_t combine_completion_flags_offset; - size_t topk_target_ranks_offset; - size_t topk_send_indices_offset; - size_t payload_data_offset; -}; - // Calculate auxiliary data offsets -MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokensPerRank) +MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) { + // TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to + // read. constexpr size_t SIZEOF_INT32 = 4; constexpr size_t CACHELINE_ALIGNMENT = 128; - MoeA2ADataOffsets offsets{}; + MoeA2ADataOffsets offsets; size_t offset = 0; // flag_val - offsets.flag_val_offset = offset; + offsets[FLAG_VAL_OFFSET_INDEX] = offset; offset += SIZEOF_INT32; // local_token_counter - offsets.local_token_counter_offset = offset; + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = offset; offset += SIZEOF_INT32; // send_counters - offsets.send_counters_offset = offset; + offsets[SEND_COUNTERS_OFFSET_INDEX] = offset; offset += epSize * SIZEOF_INT32; // recv_counters - offsets.recv_counters_offset = offset; + offsets[RECV_COUNTERS_OFFSET_INDEX] = offset; offset += epSize * SIZEOF_INT32; // dispatch completion flags offset = alignOffset(offset, CACHELINE_ALIGNMENT); - offsets.dispatch_completion_flags_offset = offset; + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] = offset; offset += epSize * SIZEOF_INT32; // combine completion flags offset = alignOffset(offset, CACHELINE_ALIGNMENT); - offsets.combine_completion_flags_offset = offset; + offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] = offset; offset += epSize * SIZEOF_INT32; - // topk_target_ranks: [maxNumTokensPerRank, kMaxTopK] + // topk_target_ranks: [maxNumTokens, kMaxTopK] offset = alignOffset(offset, CACHELINE_ALIGNMENT); - offsets.topk_target_ranks_offset = offset; - offset += static_cast(maxNumTokensPerRank) * static_cast(tensorrt_llm::kernels::moe_a2a::kMaxTopK) + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * static_cast(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK) * SIZEOF_INT32; - // topk_send_indices: [maxNumTokensPerRank, kMaxTopK] + // topk_send_indices: [maxNumTokens, kMaxTopK] offset = alignOffset(offset, CACHELINE_ALIGNMENT); - offsets.topk_send_indices_offset = offset; - offset += static_cast(maxNumTokensPerRank) * static_cast(tensorrt_llm::kernels::moe_a2a::kMaxTopK) + offsets[TOPK_SEND_INDICES_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * static_cast(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK) * SIZEOF_INT32; // payload data 
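For reference, the auxiliary-data layout built by calculateOffsets above can be modeled in a few lines of Python. This is a sketch under the constants visible in this diff (4-byte int32 entries, 128-byte cache-line alignment, and kMaxTopK assumed to be 8); the C++ function remains the source of truth.

def align_offset(offset, alignment=128):
    # Round offset up to the next multiple of alignment (same bit trick as alignOffset)
    return (offset + alignment - 1) & ~(alignment - 1)

def aux_data_offsets(ep_size, max_num_tokens, max_top_k=8, int32=4):
    # Mirrors calculateOffsets: scalar counters first, then per-rank counters,
    # then cache-line-aligned flag arrays and routing tables, then payload data.
    offsets, off = {}, 0
    offsets["flag_val"] = off;            off += int32
    offsets["local_token_counter"] = off; off += int32
    offsets["send_counters"] = off;       off += ep_size * int32
    offsets["recv_counters"] = off;       off += ep_size * int32
    off = align_offset(off); offsets["dispatch_completion_flags"] = off; off += ep_size * int32
    off = align_offset(off); offsets["combine_completion_flags"] = off;  off += ep_size * int32
    off = align_offset(off); offsets["topk_target_ranks"] = off; off += max_num_tokens * max_top_k * int32
    off = align_offset(off); offsets["topk_send_indices"] = off; off += max_num_tokens * max_top_k * int32
    off = align_offset(off); offsets["payload_data"] = off
    return offsets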
offset = alignOffset(offset, CACHELINE_ALIGNMENT); - offsets.payload_data_offset = offset; + offsets[PAYLOAD_DATA_OFFSET_INDEX] = offset; return offsets; } +// Initialize auxiliary data in workspace +// This function sets up the initial values for flag_val and completion_flags +// +// Inputs: +// - workspace: [ep_size, size_per_rank] unified virtual memory workspace +// - epRank: Current expert parallel rank +// - epSize: Total expert parallel size +// - maxNumTokens: Maximum number of tokens supported +// +// Returns: +// - metainfo: Tensor containing offsets for auxiliary data +torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens) +{ + // Validate inputs + CHECK_TH_CUDA(workspace); + CHECK_TYPE(workspace, torch::kUInt8); + TORCH_CHECK(workspace.dim() == 2, "workspace must be a 2D tensor of shape [epSize, sizePerRank]"); + TORCH_CHECK(workspace.size(0) == epSize, "workspace first dimension must equal epSize"); + TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); + + // Initialize workspace to zero + workspace[epRank].zero_(); + + // Calculate auxiliary data offsets + MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokens); + + // Return metainfo as a tensor containing offsets + torch::Tensor metainfo = torch::empty( + {static_cast(NUM_METAINFO_FIELDS)}, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU)); + + for (int i = 0; i < static_cast(NUM_METAINFO_FIELDS); i++) + { + metainfo[i] = static_cast(offsets[i]); + } + + // Synchronize among ranks + cudaDeviceSynchronize(); + tensorrt_llm::mpi::MpiComm::world().barrier(); + + return metainfo; +} + // MoE All-to-All Dispatch Operation // This operation dispatches tokens and their associated payloads to different expert ranks. // // Inputs: // - tokenSelectedExperts: [local_num_tokens, top_k] tensor of expert indices // - inputPayloads: List of tensors with shape [local_num_tokens, ...] containing data to dispatch -// - workspace: [ep_size, size_per_rank] unified virtual memory workspace where -// size_per_rank = sum(ep_size * max_tokens_per_rank * elements_per_token * element_size) for all -// payloads -// - maxTokensPerRank: Maximum number of tokens that can be received per rank +// - workspace: [ep_size, size_per_rank] unified virtual memory workspace where size_per_rank is large enough to store +// all the auxiliary data and recv payloads. +// - metainfo: [NUM_METAINFO_FIELDS] tensor containing offsets for auxiliary data +// - runtimeMaxTokensPerRank: Maximum of the number of tokens of each DP rank's local batch. This is a dynamic value +// during runtime. +// - maxNumTokens: Maximum number of tokens that could be supported. This is a static value that is setup during +// initialization. 
// - epRank: Current expert parallel rank // - epSize: Total expert parallel size // - topK: Number of experts selected per token // - numExperts: Total number of experts (must be divisible by epSize) // -// Returns: -// - recvBuffers: List of receive buffers, one for each payload -// - sendCounters: [ep_size] tensor tracking tokens sent to each rank (local) -// - recvCounters: [ep_size] tensor tracking tokens received from each rank (all ranks) -// - topkTargetRanks: [local_num_tokens, top_k] compact routing - target EP rank per k -// - topkSendIndices: [local_num_tokens, top_k] compact routing - dst slot per k +// Return values: +// - recvTensors: Vector of receive buffers (one tensor per payload), each [ep_size, runtimeMaxTokensPerRank, +// elements_per_token] +// - combinePayloadOffset: Offset into workspace for the combine payload region, to be used by the combine operation // // Note: token_selected_experts is used for routing but is NOT automatically included as a payload. // If you want to dispatch token_selected_experts, include it explicitly in inputPayloads. -std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, int64_t> -moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector const& inputPayloads, - torch::Tensor const& workspace, int64_t maxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, - int64_t numExperts) +std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, + std::vector const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo, + int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts) { - using tensorrt_llm::kernels::moe_a2a::PayloadDescriptor; - using tensorrt_llm::kernels::moe_a2a::MoeA2ADispatchParams; - using tensorrt_llm::kernels::moe_a2a::moe_a2a_dispatch_launch; - using tensorrt_llm::kernels::moe_a2a::kMaxTopK; - using tensorrt_llm::kernels::moe_a2a::kMaxPayloads; + using tensorrt_llm::kernels::mnnvl_throughput::PayloadDescriptor; + using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ADispatchParams; + using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_dispatch_launch; + using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK; + using tensorrt_llm::kernels::mnnvl_throughput::kMaxPayloads; // Validate inputs CHECK_INPUT(tokenSelectedExperts, torch::kInt32); TORCH_CHECK(tokenSelectedExperts.dim() == 2, "tokenSelectedExperts must be a 2D tensor"); TORCH_CHECK(tokenSelectedExperts.size(1) == topK, "tokenSelectedExperts must have topK columns"); + CHECK_CPU(metainfo); + CHECK_TYPE(metainfo, torch::kInt64); + TORCH_CHECK(metainfo.dim() == 1, "metainfo must be a 1D tensor"); + TORCH_CHECK(metainfo.size(0) == static_cast(NUM_METAINFO_FIELDS), + "metainfo must have NUM_METAINFO_FIELDS elements"); + MoeA2ADataOffsets const& offsets = *reinterpret_cast(metainfo.data_ptr()); + int64_t localNumTokens = tokenSelectedExperts.size(0); TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive"); - TORCH_CHECK(maxTokensPerRank > 0, "maxTokensPerRank must be positive"); + TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]"); TORCH_CHECK(!inputPayloads.empty(), "inputPayloads must not be empty"); @@ -164,20 +202,12 @@ moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector payloadByteSizes; std::vector 
payloadElementSizes; std::vector payloadElementsPerToken; - for (auto const& payload : inputPayloads) { CHECK_CONTIGUOUS(payload); @@ -189,7 +219,7 @@ moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector(payload.size(1)); int elementSize = static_cast(payload.dtype().itemsize()); // Each payload buffer stores data from ALL ranks - int64_t bytesPerPayload = epSize * maxTokensPerRank * elementsPerToken * elementSize; + int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize; payloadByteSizes.push_back(bytesPerPayload); payloadElementSizes.push_back(elementSize); @@ -197,95 +227,76 @@ moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector= requiredSize, "Workspace size per rank insufficient. " "Need at least ", - requiredSize, " bytes (", offsets.payload_data_offset, " for auxiliary data + ", totalBytesNeeded, + requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded, " for payloads), but got ", sizePerRank); - // Setup receive buffer pointers from unified workspace - std::vector payloadDescriptors; - // Get base workspace pointer - uint8_t* workspace_ptr = workspace.data_ptr(); + uint8_t* workspacePtr = workspace.data_ptr(); + uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0); // Setup payload descriptors for source data - for (int i = 0; i < static_cast(inputPayloads.size()); i++) + int num_payloads = static_cast(inputPayloads.size()); + std::vector payloadDescriptors(num_payloads); + for (int i = 0; i < num_payloads; i++) { - PayloadDescriptor desc{}; - desc.src_data = inputPayloads[i].data_ptr(); - desc.element_size = payloadElementSizes[i]; - desc.elements_per_token = payloadElementsPerToken[i]; - payloadDescriptors.push_back(desc); + payloadDescriptors[i].src_data = inputPayloads[i].data_ptr(); + payloadDescriptors[i].element_size = payloadElementSizes[i]; + payloadDescriptors[i].elements_per_token = payloadElementsPerToken[i]; } - // Create tensors for return values (these are views into workspace) - auto options = tokenSelectedExperts.options().dtype(torch::kInt32); - uint8_t* rank_workspace = workspace_ptr + epRank * workspace.stride(0); - - // Create send_counters tensor - view into workspace - // Initialized to 0 in prepare dispatch kernel - torch::Tensor sendCounters = torch::from_blob(rank_workspace + offsets.send_counters_offset, {epSize}, options); - - // Create recv_counters tensor - view into workspace - // No need for initialization - torch::Tensor recvCounters = torch::from_blob(rank_workspace + offsets.recv_counters_offset, {epSize}, options); - - // Create local_token_counter - view into workspace - // Initialized to 0 in prepare dispatch kernel - torch::Tensor localTokenCounter - = torch::from_blob(rank_workspace + offsets.local_token_counter_offset, {1}, options); - - // Allocate compact Top-K routing tensors [localNumTokens, topK] - torch::Tensor topkTargetRanks = torch::empty({localNumTokens, topK}, options); - torch::Tensor topkSendIndices = torch::empty({localNumTokens, topK}, options); - // Setup dispatch parameters MoeA2ADispatchParams params{}; params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload + params.ep_size = static_cast(epSize); + params.ep_rank = static_cast(epRank); + params.num_experts_per_rank = static_cast(numExperts) / static_cast(epSize); + params.local_num_tokens = static_cast(localNumTokens); + params.max_tokens_per_rank = 
static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + params.token_selected_experts = tokenSelectedExperts.data_ptr(); - params.num_payloads = static_cast(payloadDescriptors.size()); + + params.num_payloads = num_payloads; std::copy(payloadDescriptors.begin(), payloadDescriptors.end(), ¶ms.payloads[0]); - params.flag_val = reinterpret_cast(rank_workspace + offsets.flag_val_offset); - // Calculate and store recv buffer pointers directly in params + params.flag_val = reinterpret_cast(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]); + params.local_token_counter = reinterpret_cast(rankWorkSpacePtr + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX]); + params.send_counters = reinterpret_cast(rankWorkSpacePtr + offsets[SEND_COUNTERS_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast(rankWorkSpacePtr + offsets[TOPK_SEND_INDICES_OFFSET_INDEX]); + for (int target_rank = 0; target_rank < epSize; target_rank++) { - // Each rank gets workspace[target_rank] - calculate base pointer - uint8_t* target_workspace = workspace_ptr + (target_rank * workspace.stride(0)); + uint8_t* targetWorkSpacePtr = workspacePtr + (target_rank * workspace.stride(0)); - params.recv_counters[target_rank] = reinterpret_cast(target_workspace + offsets.recv_counters_offset); + params.recv_counters[target_rank] + = reinterpret_cast(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]); params.completion_flags[target_rank] - = reinterpret_cast(target_workspace + offsets.dispatch_completion_flags_offset); + = reinterpret_cast(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]); - int64_t offset = offsets.payload_data_offset; // Start after auxiliary data - - for (int payload_idx = 0; payload_idx < static_cast(inputPayloads.size()); payload_idx++) + size_t offset = static_cast(offsets[PAYLOAD_DATA_OFFSET_INDEX]); + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { - // Store buffer pointer for kernel - params.recv_buffers[target_rank][payload_idx] = target_workspace + offset; - + // Store pointer for current payload + params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset; // Update offset for next payload offset += payloadByteSizes[payload_idx]; } } - params.max_tokens_per_rank = static_cast(maxTokensPerRank); - params.send_counters = sendCounters.data_ptr(); - params.local_token_counter = localTokenCounter.data_ptr(); - params.topk_target_ranks = topkTargetRanks.data_ptr(); - params.topk_send_indices = topkSendIndices.data_ptr(); - params.local_num_tokens = static_cast(localNumTokens); - params.ep_size = static_cast(epSize); - params.ep_rank = static_cast(epRank); - params.top_k = static_cast(topK); - params.num_experts_per_rank = static_cast(numExperts) / static_cast(epSize); + params.stream = at::cuda::getCurrentCUDAStream(); // Prepare for dispatch (zero counters/indices and increment flag_val) @@ -297,28 +308,25 @@ moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector recvBuffers; - auto* current_rank_workspace = workspace_ptr + (epRank * workspace.stride(0)); - int64_t offset = offsets.payload_data_offset; - - for (int payload_idx = 0; payload_idx < static_cast(inputPayloads.size()); payload_idx++) + std::vector recvTensors; + size_t offset = static_cast(offsets[PAYLOAD_DATA_OFFSET_INDEX]); + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { auto const& payload = inputPayloads[payload_idx]; 
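The per-rank workspace requirement enforced above, and the combine payload offset computed a few lines below, can be estimated with a small helper. This is a sketch assuming the 128-byte alignment used throughout; payload_token_bytes holds elements_per_token * element_size for each dispatched payload, and combine_token_bytes is the per-token byte size of the tensor later passed to combine.

def min_workspace_bytes_per_rank(payload_data_offset, ep_size,
                                 runtime_max_tokens_per_rank,
                                 payload_token_bytes, combine_token_bytes,
                                 cacheline=128):
    # Dispatch recv regions: one [ep_size, runtime_max_tokens_per_rank, ...] buffer per payload
    off = payload_data_offset
    for token_bytes in payload_token_bytes:
        off += ep_size * runtime_max_tokens_per_rank * token_bytes
    # Combine payload region starts at the next cache-line boundary after the dispatch payloads
    combine_payload_offset = (off + cacheline - 1) & ~(cacheline - 1)
    total = combine_payload_offset + ep_size * runtime_max_tokens_per_rank * combine_token_bytes
    return total, combine_payload_offset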
- - // Create tensor view for this payload - contains data from ALL ranks - auto recvBuffer = torch::from_blob(current_rank_workspace + offset, - {epSize, maxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options()); - recvBuffers.push_back(recvBuffer); + // Create tensor view for this payload + auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset, + {epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options()); + recvTensors.push_back(recvTensor); // Update offset for next payload offset += payloadByteSizes[payload_idx]; } + // Compute aligned offset after dispatch payloads for combine payload region constexpr size_t CACHELINE_ALIGNMENT = 128; int64_t combinePayloadOffset = static_cast(alignOffset(static_cast(offset), CACHELINE_ALIGNMENT)); - return std::make_tuple( - std::move(recvBuffers), sendCounters, recvCounters, topkTargetRanks, topkSendIndices, combinePayloadOffset); + return std::make_tuple(std::move(recvTensors), combinePayloadOffset); } // MoE All-to-All Combine Operation @@ -332,32 +340,21 @@ moeA2ADispatchOp(torch::Tensor const& tokenSelectedExperts, std::vector 0, "elementsPerToken must be positive"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); @@ -383,40 +380,25 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& topkTargetRanks, torch::Tenso TORCH_CHECK(false, "Unsupported data type for payload"); } - // Create output tensor (local on current rank), no need for initialization - torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options()); - - // Setup combine parameters - MoeA2ACombineParams params{}; - params.one_block_per_token - = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload - params.ep_size = static_cast(epSize); - params.ep_rank = static_cast(epRank); - params.local_num_tokens = static_cast(localNumTokens); - params.max_tokens_per_rank = static_cast(maxTokensPerRank); - params.top_k = static_cast(topK); - params.topk_target_ranks = topkTargetRanks.data_ptr(); - params.topk_send_indices = topkSendIndices.data_ptr(); - params.output_data = output.data_ptr(); - params.elements_per_token = static_cast(elementsPerToken); - params.dtype = nvDtype; - params.recv_counters = recvCounters.data_ptr(); - params.stream = at::cuda::getCurrentCUDAStream(); - - MoeA2ADataOffsets offsets = calculateOffsets(static_cast(epSize), static_cast(maxTokensPerRank)); + CHECK_CPU(metainfo); + CHECK_TYPE(metainfo, torch::kInt64); + TORCH_CHECK(metainfo.dim() == 1, "metainfo must be a 1D tensor"); + TORCH_CHECK(metainfo.size(0) == static_cast(NUM_METAINFO_FIELDS), + "metainfo must have NUM_METAINFO_FIELDS elements"); + MoeA2ADataOffsets const& offsets = *reinterpret_cast(metainfo.data_ptr()); // Validate workspace and set synchronization pointers CHECK_TH_CUDA(workspace); CHECK_TYPE(workspace, torch::kUInt8); TORCH_CHECK(workspace.dim() == 2 && workspace.size(0) == epSize, "workspace must be [ep_size, size_per_rank]"); - uint8_t* workspace_ptr = workspace.data_ptr(); + uint8_t* workspacePtr = workspace.data_ptr(); int64_t sizePerRank = workspace.size(1); - uint8_t* workspace_currank_base = workspace_ptr + epRank * workspace.stride(0); + uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0); // If user claims payload is in workspace, ensure payload tensor matches combinePayloadOffset if (payloadInWorkspace) { - TORCH_CHECK(payload.data_ptr() == workspace_currank_base + combinePayloadOffset, + 
TORCH_CHECK(payload.data_ptr() == rankWorkSpacePtr + combinePayloadOffset, "payload_in_workspace is true but 'payload' dataptr does not match combinePayloadOffset"); } @@ -425,27 +407,42 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& topkTargetRanks, torch::Tenso "workspace does not contain enough space for the payload region for combine. combine payload offset=", combinePayloadOffset, ", payload size needed=", payloadSize, ", workspace size per rank=", sizePerRank); - for (int src_rank = 0; src_rank < epSize; src_rank++) - { - params.recv_buffers[src_rank] = workspace_ptr + src_rank * workspace.stride(0) + combinePayloadOffset; - } + // Create output tensor (local on current rank), no need for initialization + torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options()); - // completion flags for all ranks (combine) - for (int rank = 0; rank < epSize; rank++) + // Setup combine parameters + MoeA2ACombineParams params{}; + params.one_block_per_token + = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); // TODO: Decide this based on the workload + params.ep_size = static_cast(epSize); + params.ep_rank = static_cast(epRank); + params.local_num_tokens = static_cast(localNumTokens); + params.max_tokens_per_rank = static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + // If payload is not in workspace, stage it into current rank's region at prepare phase + if (!payloadInWorkspace) { - uint8_t* rank_base = workspace_ptr + rank * workspace.stride(0); - params.completion_flags[rank] - = reinterpret_cast(rank_base + offsets.combine_completion_flags_offset); + params.prepare_payload = payload.data_ptr(); } - params.flag_val - = reinterpret_cast(workspace_ptr + epRank * workspace.stride(0) + offsets.flag_val_offset); + params.output_data = output.data_ptr(); + params.elements_per_token = static_cast(elementsPerToken); + params.dtype = nvDtype; - // If payload is not already in workspace, stage it into current rank's region - if (!payloadInWorkspace) + params.flag_val = reinterpret_cast(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast(rankWorkSpacePtr + offsets[TOPK_SEND_INDICES_OFFSET_INDEX]); + params.recv_counters = reinterpret_cast(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]); + + for (int target_rank = 0; target_rank < epSize; target_rank++) { - params.prepare_payload = payload.data_ptr(); + uint8_t* target_workspace_ptr = workspacePtr + target_rank * workspace.stride(0); + params.completion_flags[target_rank] + = reinterpret_cast(target_workspace_ptr + offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX]); + params.recv_buffers[target_rank] = target_workspace_ptr + combinePayloadOffset; } + params.stream = at::cuda::getCurrentCUDAStream(); + moe_a2a_prepare_combine_launch(params); // Launch the combine kernel @@ -456,90 +453,47 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& topkTargetRanks, torch::Tenso return output; } -// Initialize auxiliary data in workspace -// This function sets up the initial values for flag_val and completion_flags -// -// Inputs: -// - workspace: [ep_size, size_per_rank] unified virtual memory workspace -// - epRank: Current expert parallel rank -// - epSize: Total expert parallel size -// -// Returns: -// - Auxiliary data size (payload_data_offset) in bytes -// -// The function initializes: -// - flag_val to 1 (on current rank) -// - 
completion_flags to 0 (on all ranks) -torch::Tensor moeA2AInitializeOp( - torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokensPerRank) -{ - // Validate inputs - CHECK_TH_CUDA(workspace); - CHECK_TYPE(workspace, torch::kUInt8); - TORCH_CHECK(workspace.dim() == 2, "workspace must be a 2D tensor of shape [epSize, sizePerRank]"); - TORCH_CHECK(workspace.size(0) == epSize, "workspace first dimension must equal epSize"); - TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); - - // Initialize workspace to zero - workspace[epRank].zero_(); - - // Calculate auxiliary data offsets - MoeA2ADataOffsets offsets = calculateOffsets(epSize, maxNumTokensPerRank); - - // Return moe_a2a_metainfo as a tensor containing offsets - torch::Tensor moe_a2a_metainfo - = torch::zeros({NUM_METAINFO_FIELDS}, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU)); - moe_a2a_metainfo[FLAG_VAL_OFFSET_INDEX] = static_cast(offsets.flag_val_offset); - moe_a2a_metainfo[LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = static_cast(offsets.local_token_counter_offset); - moe_a2a_metainfo[SEND_COUNTERS_OFFSET_INDEX] = static_cast(offsets.send_counters_offset); - moe_a2a_metainfo[RECV_COUNTERS_OFFSET_INDEX] = static_cast(offsets.recv_counters_offset); - moe_a2a_metainfo[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] - = static_cast(offsets.dispatch_completion_flags_offset); - moe_a2a_metainfo[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] - = static_cast(offsets.combine_completion_flags_offset); - moe_a2a_metainfo[PAYLOAD_DATA_OFFSET_INDEX] = static_cast(offsets.payload_data_offset); - - // Memset workspace to 0 and synchronize among ranks - workspace[epRank].zero_(); - cudaDeviceSynchronize(); - tensorrt_llm::mpi::MpiComm::world().barrier(); - - return moe_a2a_metainfo; -} - // Op: moe_a2a_sanitize_expert_ids -void moeA2ASanitizeExpertIdsOp(torch::Tensor expert_ids, torch::Tensor recv_counters, int64_t invalid_expert_id) +void moeA2ASanitizeExpertIdsOp(torch::Tensor& expert_ids, torch::Tensor& workspace, torch::Tensor const& metainfo, + int64_t epRank, int64_t invalid_expert_id) { CHECK_INPUT(expert_ids, torch::kInt32); - CHECK_INPUT(recv_counters, torch::kInt32); - TORCH_CHECK(expert_ids.dim() == 3, "expert_ids must be [ep_size, max_tokens_per_rank, top_k]"); - TORCH_CHECK(recv_counters.dim() == 1, "recv_counters must be [ep_size]"); - TORCH_CHECK(expert_ids.size(0) == recv_counters.size(0), "expert_ids and recv_counters must have the same ep_size"); + TORCH_CHECK(expert_ids.dim() == 3, "expert_ids must be [ep_size, runtime_max_tokens_per_rank, top_k]"); int ep_size = static_cast(expert_ids.size(0)); - int max_tokens_per_rank = static_cast(expert_ids.size(1)); + int runtime_max_tokens_per_rank = static_cast(expert_ids.size(1)); int top_k = static_cast(expert_ids.size(2)); - tensorrt_llm::kernels::moe_a2a::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr(), - recv_counters.data_ptr(), static_cast(invalid_expert_id), ep_size, max_tokens_per_rank, top_k, + CHECK_CPU(metainfo); + CHECK_TYPE(metainfo, torch::kInt64); + TORCH_CHECK(metainfo.dim() == 1, "metainfo must be a 1D tensor"); + TORCH_CHECK(metainfo.size(0) == static_cast(NUM_METAINFO_FIELDS), + "metainfo must have NUM_METAINFO_FIELDS elements"); + MoeA2ADataOffsets const& offsets = *reinterpret_cast(metainfo.data_ptr()); + + uint8_t* rankWorkSpacePtr = workspace.data_ptr() + epRank * workspace.stride(0); + int* recv_counters = reinterpret_cast(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]); + + 
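In plain PyTorch terms, the sanitize launch below appears to mark the expert-id slots of padding tokens, i.e. slots at or beyond recv_counters[src_rank] in each source rank's region of the recv tensor, with invalid_expert_id so downstream grouped GEMMs can skip them. A reference sketch of that behavior (the real kernel works in place on the workspace-backed tensor):

import torch

def sanitize_expert_ids_reference(expert_ids: torch.Tensor,
                                  recv_counters: torch.Tensor,
                                  invalid_expert_id: int) -> torch.Tensor:
    # expert_ids: [ep_size, runtime_max_tokens_per_rank, top_k], int32, modified in place
    # recv_counters: [ep_size] number of valid tokens received from each source rank
    ep_size = expert_ids.size(0)
    for src_rank in range(ep_size):
        valid = int(recv_counters[src_rank].item())
        expert_ids[src_rank, valid:, :] = invalid_expert_id
    return expert_ids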
tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr(), + recv_counters, static_cast(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k, at::cuda::getCurrentCUDAStream()); } // Return a workspace-backed tensor for combine payload region using from_blob torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, - int64_t maxTokensPerRank, int64_t combinePayloadOffset, c10::ScalarType outDtype, int64_t hiddenSize) + int64_t runtimeMaxTokensPerRank, int64_t combinePayloadOffset, c10::ScalarType outDtype, int64_t hiddenSize) { CHECK_TH_CUDA(workspace); CHECK_TYPE(workspace, torch::kUInt8); TORCH_CHECK(workspace.dim() == 2, "workspace must be [ep_size, size_per_rank_bytes]"); TORCH_CHECK(epRank >= 0 && epRank < workspace.size(0), "epRank out of range"); TORCH_CHECK(epSize == workspace.size(0), "epSize mismatch with workspace"); - TORCH_CHECK(maxTokensPerRank > 0, "maxTokensPerRank must be positive"); + TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive"); TORCH_CHECK(hiddenSize > 0, "hidden must be positive"); int64_t sizePerRank = workspace.size(1); // bytes int64_t elementSize = static_cast(c10::elementSize(outDtype)); - int64_t bytesNeeded = epSize * maxTokensPerRank * hiddenSize * elementSize; + int64_t bytesNeeded = epSize * runtimeMaxTokensPerRank * hiddenSize * elementSize; TORCH_CHECK(combinePayloadOffset >= 0, "combine_payload_offset must be non-negative"); TORCH_CHECK(combinePayloadOffset + bytesNeeded <= sizePerRank, "workspace does not have enough space for combine payload tensor. combine payload offset=", @@ -550,46 +504,45 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in uint8_t* dataPtr = rankBase + combinePayloadOffset; auto options = workspace.options().dtype(outDtype); - torch::Tensor t = torch::from_blob(dataPtr, {epSize * maxTokensPerRank, hiddenSize}, options); + torch::Tensor t = torch::from_blob(dataPtr, {epSize * runtimeMaxTokensPerRank, hiddenSize}, options); return t; } -} // anonymous namespace +} // namespace mnnvl_throughput } // namespace torch_ext // PyTorch bindings TORCH_LIBRARY_FRAGMENT(trtllm, module) { - // Note that we returns recv_buffers as a list of views into workspace, we need to upcast its alias + // Note that we returns recv_tensors as a list of views into workspace, we need to upcast its alias // group to wildcard (a!->*). See // https://github.com/pytorch/pytorch/blob/b1eb6dede556136f9fdcee28415b0358d58ad877/aten/src/ATen/native/README.md#annotations module.def( - "moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, Tensor(a!->*) workspace, int " - "max_tokens_per_rank, " - "int ep_rank, int ep_size, int top_k, int num_experts) -> (Tensor(a!)[], Tensor(a!), Tensor(a!), Tensor, " - "Tensor, " - "int)"); + "moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, " + "Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " + "int ep_rank, int ep_size, int top_k, int num_experts) -> (Tensor(a!)[], int)"); module.def( - "moe_a2a_combine(Tensor topk_target_ranks, Tensor topk_send_indices, Tensor(a) recv_counters, Tensor(a) " - "payload, " - "Tensor(a!) workspace, int max_tokens_per_rank, int ep_rank, int ep_size, int top_k, int " - "combine_payload_offset, " + "moe_a2a_combine(Tensor(a) payload, int local_num_tokens," + "Tensor(a!) 
workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " + "int ep_rank, int ep_size, int top_k, int combine_payload_offset, " "bool payload_in_workspace) -> Tensor"); module.def( "moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank) -> Tensor"); - module.def("moe_a2a_sanitize_expert_ids(Tensor(a!) expert_ids, Tensor recv_counters, int invalid_expert_id) -> ()"); module.def( - "moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int max_tokens_per_rank, " - "int " - "combine_payload_offset, ScalarType out_dtype, int hidden) -> Tensor(a)"); + "moe_a2a_sanitize_expert_ids(Tensor(a!) expert_ids, Tensor(a!) workspace, Tensor metainfo, int ep_rank, int " + "invalid_expert_id) -> ()"); + module.def( + "moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int " + "runtime_max_tokens_per_rank, " + "int combine_payload_offset, ScalarType out_dtype, int hidden_size) -> Tensor(a)"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, module) { - module.impl("moe_a2a_dispatch", &torch_ext::moeA2ADispatchOp); - module.impl("moe_a2a_combine", &torch_ext::moeA2ACombineOp); - module.impl("moe_a2a_initialize", &torch_ext::moeA2AInitializeOp); - module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::moeA2ASanitizeExpertIdsOp); - module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::moeA2AGetCombinePayloadTensorOp); + module.impl("moe_a2a_dispatch", &torch_ext::mnnvl_throughput::moeA2ADispatchOp); + module.impl("moe_a2a_combine", &torch_ext::mnnvl_throughput::moeA2ACombineOp); + module.impl("moe_a2a_initialize", &torch_ext::mnnvl_throughput::moeA2AInitializeOp); + module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::mnnvl_throughput::moeA2ASanitizeExpertIdsOp); + module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::mnnvl_throughput::moeA2AGetCombinePayloadTensorOp); } diff --git a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp index 0dd0445e67f..551a6150843 100644 --- a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp @@ -45,7 +45,8 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional std::optional const hidden_size_output, int64_t const local_expert_offset, int64_t const local_num_experts, std::optional const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t moeConfigIndex, - torch::optional const& topk_weights, torch::optional const& topk_ids) + torch::optional const& topk_weights, torch::optional const& topk_ids, + torch::optional const& out_tensor) { TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by MXFP4 block scale MOE"); TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64 @@ -402,9 +403,23 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional output2_scale_scalar->sizes()[0] == local_num_experts, "output2_scales_scalar has incorrect dim 0."); } - // allocate output - at::Tensor output = at::detail::empty_cuda({args.num_tokens, args.hidden_size_output.value()}, - at::ScalarType::BFloat16, hidden_states.device(), std::nullopt); + // allocate or use provided output + at::Tensor output; + if (out_tensor.has_value()) + { + TORCH_CHECK(out_tensor->scalar_type() == at::ScalarType::BFloat16, "out_tensor must be bfloat16."); + TORCH_CHECK(out_tensor->dim() == 2, "out_tensor must be 2D."); + 
TORCH_CHECK( + out_tensor->sizes()[0] == args.num_tokens && out_tensor->sizes()[1] == args.hidden_size_output.value(), + "out_tensor has incorrect shape."); + TORCH_CHECK(out_tensor->device() == hidden_states.device(), "out_tensor must be on the same device as inputs."); + output = out_tensor.value(); + } + else + { + output = at::detail::empty_cuda({args.num_tokens, args.hidden_size_output.value()}, at::ScalarType::BFloat16, + hidden_states.device(), std::nullopt); + } // setup workspace workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr(); @@ -513,7 +528,8 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, std::nullopt, std::nullopt, std::nullopt, num_experts, top_k, n_group, topk_group, intermediate_size, std::nullopt, local_expert_offset, local_num_experts, routed_scaling_factor, - tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids); + tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids, + /*output=*/torch::nullopt); // TODO: Support user-provided output } private: @@ -574,7 +590,8 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder std::optional const n_group, std::optional const topk_group, int64_t intermediate_size, std::optional const hidden_size_output, int64_t local_expert_offset, int64_t local_num_experts, std::optional routed_scaling_factor, int64_t routing_method_type, std::vector tile_config_pair, - torch::optional const& topk_weights, torch::optional const& topk_ids) + torch::optional const& topk_weights, torch::optional const& topk_ids, + torch::optional const& output) { // tile_config_pair corresponds to pair (tileN, config) auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]); @@ -597,7 +614,7 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder gemm2_weights_scale, gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar, num_experts, top_k, n_group, topk_group, intermediate_size, hidden_size_output, local_expert_offset, local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config, - topk_weights, topk_ids); + topk_weights, topk_ids, output); } private: diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index ba41196ca9b..11d6f86670b 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -719,6 +719,7 @@ def forward( self, inputs: List[torch.Tensor], tactic: List[int] = [-1, -1], + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert isinstance(tactic, list) @@ -735,7 +736,7 @@ def forward( self.intermediate_size, self.hidden_size_output, self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic, - args.topk_weights, args.topk_ids) + args.topk_weights, args.topk_ids, output) def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, @@ -852,7 +853,8 @@ def mxe4m3_mxe2m1_block_scale_moe_runner( routing_method_type: int, act_type: int, topk_weights: Optional[torch.Tensor] = None, - topk_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + topk_ids: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None) -> torch.Tensor: tuner = 
AutoTuner.get() kernel_runner = MxE4m3MxE2m1BlockScaleMoERunner( @@ -905,7 +907,8 @@ def mxe4m3_mxe2m1_block_scale_moe_runner( input_tensors[ 0] = routing_logits # replace dummy routing logits with actual routing logits return kernel_runner(input_tensors, - tactic=[-1, -1] if best_tactic == -1 else best_tactic) + tactic=[-1, -1] if best_tactic == -1 else best_tactic, + output=output) @dataclass(frozen=True) @@ -1003,7 +1006,8 @@ def forward( self.n_group, self.topk_group, self.intermediate_size, None, self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic, - args.topk_weights, args.topk_ids) + args.topk_weights, args.topk_ids, None + ) # TODO: Currently user provided output is only supported in w4a8_mxfp4_mxfp8 def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, diff --git a/tensorrt_llm/_torch/distributed/moe_alltoall.py b/tensorrt_llm/_torch/distributed/moe_alltoall.py index 21ec2574c85..fe906600e62 100644 --- a/tensorrt_llm/_torch/distributed/moe_alltoall.py +++ b/tensorrt_llm/_torch/distributed/moe_alltoall.py @@ -5,7 +5,8 @@ with proper workspace management and synchronization. """ -from typing import Optional +from dataclasses import dataclass +from typing import Dict, Optional import torch @@ -15,6 +16,13 @@ from tensorrt_llm.mapping import Mapping +@dataclass +class _A2AState: + phase: str = "idle" # idle | dispatched + local_num_tokens: int | None = None + combine_payload_offset: int | None = None + + class MoeAlltoAll: """ Manages MoE All-to-All operations with proper workspace allocation and synchronization. @@ -23,48 +31,47 @@ class MoeAlltoAll: and auxiliary data structures needed for cross-GPU communication. """ - # Constants from C++ (must match moeAlltoAllKernels.h) - MAX_RANKS = 64 - MAX_TOP_K = 8 - MAX_PAYLOADS = 8 - # Single shared workspace/memory across the process _WORKSPACE: dict | None = None - # MetaInfo indices - initialized from C++ constants - FLAG_VAL_OFFSET_INDEX = None - LOCAL_TOKEN_COUNTER_OFFSET_INDEX = None - SEND_COUNTERS_OFFSET_INDEX = None - RECV_COUNTERS_OFFSET_INDEX = None - DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None - COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None - PAYLOAD_DATA_OFFSET_INDEX = None + _METAINFO_INDEX: Dict[str, int] | None = None @classmethod def _init_constants(cls): """Initialize constants from C++ if not already done.""" - if cls.FLAG_VAL_OFFSET_INDEX is None: + # TODO: Can we avoid such code duplication? 
+ if cls._METAINFO_INDEX is None: thop = _tllm_internal.thop - cls.FLAG_VAL_OFFSET_INDEX = int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX) - cls.LOCAL_TOKEN_COUNTER_OFFSET_INDEX = int( - thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX) - cls.SEND_COUNTERS_OFFSET_INDEX = int( - thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX) - cls.RECV_COUNTERS_OFFSET_INDEX = int( - thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX) - cls.DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = int( - thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX) - cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int( - thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX) - cls.PAYLOAD_DATA_OFFSET_INDEX = int( - thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX) + cls._METAINFO_INDEX = { + "FLAG_VAL_OFFSET_INDEX": + int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX), + "LOCAL_TOKEN_COUNTER_OFFSET_INDEX": + int(thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX), + "SEND_COUNTERS_OFFSET_INDEX": + int(thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX), + "RECV_COUNTERS_OFFSET_INDEX": + int(thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX), + "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX": + int(thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX), + "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX": + int(thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX), + "TOPK_TARGET_RANKS_OFFSET_INDEX": + int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX), + "TOPK_SEND_INDICES_OFFSET_INDEX": + int(thop.MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX), + "PAYLOAD_DATA_OFFSET_INDEX": + int(thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX), + "NUM_METAINFO_FIELDS": + int(thop.MOE_A2A_NUM_METAINFO_FIELDS), + } def __init__( self, mapping: Mapping, - max_num_tokens_per_rank: int, + max_num_tokens: int, top_k: int, num_experts: int, + # TODO: WE should be able to know the required workspace size if knowing max_num_tokens, ep_size and hidden_size workspace_size_per_rank: int = 256 * 1024 * 1024, ): """ @@ -72,30 +79,30 @@ def __init__( Args: mapping: TensorRT-LLM Mapping object containing rank information - max_num_tokens_per_rank: Maximum number of tokens per rank + max_num_tokens: Maximum number of tokens supported. Should be ModelConfig.max_num_tokens. workspace_size_per_rank: Size of workspace per rank in bytes """ # Initialize constants from C++ self._init_constants() - self.mapping = mapping - self.ep_size = mapping.moe_ep_size # Expert parallel size + # Initialize or reuse workspace + MnnvlMemory.initialize() + + self.workspace_size_per_rank = workspace_size_per_rank + self.max_num_tokens = max_num_tokens + self.ep_size = mapping.moe_ep_size self.ep_rank = mapping.moe_ep_rank - self.max_num_tokens_per_rank = max_num_tokens_per_rank + self.top_k = top_k self.num_experts = num_experts if not isinstance(self.top_k, int) or self.top_k <= 0: raise ValueError("top_k must be a positive int") if not isinstance(self.num_experts, int) or self.num_experts <= 0: raise ValueError("num_experts must be a positive int") - self.workspace_size_per_rank = workspace_size_per_rank - - # Initialize or reuse workspace - MnnvlMemory.initialize() if self._WORKSPACE is None: tllm_logger.info( - f"MoE AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}" + f"MoE AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. 
ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}" ) mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) @@ -103,11 +110,11 @@ def __init__( workspace, self.ep_rank, self.ep_size, - self.max_num_tokens_per_rank, + self.max_num_tokens, ) MoeAlltoAll._WORKSPACE = { "workspace_size_per_rank": workspace_size_per_rank, - "max_num_tokens_per_rank": self.max_num_tokens_per_rank, + "max_num_tokens": self.max_num_tokens, "ep_rank": self.ep_rank, "ep_size": self.ep_size, "mnnvl_mem": mnnvl_mem, @@ -118,7 +125,7 @@ def __init__( assert self._WORKSPACE[ "workspace_size_per_rank"] == workspace_size_per_rank, "reuse workspace with different workspace_size_per_rank" assert self._WORKSPACE[ - "max_num_tokens_per_rank"] == self.max_num_tokens_per_rank, "reuse workspace with different max_num_tokens_per_rank" + "max_num_tokens"] == self.max_num_tokens, "reuse workspace with different max_num_tokens" assert self._WORKSPACE[ "ep_rank"] == self.ep_rank, "reuse workspace with different ep_rank" assert self._WORKSPACE[ @@ -126,17 +133,14 @@ def __init__( self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] self.workspace = self._WORKSPACE["workspace"] - self.moe_a2a_metainfo = self._WORKSPACE["metainfo"] - self.max_num_tokens_per_rank = self._WORKSPACE[ - "max_num_tokens_per_rank"] - # Internal state and aux data - self.send_counters: torch.Tensor | None = None - self.recv_counters: torch.Tensor | None = None - self._state: str = "idle" # idle | dispatched + self.metainfo = self._WORKSPACE["metainfo"] + # Internal state + self._state: _A2AState = _A2AState() def dispatch(self, token_selected_experts: torch.Tensor, input_payloads: list[torch.Tensor], + runtime_max_tokens_per_rank: int, invalid_token_expert_id: Optional[int] = None, expert_id_payload_index: Optional[int] = None): """ @@ -145,84 +149,87 @@ def dispatch(self, Args: token_selected_experts: [local_num_tokens, top_k] tensor of expert indices input_payloads: List of tensors to dispatch, each has shape [local_num_tokens, payload_num_elements_per_token] + runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch. invalid_token_expert_id: If not None, set the token_selected_experts of the invalid tokens to this expert id. This is used to notify the MoE to skip these tokens for GroupGEMM. expert_id_payload_index: The index of token_selected_experts in the input_payloads. Must be provided if invalid_token_expert_id is not None. 
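With this change the compact routing tables (topk_target_ranks, topk_send_indices) and the send/recv counters stay inside the workspace and are located via metainfo rather than being returned to Python. Conceptually, dispatch fills them as sketched below; this is a simplified, non-atomic illustration that assumes the contiguous expert-to-rank mapping implied by num_experts_per_rank = num_experts / ep_size, with duplicate target ranks for a token marked as -1.

def build_compact_routing(token_selected_experts, ep_size, num_experts):
    # token_selected_experts: [local_num_tokens][top_k] expert ids (plain Python lists for clarity)
    experts_per_rank = num_experts // ep_size
    send_counters = [0] * ep_size          # tokens this rank sends to each target rank
    topk_target_ranks, topk_send_indices = [], []
    for experts in token_selected_experts:
        ranks_row, indices_row, seen = [], [], set()
        for expert in experts:
            rank = expert // experts_per_rank
            if rank in seen:
                ranks_row.append(-1)       # token already routed to this rank once
                indices_row.append(-1)
                continue
            seen.add(rank)
            ranks_row.append(rank)
            indices_row.append(send_counters[rank])  # destination slot on the target rank
            send_counters[rank] += 1
        topk_target_ranks.append(ranks_row)
        topk_send_indices.append(indices_row)
    return topk_target_ranks, topk_send_indices, send_counters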
Returns: - recv_buffers: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token] + recv_tensors: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token] """ - if self._state == "dispatched": - raise RuntimeError( - "dispatch called twice without an intervening combine") - - recv_buffers, send_counters, recv_counters, topk_target_ranks, topk_send_indices, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch( + assert self._state.phase == "idle", "dispatch called twice without an intervening combine" + assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens" + recv_tensors, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch( token_selected_experts, input_payloads, self.workspace, - self.max_num_tokens_per_rank, self.ep_rank, self.ep_size, - self.top_k, self.num_experts) - self._state = "dispatched" - self.send_counters = send_counters - self.recv_counters = recv_counters - self.topk_target_ranks = topk_target_ranks - self.topk_send_indices = topk_send_indices - self.combine_payload_offset = int(combine_payload_offset) + self.metainfo, runtime_max_tokens_per_rank, self.ep_rank, + self.ep_size, self.top_k, self.num_experts) + # Update state together after successful dispatch + self._state.local_num_tokens = token_selected_experts.size(0) + self._state.combine_payload_offset = combine_payload_offset + self._state.phase = "dispatched" if invalid_token_expert_id is not None: assert expert_id_payload_index is not None, "expert_id_payload_index must be provided if invalid_token_expert_id is not None" - # Sanitize expert IDs for invalid tokens directly on the recv buffer payload - recv_token_selected_experts = recv_buffers[expert_id_payload_index] + # Sanitize expert IDs for invalid tokens directly on the recv tensor payload + recv_token_selected_experts = recv_tensors[expert_id_payload_index] torch.ops.trtllm.moe_a2a_sanitize_expert_ids( recv_token_selected_experts, - self.recv_counters, - int(invalid_token_expert_id), + self.workspace, + self.metainfo, + self.ep_rank, + invalid_token_expert_id, ) - return recv_buffers + return recv_tensors - def combine(self, payload, payload_in_workspace: bool = False): + def combine( + self, + payload, + runtime_max_tokens_per_rank: int, + payload_in_workspace: bool = False, + ): """ Perform MoE all-to-all combine operation. Args: payload: [ep_size, max_tokens_per_rank, num_elements_per_token] tensor to combine. The dtype must be float32, bfloat16 or float16. + runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch. payload_in_workspace: If True, 'payload' is a view into 'workspace' at 'combine_payload_offset' and no staging copy is needed. If False, the op stages 'payload' into the workspace region before combining. 
Returns: combined_output: [local_num_tokens, num_elements_per_token] tensor of combined results """ - if self._state != "dispatched": - raise RuntimeError("combine called before a successful dispatch") + assert self._state.phase == "dispatched", "combine called before a successful dispatch" + assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens" output = torch.ops.trtllm.moe_a2a_combine( - self.topk_target_ranks, self.topk_send_indices, self.recv_counters, - payload, self.workspace, self.max_num_tokens_per_rank, self.ep_rank, - self.ep_size, self.top_k, int(self.combine_payload_offset), - bool(payload_in_workspace)) + payload, self._state.local_num_tokens, self.workspace, + self.metainfo, runtime_max_tokens_per_rank, self.ep_rank, + self.ep_size, self.top_k, self._state.combine_payload_offset, + payload_in_workspace) + # Reset state for next round - self._state = "idle" - self.send_counters = None - self.recv_counters = None - self.topk_target_ranks = None - self.topk_send_indices = None - self.combine_payload_offset = None + self._state = _A2AState() + return output def get_combine_payload_tensor_in_workspace( - self, hidden_size: int, dtype: torch.dtype) -> torch.Tensor: + self, runtime_max_tokens_per_rank: int, hidden_size: int, + dtype: torch.dtype) -> torch.Tensor: """ Return the combine payload tensor in the workspace, which could be used as the output of MoE kernel to avoid extra copy. See "payload_in_workspace" in combine method. """ - if self._state != "dispatched": + if self._state.phase != "dispatched": raise RuntimeError( "get_combine_payload_tensor_in_workspace called before a successful dispatch" ) return torch.ops.trtllm.moe_a2a_get_combine_payload_tensor( self.workspace, - int(self.ep_rank), - int(self.ep_size), - int(self.max_num_tokens_per_rank), - int(self.combine_payload_offset), + self.ep_rank, + self.ep_size, + runtime_max_tokens_per_rank, + self._state.combine_payload_offset, dtype, - int(hidden_size), + hidden_size, ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 13325a2b832..fbb487e3423 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -163,7 +163,7 @@ def __init__( os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512")) self.moe_a2a = MoeAlltoAll( mapping=self.mapping, - max_num_tokens_per_rank=model_config.max_num_tokens, + max_num_tokens=model_config.max_num_tokens, top_k=self.routing_method.experts_per_token, num_experts=self.num_experts, workspace_size_per_rank=workspace_mb * 1024 * 1024, @@ -408,7 +408,7 @@ def forward_chunk( assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall" # Prepare alltoall indices top_k = self.routing_method.experts_per_token - max_num_token = max( + runtime_max_tokens_per_rank = max( all_rank_num_tokens) if all_rank_num_tokens else token_count # Handle case where token_final_scales might be None (when apply_router_weight_on_input=True) @@ -420,9 +420,9 @@ def forward_chunk( assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( token_selected_experts, None, - self.alltoall_prepare_workspace, max_num_token, - self.ep_rank, self.ep_size, self.num_experts, - self.num_experts, top_k) + self.alltoall_prepare_workspace, + 
runtime_max_tokens_per_rank, self.ep_rank, self.ep_size, + self.num_experts, self.num_experts, top_k) if x_sf is not None: x_sf = x_sf.view(x_row, @@ -437,8 +437,9 @@ def forward_chunk( torch.ops.trtllm.memset_expert_ids( token_selected_experts, - alltoall_info.recv_rank_count_cumsum, max_num_token, top_k, - self.num_experts, self.ep_size) + alltoall_info.recv_rank_count_cumsum, + runtime_max_tokens_per_rank, top_k, self.num_experts, + self.ep_size) elif self.moe_alltoall_backend == "mnnvlthroughput": # Python MoeAlltoAll path if x_sf is not None: @@ -456,18 +457,20 @@ def forward_chunk( payloads.append(token_selected_experts) payloads.append(token_final_scales) - recv_buffers = self.moe_a2a.dispatch( + recv_tensors = self.moe_a2a.dispatch( token_selected_experts, payloads, - invalid_token_expert_id=self.num_experts, + runtime_max_tokens_per_rank, + invalid_token_expert_id=self. + num_slots, # Caution: Cutlass MoE uses num_slots as invalid token expert id expert_id_payload_index=expert_id_payload_index, ) if x_sf is not None: - x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers + x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) else: - x_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers + x_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors x = x_recv.view(-1, x_recv.shape[-1]) token_selected_experts = token_selected_experts_recv.view( -1, token_selected_experts_recv.shape[-1]) @@ -498,9 +501,12 @@ def forward_chunk( # Optionally provide an output tensor to fused_moe so it writes directly to our buffer moe_output: Optional[torch.Tensor] = None if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput": - # Retrieve a workspace-backed output tensor + # Retrieve a workspace-backed output tensor sized by runtime tokens + runtime_max_tokens_per_rank = max( + all_rank_num_tokens) if all_rank_num_tokens else x.shape[0] moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace( - self.unpadded_hidden_size, output_dtype) + runtime_max_tokens_per_rank, self.unpadded_hidden_size, + output_dtype) final_hidden_states = torch.ops.trtllm.fused_moe( x, token_selected_experts, @@ -556,11 +562,14 @@ def forward_chunk( use_low_precision_combine, token_count=token_count) elif self.moe_alltoall_backend == "mnnvlthroughput": - hidden = final_hidden_states.shape[-1] + output_hidden_size = final_hidden_states.shape[-1] + runtime_max_tokens_per_rank = max( + all_rank_num_tokens) if all_rank_num_tokens else token_count final_hidden_states = self.moe_a2a.combine( - final_hidden_states.view( - self.ep_size, self.moe_a2a.max_num_tokens_per_rank, - hidden), + final_hidden_states.view(self.ep_size, + runtime_max_tokens_per_rank, + output_hidden_size), + runtime_max_tokens_per_rank, payload_in_workspace=True) else: raise ValueError( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 1aca60ce417..24bd9b0d909 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -6,6 +6,7 @@ from torch import nn from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe +from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll from tensorrt_llm._utils import get_sm_version from tensorrt_llm.logger import logger @@ -122,11 +123,26 @@ def __init__( 
self.use_low_precision_combine = model_config.use_low_precision_moe_combine if self.alltoall_method_type == AlltoallMethodType.MNNVL: - MnnvlMemory.initialize() - self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( - model_config.mapping) - self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( - model_config.mapping) + if self.moe_alltoall_backend == "mnnvllatency": + MnnvlMemory.initialize() + self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( + model_config.mapping) + self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( + model_config.mapping) + elif self.moe_alltoall_backend == "mnnvlthroughput": + workspace_mb = int( + os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512")) + self.moe_a2a = MoeAlltoAll( + mapping=self.mapping, + max_num_tokens=model_config.max_num_tokens, + top_k=self.routing_method.experts_per_token, + num_experts=self.num_experts, + workspace_size_per_rank=workspace_mb * 1024 * 1024, + ) + else: + raise ValueError( + f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}" + ) elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: raise NotImplementedError( "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet" @@ -173,6 +189,12 @@ def enable_alltoall(self): """ return self.alltoall_method_type != AlltoallMethodType.NotEnabled + @cached_property + def moe_alltoall_backend(self): + # "mnnvllatency" (default) or "mnnvlthroughput" + return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND", + "mnnvllatency").strip().lower() + def _check_configs(self): assert self.has_deepseek_fp8_block_scales \ or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \ @@ -331,7 +353,7 @@ def forward_impl( if self.enable_alltoall: assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall" - max_num_token = max( + runtime_max_tokens_per_rank = max( all_rank_num_tokens) if all_rank_num_tokens else token_count if token_final_scales is None: @@ -340,45 +362,90 @@ def forward_impl( else: token_final_scales = token_final_scales.to(torch.float32) - assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" - alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - token_selected_experts, - None, - self.alltoall_prepare_workspace, - max_num_token, - self.ep_rank, - self.ep_size, - self.num_experts, - self.num_slots, - top_k, - ) + if self.moe_alltoall_backend == "mnnvllatency": + assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" + alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_experts, + None, + self.alltoall_prepare_workspace, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.num_experts, + self.num_slots, + top_k, + ) - if x_sf is not None: - x_sf = x_sf.view(x_row, ceil_div(x_col, - self.scaling_vector_size)) + if x_sf is not None: + x_sf = x_sf.view(x_row, + ceil_div(x_col, self.scaling_vector_size)) - x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( - [x, x_sf, token_selected_experts, token_final_scales], - alltoall_info, - self.alltoall_workspace, - self.ep_rank, - self.ep_size, - ) + x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( + [x, x_sf, token_selected_experts, token_final_scales], + alltoall_info, + self.alltoall_workspace, + self.ep_rank, + self.ep_size, + ) - 
torch.ops.trtllm.memset_expert_ids( - token_selected_experts, - alltoall_info.recv_rank_count_cumsum, - max_num_token, - top_k, - -1, # Trtllm Gen uses -1 as invalid expert id - self.ep_size, - ) + torch.ops.trtllm.memset_expert_ids( + token_selected_experts, + alltoall_info.recv_rank_count_cumsum, + runtime_max_tokens_per_rank, + top_k, + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + self.ep_size, + ) - if x_sf is not None: - x_sf = x_sf.flatten() + if x_sf is not None: + x_sf = x_sf.flatten() + + if token_final_scales is not None: + token_final_scales = token_final_scales.to(torch.bfloat16) + elif self.moe_alltoall_backend == "mnnvlthroughput": + if x_sf is not None: + x_sf = x_sf.view(x_row, + ceil_div(x_col, self.scaling_vector_size)) + + payloads = [] + payloads.append(x) + if x_sf is not None: + payloads.append(x_sf) + expert_id_payload_index = 2 + else: + expert_id_payload_index = 1 + payloads.append(token_selected_experts) + payloads.append(token_final_scales) + + recv_tensors = self.moe_a2a.dispatch( + token_selected_experts, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id= + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + ) - if token_final_scales is not None: - token_final_scales = token_final_scales.to(torch.bfloat16) + if x_sf is not None: + x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) + else: + x_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x = x_recv.view(-1, x_recv.shape[-1]) + token_selected_experts = token_selected_experts_recv.view( + -1, token_selected_experts_recv.shape[-1]) + token_final_scales = token_final_scales_recv.view( + -1, token_final_scales_recv.shape[-1]) + + if x_sf is not None: + x_sf = x_sf.flatten() + + if token_final_scales is not None: + token_final_scales = token_final_scales.to(torch.bfloat16) + else: + raise ValueError( + f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}" + ) elif run_post_quant_allgather: if x_sf is not None: @@ -398,6 +465,13 @@ def forward_impl( router_logits_arg = router_logits if not post_quant_comm else None routing_bias_arg = routing_bias if not post_quant_comm else None + moe_output: Optional[torch.Tensor] = None + use_workspace_output = False + if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput": + moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16) + use_workspace_output = True + # TODO: since routing kernel is integrated into moe_runner for fp8, # here we just route the I/Os for moe_runner if self.has_deepseek_fp8_block_scales: @@ -635,6 +709,7 @@ def forward_impl( 0, # act_type token_final_scales, token_selected_experts, + output=moe_output, ) else: raise NotImplementedError( @@ -642,17 +717,45 @@ def forward_impl( ) # Combine results if using alltoall - if self.enable_alltoall and alltoall_info is not None: - final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine( - final_hidden_states, - alltoall_info, - self.alltoall_workspace, - ep_rank=self.ep_rank, - ep_size=self.ep_size, - top_k=top_k, - use_low_precision_combine=self.use_low_precision_combine, - token_count=token_count, - ) + if self.enable_alltoall: + if self.moe_alltoall_backend == "mnnvllatency": + if alltoall_info is not None: + final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine( + final_hidden_states, + 
alltoall_info, + self.alltoall_workspace, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + top_k=top_k, + use_low_precision_combine=self. + use_low_precision_combine, + token_count=token_count, + ) + elif self.moe_alltoall_backend == "mnnvlthroughput": + # If use_workspace_output=True, the MoE result is already in workspace + # Otherwise, we need to reshape and pass it + if use_workspace_output: + # Workspace payload is returned as 2D [ep_size * max_tokens, hidden]; reshape to 3D. + hidden = final_hidden_states.shape[-1] + payload = moe_output.view(self.ep_size, + runtime_max_tokens_per_rank, + hidden) + final_hidden_states = self.moe_a2a.combine( + payload, + runtime_max_tokens_per_rank, + payload_in_workspace=True) + else: + hidden = final_hidden_states.shape[-1] + payload = final_hidden_states.view( + self.ep_size, runtime_max_tokens_per_rank, hidden) + final_hidden_states = self.moe_a2a.combine( + payload, + runtime_max_tokens_per_rank, + payload_in_workspace=False) + else: + raise ValueError( + f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}" + ) final_hidden_states = self.reducescatter_or_allreduce( final_hidden_states, diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py index 4ad94b5f5ef..7defc0dae72 100644 --- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py +++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py @@ -228,7 +228,7 @@ def make_bfloat16_payloads( def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, workspace_size_per_rank, num_experts_per_rank, hidden_size, - max_tokens_per_rank): + invalid_token_expert_id): """Worker function for MPIPoolExecutor.""" rank = tllm.mpi_rank() torch.cuda.set_device(rank) @@ -242,7 +242,9 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, ) # Create MoeAlltoAll manager - moe_a2a = MoeAlltoAll(mapping, max_tokens_per_rank, top_k, + max_num_tokens = max(all_num_tokens) + + moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k, ep_size * num_experts_per_rank, workspace_size_per_rank) @@ -255,20 +257,21 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, payloads, expert_id_payload_index = make_nvfp4_payloads( rank_local_tokens, hidden_size, top_k, rank, token_selected_experts) - recv_buffers = moe_a2a.dispatch( + recv_tensors = moe_a2a.dispatch( token_selected_experts, payloads, - invalid_token_expert_id=-1, + max_num_tokens, + invalid_token_expert_id=invalid_token_expert_id, expert_id_payload_index=expert_id_payload_index) # Verify completion flags after dispatch - completion_flags_offset = moe_a2a.moe_a2a_metainfo[ - MoeAlltoAll.DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX].item() + completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[ + "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX"]].item() completion_flags = moe_a2a.workspace[ rank, completion_flags_offset:completion_flags_offset + ep_size * 4].view(torch.int32).cpu() - flag_val_offset = moe_a2a.moe_a2a_metainfo[ - MoeAlltoAll.FLAG_VAL_OFFSET_INDEX].item() + flag_val_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item() expected_flag_val = moe_a2a.workspace[rank, flag_val_offset:flag_val_offset + 4].view(torch.int32).cpu() @@ -277,24 +280,49 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k, f"Rank {rank} completion flags: {completion_flags}, expected flag val: {expected_flag_val}" ) + # Read counters and compact routing tensors from workspace + send_counters_offset = moe_a2a.metainfo[ + 
MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]].item() + recv_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"]].item() + topk_target_ranks_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[ + "TOPK_TARGET_RANKS_OFFSET_INDEX"]].item() + topk_send_indices_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[ + "TOPK_SEND_INDICES_OFFSET_INDEX"]].item() + + send_counters = moe_a2a.workspace[ + rank, send_counters_offset:send_counters_offset + ep_size * 4].view( + torch.int32).cpu() + recv_counters = moe_a2a.workspace[ + rank, recv_counters_offset:recv_counters_offset + ep_size * 4].view( + torch.int32).cpu() + topk_target_ranks = moe_a2a.workspace[ + rank, topk_target_ranks_offset:topk_target_ranks_offset + + max_num_tokens * top_k * 4].view(torch.int32).view( + max_num_tokens, top_k).cpu() + topk_send_indices = moe_a2a.workspace[ + rank, topk_send_indices_offset:topk_send_indices_offset + + max_num_tokens * top_k * 4].view(torch.int32).view( + max_num_tokens, top_k).cpu() + # Return results to be collected (move to CPU for MPI transfer) - return (token_selected_experts.cpu(), - [p.cpu() for p in payloads], [rb.cpu() for rb in recv_buffers], - moe_a2a.send_counters.cpu(), moe_a2a.topk_send_indices.cpu(), - moe_a2a.topk_target_ranks.cpu(), moe_a2a.recv_counters.cpu(), - expert_id_payload_index) + return (token_selected_experts.cpu(), [p.cpu() for p in payloads], + [rt.cpu() + for rt in recv_tensors], send_counters, topk_send_indices, + topk_target_ranks, recv_counters, expert_id_payload_index) except Exception: traceback.print_exc() raise -def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_buffers, +def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_tensors, all_send_counters, all_topk_send_indices, all_topk_target_ranks, all_recv_counters, ep_size, - all_num_tokens, top_k, max_tokens_per_rank, - num_experts_per_rank, expert_id_payload_index: int): + all_num_tokens, top_k, num_experts_per_rank, + expert_id_payload_index, invalid_token_expert_id): """Verify dispatch results including actual content verification""" + max_num_tokens = max(all_num_tokens) # Verify dimensions and dtypes for send_rank in range(ep_size): local_num_tokens = all_num_tokens[send_rank] @@ -309,65 +337,52 @@ def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_buffers, 1] == top_k, "token_selected_experts.shape[1] should be top_k" payloads = all_payloads[send_rank] - recv_buffers = all_recv_buffers[send_rank] + recv_tensors = all_recv_tensors[send_rank] num_payloads = len(payloads) assert len( - recv_buffers - ) == num_payloads, "recv_buffers should have the same number of payloads as payloads" + recv_tensors + ) == num_payloads, "recv_tensors should have the same number of payloads as payloads" for i in range(num_payloads): payload = payloads[i] assert len(payload.shape) == 2, "payload should be a 2D tensor" assert payload.shape[ 0] == local_num_tokens, "payload.shape[0] should be local_num_tokens" - recv_buffer = recv_buffers[i] + recv_tensor = recv_tensors[i] assert len( - recv_buffer.shape) == 3, "recv_buffer should be a 3D tensor" - assert recv_buffer.shape[ - 0] == ep_size, "recv_buffer.shape[0] should be ep_size" - assert recv_buffer.shape[ - 1] == max_tokens_per_rank, "recv_buffer.shape[1] should be max_tokens_per_rank" - assert recv_buffer.shape[2] == payload.shape[ - 1], "recv_buffer.shape[2] should be payload.shape[1]" - assert recv_buffer.dtype == payload.dtype, "recv_buffer.dtype should be 
payload.dtype" - + recv_tensor.shape) == 3, "recv_tensor should be a 3D tensor" + assert recv_tensor.shape[ + 0] == ep_size, "recv_tensor.shape[0] should be ep_size" + assert recv_tensor.shape[ + 1] == max_num_tokens, "recv_tensor.shape[1] should be max_num_tokens" + assert recv_tensor.shape[2] == payload.shape[ + 1], "recv_tensor.shape[2] should be payload.shape[1]" + assert recv_tensor.dtype == payload.dtype, "recv_tensor.dtype should be payload.dtype" + + # Verify counters and compact routing tensors send_counters = all_send_counters[send_rank] assert len( send_counters.shape) == 1, "send_counters should be a 1D tensor" - assert send_counters.shape[ - 0] == ep_size, "send_counters.shape[0] should be ep_size" - assert send_counters.dtype == torch.int32, "send_counters.dtype should be torch.int32" + assert send_counters.shape[0] == ep_size + assert send_counters.dtype == torch.int32 recv_counters = all_recv_counters[send_rank] assert len( recv_counters.shape) == 1, "recv_counters should be a 1D tensor" - assert recv_counters.shape[ - 0] == ep_size, "recv_counters.shape[0] should be ep_size" - assert recv_counters.dtype == torch.int32, "recv_counters.dtype should be torch.int32" + assert recv_counters.shape[0] == ep_size + assert recv_counters.dtype == torch.int32 topk_send_indices = all_topk_send_indices[send_rank] - assert len(topk_send_indices.shape - ) == 2, "topk_send_indices should be a 2D tensor" - assert topk_send_indices.shape[ - 0] == local_num_tokens, "topk_send_indices.shape[0] should be local_num_tokens" - assert topk_send_indices.shape[ - 1] == top_k, "topk_send_indices.shape[1] should be top_k" - assert topk_send_indices.dtype == torch.int32, "topk_send_indices.dtype should be torch.int32" - topk_target_ranks = all_topk_target_ranks[send_rank] - assert len(topk_target_ranks.shape - ) == 2, "topk_target_ranks should be a 2D tensor" - assert topk_target_ranks.shape[ - 0] == local_num_tokens, "topk_target_ranks.shape[0] should be local_num_tokens" - assert topk_target_ranks.shape[ - 1] == top_k, "topk_target_ranks.shape[1] should be top_k" - assert topk_target_ranks.dtype == torch.int32, "topk_target_ranks.dtype should be torch.int32" - - # Verify send_counters + assert topk_send_indices.shape == (max_num_tokens, + top_k), "topk_send_indices shape" + assert topk_target_ranks.shape == (max_num_tokens, + top_k), "topk_target_ranks shape" + assert topk_send_indices.dtype == torch.int32 + assert topk_target_ranks.dtype == torch.int32 + + # Verify send_counters per (send_rank -> target_rank) for send_rank in range(ep_size): - send_counters = all_send_counters[send_rank] - - # Count expected sends to each target expected_sends = {} token_experts = all_token_selected_experts[send_rank] sent_to_rank = set() @@ -377,7 +392,6 @@ def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_buffers, target_ranks = compute_target_rank_id(experts, num_experts_per_rank) sent_to_rank.clear() - # Due to deduplication, each token is sent to each unique target rank only once for target_rank in target_ranks.tolist(): if target_rank not in sent_to_rank: if target_rank not in expected_sends: @@ -385,39 +399,34 @@ def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_buffers, expected_sends[target_rank] += 1 sent_to_rank.add(target_rank) - # Verify send counters for each target rank for target_rank in range(ep_size): expected_to_rank = expected_sends.get(target_rank, 0) - actual_to_rank = send_counters[target_rank].item() - assert actual_to_rank == expected_to_rank, \ - 
f"Rank {send_rank} sent {actual_to_rank} tokens to rank {target_rank}, " \ - f"expected {expected_to_rank}" + actual_to_rank = all_send_counters[send_rank][target_rank].item() + assert actual_to_rank == expected_to_rank, ( + f"Rank {send_rank} sent {actual_to_rank} tokens to rank {target_rank}, expected {expected_to_rank}" + ) - # Verify recv_counters + # Verify recv_counters match send_counters for recv_rank in range(ep_size): - recv_counters = all_recv_counters[recv_rank] - for send_rank in range(ep_size): expected_recv = all_send_counters[send_rank][recv_rank].item() - actual_recv = recv_counters[send_rank].item() - assert actual_recv == expected_recv, \ - f"Rank {recv_rank} received {actual_recv} tokens from rank {send_rank}, " \ - f"expected {expected_recv} (based on send_counters)" + actual_recv = all_recv_counters[recv_rank][send_rank].item() + assert actual_recv == expected_recv, ( + f"Rank {recv_rank} received {actual_recv} tokens from rank {send_rank}, expected {expected_recv}" + ) - # Verify payloads using topk_send_indices and topk_target_ranks + # Verify payload content using topk_send_indices and topk_target_ranks for send_rank in range(ep_size): - topk_send_indices = all_topk_send_indices[send_rank] - topk_target_ranks = all_topk_target_ranks[send_rank] token_selected_experts = all_token_selected_experts[send_rank] payloads = all_payloads[send_rank] - + topk_send_indices = all_topk_send_indices[send_rank] + topk_target_ranks = all_topk_target_ranks[send_rank] local_num_tokens = all_num_tokens[send_rank] - # For each source token on this send rank for token_idx in range(local_num_tokens): experts = token_selected_experts[token_idx] target_ranks = compute_target_rank_id(experts, num_experts_per_rank) - # Deduplicate target ranks per token: keep first occurrence, set duplicates to -1 + # Deduplicate target ranks per token topk_target_ranks_ref = target_ranks.clone() seen = set() for kk in range(top_k): @@ -427,57 +436,35 @@ def verify_dispatch(all_token_selected_experts, all_payloads, all_recv_buffers, else: seen.add(tr) - assert topk_target_ranks[token_idx, :].tolist() == topk_target_ranks_ref.tolist(), \ - f"topk_target_ranks[token_idx, :] should match deduplicated target_ranks: {topk_target_ranks_ref.tolist()}" + assert topk_target_ranks[ + token_idx, :].tolist() == topk_target_ranks_ref.tolist() - # For each top_k expert for k in range(top_k): dst_pos = topk_send_indices[token_idx, k].item() target_rank = topk_target_ranks[token_idx, k].item() - if dst_pos == -1: - assert target_rank == -1, \ - f"target_rank should be -1: dst_pos={dst_pos} target_rank={target_rank}" + assert target_rank == -1 continue - - # Verify actual payload content was copied correctly - recv_buffers = all_recv_buffers[target_rank] + recv_tensors = all_recv_tensors[target_rank] for payload_idx, payload in enumerate(payloads): - recv_buffer = recv_buffers[payload_idx] - + recv_tensor = recv_tensors[payload_idx] source_data = payload[token_idx] - received_data = recv_buffer[send_rank, dst_pos] - # Compare source and received data - torch.testing.assert_close( - received_data, - source_data, - atol= - 0, # Dispatch is pure copy, should expact exactly the same - rtol=0, - msg= - f"Content mismatch: received_data={received_data} source_data={source_data} send_rank={send_rank} token_idx={token_idx} experts={experts.tolist()} target_rank={target_rank}, topk_send_indices[token_idx]={topk_send_indices[token_idx].tolist()}" - ) + received_data = recv_tensor[send_rank, dst_pos] + 
torch.testing.assert_close(received_data, + source_data, + atol=0, + rtol=0) # Verify token_selected_experts of invalid tokens are correctly sanitized for recv_rank in range(ep_size): - recv_counters = all_recv_counters[recv_rank] - # expert ids received on recv_rank for all sources - try: - expert_ids_recv = all_recv_buffers[recv_rank][ - expert_id_payload_index] - except Exception: - # Not present in this variant of the test - continue - - # expert_ids_recv: [ep_size, max_tokens_per_rank, top_k] + expert_ids_recv = all_recv_tensors[recv_rank][expert_id_payload_index] for source_rank in range(ep_size): - valid = int(recv_counters[source_rank].item()) - for token_idx in range(max_tokens_per_rank): + valid = int(all_recv_counters[recv_rank][source_rank].item()) + for token_idx in range(max_num_tokens): token_expert_ids = expert_ids_recv[source_rank, token_idx] if token_idx >= valid: - assert torch.all(token_expert_ids == -1), ( - f"recv_rank={recv_rank} src={source_rank} token={token_idx} should be sanitized to -1" - ) + assert torch.all( + token_expert_ids == invalid_token_expert_id) class TestMoEAlltoAll: @@ -525,20 +512,19 @@ def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k): hidden_size = 1024 num_experts_per_rank = 8 - max_tokens_per_rank = max(all_num_tokens) - # Calculate workspace size for all payloads - # workspace_size_per_rank = compute_nvfp4_workspace_size( - # ep_size, max_tokens_per_rank, hidden_size, top_k) - workspace_size_per_rank = 512 * 1024 * 1024 # Large enough workspace + # Large enough workspace + workspace_size_per_rank = 512 * 1024 * 1024 + + invalid_token_expert_id = -1 # Run dispatch on workers - each worker executes the same logic as single-GPU # but on separate GPUs with MNNVL memory instead of regular CUDA memory results = mpi_pool_executor.map( run_moe_a2a_dispatch_single_rank, *zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank, - num_experts_per_rank, hidden_size, max_tokens_per_rank)] * - ep_size), + num_experts_per_rank, hidden_size, + invalid_token_expert_id)] * ep_size), ) # Collect results from all ranks (same as single-GPU collecting from emulated ranks) @@ -547,7 +533,7 @@ def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k): # Extract results in same format as single-GPU test all_token_selected_experts = [r[0] for r in all_results] all_payloads = [r[1] for r in all_results] - all_recv_buffers = [r[2] for r in all_results] + all_recv_tensors = [r[2] for r in all_results] all_send_counters = [r[3] for r in all_results] all_topk_send_indices = [r[4] for r in all_results] all_topk_target_ranks = [r[5] for r in all_results] @@ -561,28 +547,28 @@ def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k): # Verify dispatch results with content verification verify_dispatch(all_token_selected_experts, all_payloads, - all_recv_buffers, all_send_counters, + all_recv_tensors, all_send_counters, all_topk_send_indices, all_topk_target_ranks, all_recv_counters, ep_size, all_num_tokens, top_k, - max_tokens_per_rank, num_experts_per_rank, - expert_id_payload_index) + num_experts_per_rank, expert_id_payload_index, + invalid_token_expert_id) @pytest.mark.skipif(torch.cuda.device_count() < 8, reason='needs at least 8 GPUs to run multi-GPU test') @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize( - "mpi_pool_executor,all_num_tokens,top_k,dtype", + "mpi_pool_executor,all_num_tokens,top_k", [ - # (num_workers, all_num_tokens, top_k, dtype) - (4, [32, 32, 32, 32], 2, torch.float32), - (4, [16, 32, 64, 48], 2, 
torch.float32), - (2, [100, 50], 2, torch.float16), - (4, [32, 32, 32, 32], 4, torch.float32), - (4, [1, 1, 1, 1], 2, torch.float32), - (8, [640, 640, 640, 640, 640, 640, 640, 640], 4, torch.bfloat16), + # (num_workers, all_num_tokens, top_k) + (4, [32, 32, 32, 32], 2), + (4, [16, 32, 64, 48], 2), + (2, [100, 50], 2), + (4, [32, 32, 32, 32], 4), + (4, [1, 1, 1, 1], 2), + (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), ], indirect=["mpi_pool_executor"]) - def test_combine(self, mpi_pool_executor, all_num_tokens, top_k, dtype): + def test_combine(self, mpi_pool_executor, all_num_tokens, top_k): """Test MoE A2A combine with MNNVL across multiple GPUs""" try: @@ -600,20 +586,18 @@ def test_combine(self, mpi_pool_executor, all_num_tokens, top_k, dtype): hidden_size = 2880 # gpt-oss num_experts_per_rank = 8 - max_tokens_per_rank = max(all_num_tokens) - # Calculate workspace size - # workspace_size_per_rank = compute_nvfp4_workspace_size( - # ep_size, max_tokens_per_rank, hidden_size, top_k) - workspace_size_per_rank = 512 * 1024 * 1024 # Large enough workspace + # Large enough workspace + workspace_size_per_rank = 512 * 1024 * 1024 # Run dispatch and combine on workers print("Starting dispatch and combine on workers...") + invalid_token_expert_id = -1 results = mpi_pool_executor.map( run_moe_a2a_dispatch_moe_combine_single_rank, *zip(*[(ep_size, all_num_tokens, top_k, workspace_size_per_rank, - num_experts_per_rank, hidden_size, max_tokens_per_rank, - dtype)] * ep_size), + num_experts_per_rank, hidden_size, + invalid_token_expert_id)] * ep_size), ) # Collect results @@ -631,19 +615,19 @@ def test_combine(self, mpi_pool_executor, all_num_tokens, top_k, dtype): # Verify combine results print("Starting verification...") verify_combine_results(all_results, ep_size, all_num_tokens, top_k, - hidden_size, num_experts_per_rank, - max_tokens_per_rank) + hidden_size, num_experts_per_rank) def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, workspace_size_per_rank, num_experts_per_rank, hidden_size, - max_tokens_per_rank, dtype): + invalid_token_expert_id): """Worker function for dispatch and combine test.""" rank = tllm.mpi_rank() torch.cuda.set_device(rank) device = torch.cuda.current_device() + max_num_tokens = max(all_num_tokens) try: mapping = Mapping(rank=rank, @@ -652,7 +636,7 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, world_size=ep_size) # Create MoeAlltoAll manager - moe_a2a = MoeAlltoAll(mapping, max_tokens_per_rank, top_k, + moe_a2a = MoeAlltoAll(mapping, max_num_tokens, top_k, ep_size * num_experts_per_rank, workspace_size_per_rank) @@ -667,21 +651,19 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, # Run dispatch with torch.cuda.profiler.profile(): - recv_buffers = moe_a2a.dispatch( + recv_tensors = moe_a2a.dispatch( token_selected_experts, payloads, - invalid_token_expert_id=-1, + max_num_tokens, + invalid_token_expert_id=invalid_token_expert_id, expert_id_payload_index=expert_id_payload_index) - hidden_states_recv = recv_buffers[ - 0] # [ep_size, max_tokens_per_rank, hidden_size] - token_selected_experts_recv = recv_buffers[ - 1] # [ep_size, max_tokens_per_rank, top_k] - token_final_scales_recv = recv_buffers[ - 2] # [ep_size, max_tokens_per_rank, top_k] - - ep_size = hidden_states_recv.shape[0] - max_tokens_per_rank = hidden_states_recv.shape[1] + hidden_states_recv = recv_tensors[ + 0] # [ep_size, max_num_tokens, hidden_size] + token_selected_experts_recv = recv_tensors[ + 1] # 
[ep_size, max_num_tokens, top_k] + token_final_scales_recv = recv_tensors[ + 2] # [ep_size, max_num_tokens, top_k] # emulate MoE computation on the received data # Create experts for this rank @@ -692,30 +674,31 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, dtype=torch.bfloat16) hidden_states_recv = fake_moe( - hidden_states_recv.view(ep_size * max_tokens_per_rank, + hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]), token_selected_experts_recv.view( - ep_size * max_tokens_per_rank, + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]), - token_final_scales_recv.view(ep_size * max_tokens_per_rank, + token_final_scales_recv.view(ep_size * max_num_tokens, token_final_scales_recv.shape[-1]), rank_experts, # experts for current rank is_ep=True, ep_rank=rank, num_experts_per_rank=num_experts_per_rank).view( - ep_size, max_tokens_per_rank, hidden_states_recv.shape[-1]) + ep_size, max_num_tokens, hidden_states_recv.shape[-1]) with torch.cuda.profiler.profile(): - combined_output = moe_a2a.combine(hidden_states_recv) + combined_output = moe_a2a.combine(hidden_states_recv, + max_num_tokens) # Verify completion flags after combine - completion_flags_offset = moe_a2a.moe_a2a_metainfo[ - MoeAlltoAll.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX].item() + completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[ + "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX"]].item() completion_flags_ptr = moe_a2a.workspace[ rank, completion_flags_offset:completion_flags_offset + ep_size * 4] completion_flags = completion_flags_ptr.view(torch.int32).cpu() - flag_val_offset = moe_a2a.moe_a2a_metainfo[ - MoeAlltoAll.FLAG_VAL_OFFSET_INDEX].item() + flag_val_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item() expected_flag_val = moe_a2a.workspace[rank, flag_val_offset:flag_val_offset + 4].view(torch.int32).cpu() @@ -736,8 +719,7 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k, def verify_combine_results(all_results, ep_size, all_num_tokens, top_k, - hidden_size, num_experts_per_rank, - max_tokens_per_rank): + hidden_size, num_experts_per_rank): """Verify that combine correctly sums the dispatched tokens.""" # Extract results
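For reference, the per-token reduction that combine is expected to produce, and that this verification checks, can be sketched as follows (an illustrative helper only; `all_moe_outputs`, holding each rank's [ep_size, max_num_tokens, hidden_size] MoE result, is an assumed name and not part of the test):

    def reference_combined_token(send_rank, token_idx, all_topk_target_ranks,
                                 all_topk_send_indices, all_moe_outputs,
                                 top_k, hidden_size):
        # A source token is summed over every unique target rank it was
        # dispatched to; duplicate (rank, slot) pairs were marked -1 at dispatch.
        acc = torch.zeros(hidden_size, dtype=torch.float32)
        for k in range(top_k):
            dst_pos = all_topk_send_indices[send_rank][token_idx, k].item()
            if dst_pos == -1:
                continue
            target_rank = all_topk_target_ranks[send_rank][token_idx, k].item()
            # The target rank stored this token's MoE result at [send_rank, dst_pos].
            acc += all_moe_outputs[target_rank][send_rank, dst_pos].float()
        return acc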