@@ -362,6 +362,18 @@ void grouped_topk_kernel_impl(
362
362
e_score_correction_bias.data_ptr<float >(), \
363
363
routed_scaling_factor.data_ptr<float >());
364
364
365
+ #define LAUNCH_GROUPED_TOPK_KERNEL_FP16 (NE ) \
366
+ grouped_topk_kernel_impl<at::Half, NE>( \
367
+ topk_weights.data_ptr<float >(), \
368
+ topk_ids.data_ptr<int32_t >(), \
369
+ gating_output.data_ptr<at::Half>(), \
370
+ num_tokens, \
371
+ topk, \
372
+ num_expert_group, \
373
+ topk_group, \
374
+ renormalize, \
375
+ e_score_correction_bias.data_ptr<float >(), \
376
+ routed_scaling_factor.data_ptr<float >());
365
377
//
366
378
std::tuple<at::Tensor, at::Tensor> grouped_topk (
367
379
at::Tensor& hidden_states,
@@ -380,36 +392,70 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk(
380
392
TORCH_CHECK (gating_output.size (0 ) == num_tokens, " Number of tokens mismatch" );
381
393
auto topk_weights = at::empty ({num_tokens, topk}, at::kFloat );
382
394
auto topk_ids = at::empty_like (topk_weights, at::kInt );
383
- switch (num_experts) {
384
- case 1 :
385
- LAUNCH_GROUPED_TOPK_KERNEL (1 );
386
- break ;
387
- case 2 :
388
- LAUNCH_GROUPED_TOPK_KERNEL (2 );
389
- break ;
390
- case 4 :
391
- LAUNCH_GROUPED_TOPK_KERNEL (4 );
392
- break ;
393
- case 8 :
394
- LAUNCH_GROUPED_TOPK_KERNEL (8 );
395
- break ;
396
- case 16 :
397
- LAUNCH_GROUPED_TOPK_KERNEL (16 );
398
- break ;
399
- case 32 :
400
- LAUNCH_GROUPED_TOPK_KERNEL (32 );
401
- break ;
402
- case 64 :
403
- LAUNCH_GROUPED_TOPK_KERNEL (64 );
404
- break ;
405
- case 128 :
406
- LAUNCH_GROUPED_TOPK_KERNEL (128 );
407
- break ;
408
- case 256 :
409
- LAUNCH_GROUPED_TOPK_KERNEL (256 );
410
- break ;
411
- default :
412
- TORCH_CHECK (false , " Unexpected num_experts: " , num_experts);
395
+ if (st == at::kBFloat16 ) {
396
+ switch (num_experts) {
397
+ case 1 :
398
+ LAUNCH_GROUPED_TOPK_KERNEL (1 );
399
+ break ;
400
+ case 2 :
401
+ LAUNCH_GROUPED_TOPK_KERNEL (2 );
402
+ break ;
403
+ case 4 :
404
+ LAUNCH_GROUPED_TOPK_KERNEL (4 );
405
+ break ;
406
+ case 8 :
407
+ LAUNCH_GROUPED_TOPK_KERNEL (8 );
408
+ break ;
409
+ case 16 :
410
+ LAUNCH_GROUPED_TOPK_KERNEL (16 );
411
+ break ;
412
+ case 32 :
413
+ LAUNCH_GROUPED_TOPK_KERNEL (32 );
414
+ break ;
415
+ case 64 :
416
+ LAUNCH_GROUPED_TOPK_KERNEL (64 );
417
+ break ;
418
+ case 128 :
419
+ LAUNCH_GROUPED_TOPK_KERNEL (128 );
420
+ break ;
421
+ case 256 :
422
+ LAUNCH_GROUPED_TOPK_KERNEL (256 );
423
+ break ;
424
+ default :
425
+ TORCH_CHECK (false , " Unexpected num_experts: " , num_experts);
426
+ }
427
+ } else if (st == at::kHalf ) {
428
+ switch (num_experts) {
429
+ case 1 :
430
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (1 );
431
+ break ;
432
+ case 2 :
433
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (2 );
434
+ break ;
435
+ case 4 :
436
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (4 );
437
+ break ;
438
+ case 8 :
439
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (8 );
440
+ break ;
441
+ case 16 :
442
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (16 );
443
+ break ;
444
+ case 32 :
445
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (32 );
446
+ break ;
447
+ case 64 :
448
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (64 );
449
+ break ;
450
+ case 128 :
451
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (128 );
452
+ break ;
453
+ case 256 :
454
+ LAUNCH_GROUPED_TOPK_KERNEL_FP16 (256 );
455
+ break ;
456
+ default :
457
+ TORCH_CHECK (false , " Unexpected num_experts: " , num_experts);
458
+ }
413
459
}
414
460
return std::make_tuple (topk_ids, topk_weights);
415
461
}
0 commit comments