diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 7ee9e7d8a66f..50a85b92dbb7 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -103,29 +103,6 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
 
-  @staticmethod
-  def _scatter():
-    dist.init_process_group("xla", init_method='xla://')
-    device = torch_xla.device()
-    world_size = xr.world_size()
-    tensors = None
-    if xr.global_ordinal() == 0:
-      tensors = [
-          torch.tensor([i], device=device, dtype=torch.float)
-          for i in range(world_size)
-      ]
-
-    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
-    dist.scatter(output_tensor, tensors, src=0)
-    return output_tensor.cpu()
-
-  def test_scatter(self):
-    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
-    on device 0, then scatters it. Device i should therefore receive [i]."""
-    results = pjrt.run_multiprocess(self._scatter)
-    for ordinal, value in results.items():
-      np.testing.assert_array_equal(value, [ordinal])
-
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
@@ -359,6 +336,49 @@ def test_all_to_all_single(self, use_dynamo):
               expected.sort().values),
           f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    tensors = None
+    if xr.global_ordinal() == 0:
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
+
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
+  @staticmethod
+  def _reduce():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+    dist.reduce(input, dst=0, op=dist.ReduceOp.SUM)
+
+    return input.cpu()
+
+  def test_reduce(self):
+    results = pjrt.run_multiprocess(self._reduce)
+    for ordinal, value in results.items():
+      if ordinal == 0:
+        expected = sum(range(tpu.num_expected_global_devices()))
+      else:
+        expected = ordinal
+      np.testing.assert_array_equal(value, [expected])
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 99b721a4fa16..e0c5af31bf2a 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -356,7 +356,6 @@ def test_barrier(self):
     dist.barrier()
 
   @parameterized.parameters(
-      'reduce',
       'allreduce_coalesced',
       'alltoall',
       'gather',
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index daef50c243dc..7c14f0f23f69 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -1,11 +1,12 @@
 import torch
 import torch.distributed as dist
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -224,11 +225,24 @@ def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
   def barrier(self, opts):
     return _ret_work([])
 
-  # Call site:
-  # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L1417
-  # `reduce` is not needed by DeepSpeed for now.
-  def reduce(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.reduce. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L2925
+  # Tensors are reduced, but the result is only saved on the dst device;
+  # the input tensor is unchanged on all other devices.
+  # This is an inefficient operation: to avoid XLA deadlocks, it performs
+  # redundant reductions on all devices and materializes the result.
+  def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
+    rank = xr.global_ordinal()
+    dst = opts.rootRank
+    reduce_type = self._get_reduce_type(opts.reduceOp)
+    for tensor in tensors:
+      result = xm.all_reduce(reduce_type, inputs=tensor)
+      torch_xla.sync()
+
+      if rank == dst:
+        tensor.copy_(result)
+
+    return _ret_work(tensors)
 
   def allreduce_coalesced(self, *args):
     raise NotImplementedError
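
For context, the sketch below shows how the newly supported torch.distributed.reduce is expected to be used on the XLA backend, mirroring the _reduce test above. It is illustrative only and not part of the patch; the launch mechanism (one process per XLA device, e.g. pjrt.run_multiprocess as in the tests, or torchrun) is assumed, everything else uses only APIs that appear in the diff.

# Illustrative sketch -- not part of the patch. Assumes one process per XLA
# device (e.g. spawned like the tests via pjrt.run_multiprocess, or torchrun).
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the "xla" backend
import torch_xla.runtime as xr


def reduce_example():
  dist.init_process_group("xla", init_method="xla://")
  device = torch_xla.device()

  # Every rank contributes its ordinal; after dist.reduce only dst=0 holds
  # the sum, while all other ranks keep their original tensor.
  value = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device)
  dist.reduce(value, dst=0, op=dist.ReduceOp.SUM)

  if xr.global_ordinal() == 0:
    print(value.cpu())  # tensor([sum(range(world_size))]) on rank 0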