@@ -127,6 +127,12 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
127127 // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked
128128 for (int ie = 0 ; ie < param.topK ; ++ie)
129129 {
130+ // Set invalid topk indices (-1) for the first topK / 2 + 1 entries, i.e. just over half of the topk
131+ if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1 )
132+ {
133+ expIdx[ie].idx = -1 ;
134+ }
135+
130136 PackedType si{static_cast <T>(expIdx[ie].score ), expIdx[ie].idx };
131137 reinterpret_cast <PackedType*>(bufferCast<int8_t >(*this ->mPtrTopKPackedHost ))[it * param.topK + ie] = si;
132138 if (param.useTopKAsInput )
@@ -198,7 +204,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization)
198204 /* numExperts=*/ 128 , /* topK=*/ 8 ,
199205 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
200206 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
201- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
207+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
202208 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
203209 this ->runTest (param);
204210};
@@ -209,7 +215,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertPara
209215 /* numExperts=*/ 128 , /* topK=*/ 8 ,
210216 /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
211217 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
212- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
218+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
219+ /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
220+ this ->runTest (param);
221+ };
222+
223+ TYPED_TEST (RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput)
224+ { // Renormalize routing fed with top-k input (useTopKAsInput) that contains invalid expert indices.
225+ RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 4 ,
226+ /* numExperts=*/ 128 , /* topK=*/ 8 ,
227+ /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
228+ /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
229+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true , // the fixture sets idx = -1 for ie < topK / 2 + 1, i.e. 5 of the 8 entries per token
213230 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
214231 this ->runTest (param); // NOTE(review): "BlockLevel" presumably names the kernel path chosen for this small token count - confirm
215232};
@@ -220,7 +237,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
220237 /* numExperts=*/ 128 , /* topK=*/ 8 ,
221238 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 192 ,
222239 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
223- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
240+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
224241 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
225242 this ->runTest (param);
226243};
@@ -231,7 +248,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
231248 /* numExperts=*/ 128 , /* topK=*/ 8 ,
232249 /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 256 ,
233250 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
234- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
251+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
252+ /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
253+ this ->runTest (param);
254+ };
255+
256+ TYPED_TEST (RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput)
257+ { // Renormalize routing fed with top-k input (useTopKAsInput) that contains invalid expert indices, 100 tokens.
258+ RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 100 ,
259+ /* numExperts=*/ 128 , /* topK=*/ 8 ,
260+ /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 256 ,
261+ /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
262+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true , // the fixture sets idx = -1 for ie < topK / 2 + 1, i.e. 5 of the 8 entries per token
235263 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
236264 this ->runTest (param); // NOTE(review): "ClusterLevel" presumably names the kernel path chosen for this token count - confirm
237265};
@@ -242,7 +270,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
242270 /* numExperts=*/ 128 , /* topK=*/ 8 ,
243271 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
244272 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
245- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
273+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
246274 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
247275 this ->runTest (param);
248276};
@@ -264,7 +292,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
264292 /* numExperts=*/ 128 , /* topK=*/ 4 ,
265293 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
266294 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
267- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
295+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
296+ /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
297+ this ->runTest (param);
298+ };
299+
300+ TYPED_TEST (RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4)
301+ { // Renormalize routing with topK = 4, fed with top-k input (useTopKAsInput) that contains invalid expert indices.
302+ RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 200 ,
303+ /* numExperts=*/ 128 , /* topK=*/ 4 ,
304+ /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
305+ /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
306+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true , // the fixture sets idx = -1 for ie < topK / 2 + 1, i.e. 3 of the 4 entries per token
268307 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
269308 this ->runTest (param); // NOTE(review): "ClusterLevel" presumably names the kernel path chosen for this token count - confirm
270309};
@@ -275,7 +314,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
275314 /* numExperts=*/ 128 , /* topK=*/ 4 ,
276315 /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 8 ,
277316 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
278- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
317+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
279318 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
280319 this ->runTest (param);
281320};
@@ -286,7 +325,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
286325 /* numExperts=*/ 128 , /* topK=*/ 4 ,
287326 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
288327 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
289- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
328+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
290329 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
291330 this ->runTest (param);
292331};
@@ -297,7 +336,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
297336 /* numExperts=*/ 128 , /* topK=*/ 4 ,
298337 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
299338 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
300- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
339+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput= */ true ,
301340 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 8 );
302341 this ->runTest (param);
303342};
@@ -308,7 +347,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN)
308347 /* numExperts=*/ 512 , /* topK=*/ 10 ,
309348 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
310349 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
311- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
350+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
312351 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
313352 this ->runTest (param);
314353};
@@ -319,7 +358,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN)
319358 /* numExperts=*/ 512 , /* topK=*/ 10 ,
320359 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
321360 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
322- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
361+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput= */ false ,
323362 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 9 );
324363 this ->runTest (param);
325364};
@@ -330,7 +369,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN)
330369 /* numExperts=*/ 512 , /* topK=*/ 10 ,
331370 /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
332371 /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
333- /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
372+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
373+ /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 8 );
374+ this ->runTest (param);
375+ };
376+
377+ TYPED_TEST (RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput)
378+ { // Renormalize routing with 512 experts and topK = 10, fed with top-k input (useTopKAsInput) that contains invalid expert indices.
379+ RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 1000 ,
380+ /* numExperts=*/ 512 , /* topK=*/ 10 ,
381+ /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
382+ /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
383+ /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true , // the fixture sets idx = -1 for ie < topK / 2 + 1, i.e. 6 of the 10 entries per token
334384 /* nGroup*/ 0 , /* topkGroup*/ 0 , /* routedScalingFactor*/ 1 .0f , /* requiredComputeCapability*/ 8 );
335385 this ->runTest (param); // NOTE(review): "DeviceLevel" presumably names the kernel path chosen for this large token count - confirm
336386};
0 commit comments