Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
{
if (laneIdx < params.mTopK)
{
int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
if (expertIdx != -1)
{
int offset = warpIdx * MaxNumExperts + expertIdx;
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
}
else
{
params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1};
}
}
}
}
Expand Down
30 changes: 15 additions & 15 deletions cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32)
/*numExperts=*/32, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -228,7 +228,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72)
/*numExperts=*/72, /*topK=*/6,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -239,7 +239,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
/*numExperts=*/384, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -250,7 +250,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
/*numExperts=*/256, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -261,7 +261,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
/*numExperts=*/256, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -272,7 +272,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
/*numExperts=*/384, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -283,7 +283,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
/*numExperts=*/256, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -294,7 +294,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
/*numExperts=*/256, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand All @@ -305,7 +305,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
/*numExperts=*/384, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand All @@ -316,7 +316,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
/*numExperts=*/256, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand All @@ -327,7 +327,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
/*numExperts=*/384, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand All @@ -338,7 +338,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
/*numExperts=*/256, /*topK=*/2,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -349,7 +349,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
/*numExperts=*/256, /*topK=*/2,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -360,7 +360,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
/*numExperts=*/256, /*topK=*/2,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand All @@ -371,7 +371,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8)
/*numExperts=*/32, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
Expand Down
74 changes: 62 additions & 12 deletions cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
// convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked
for (int ie = 0; ie < param.topK; ++ie)
{
// Set invalid topk indices for the first half of the topk
if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1)
{
expIdx[ie].idx = -1;
}

PackedType si{static_cast<T>(expIdx[ie].score), expIdx[ie].idx};
reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si;
if (param.useTopKAsInput)
Expand Down Expand Up @@ -198,7 +204,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization)
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -209,7 +215,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertPara
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4,
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -220,7 +237,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -231,7 +248,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100,
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -242,7 +270,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -264,7 +292,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200,
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -275,7 +314,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/8,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -286,7 +325,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -297,7 +336,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
this->runTest(param);
};
Expand All @@ -308,7 +347,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN)
/*numExperts=*/512, /*topK=*/10,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -319,7 +358,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN)
/*numExperts=*/512, /*topK=*/10,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
Expand All @@ -330,7 +369,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN)
/*numExperts=*/512, /*topK=*/10,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
/*numExperts=*/512, /*topK=*/10,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
this->runTest(param);
};
Expand Down
Loading