Commit 1700599

ngimel authored and pytorchmergebot committed
Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (pytorch#150129)
Per the title, we want to be able to use one-shot all-reduce even if the inputs are not registered with symmetric memory. A separate copy kernel would add latency, and one-shot is all about the lowest possible latency.

Pull Request resolved: pytorch#150129
Approved by: https://github.com/xw285cornell
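A minimal usage sketch of the new op, assuming torch.distributed is already initialized with a CUDA backend and that torch.distributed._symmetric_memory is imported as symm_mem (buffer size, dtype, and variable names below are illustrative only):

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    group_name = dist.group.WORLD.group_name

    # Scratch buffer allocated from symmetric memory and registered once up front.
    symm_buf = symm_mem.empty(8192, dtype=torch.bfloat16, device="cuda")
    symm_mem.rendezvous(symm_buf, group=group_name)

    # Ordinary tensor produced elsewhere in the program (not symm-mem allocated).
    local_inp = torch.randn(8192, dtype=torch.bfloat16, device="cuda")

    # Fused copy + one-shot all-reduce: the kernel stages local_inp into this
    # rank's symmetric buffer itself, avoiding a separate copy launch.
    res = torch.ops.symm_mem.one_shot_all_reduce_copy(
        symm_buf, local_inp, "sum", group_name
    )

Per the checks added in one_shot_all_reduce_out_impl, local_inp must be contiguous, no larger than the symmetric buffer, and have the same alignment as the buffer.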
1 parent 414b9ae commit 1700599

3 files changed (+106, -23 lines)


test/distributed/test_symmetric_memory.py

Lines changed: 23 additions & 13 deletions
@@ -861,22 +861,32 @@ def test_multimem_one_shot_all_reduce(
 
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_one_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
 
-        inp = symm_mem.empty(
-            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
-        ).normal_()
-        symm_mem.rendezvous(inp, group=group_name)
-
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self._verify_all_reduce_result(inp, res)
+        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+            [0, 16],
+        ):
+            inp = symm_mem.empty(
+                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
+            )
+            symm_mem.rendezvous(inp, group=group_name)
+            if not copy:
+                inp.normal_()
+                res = torch.ops.symm_mem.one_shot_all_reduce(
+                    inp[offset:], "sum", group_name
+                )
+            if copy:
+                local_inp = torch.randn_like(inp[offset:])
+                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
+                    inp[offset:], local_inp, "sum", group_name
+                )
+            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)
 
         dist.destroy_process_group()
 
torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 69 additions & 10 deletions
@@ -397,27 +397,34 @@ at::Tensor multimem_all_gather_out(
 // One-shot all-reduce is register-intensive because it stages values loaded
 // from peers in registers before performing reduction. Setting the thread
 // count to 512 to prevent/alleviate register spill.
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8;
+constexpr size_t one_shot_all_reduce_max_num_blocks = 24;
 constexpr size_t one_shot_all_reduce_max_num_threads = 512;
 
 template <typename T, int alignment, int k_world_size>
 static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
     void one_shot_all_reduce_kernel(
         T** input_ptrs,
         T* output_ptr,
+        T* input_ptr,
         size_t input_offset,
         size_t numel,
         uint32_t** signal_pads,
         size_t rank,
         size_t world_size) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
-  __syncthreads();
-
+  // copy input to shared ptr
   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  if (input_ptr) {
+    for (size_t i = offset; i < numel; i += stride) {
+      Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
+      st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
+    }
+  }
+  // TODO make it sync with one block for no-copy case
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
 
   for (size_t i = offset; i < numel; i += stride) {
     auto vec = load_and_reduce<T, alignment, k_world_size>(
@@ -426,11 +433,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
   }
 
   __syncthreads();
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor one_shot_all_reduce_out(
+at::Tensor one_shot_all_reduce_out_impl(
     const at::Tensor& input,
+    const std::optional<at::Tensor>& local_input,
     std::string reduce_op,
     std::string group_name,
     at::Tensor out) {
@@ -440,18 +448,35 @@ at::Tensor one_shot_all_reduce_out(
       out.is_contiguous(), "one_shot_all_reduce: output must be contiguous.");
   TORCH_CHECK(
       out.sizes() == input.sizes(),
-      "one_shot_all_reduce: input/output size mismatch.");
+      "one_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+      input.sizes(),
+      ", output.sizes(): ",
+      out.sizes());
   TORCH_CHECK(
       reduce_op == "sum",
       "one_shot_all_reduce: only sum is supported for now.");
-
+  if (local_input.has_value()) {
+    TORCH_CHECK(
+        local_input->is_contiguous(),
+        "one_shot_all_reduce: local input must be contiguous.");
+    TORCH_CHECK(
+        local_input->numel() <= input.numel(),
+        "one_shot_all_reduce: local input size must be smaller than symm buffer size.");
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
       "one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
 
   const size_t alignment =
       get_and_verify_alignment(input, "one_shot_all_reduce");
+  if (local_input.has_value()) {
+    const size_t local_alignment =
+        get_and_verify_alignment(*local_input, "one_shot_all_reduce");
+    TORCH_CHECK(
+        alignment == local_alignment,
+        "one_shot_all_reduce: local input and symm buffer must have the same alignment.");
+  }
 
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
@@ -476,6 +501,8 @@ at::Tensor one_shot_all_reduce_out(
             reinterpret_cast<scalar_t**>(
                 symm_mem->get_buffer_ptrs_dev()),
             out.data_ptr<scalar_t>(),
+            local_input.has_value() ? local_input->data_ptr<scalar_t>()
+                                    : nullptr,
            input.storage_offset(),
            input.numel(),
            reinterpret_cast<uint32_t**>(
@@ -489,12 +516,42 @@ at::Tensor one_shot_all_reduce_out(
   return out;
 }
 
+at::Tensor one_shot_all_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy_out(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}
+
 at::Tensor one_shot_all_reduce(
     const at::Tensor& input,
     std::string reduce_op,
     std::string group_name) {
   auto out = at::empty_like(input);
-  return one_shot_all_reduce_out(input, reduce_op, group_name, out);
+  return one_shot_all_reduce_out_impl(
+      input, std::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  auto out = at::empty_like(local_input);
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
 }
 
 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
@@ -838,6 +895,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
+  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
+  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
 
torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 14 additions & 0 deletions
@@ -217,6 +217,14 @@ at::Tensor one_shot_all_reduce_meta(
   return at::empty_like(input);
 }
 
+at::Tensor one_shot_all_reduce_copy_meta(
+    const at::Tensor& symm_buffer,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  return at::empty_like(local_input);
+}
+
 TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
@@ -230,6 +238,11 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor");
   m.def(
       "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "one_shot_all_reduce_copy(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name) -> Tensor");
+  m.def(
+      "one_shot_all_reduce_copy_out(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
+
   m.def(
       "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
 
@@ -256,6 +269,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
   m.impl("one_shot_all_reduce", one_shot_all_reduce_meta);
+  m.impl("one_shot_all_reduce_copy", one_shot_all_reduce_copy_meta);
 }
 
 } // namespace
