[Bug] No backend type associated with device type cpu #2477

Open
rballeba opened this issue Mar 27, 2024 · 11 comments

Labels: bug / fix, help wanted, v1.3.x

rballeba commented Mar 27, 2024

🐛 Bug

Metrics (both built-in and custom implementations) that use concatenation (dist_reduce_fx="cat") together with CPU computation (compute_on_cpu=True) raise an error when training on multiple GPUs with DDP. The exact error is RuntimeError: No backend type associated with device type cpu.

To Reproduce

Code sample:

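import torch
from lightning import Trainer, LightningModule
from torch.utils.data import DataLoader
from torchmetrics import AUROC


class LitModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)
        # compute_on_cpu=True moves the metric states to the CPU after each update
        self.auroc = AUROC(task="binary", compute_on_cpu=True)

    def training_step(self, batch, batch_idx):
        # Fixed toy predictions/targets, created on this rank's GPU
        preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34], device=self.device)
        target = torch.tensor([0, 0, 1, 1, 1], device=self.device)
        self.auroc(preds, target)
        self.log("train_auroc", self.auroc, on_step=True, on_epoch=True)
        loss = self.layer(batch).mean()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(torch.randn(32, 1), batch_size=1)


def main():
    model = LitModel()
    # Assumed Trainer configuration: multi-GPU training with DDP (2 GPUs here)
    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(model)


if __name__ == "__main__":
    main()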

Stacktrace

Traceback (most recent call last):
  File "/home/ruben/Documents/PhD/Research/Topological Deep Learning/lightning/pythonProject/main.py", line 35, in <module>
    main()
  File "/home/ruben/Documents/PhD/Research/Topological Deep Learning/lightning/pythonProject/main.py", line 31, in main
    trainer.fit(model)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
    self.on_advance_end()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 376, in on_advance_end
    call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=False)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/callbacks/progress/tqdm_progress.py", line 281, in on_train_epoch_end
    self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/callbacks/progress/progress_bar.py", line 198, in get_metrics
    pbar_metrics = trainer.progress_bar_metrics
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1651, in progress_bar_metrics
    return self._logger_connector.progress_bar_metrics
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 253, in progress_bar_metrics
    metrics = self.metrics["pbar"]
              ^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 234, in metrics
    return self.trainer._results.metrics(on_step)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 483, in metrics
    value = self._get_cache(result_metric, on_step)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 447, in _get_cache
    result_metric.compute()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 289, in wrapped_func
    self._computed = compute(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 254, in compute
    return self.value.compute()
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 611, in wrapped_func
    with self.sync_context(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 582, in sync_context
    self.sync(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 531, in sync
    self._sync_dist(dist_sync_fn, process_group=process_group)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 435, in _sync_dist
    output_dict = apply_to_collection(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 125, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 96, in _apply_to_collection_slow
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/utilities/distributed.py", line 127, in gather_all_tensors
    torch.distributed.all_gather(local_sizes, local_size, group=group)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2808, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: No backend type associated with device type cpu

Expected behavior

The metric is computed correctly in multi-GPU training, with the state lists from the different processes merged.
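To make the expected semantics concrete: with dist_reduce_fx="cat", each process holds a list of tensors, and after syncing every process should see all processes' entries concatenated. The sketch below simulates that merge for two ranks without any distributed setup; the helper merge_cat_states and the rank-ordered concatenation are assumptions for illustration, not TorchMetrics' actual code path.

```python
import torch


def merge_cat_states(per_rank_states):
    """Simulate a "cat" reduction over list states from several processes.

    per_rank_states: one list of tensors per (simulated) rank.
    Returns one tensor with every rank's entries concatenated along dim 0,
    in rank order. Hypothetical helper for illustration only.
    """
    # First concatenate each rank's local list (skipping ranks with no updates)
    local = [torch.cat(states, dim=0) for states in per_rank_states if states]
    # Then concatenate the per-rank results, as if they had been gathered
    return torch.cat(local, dim=0)


# Two simulated ranks, each having seen two batches of predictions
rank0 = [torch.tensor([0.13, 0.26]), torch.tensor([0.08])]
rank1 = [torch.tensor([0.19]), torch.tensor([0.34])]
merged = merge_cat_states([rank0, rank1])
print(merged.shape)  # torch.Size([5])
```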

Environment

  • TorchMetrics version: 1.3.2 (installed with pip)
  • Python and PyTorch versions: 3.11 and 2.1.2, respectively
  • OS: Ubuntu 23.10

Additional context

Related bug in PyTorch Lightning

Lightning-AI/pytorch-lightning#18803

@rballeba rballeba added the bug / fix and help wanted labels Mar 27, 2024

Hi! Thanks for your contribution, great first issue!

@Borda Borda added the v1.3.x label Mar 28, 2024
@HGGshiwo

I ran into the same bug.

@SangbumChoi

I also ran into the same bug.

@Rbrq03

Rbrq03 commented Jul 9, 2024

Hello there, any update on this issue?

@SkafteNicki
Member

Hi all, thanks for reporting this issue. I am currently looking into what can be done to solve it. Unfortunately, the compute_on_cpu argument was only ever tested on single-GPU setups, never on multi-GPU setups.

@xiuqhou

xiuqhou commented Aug 19, 2024

Hi, @SkafteNicki
I also ran into the same bug when using MeanAveragePrecision. The error occurs in the default dist_sync_fn, gather_all_tensors, which fails to gather tensors during evaluation on multi-GPU setups. I hope the following function helps:

import pickle

import torch
import torch.distributed as dist


def get_world_size(group=None):
    # Treat a non-distributed run as a single process
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=group)


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
        group: optional process group to gather over
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [data]

    # serialize to a uint8 Tensor on the current CUDA device,
    # so the collective itself runs on the GPU (NCCL-compatible)
    buffer = pickle.dumps(data)
    tensor = torch.frombuffer(bytearray(buffer), dtype=torch.uint8).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda") for _ in size_list
    ]
    if tensor.numel() != max_size:
        padding = torch.empty((max_size - tensor.numel(),), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor, group=group)

    # strip the padding and deserialize each rank's payload
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

The function comes from the torchvision detection training references: https://github.com/pytorch/vision/blob/main/references/detection/utils.py
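PyTorch also ships torch.distributed.all_gather_object, which performs the pickle, pad and unpickle steps above internally. Below is a minimal sketch of an equivalent helper built on it; the name all_gather_objects is hypothetical. Per the PyTorch docs, with an NCCL group the serialized objects are placed on torch.cuda.current_device(), so each rank must have its own GPU selected (e.g. via torch.cuda.set_device), and since it relies on pickle it should only be used with trusted data. The demo part runs a single-process Gloo group on localhost purely so the sketch can execute on a CPU-only machine.

```python
import socket

import torch
import torch.distributed as dist


def all_gather_objects(data, group=None):
    """Gather an arbitrary picklable object from every rank.

    Hypothetical helper built on torch.distributed.all_gather_object.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Not running distributed: behave like a single-rank gather
        return [data]
    gathered = [None] * dist.get_world_size(group=group)
    # all_gather_object pickles `data`, exchanges the payloads and unpickles them
    dist.all_gather_object(gathered, data, group=group)
    return gathered


def _free_port():
    # Ask the OS for an unused local TCP port for the rendezvous
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


created = not dist.is_initialized()
if created:
    # Single-process Gloo group for demonstration only
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{_free_port()}", rank=0, world_size=1
    )
try:
    result = all_gather_objects({"preds": torch.tensor([0.1, 0.2])})
finally:
    if created:
        dist.destroy_process_group()
```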

After adding the following lines to my script, multi-GPU evaluation works with compute_on_cpu=True; with compute_on_cpu=False, however, it fails. Perhaps gather_all_tensors could be extended along the lines of all_gather so that it works on both single- and multi-GPU setups, whether or not compute_on_cpu is set.

coco_evaluator = MeanAveragePrecision(iou_type=args.iou_type, backend=args.backend)
coco_evaluator.dist_sync_fn = utils.all_gather if args.evaluate_on_cpu else None
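The failure with compute_on_cpu=False suggests the transport should depend on the process group's backend rather than always going through pickling. Here is a minimal tensor-only sketch. It assumes, mirroring TorchMetrics' default gather_all_tensors(result, group=None) seen in the stack trace above, that a dist_sync_fn receives one tensor plus a group and returns a list with one tensor per rank. When the group uses NCCL, CPU tensors are staged on the current CUDA device for the collective and moved back afterwards. The name cpu_safe_gather_all_tensors is hypothetical, this is not TorchMetrics' implementation, and it has not been tested against TorchMetrics itself; the demo uses a single-process Gloo group only so it runs on CPU.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn.functional as F


def cpu_safe_gather_all_tensors(result, group=None):
    """Gather `result` from every rank, returning one tensor per rank (hypothetical)."""
    if group is None:
        group = dist.group.WORLD
    world_size = dist.get_world_size(group)
    orig_device = result.device
    # NCCL only handles CUDA tensors, so stage CPU states on this rank's GPU
    if dist.get_backend(group) == "nccl" and orig_device.type == "cpu":
        result = result.to(torch.device("cuda", torch.cuda.current_device()))
    result = result.contiguous()

    if result.ndim == 0:
        # Scalars all share one shape, so they can be gathered directly
        gathered = [torch.zeros_like(result) for _ in range(world_size)]
        dist.all_gather(gathered, result, group=group)
        return [t.to(orig_device) for t in gathered]

    # 1. Gather every rank's shape (all_gather needs equally shaped tensors)
    local_size = torch.tensor(result.shape, device=result.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=group)
    max_size = torch.stack(sizes).max(dim=0).values

    # 2. Zero-pad to the largest shape; F.pad takes (left, right) pairs from the last dim
    pad = []
    for dim in reversed(range(result.ndim)):
        pad.extend([0, int(max_size[dim]) - result.shape[dim]])
    padded = F.pad(result, pad)
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)

    # 3. Trim the padding and return each tensor on the original device
    return [
        t[tuple(slice(0, int(s)) for s in size)].to(orig_device)
        for t, size in zip(gathered, sizes)
    ]


def _free_port():
    # Ask the OS for an unused local TCP port for the rendezvous
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


created = not dist.is_initialized()
if created:
    # Single-process Gloo group, only to exercise the function on a CPU-only machine
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{_free_port()}", rank=0, world_size=1
    )
try:
    state = torch.arange(6).reshape(2, 3)  # CPU tensor, like a state with compute_on_cpu=True
    gathered_state = cpu_safe_gather_all_tensors(state)
    gathered_scalar = cpu_safe_gather_all_tensors(torch.tensor(3.0))
finally:
    if created:
        dist.destroy_process_group()
```

If the assumption about the calling convention holds, it could be assigned the same way as above, e.g. coco_evaluator.dist_sync_fn = cpu_safe_gather_all_tensors.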

@Borda
Member

Borda commented Aug 21, 2024


@SkafteNicki, let's add a first multi-GPU test for this so we can reproduce the bug and prevent regressions in the future?

@crazyboy9103

crazyboy9103 commented Aug 21, 2024

I've looked into the problem and found that the main cause of this error is that Lightning's default distributed backend on GPUs is NCCL. With compute_on_cpu=True, the metric states live on the CPU, and NCCL does not support all_gather on CPU tensors, hence the error. One way to resolve this is to use the Gloo backend, which supports all_gather on CPU tensors.
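A minimal illustration of that point with plain torch.distributed, run as a single process (world size 1, localhost rendezvous chosen only for the demo): a Gloo process group can all_gather CPU tensors, which is the same collective that fails on the NCCL group in the stack trace above. In Lightning, the DDP process group backend can be selected with DDPStrategy(process_group_backend="gloo"), at the cost of slower communication for GPU tensors than NCCL; whether that trade-off is acceptable depends on the workload.

```python
import socket

import torch
import torch.distributed as dist


def free_port():
    # Ask the OS for an unused local TCP port for the rendezvous
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


created = not dist.is_initialized()
if created:
    # Single-process group for demonstration; real DDP runs use one rank per GPU
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{free_port()}",
        rank=0,
        world_size=1,
    )
try:
    backend = dist.get_backend()
    # A CPU tensor, like a metric state with compute_on_cpu=True
    local_size = torch.tensor([5])
    local_sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())]
    # Gloo implements all_gather for CPU tensors; in the setup reported above,
    # the NCCL group raised "No backend type associated with device type cpu" here
    dist.all_gather(local_sizes, local_size)
finally:
    if created:
        dist.destroy_process_group()

print(backend, local_sizes)
```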

@sandychoii

I have also encountered this issue when using compute_on_cpu=True in a DDP setup. I tried initializing the metric with compute_on_cpu and a Gloo process_group inside the LightningModule's on_fit_start hook:

def on_fit_start(self):
    cpu_comm = torch.distributed.new_group(backend="gloo") 
    self.metric = SomeMetric(..., compute_on_cpu=True, process_group=cpu_comm)

However, this did not resolve the problem. The root cause seems to be that a separate _ResultMetric instance is created during the evaluation loop, and its metadata does not include the compute_on_cpu or process_group arguments. As a result, the Gloo process group is never passed to _ResultMetric, which falls back to the default group (torch.distributed.group.WORLD), i.e. NCCL.

This appears to be where _ResultMetric is initialized:
https://github.com/Lightning-AI/pytorch-lightning/blob/32e7d32956e1685d36f2ab0ca3770baa2f76ce10/pytorch_lightning/trainer/connectors/logger_connector/result.py#L503
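For context on that approach: a Gloo subgroup only helps if it actually reaches the collective call, because any all_gather issued without group= uses the default WORLD group (NCCL in a typical GPU DDP run). The sketch below shows the pattern with dist.new_group(backend="gloo") and an explicit group= argument; the helper gather_cpu_state is hypothetical, it assumes equal tensor shapes on all ranks, and it does not patch Lightning's _ResultMetric. A Gloo default group stands in for NCCL so the demo runs in a single CPU process.

```python
import socket

import torch
import torch.distributed as dist


def gather_cpu_state(state, cpu_group):
    """All-gather a CPU tensor through an explicitly supplied Gloo group.

    Hypothetical helper; assumes the tensor has the same shape on every rank.
    """
    out = [torch.empty_like(state) for _ in range(dist.get_world_size(group=cpu_group))]
    # Passing group= is the important part: without it, the default
    # WORLD group (NCCL under GPU DDP) would be used instead
    dist.all_gather(out, state, group=cpu_group)
    return out


def _free_port():
    # Ask the OS for an unused local TCP port for the rendezvous
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


created = not dist.is_initialized()
if created:
    # Stand-in default group; under GPU DDP this would be the NCCL WORLD group
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{_free_port()}", rank=0, world_size=1
    )
try:
    # new_group must be entered by every process, in the same order on all ranks
    cpu_group = dist.new_group(backend="gloo")
    cpu_backend = dist.get_backend(cpu_group)
    gathered = gather_cpu_state(torch.tensor([0.13, 0.26, 0.08]), cpu_group)
finally:
    if created:
        # Destroys the default group and all subgroups
        dist.destroy_process_group()
```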

@Holer90

Holer90 commented Sep 17, 2024

This is a copy-paste of my reply to this issue: Lightning-AI/pytorch-lightning#18803

I was getting the same error message when using MeanAveragePrecision() on Databricks.

For me, it worked after adding the following three kwargs when initializing the metric:

  • compute_on_cpu=False
  • sync_on_compute=False
  • dist_sync_on_step=True

All three arguments are needed to solve it in my case.

My code now looks like:

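from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(
    iou_type="segm",
    class_metrics=True,
    compute_on_cpu=False,  # keep states on the GPU
    sync_on_compute=False,  # do not sync states when compute() is called
    dist_sync_on_step=True,  # sync states across processes on every forward()
)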

@mdifatta

Thank you @Holer90 for sharing. Unfortunately, your solution doesn't seem to work for me. It would be useful to know a bit more about your configuration. In particular, what are your Trainer flags, e.g. devices, strategy, etc.?
Thanks
