Skip to content

Commit

Permalink
fix(model): compatible meshgrid and CUDA OOM error
Browse files Browse the repository at this point in the history
  • Loading branch information
FateScript committed Mar 21, 2022
1 parent d18f5e8 commit e685457
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 5 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
dependencies = ["torch"]

from yolox.models import ( # noqa: F401, E402
from yolox.models import ( # isort:skip # noqa: F401, E402
yolox_tiny,
yolox_nano,
yolox_s,
Expand Down
12 changes: 8 additions & 4 deletions yolox/models/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from yolox.utils import bboxes_iou
from yolox.utils import bboxes_iou, meshgrid

from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_output_and_grid(self, output, k, stride, dtype):
n_ch = 5 + self.num_classes
hsize, wsize = output.shape[-2:]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid

Expand All @@ -237,7 +237,7 @@ def decode_outputs(self, outputs, dtype):
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
Expand Down Expand Up @@ -321,7 +321,11 @@ def get_losses(
labels,
imgs,
)
except RuntimeError:
except RuntimeError as e:
# TODO: the string might change, consider a better way
if "CUDA out of memory. " not in str(e):
raise # RuntimeError might not caused by CUDA OOM

logger.error(
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
CPU mode is applied in this batch. If you want to avoid this issue, \
Expand Down
1 change: 1 addition & 0 deletions yolox/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .allreduce_norm import *
from .boxes import *
from .checkpoint import load_ckpt, save_checkpoint
from .compat import meshgrid
from .demo_utils import *
from .dist import *
from .ema import *
Expand Down
15 changes: 15 additions & 0 deletions yolox/utils/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import torch

_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]

__all__ = ["meshgrid"]


def meshgrid(*tensors):
if _TORCH_VER >= [1, 10]:
return torch.meshgrid(*tensors, indexing="ij")
else:
return torch.meshgrid(*tensors)

0 comments on commit e685457

Please sign in to comment.