Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
// Compute the runtime config for projections
// Whether or not an expert is local is taken into account when smemExpertCount is computed
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);

int32_t numCta;
if constexpr (KernelParams::isPow2) {
numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
} else {
numCta = divUpTileN<int32_t>(count, params.mTileTokensDim);
}

int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
Expand All @@ -401,14 +408,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
const int32_t localExpertIdx =
(threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] =
min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
int32_t mnLimit1;
int32_t mnLimit2;
if constexpr (KernelParams::isPow2) {
mnLimit1 = mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2);
mnLimit2 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count;
} else {
mnLimit1 = mulTileN<int32_t>(ctaOffset + cta + 1, params.mTileTokensDim);
mnLimit2 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim) + count;
}
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
}

// get the padded offset associated with this expert
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
int32_t offset;
if constexpr (KernelParams::isPow2) {
offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
} else {
offset = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim);
}
int32_t permutedIdxSize;
if constexpr (KernelParams::isPow2) {
permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
} else {
permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
}

// write out padded count
if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) {
Expand Down Expand Up @@ -542,8 +566,6 @@ void runImpl(Data& data, void* stream) {
}
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d",
data.mPaddingLog2);

int const numBlocks = data.mNumTokens;
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
Expand Down
54 changes: 42 additions & 12 deletions csrc/trtllm_fused_moe_routing_llama4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
#pragma unroll
for (int ii = 0; ii < ExpertsPerThread; ++ii) {
auto count = getBits(expertCount, ii);
numCta += divUpLog2<int32_t>(count, params.mPaddingLog2);
int32_t num;
if constexpr (KernelParams::isPow2) {
num = divUpLog2<int32_t>(count, params.mPaddingLog2);
} else {
num = divUpTileN<int32_t>(count, params.mTileTokensDim);
}
numCta += num;
}
// second, we perform the exclusive sum across the warp
int32_t ctaOffset;
Expand All @@ -202,22 +208,39 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
#pragma unroll
for (int ii = 0; ii < ExpertsPerThread; ++ii) {
auto count = getBits(expertCount, ii);
auto finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
int32_t finalNumCta;
if constexpr (KernelParams::isPow2) {
finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
} else {
finalNumCta = divUpTileN<int32_t>(count, params.mTileTokensDim);
}
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
// during the scan for expert offsets, we can already write out
// both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
for (int cta = 0; cta < finalNumCta; ++cta) {
params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] =
min(mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count);
int32_t mnLimit1;
int32_t mnLimit2;
if constexpr (KernelParams::isPow2) {
mnLimit1 = mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2);
mnLimit2 = mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count;
} else {
mnLimit1 = mulTileN<int32_t>(ctaOffsetExp + cta + 1, params.mTileTokensDim);
mnLimit2 = mulTileN<int32_t>(ctaOffsetExp, params.mTileTokensDim) + count;
}
params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, mnLimit2);
}
ctaOffsetExp += finalNumCta;
}

// at this point, we can write out padded count from the warp-aggregate
if (cute::elect_one_sync()) {
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
int32_t permutedIdxSize;
if constexpr (KernelParams::isPow2) {
permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
} else {
permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
}
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
Expand All @@ -236,12 +259,21 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
// of registers
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
int32_t finalExpertOffset[ExpertsPerThread];
finalExpertOffset[0] = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
int32_t finalExpertOffset0;
if constexpr (KernelParams::isPow2) {
finalExpertOffset0 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
} else {
finalExpertOffset0 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim);
}
#pragma unroll
for (int ii = 1; ii < ExpertsPerThread; ++ii) {
finalExpertOffset[ii] =
finalExpertOffset[ii - 1] +
divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
int32_t tmp;
if constexpr (KernelParams::isPow2) {
tmp = divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
} else {
tmp = divUpMulTileN<int32_t>(getBits(expertCount, ii - 1), params.mTileTokensDim);
}
finalExpertOffset[ii] = finalExpertOffset[ii - 1] + tmp;
}

#pragma unroll
Expand Down Expand Up @@ -455,8 +487,6 @@ void runImpl(Data const& data, void* stream) {
NumExpertsLimit);
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d",
data.mPaddingLog2);

bool const useSingleWarp =
(data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) ||
Expand Down
38 changes: 29 additions & 9 deletions csrc/trtllm_fused_moe_routing_renormalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,30 +165,52 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
}
__syncthreads();
// Get the number of CTAs and the offset for each CTA
const int32_t numCta = divUpLog2<int32_t>(accExpertCount, params.mPaddingLog2);
int32_t numCta;
if constexpr (KernelParams::isPow2) {
numCta = divUpLog2<int32_t>(accExpertCount, params.mPaddingLog2);
} else {
numCta = divUpTileN<int32_t>(accExpertCount, params.mTileTokensDim);
}
int32_t ctaOffset = 0;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

int32_t expertScanCounts = 0;
Scan(tempStorage)
.ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts);
int32_t tmpCount;
if constexpr (KernelParams::isPow2) {
tmpCount = divUpMulLog2<int32_t>(accExpertCount, params.mPaddingLog2);
} else {
tmpCount = divUpMulTileN<int32_t>(accExpertCount, params.mTileTokensDim);
}
Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts);
__syncthreads();

if (isLocalExpert) {
for (int cta = 0; cta < numCta; ++cta) {
const int32_t localExpertIdx =
(expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] =
min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + accExpertCount);
int32_t mnLimit1;
int32_t mnLimit2;
if constexpr (KernelParams::isPow2) {
mnLimit1 = mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2);
mnLimit2 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + accExpertCount;
} else {
mnLimit1 = mulTileN<int32_t>(ctaOffset + cta + 1, params.mTileTokensDim);
mnLimit2 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim) + accExpertCount;
}
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
}
}

// at this point, we can write out padded count
if (threadIdx.x == 0) {
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
int32_t permutedIdxSize;
if constexpr (KernelParams::isPow2) {
permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
} else {
permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
}
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
Expand Down Expand Up @@ -399,8 +421,6 @@ void run(Data const& data, void* stream) {
<< NumExpertsLimit << ".";
TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0)
<< "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4.";
TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8)
<< "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2;

bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;

Expand Down
8 changes: 6 additions & 2 deletions csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "flashinfer/trtllm/fused_moe/DevKernel.h"
#include "flashinfer/trtllm/fused_moe/RoutingKernel.h"
#include "flashinfer/trtllm/fused_moe/runner.h"
// #include <tensorrt_llm/common/assert.h>

namespace tensorrt_llm {
namespace kernels {
Expand All @@ -39,7 +38,9 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") {
while (n >>= 1) {
++out;
}
FLASHINFER_CHECK((1 << out) == val, "Expected ", name, " to be a power of 2, got ", val);
if ((1 << out) != val) {
out = -1;
}
return out;
}
} // namespace
Expand Down Expand Up @@ -90,6 +91,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNumLimitedGroups = topkGroup;
routingData.mTopK = topK;
routingData.mPaddingLog2 = computeLog2(mTileTokensDim);
routingData.mTileTokensDim = mTileTokensDim;
routingData.mLocalExpertsStartIdx = localExpertOffset;
routingData.mLocalExpertsStrideLog2 = 0;
routingData.mNumLocalExperts = localNumExperts;
Expand Down Expand Up @@ -124,6 +126,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNumExperts = numExperts;
routingData.mTopK = topK;
routingData.mPaddingLog2 = computeLog2(mTileTokensDim);
routingData.mTileTokensDim = mTileTokensDim;
routingData.mLocalExpertsStartIdx = localExpertOffset;
routingData.mLocalExpertsStrideLog2 = 0;
routingData.mNumLocalExperts = localNumExperts;
Expand Down Expand Up @@ -170,6 +173,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNumExperts = numExperts;
routingData.mTopK = topK;
routingData.mPaddingLog2 = computeLog2(mTileTokensDim);
routingData.mTileTokensDim = mTileTokensDim;
routingData.mLocalExpertsStartIdx = localExpertOffset;
routingData.mLocalExpertsStrideLog2 = 0;
routingData.mNumLocalExperts = localNumExperts;
Expand Down
Loading