From 209554bfef1bb62a6cd105e5b4e2fd58eb943c07 Mon Sep 17 00:00:00 2001 From: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:38:08 +0000 Subject: [PATCH] Add unit tests and revisions in block_level kernel for invalid input Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> --- .../blockScaleMoe/RoutingRenormalize.cu | 12 ++- .../kernels/routing/routingDeepSeekTest.cpp | 30 ++++---- .../routing/routingRenormalizeTest.cpp | 74 ++++++++++++++++--- .../kernels/routing/routingTest.cpp | 9 ++- .../unit_tests/kernels/routing/routingTest.h | 16 +++- .../unittest/_torch/thop/parallel/test_moe.py | 21 +++++- 6 files changed, 123 insertions(+), 39 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu index 3959acee4da..a2988863937 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu @@ -130,8 +130,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo { if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; - smemKIdx[offset] = static_cast(laneIdx); + auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; + if (expertIdx != -1) + { + int offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); + } + else + { + params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1}; + } } } } diff --git a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 4effc318d7c..ad2e0401274 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ 
-217,7 +217,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32) /*numExperts=*/32, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -228,7 +228,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72) /*numExperts=*/72, /*topK=*/6, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -239,7 +239,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384) /*numExperts=*/384, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -250,7 +250,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) /*numExperts=*/256, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, 
/*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -261,7 +261,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput /*numExperts=*/256, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -272,7 +272,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput /*numExperts=*/384, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false, /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -283,7 +283,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal /*numExperts=*/256, /*topK=*/8, /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -294,7 +294,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization) /*numExperts=*/256, /*topK=*/8, /*expertParallelization=*/1, 
/*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; @@ -305,7 +305,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384) /*numExperts=*/384, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; @@ -316,7 +316,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization) /*numExperts=*/256, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; @@ -327,7 +327,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384) /*numExperts=*/384, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); 
this->runTest(param); }; @@ -338,7 +338,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2) /*numExperts=*/256, /*topK=*/2, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -349,7 +349,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal /*numExperts=*/256, /*topK=*/2, /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -360,7 +360,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2) /*numExperts=*/256, /*topK=*/2, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; @@ -371,7 +371,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8) /*numExperts=*/32, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, 
+ /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp index 23e3a6e3d96..c77b384a7c4 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp @@ -127,6 +127,12 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked for (int ie = 0; ie < param.topK; ++ie) { + // Set invalid topk indices for the first half of the topk + if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1) + { + expIdx[ie].idx = -1; + } + PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; if (param.useTopKAsInput) @@ -198,7 +204,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization) /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -209,7 +215,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertPara /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, 
/*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + +TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput) +{ + RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, + /*numExperts=*/128, /*topK=*/8, + /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -220,7 +237,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -231,7 +248,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + +TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput) +{ + RoutingKernelTestParam 
param(RoutingMethodType::Renormalize, /*numTokens=*/100, + /*numExperts=*/128, /*topK=*/8, + /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -242,7 +270,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -264,7 +292,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4) /*numExperts=*/128, /*topK=*/4, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + +TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4) +{ + RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200, + /*numExperts=*/128, /*topK=*/4, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, 
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -275,7 +314,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa /*numExperts=*/128, /*topK=*/4, /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/8, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -286,7 +325,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal /*numExperts=*/128, /*topK=*/4, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -297,7 +336,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4) /*numExperts=*/128, /*topK=*/4, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); this->runTest(param); }; @@ -308,7 +347,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN) /*numExperts=*/512, /*topK=*/10, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, 
/*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -319,7 +358,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN) /*numExperts=*/512, /*topK=*/10, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -330,7 +369,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN) /*numExperts=*/512, /*topK=*/10, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); + this->runTest(param); +}; + +TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput) +{ + RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, + /*numExperts=*/512, /*topK=*/10, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); this->runTest(param); }; 
diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp index 144ff4a8c73..b8f387f4b41 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp @@ -159,10 +159,13 @@ void RoutingKernelTest::computePermutation(RoutingKernelTestParam const& para int32_t index = expIdxHostPtr[it * param.topK + ie].idx; tokenToExpertHostPtr[it * param.topK + ie] = index; - auto localExpertIdx = index - param.localExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < param.numLocalExperts + int32_t localExpertIdx = index - param.localExpertsStartIdx; + bool isLocalExpert = localExpertIdx >= 0 && localExpertIdx < param.numLocalExperts && (localExpertIdx & param.localExpertsStrideLog2) == 0; - tokenToIdxInExpertHostPtr[it * param.topK + ie] = expertCountsHostPtr[index]; + if (index >= 0) + { + tokenToIdxInExpertHostPtr[it * param.topK + ie] = expertCountsHostPtr[index]; + } if (isLocalExpert) { expertCountsHostPtr[index]++; diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.h b/cpp/tests/unit_tests/kernels/routing/routingTest.h index 1e917e2b7c0..e6955d68a0d 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.h +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.h @@ -245,6 +245,7 @@ struct RoutingKernelTestParam // Check the input parameters bool useTopKAsInput{false}; + bool hasInvalidTopKInput{false}; // Special for renormalize routing method bool doSoftmaxBeforeTopK{false}; @@ -270,8 +271,8 @@ struct RoutingKernelTestParam RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK, int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t tileTokensDim = 1, int32_t paddingLog2 = 3, int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, - bool useTopKAsInput = false, int32_t nGroup = 1, int32_t topkGroup = 
1, float routedScalingFactor = 1.0f, - int requiredComputeCapability = 9) + bool useTopKAsInput = false, bool hasInvalidTopKInput = false, int32_t nGroup = 1, int32_t topkGroup = 1, + float routedScalingFactor = 1.0f, int requiredComputeCapability = 9) : routingMethod(routingMethod) , numTokens(numTokens) , numExperts(numExperts) @@ -282,6 +283,7 @@ struct RoutingKernelTestParam , usePdl(usePdl) , getExpWeights(getExpWeights) , useTopKAsInput(useTopKAsInput) + , hasInvalidTopKInput(hasInvalidTopKInput) , nGroup(nGroup) , topkGroup(topkGroup) , routedScalingFactor(routedScalingFactor) @@ -319,6 +321,11 @@ struct RoutingKernelTestParam { singleClusterTokenNum = 256; } + + if (hasInvalidTopKInput && !useTopKAsInput) + { + throw std::invalid_argument("hasInvalidTopKInput is only supported when useTopKAsInput is true"); + } } // Copy constructor @@ -340,9 +347,10 @@ struct RoutingKernelTestParam { return tensorrt_llm::common::fmtstr( "RoutingKernelTestParam[num_tokens=%d, num_experts=%d, topK=%u, doSoftmaxBeforeTopK=%d, normTopkProb=%d, " - "localExpertsStartIdx=%d, localExpertsStrideLog2=%d, numLocalExperts=%d, usePdl=%d]", + "localExpertsStartIdx=%d, localExpertsStrideLog2=%d, numLocalExperts=%d, usePdl=%d, useTopKAsInput=%d, " + "hasInvalidTopKInput=%d]", numTokens, numExperts, topK, doSoftmaxBeforeTopK, normTopkProb, localExpertsStartIdx, - localExpertsStrideLog2, numLocalExperts, usePdl); + localExpertsStrideLog2, numLocalExperts, usePdl, useTopKAsInput, hasInvalidTopKInput); } }; diff --git a/tests/unittest/_torch/thop/parallel/test_moe.py b/tests/unittest/_torch/thop/parallel/test_moe.py index ec54da2f05b..00112202d38 100644 --- a/tests/unittest/_torch/thop/parallel/test_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_moe.py @@ -1895,7 +1895,7 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size, # # Data Generation # - + test_invalid_topk_input = False act_type = ActType.SwiGlu num_experts = routing_info["num_experts"] top_k = 
routing_info["top_k"] @@ -1931,10 +1931,12 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size, pytest.skip( "512 experts is tested only with no autotune, mxfp8, SwiGlu") if use_topk_as_input: - if dtype_activation != "mxfp8" or top_k != 4 or act_type_str != "SwiGlu" or not use_autotune or num_tokens != 1: + if dtype_activation != "mxfp8" or top_k != 4 or act_type_str != "SwiGlu" or not use_autotune: pytest.skip( - "use_topk_as_input is tested only with mxfp8, topk=4, SwiGlu, autotune, and num_tokens=1" + "use_topk_as_input is tested only with mxfp8, topk=4, SwiGlu and not use_autotune" ) + else: + test_invalid_topk_input = True assert top_k <= num_experts assert top_k <= 10 @@ -2037,6 +2039,17 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size, topk_ids = permute_info["topKIndices"].to(torch.int32) topk_weights = permute_info["topKLogits"] expert_logits = None + if test_invalid_topk_input: + extra_col = torch.full((num_tokens, 1), + -1, + dtype=topk_ids.dtype, + device=topk_ids.device) + topk_ids = torch.cat([topk_ids, extra_col], dim=1) + extra_col = torch.full((num_tokens, 1), + -1, + dtype=topk_weights.dtype, + device=topk_weights.device) + topk_weights = torch.cat([topk_weights, extra_col], dim=1) else: topk_ids = None topk_weights = None @@ -2182,6 +2195,8 @@ def test_moe_mxe2m1_weights(num_tokens, hidden_size, intermediate_size, # Run the TRT-LLM kernel # unpadded_hidden_size = hidden_size + if test_invalid_topk_input: + top_k = top_k + 1 AutoTuner.get().clear_cache() with autotune(use_autotune): if dtype_activation == "mxfp8":