
Commit 414b9ae

ngimel authored and pytorchmergebot committed
enable out variant of 2-shot reduction (pytorch#150153)
Per the title, this version uses the symmetric memory input both as the input source and as a work buffer, so the input is modified by the time the op completes (similar to what the fbgemm car reduction does). It is intended to be wrapped in an op that would first copy the real inputs into symmetric memory buffers that are not exposed to the caller.

Pull Request resolved: pytorch#150153
Approved by: https://github.com/xw285cornell
1 parent 7e7e569 commit 414b9ae
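The wrapping pattern described in the commit message is not part of this change; below is a minimal sketch of what such a wrapper could look like, assuming a caller-managed symmetric-memory staging buffer. The function name wrapped_two_shot_all_reduce is illustrative only, and a real wrapper would cache the staging buffer rather than rendezvousing on every call:

import torch
import torch.distributed._symmetric_memory as symm_mem

def wrapped_two_shot_all_reduce(inp: torch.Tensor, group_name: str) -> torch.Tensor:
    # Assumes `inp` is a contiguous 1-D CUDA tensor and the process group is
    # already initialized; a real wrapper would cache and reuse the staging
    # buffer instead of allocating and rendezvousing on every call.
    staging = symm_mem.empty(inp.numel(), dtype=inp.dtype, device=inp.device)
    symm_mem.rendezvous(staging, group=group_name)
    staging.copy_(inp)
    out = torch.empty_like(inp)
    torch.ops.symm_mem.two_shot_all_reduce_out(staging, "sum", group_name, out)
    # `staging` now holds intermediate (reduce-scattered) data and must not be
    # treated as valid input afterwards; only `out` holds the reduced values.
    return out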

File tree

test/distributed/test_symmetric_memory.py
torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
torch/csrc/distributed/c10d/SymmetricMemory.cpp

3 files changed: +182 −45


test/distributed/test_symmetric_memory.py

Lines changed: 29 additions & 24 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: c10d"]
 
+import itertools
 import os
 from unittest import skipIf
 

@@ -881,34 +882,38 @@ def test_one_shot_all_reduce(
 
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_two_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_two_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
 
-        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
-        symm_mem.rendezvous(t, group=group_name)
-
-        self.assertTrue(t.data_ptr() % 16 == 0)
-        self.assertTrue(align_bytes % t.element_size() == 0)
-        self.assertTrue(size_bytes % t.element_size() == 0)
-
-        shift = align_bytes // t.element_size()
-        numel = size_bytes // t.element_size()
-        res = t[shift : shift + numel]
-        res.normal_()
-        inp = res.clone()
-
-        torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
+        for dtype, size_bytes, align_bytes, inplace in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+        ):
+            t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
+            symm_mem.rendezvous(t, group=group_name)
+
+            self.assertTrue(t.data_ptr() % 16 == 0)
+            self.assertTrue(align_bytes % t.element_size() == 0)
+            self.assertTrue(size_bytes % t.element_size() == 0)
+
+            shift = align_bytes // t.element_size()
+            numel = size_bytes // t.element_size()
+            res = t[shift : shift + numel]
+            res.normal_().fill_(1)
+            inp = res.clone()
+            if not inplace:
+                out = torch.empty_like(inp)
+                torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out)
+            else:
+                torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
 
-        # Head and tail should not be written
-        self.assertTrue(t[:shift].eq(0).all().item())
-        self.assertTrue(t[shift + numel :].eq(0).all().item())
-        self._verify_all_reduce_result(inp, res)
+            # Head and tail should not be written
+            self.assertTrue(t[:shift].eq(0).all().item())
+            self.assertTrue(t[shift + numel :].eq(0).all().item())
+            self._verify_all_reduce_result(inp, res if inplace else out)
 
         dist.destroy_process_group()
 
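For context, _verify_all_reduce_result is a helper defined elsewhere in this test file; conceptually it compares the symmetric-memory result against a plain all-reduce of the same data over the default process group. A rough approximation of such a check, not the helper's actual code:

import torch
import torch.distributed as dist

def verify_all_reduce_result(inp: torch.Tensor, result: torch.Tensor) -> None:
    # Reference: all-reduce the untouched copy of the input with the default
    # process group, then compare against what the symm_mem kernel produced.
    expected = inp.clone()
    dist.all_reduce(expected, op=dist.ReduceOp.SUM)
    torch.testing.assert_close(result, expected)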

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 149 additions & 21 deletions
@@ -39,6 +39,16 @@
   }                                                     \
 }
 
+#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...)                 \
+  switch (world_size) {                                                  \
+    INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__);                       \
+    INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__);                       \
+    INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__);                       \
+    default: {                                                           \
+      TORCH_CHECK(false, "Not implemented for world_size=", world_size); \
+    }                                                                    \
+  }
+
 #define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...) \
   switch (alignment) {                             \
     INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__); \
@@ -493,6 +503,70 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 512;
 template <typename T, int alignment, int k_world_size>
 static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
     void two_shot_all_reduce_kernel(
+        T** input_ptrs,
+        T* output_ptr,
+        size_t input_offset,
+        size_t numel,
+        uint32_t** signal_pads,
+        size_t rank,
+        size_t world_size) {
+  static_assert(alignment % sizeof(T) == 0);
+  constexpr size_t numel_per_thread = alignment / sizeof(T);
+
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+
+  const size_t numel_per_rank =
+      at::round_up(numel, alignment * world_size) / world_size;
+  const size_t start = numel_per_rank * rank;
+
+  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
+  auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    if (start + i >= numel) {
+      continue;
+    }
+    auto vec = load_and_reduce<T, alignment, k_world_size>(
+        input_ptrs, rank, world_size, input_offset + start + i);
+    // store to local buffer
+    st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+  }
+
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
+  for (size_t i = offset; i < numel_per_rank; i += stride) {
+    Vec<alignment> tmp[k_world_size];
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      tmp[step] = ld_vec<alignment>(
+          input_ptrs[remote_rank] + input_offset + remote_start + i);
+    }
+#pragma unroll k_world_size
+    for (size_t step = 0; step < k_world_size; ++step) {
+      size_t remote_rank = (rank + step) % k_world_size;
+      size_t remote_start = numel_per_rank * remote_rank;
+      if (remote_start + i >= numel) {
+        continue;
+      }
+      st_vec<alignment>(
+          output_ptr + remote_start + i, tmp[step]);
+    }
+  }
+  // need to make sure all blocks exit simultaneously so that the data
+  // is not corrupted by the subsequent kernels
+  __syncthreads();
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+}
+
+template <typename T, int alignment, int k_world_size>
+static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
+    void two_shot_all_reduce_kernel_inplace(
         T** input_ptrs,
         size_t input_offset,
         size_t numel,
@@ -528,8 +602,9 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor two_shot_all_reduce_(
+at::Tensor two_shot_all_reduce_impl(
     at::Tensor input,
+    std::optional<at::Tensor> output,
     std::string reduce_op,
     std::string group_name) {
   TORCH_CHECK(
@@ -546,6 +621,14 @@ at::Tensor two_shot_all_reduce_(
   const size_t alignment =
       get_and_verify_alignment(input, "two_shot_all_reduce");
 
+  if (output.has_value()) {
+    const size_t output_alignment =
+        get_and_verify_alignment(*output, "two_shot_all_reduce");
+    TORCH_CHECK(
+        alignment <= output_alignment,
+        "two_shot_all_reduce: output alignment must be equal to or larger than input.");
+  }
+
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
       input.numel(),
@@ -557,30 +640,73 @@ at::Tensor two_shot_all_reduce_(
       num_blocks,
       num_threads);
 
-  AT_DISPATCH_FLOAT_AND_BFLOAT16(
-      input.scalar_type(), "two_shot_all_reduce", [&]() {
-        DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
-          DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
-            two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
-                <<<num_blocks,
-                   num_threads,
-                   0,
-                   at::cuda::getCurrentCUDAStream()>>>(
-                    reinterpret_cast<scalar_t**>(
-                        symm_mem->get_buffer_ptrs_dev()),
-                    input.storage_offset(),
-                    input.numel(),
-                    reinterpret_cast<uint32_t**>(
-                        symm_mem->get_signal_pad_ptrs_dev()),
-                    symm_mem->get_rank(),
-                    symm_mem->get_world_size());
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!output.has_value()) {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel_inplace<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
           });
         });
-      });
-  return input;
+    return input;
+  } else {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<scalar_t, k_alignment, k_world_size>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output->data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+    return *output;
+  }
+}
+
+at::Tensor two_shot_all_reduce_(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name) {
+  return two_shot_all_reduce_impl(input, std::nullopt, reduce_op, group_name);
 }
 
+at::Tensor two_shot_all_reduce_out(
+    at::Tensor input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor output) {
+  return two_shot_all_reduce_impl(input, output, reduce_op, group_name);
+}
 } // namespace
 
 #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
 
@@ -713,6 +839,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
+  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
+
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
   m.impl("stream_write_value32_", ::stream_write_value32_);

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 4 additions & 0 deletions
@@ -233,6 +233,10 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
 
+  // note this implementation also modifies the input tensor
+  m.def(
+      "two_shot_all_reduce_out(Tensor(a!) input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)");
+
   // An mm that supports consuming asynchronous input. It guarantees the
   // following rasterization order, and that the corresponding signal arrives
   // before an input chunk is consumed.
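A note on the schema: the (a!) and (b!) annotations declare that the op mutates both the symmetric-memory input and the out tensor, so after the call only output holds the reduced values and the input must be treated as scratch. On a build where the symm_mem ops are registered, the schema can be inspected from Python:

import torch

# Prints the registered signature, including the alias/mutation annotations
# (Tensor(a!) input, ..., Tensor(b!) output).
print(torch.ops.symm_mem.two_shot_all_reduce_out.default._schema)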
