From e685457fa5dc1d82ce91ca43ac7e6f1b41a95a60 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Mon, 21 Mar 2022 15:07:23 +0800 Subject: [PATCH] fix(model): compatible meshgrid and CUDA OOM error --- hubconf.py | 2 +- yolox/models/yolo_head.py | 12 ++++++++---- yolox/utils/__init__.py | 1 + yolox/utils/compat.py | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 yolox/utils/compat.py diff --git a/hubconf.py b/hubconf.py index 9c0e56f8f..d6736478e 100644 --- a/hubconf.py +++ b/hubconf.py @@ -8,7 +8,7 @@ """ dependencies = ["torch"] -from yolox.models import ( # noqa: F401, E402 +from yolox.models import ( # isort:skip # noqa: F401, E402 yolox_tiny, yolox_nano, yolox_s, diff --git a/yolox/models/yolo_head.py b/yolox/models/yolo_head.py index d3ca5e613..d67abd1a0 100644 --- a/yolox/models/yolo_head.py +++ b/yolox/models/yolo_head.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from yolox.utils import bboxes_iou +from yolox.utils import bboxes_iou, meshgrid from .losses import IOUloss from .network_blocks import BaseConv, DWConv @@ -220,7 +220,7 @@ def get_output_and_grid(self, output, k, stride, dtype): n_ch = 5 + self.num_classes hsize, wsize = output.shape[-2:] if grid.shape[2:4] != output.shape[2:4]: - yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij") + yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) self.grids[k] = grid @@ -237,7 +237,7 @@ def decode_outputs(self, outputs, dtype): grids = [] strides = [] for (hsize, wsize), stride in zip(self.hw, self.strides): - yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij") + yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) grid = torch.stack((xv, yv), 2).view(1, -1, 2) grids.append(grid) shape = grid.shape[:2] @@ -321,7 +321,11 @@ def get_losses( labels, imgs, ) - except RuntimeError: + except 
RuntimeError as e: + # TODO: the string might change, consider a better way + if "CUDA out of memory. " not in str(e): + raise # RuntimeError might not be caused by CUDA OOM + logger.error( "OOM RuntimeError is raised due to the huge memory cost during label assignment. \ CPU mode is applied in this batch. If you want to avoid this issue, \ diff --git a/yolox/utils/__init__.py b/yolox/utils/__init__.py index ff8db0dbc..15426396e 100644 --- a/yolox/utils/__init__.py +++ b/yolox/utils/__init__.py @@ -5,6 +5,7 @@ from .allreduce_norm import * from .boxes import * from .checkpoint import load_ckpt, save_checkpoint +from .compat import meshgrid from .demo_utils import * from .dist import * from .ema import * diff --git a/yolox/utils/compat.py b/yolox/utils/compat.py new file mode 100644 index 000000000..1324077e6 --- /dev/null +++ b/yolox/utils/compat.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +import torch + +_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]] + +__all__ = ["meshgrid"] + + +def meshgrid(*tensors): + if _TORCH_VER >= [1, 10]: + return torch.meshgrid(*tensors, indexing="ij") + else: + return torch.meshgrid(*tensors)