 } \
 }
 
+#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...)                  \
+  switch (world_size) {                                                   \
+    INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__);                        \
+    INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__);                        \
+    INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__);                        \
+    default: {                                                            \
+      TORCH_CHECK(false, "Not implemented for world_size=", world_size);  \
+    }                                                                     \
+  }
+
 #define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...)                        \
   switch (alignment) {                                                    \
     INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__);                        \
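The new macro mirrors the existing DISPATCH_WORLD_SIZES / DISPATCH_ALIGNMENTS_16_8_4 pattern: a runtime integer is matched against a fixed set of values, and each case binds it to a constexpr name (k_world_size) before invoking the callback, so the callback can use it as a template argument. Below is a minimal, self-contained sketch of that pattern; the INT_SWITCH_CASE definition here is an assumption (the real macro lives earlier in this file and may differ), and TORCH_CHECK is replaced by a plain exception.

#include <iostream>
#include <stdexcept>

// Assumed shape of INT_SWITCH_CASE: bind `val` to a constexpr `name`, then run
// the callback inside the case so `name` is usable as a template argument.
#define INT_SWITCH_CASE(name, val, ...) \
  case val: {                           \
    constexpr int name = val;           \
    __VA_ARGS__();                      \
    break;                              \
  }

#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...)               \
  switch (world_size) {                                                \
    INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__);                     \
    INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__);                     \
    INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__);                     \
    default: {                                                         \
      throw std::runtime_error("Not implemented for this world_size"); \
    }                                                                  \
  }

template <int k_world_size>
void launch() {
  // k_world_size is a compile-time constant here, so it could size a
  // per-thread array or pick a kernel specialization.
  std::cout << "compiled for k_world_size=" << k_world_size << "\n";
}

int main() {
  int world_size = 4;  // runtime value, e.g. queried from the process group
  DISPATCH_WORLD_SIZES_NO_DEFAULT(
      world_size, [&]() { launch<k_world_size>(); });
}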
@@ -493,6 +503,70 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 512;
 template <typename T, int alignment, int k_world_size>
 static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
     void two_shot_all_reduce_kernel(
+        T** input_ptrs,
+        T* output_ptr,
+        size_t input_offset,
+        size_t numel,
+        uint32_t** signal_pads,
+        size_t rank,
+        size_t world_size) {
+  static_assert(alignment % sizeof(T) == 0);
+  constexpr size_t numel_per_thread = alignment / sizeof(T);
+
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+
+  const size_t numel_per_rank =
+      at::round_up(numel, alignment * world_size) / world_size;
+  const size_t start = numel_per_rank * rank;
+
+  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
+  auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    if (start + i >= numel) {
+      continue;
+    }
+    auto vec = load_and_reduce<T, alignment, k_world_size>(
+        input_ptrs, rank, world_size, input_offset + start + i);
+    // store to local buffer
+    st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+  }
+
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    Vec<alignment> tmp[k_world_size];
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      tmp[step] = ld_vec<alignment>(
+          input_ptrs[remote_rank] + input_offset + remote_start + i);
+    }
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      st_vec<alignment>(
+          output_ptr + remote_start + i, tmp[step]);
+    }
+  }
+  // need to make sure all blocks exit simultaneously so that the data
+  // is not corrupted by the subsequent kernels
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+}
+
+template <typename T, int alignment, int k_world_size>
+static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
+    void two_shot_all_reduce_kernel_inplace(
         T** input_ptrs,
         size_t input_offset,
         size_t numel,
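The new out-of-place kernel above runs the same two phases as the in-place one: each rank first reduces its own contiguous shard across all peers' symmetric buffers and stores the result locally (a reduce-scatter into input_ptrs[rank]), then every rank pulls each peer's reduced shard and writes it into its private output_ptr (an all-gather). The shard size is padded so each shard stays a multiple of the vectorized access, and the `start + i >= numel` guard skips the padded tail. A host-side sketch of that shard arithmetic, with made-up sizes and a local round_up standing in for at::round_up:

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Local stand-in for at::round_up: smallest multiple of `multiple` that is >= x.
static size_t round_up(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  const size_t numel = 1000;      // total elements to reduce (made up)
  const size_t world_size = 4;    // participating ranks (made up)
  const size_t alignment = 16;    // vector width used by the kernel's rounding
  // Same expression as the kernel: pad, then split evenly across ranks.
  const size_t numel_per_rank =
      round_up(numel, alignment * world_size) / world_size;
  for (size_t rank = 0; rank < world_size; ++rank) {
    const size_t start = numel_per_rank * rank;
    // The last shard is clamped, mirroring the `start + i >= numel` guard.
    const size_t end = std::min(start + numel_per_rank, numel);
    std::printf("rank %zu owns [%zu, %zu)\n", rank, std::min(start, numel), end);
  }
}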
@@ -528,8 +602,9 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor two_shot_all_reduce_(
+at::Tensor two_shot_all_reduce_impl(
     at::Tensor input,
+    std::optional<at::Tensor> output,
     std::string reduce_op,
     std::string group_name) {
   TORCH_CHECK(
@@ -546,6 +621,14 @@ at::Tensor two_shot_all_reduce_(
   const size_t alignment =
       get_and_verify_alignment(input, "two_shot_all_reduce");
 
+  if (output.has_value()) {
+    const size_t output_alignment =
+        get_and_verify_alignment(*output, "two_shot_all_reduce");
+    TORCH_CHECK(
+        alignment <= output_alignment,
+        "two_shot_all_reduce: output alignment must be equal to or larger than input.");
+  }
+
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
       input.numel(),
@@ -557,30 +640,73 @@ at::Tensor two_shot_all_reduce_(
       num_blocks,
       num_threads);
 
-  AT_DISPATCH_FLOAT_AND_BFLOAT16(
-      input.scalar_type(), "two_shot_all_reduce", [&]() {
-        DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
-          DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
-            two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
-                <<<num_blocks,
-                   num_threads,
-                   0,
-                   at::cuda::getCurrentCUDAStream()>>>(
-                    reinterpret_cast<scalar_t**>(
-                        symm_mem->get_buffer_ptrs_dev()),
-                    input.storage_offset(),
-                    input.numel(),
-                    reinterpret_cast<uint32_t**>(
-                        symm_mem->get_signal_pad_ptrs_dev()),
-                    symm_mem->get_rank(),
-                    symm_mem->get_world_size());
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!output.has_value()) {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel_inplace<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
           });
         });
-      });
-  return input;
+    return input;
+  } else {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output->data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+    return *output;
+  }
+}
+
+at::Tensor two_shot_all_reduce_(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name) {
+  return two_shot_all_reduce_impl(input, std::nullopt, reduce_op, group_name);
 }
 
+at::Tensor two_shot_all_reduce_out(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor output) {
+  return two_shot_all_reduce_impl(input, output, reduce_op, group_name);
+}
 } // namespace
 
 #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
 
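The two public entry points above are thin wrappers over two_shot_all_reduce_impl: the in-place variant passes std::nullopt and returns input, while the out variant forwards the output tensor and returns it. Note that only the out path dispatches through DISPATCH_WORLD_SIZES_NO_DEFAULT, so an unsupported world size fails a TORCH_CHECK there rather than taking whatever fallback the plain DISPATCH_WORLD_SIZES provides. A small self-contained sketch of this optional-output wrapper pattern, using a std::vector stand-in instead of the real at::Tensor API (all names here are placeholders):

#include <functional>
#include <iostream>
#include <optional>
#include <vector>

using Buffer = std::vector<float>;

// One implementation; the optional selects the in-place or out-of-place path.
Buffer& reduce_impl(
    Buffer& input, std::optional<std::reference_wrapper<Buffer>> output) {
  if (!output.has_value()) {
    return input;  // in-place: result stays in the (symmetric) input buffer
  }
  output->get() = input;  // out-of-place: result goes to the caller's buffer
  return output->get();
}

Buffer& reduce_(Buffer& input) {  // shaped like two_shot_all_reduce_
  return reduce_impl(input, std::nullopt);
}

Buffer& reduce_out(Buffer& input, Buffer& output) {  // shaped like two_shot_all_reduce_out
  return reduce_impl(input, output);
}

int main() {
  Buffer in{1.f, 2.f, 3.f}, out;
  std::cout << reduce_(in).size() << " " << reduce_out(in, out).size() << "\n";
}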
@@ -713,6 +839,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
+  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
+
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
   m.impl("stream_write_value32_", ::stream_write_value32_);