diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 3b05f06a71..49a8f79695 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -186,9 +186,6 @@ def gather(data: Tensor,
         Calling ``gather`` in non-distributed environment dose nothing
         and just returns a list containing :attr:`data` itself.
 
-    Note:
-        ``NCCL`` backend does not support ``gather``.
-
     Note:
         Unlike PyTorch ``torch.distributed.gather``, :meth:`gather` in
         MMEngine does not pass in an empty list ``gather_list`` and returns
@@ -251,7 +248,19 @@ def gather(data: Tensor,
         else:
             gather_list = []
 
-        torch_dist.gather(data, gather_list, dst, group)
+        # Check if the backend is NCCL
+        if get_backend(group) == torch_dist.Backend.NCCL:
+            if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+                torch_dist.gather(data, gather_list, dst, group)
+            else:
+                # NCCL does not support ``gather`` before PyTorch 1.11; fall
+                # back to ``all_gather`` on every rank and keep the result
+                # only on the destination rank
+                result = all_gather(data, group)
+                if get_rank(group) == dst:
+                    gather_list = result
+        else:
+            torch_dist.gather(data, gather_list, dst, group)
 
         if get_rank(group) == dst:
             return cast_data_device(gather_list, input_device)  # type: ignore
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index d89f5eb878..d2c4e02525 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -415,6 +415,23 @@ def test_all_gather(self):
                 torch.allclose(output[dist.get_rank()],
                                expected[dist.get_rank()]))
 
+    def test_gather(self):
+        self._init_dist_env(self.rank, self.world_size)
+        for device_type in ('cpu', 'cuda'):
+            data = torch.tensor([self.rank, self.rank + 1]).to(device_type)
+            dst = 0
+            expected = [
+                torch.tensor([0, 1]).to(device_type),
+                torch.tensor([1, 2]).to(device_type),
+            ]
+            gather_list = dist.gather(data, dst=dst, group=self.group)
+            if self.rank == dst:
+                for i in range(self.world_size):
+                    self.assertTrue(
+                        torch.allclose(gather_list[i], expected[i]))
+            else:
+                self.assertEqual(gather_list, [])
+
     def test_broadcast_dist(self):
         self._init_dist_env(self.rank, self.world_size)
         for device_type in ('cpu', 'cuda'):
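
For reference, a minimal usage sketch of the patched ``gather`` (not part of the patch above): it assumes a two-process job launched with ``torchrun --nproc_per_node=2`` and, for the NCCL backend, one GPU per process; ``init_dist``, ``get_rank`` and ``gather`` are existing ``mmengine.dist`` APIs, while the script name ``demo_gather.py`` is hypothetical.

    # demo_gather.py -- run with: torchrun --nproc_per_node=2 demo_gather.py
    import torch

    import mmengine.dist as dist

    # initialize the default process group from the torchrun environment
    dist.init_dist(launcher='pytorch', backend='nccl')
    rank = dist.get_rank()

    # each rank contributes a different tensor, mirroring the new test case
    data = torch.tensor([rank, rank + 1], device='cuda')
    gathered = dist.gather(data, dst=0)

    if rank == 0:
        # the destination rank receives one tensor per process
        print([t.tolist() for t in gathered])  # expected: [[0, 1], [1, 2]]
    else:
        # every other rank gets an empty list back
        assert gathered == []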