Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <cstdint>
#include <type_traits>

namespace tensorrt_llm::kernels::moe_a2a
namespace tensorrt_llm::kernels::mnnvl_throughput
{

#define ENABLE_DEBUG_PRINT 0
Expand Down Expand Up @@ -506,7 +506,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);

// Prepare kernel pointers struct
DispatchKernelPointers kernel_ptrs = {}; // Zero-initialize
DispatchKernelPointers kernel_ptrs = {};

// Fill source data pointers and payload sizes
for (int i = 0; i < params.num_payloads; i++)
Expand Down Expand Up @@ -958,4 +958,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
}

} // namespace tensorrt_llm::kernels::moe_a2a
} // namespace tensorrt_llm::kernels::mnnvl_throughput
66 changes: 34 additions & 32 deletions cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm::kernels::moe_a2a
namespace tensorrt_llm::kernels::mnnvl_throughput
{

// Configuration constants
Expand Down Expand Up @@ -91,7 +91,7 @@ struct MoeA2ADispatchParams

// Token configuration
int local_num_tokens; // Number of tokens on this rank
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
int top_k; // Number of experts per token

// Expert routing information
Expand All @@ -101,23 +101,22 @@ struct MoeA2ADispatchParams
int num_payloads; // Number of different payload types
PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors

// Receive buffers and synchronization
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
// Local aux data
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
int* local_token_counter; // Atomic counter for completed tokens on this rank
int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
// per k, -1 for duplicates
int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
// per k, -1 for duplicates

// Synchronization
// Distributed aux data and recv buffers
int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)

// Communication tracking
int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
int* local_token_counter; // Atomic counter for completed tokens on this rank

// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
int* topk_target_ranks; // target rank per k, -1 for duplicates
int* topk_send_indices; // dst index per k, -1 for duplicates
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload

// CUDA stream
cudaStream_t stream;
};

Expand All @@ -137,30 +136,33 @@ struct MoeA2ACombineParams

// Token configuration
int local_num_tokens; // Number of tokens on this rank
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
int top_k; // Number of experts per token

// Expert routing information
int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target

// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
int const* topk_target_ranks; // target rank per k, -1 for duplicates
int const* topk_send_indices; // dst index per k, -1 for duplicates
// Prepare-only field: original payload tensor pointer used to stage into workspace
void const* prepare_payload;

// Single payload information
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
void* output_data; // Output buffer [local_num_tokens, elements_per_token]
int elements_per_token; // Number of elements per token
nvinfer1::DataType dtype; // Data type for proper summation
// Output tensor
void* output_data; // Output buffer [local_num_tokens, elements_per_token]
// Payload information
int elements_per_token; // Number of elements per token
nvinfer1::DataType dtype; // Data type for proper summation

// Local aux data
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
// per k, -1 for duplicates
int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
// per k, -1 for duplicates
int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target

// Synchronization
// Distributed aux data and recv buffers
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)

// CUDA stream
cudaStream_t stream;
// Prepare-only field: original payload tensor pointer used to stage into workspace
void const* prepare_payload;
};

// Combine kernels
Expand All @@ -175,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);

} // namespace tensorrt_llm::kernels::moe_a2a
} // namespace tensorrt_llm::kernels::mnnvl_throughput
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop
void initBindings(nb::module_& m)
{
// Export MoE A2A constants
for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs())
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
{
m.attr(kv.first) = kv.second;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/pybind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop
void initBindings(pybind11::module_& m)
{
// Export MoE A2A constants
for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs())
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
{
m.attr(kv.first) = py::int_(kv.second);
}
Expand Down
33 changes: 22 additions & 11 deletions cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@

#pragma once

#include <array>
#include <cstdint>
#include <utility>
#include <vector>

namespace torch_ext
{
namespace mnnvl_throughput
{

// Enum for indexing into moe_a2a_metainfo tensor
enum MoeA2AMetaInfoIndex
enum MoeA2AMetaInfoIndex : int64_t
{
FLAG_VAL_OFFSET_INDEX = 0,
LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1,
Expand All @@ -34,21 +37,29 @@ enum MoeA2AMetaInfoIndex
DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4,
// Combine completion flags offset
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
PAYLOAD_DATA_OFFSET_INDEX = 6,
NUM_METAINFO_FIELDS = 7
TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
TOPK_SEND_INDICES_OFFSET_INDEX = 7,
PAYLOAD_DATA_OFFSET_INDEX = 8,
NUM_METAINFO_FIELDS = 9
};

using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;

inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs()
{
return {
{"MOE_A2A_FLAG_VAL_OFFSET_INDEX", static_cast<int64_t>(FLAG_VAL_OFFSET_INDEX)},
{"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", static_cast<int64_t>(LOCAL_TOKEN_COUNTER_OFFSET_INDEX)},
{"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", static_cast<int64_t>(SEND_COUNTERS_OFFSET_INDEX)},
{"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", static_cast<int64_t>(RECV_COUNTERS_OFFSET_INDEX)},
{"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX",
static_cast<int64_t>(DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX)},
{"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", static_cast<int64_t>(COMBINE_COMPLETION_FLAGS_OFFSET_INDEX)},
{"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", static_cast<int64_t>(PAYLOAD_DATA_OFFSET_INDEX)},
{"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX},
{"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX},
{"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX},
{"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX},
{"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX},
{"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
{"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
{"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
{"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX},
{"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS},
};
}

} // namespace mnnvl_throughput
} // namespace torch_ext
Loading
Loading