@@ -397,27 +397,34 @@ at::Tensor multimem_all_gather_out(
// One-shot all-reduce is register-intensive because it stages values loaded
// from peers in registers before performing reduction. Setting the thread
// count to 512 to prevent/alleviate register spill.
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8;
+constexpr size_t one_shot_all_reduce_max_num_blocks = 24;
constexpr size_t one_shot_all_reduce_max_num_threads = 512;

template <typename T, int alignment, int k_world_size>
static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
    void one_shot_all_reduce_kernel(
        T** input_ptrs,
        T* output_ptr,
+        T* input_ptr,
        size_t input_offset,
        size_t numel,
        uint32_t** signal_pads,
        size_t rank,
        size_t world_size) {
  static_assert(alignment % sizeof(T) == 0);
  constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
-  __syncthreads();
-
+  // Copy the local input into this rank's symmetric buffer.
  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
  auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  if (input_ptr) {
+    for (size_t i = offset; i < numel; i += stride) {
+      Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
+      st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
+    }
+  }
+  // TODO: sync with a single block in the no-copy case.
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();

  for (size_t i = offset; i < numel; i += stride) {
    auto vec = load_and_reduce<T, alignment, k_world_size>(
@@ -426,11 +433,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
  }

  __syncthreads();
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
}
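With the optional input_ptr, the kernel now stages the caller's local tensor into this rank's symmetric buffer using the same grid-stride, vectorized loop as the reduction itself, and both barriers move from std::memory_order_relaxed to acq_rel so peers are guaranteed to observe the staged data before loading it. A minimal standalone sketch of that copy pattern, assuming ld_vec/st_vec boil down to aligned 16-byte accesses (the kernel name and uint4 framing below are illustrative, not the real helpers):

#include <cuda_runtime.h>
#include <cstddef>

// Grid-stride, vectorized copy: each thread starts at its global index and
// strides by the total thread count, so any launch shape covers the buffer.
__global__ void vec_copy_sketch(
    const uint4* __restrict__ src, // uint4 = 16 bytes, mirroring alignment=16
    uint4* __restrict__ dst,
    size_t num_vecs) {
  size_t offset = blockDim.x * blockIdx.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;
  for (size_t i = offset; i < num_vecs; i += stride) {
    dst[i] = src[i]; // one 16-byte load plus one 16-byte store per element
  }
}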

-at::Tensor one_shot_all_reduce_out(
+at::Tensor one_shot_all_reduce_out_impl(
    const at::Tensor& input,
+    const std::optional<at::Tensor>& local_input,
    std::string reduce_op,
    std::string group_name,
    at::Tensor out) {
@@ -440,18 +448,35 @@ at::Tensor one_shot_all_reduce_out(
      out.is_contiguous(), "one_shot_all_reduce: output must be contiguous.");
  TORCH_CHECK(
      out.sizes() == input.sizes(),
-      "one_shot_all_reduce: input/output size mismatch.");
+      "one_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+      input.sizes(),
+      ", output.sizes(): ",
+      out.sizes());
  TORCH_CHECK(
      reduce_op == "sum",
      "one_shot_all_reduce: only sum is supported for now.");
-
+  if (local_input.has_value()) {
+    TORCH_CHECK(
+        local_input->is_contiguous(),
+        "one_shot_all_reduce: local input must be contiguous.");
+    TORCH_CHECK(
+        local_input->numel() <= input.numel(),
+        "one_shot_all_reduce: local input size must not exceed the symm buffer size.");
+  }
  auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
  TORCH_CHECK(
      symm_mem != nullptr,
      "one_shot_all_reduce: input must be allocated with empty_strided_p2p().");

  const size_t alignment =
      get_and_verify_alignment(input, "one_shot_all_reduce");
+  if (local_input.has_value()) {
+    const size_t local_alignment =
+        get_and_verify_alignment(*local_input, "one_shot_all_reduce");
+    TORCH_CHECK(
+        alignment == local_alignment,
+        "one_shot_all_reduce: local input and symm buffer must have the same alignment.");
+  }

  int num_blocks = 0, num_threads = 0;
  init_elementwise_launch_config(
@@ -476,6 +501,8 @@ at::Tensor one_shot_all_reduce_out(
            reinterpret_cast<scalar_t**>(
                symm_mem->get_buffer_ptrs_dev()),
            out.data_ptr<scalar_t>(),
+            local_input.has_value() ? local_input->data_ptr<scalar_t>()
+                                    : nullptr,
            input.storage_offset(),
            input.numel(),
            reinterpret_cast<uint32_t**>(
@@ -489,12 +516,42 @@ at::Tensor one_shot_all_reduce_out(
  return out;
}

+at::Tensor one_shot_all_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy_out(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}
+
at::Tensor one_shot_all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  auto out = at::empty_like(input);
-  return one_shot_all_reduce_out(input, reduce_op, group_name, out);
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  auto out = at::empty_like(local_input);
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
}
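The new entry points differ only in where the data starts: one_shot_all_reduce still reduces in place out of the symmetric buffer, while the _copy variants accept an ordinary CUDA tensor and let the kernel stage it into the buffer. A hypothetical call site, assuming symm_buf was allocated with empty_strided_p2p() and already rendezvous'd under group_name (the helper name is illustrative):

// Sketch: all-reduce a plain contiguous CUDA tensor without an explicit
// copy_() into the symmetric buffer first; the kernel does the staging copy.
at::Tensor allreduce_from_local(
    const at::Tensor& symm_buf,    // symmetric-memory tensor, numel >= local
    const at::Tensor& local_input, // ordinary contiguous CUDA tensor
    const std::string& group_name) {
  return one_shot_all_reduce_copy(symm_buf, local_input, "sum", group_name);
}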

constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
@@ -838,6 +895,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
  m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
  m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
+  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
+  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
  m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);