From d18f5e8202cd2d8480eff9263a6a4456a11d90cf Mon Sep 17 00:00:00 2001
From: Feng Wang
Date: Mon, 21 Mar 2022 15:04:12 +0800
Subject: [PATCH 1/2] feat(model): support hub load

---
 hubconf.py               | 19 +++++++++
 yolox/models/__init__.py |  1 +
 yolox/models/build.py    | 91 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+)
 create mode 100644 hubconf.py
 create mode 100644 yolox/models/build.py

diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 000000000..9c0e56f8f
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+"""
+Usage example:
+    import torch
+    model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
+"""
+dependencies = ["torch"]
+
+from yolox.models import (  # noqa: F401, E402
+    yolox_tiny,
+    yolox_nano,
+    yolox_s,
+    yolox_m,
+    yolox_l,
+    yolox_x,
+    yolov3,
+)
diff --git a/yolox/models/__init__.py b/yolox/models/__init__.py
index f2f26603a..c74fd3064 100644
--- a/yolox/models/__init__.py
+++ b/yolox/models/__init__.py
@@ -2,6 +2,7 @@
 # -*- coding:utf-8 -*-
 # Copyright (c) Megvii Inc. All rights reserved.
 
+from .build import *
 from .darknet import CSPDarknet, Darknet
 from .losses import IOUloss
 from .yolo_fpn import YOLOFPN
diff --git a/yolox/models/build.py b/yolox/models/build.py
new file mode 100644
index 000000000..1195bb870
--- /dev/null
+++ b/yolox/models/build.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+from torch import nn
+from torch.hub import load_state_dict_from_url
+
+__all__ = [
+    "create_yolox_model",
+    "yolox_nano",
+    "yolox_tiny",
+    "yolox_s",
+    "yolox_m",
+    "yolox_l",
+    "yolox_x",
+    "yolov3",
+]
+
+_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
+_CKPT_FULL_PATH = {
+    "yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
+    "yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
+    "yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
+    "yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
+    "yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
+    "yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
+    "yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
+}
+
+
+def create_yolox_model(
+    name: str, pretrained: bool = True, num_classes: int = 80, device=None
+) -> nn.Module:
+    """creates and loads a YOLOX model
+
+    Args:
+        name (str): name of model. for example, "yolox-s", "yolox-tiny".
+        pretrained (bool): load pretrained weights (only if num_classes == 80). Default to True.
+        num_classes (int): number of model classes. Default to 80.
+        device (str): device for the model. Default to None (cuda:0 if available, else cpu).
+
+    Returns:
+        YOLOX model (nn.Module)
+    """
+    from yolox.exp import get_exp, Exp
+
+    if device is None:
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+
+    assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
+    exp: Exp = get_exp(exp_name=name)
+    exp.num_classes = num_classes
+    yolox_model = exp.get_model()
+    if pretrained and num_classes == 80:
+        weights_url = _CKPT_FULL_PATH[name]
+        ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
+        if "model" in ckpt:
+            ckpt = ckpt["model"]
+        yolox_model.load_state_dict(ckpt)
+
+    yolox_model.to(device)
+    return yolox_model
+
+
+def yolox_nano(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-nano", pretrained, num_classes, device)
+
+
+def yolox_tiny(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
+
+
+def yolox_s(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-s", pretrained, num_classes, device)
+
+
+def yolox_m(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-m", pretrained, num_classes, device)
+
+
+def yolox_l(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-l", pretrained, num_classes, device)
+
+
+def yolox_x(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolox-x", pretrained, num_classes, device)
+
+
+def yolov3(pretrained=True, num_classes=80, device=None):
+    return create_yolox_model("yolov3", pretrained, num_classes, device)

From e685457fa5dc1d82ce91ca43ac7e6f1b41a95a60 Mon Sep 17 00:00:00 2001
From: Feng Wang
Date: Mon, 21 Mar 2022 15:07:23 +0800
Subject: [PATCH 2/2] fix(model): compatible meshgrid and CUDA OOM error

---
 hubconf.py                |  2 +-
 yolox/models/yolo_head.py | 12 ++++++++----
 yolox/utils/__init__.py   |  1 +
 yolox/utils/compat.py     | 15 +++++++++++++++
 4 files changed, 25 insertions(+), 5 deletions(-)
 create mode 100644 yolox/utils/compat.py

diff --git a/hubconf.py b/hubconf.py
index 9c0e56f8f..d6736478e 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -8,7 +8,7 @@
 """
 dependencies = ["torch"]
 
-from yolox.models import (  # noqa: F401, E402
+from yolox.models import (  # isort:skip  # noqa: F401, E402
     yolox_tiny,
     yolox_nano,
     yolox_s,
diff --git a/yolox/models/yolo_head.py b/yolox/models/yolo_head.py
index d3ca5e613..d67abd1a0 100644
--- a/yolox/models/yolo_head.py
+++ b/yolox/models/yolo_head.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from yolox.utils import bboxes_iou
+from yolox.utils import bboxes_iou, meshgrid
 
 from .losses import IOUloss
 from .network_blocks import BaseConv, DWConv
@@ -220,7 +220,7 @@ def get_output_and_grid(self, output, k, stride, dtype):
         n_ch = 5 + self.num_classes
         hsize, wsize = output.shape[-2:]
         if grid.shape[2:4] != output.shape[2:4]:
-            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
             self.grids[k] = grid
 
@@ -237,7 +237,7 @@ def decode_outputs(self, outputs, dtype):
         grids = []
         strides = []
         for (hsize, wsize), stride in zip(self.hw, self.strides):
-            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
             grid = torch.stack((xv, yv), 2).view(1, -1, 2)
             grids.append(grid)
             shape = grid.shape[:2]
@@ -321,7 +321,11 @@ def get_losses(
                         labels,
                         imgs,
                     )
-                except RuntimeError:
+                except RuntimeError as e:
+                    # TODO: the string might change, consider a better way
+                    if "CUDA out of memory. " not in str(e):
+                        raise  # RuntimeError might not be caused by CUDA OOM
+
                     logger.error(
                         "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
                            CPU mode is applied in this batch. If you want to avoid this issue, \
diff --git a/yolox/utils/__init__.py b/yolox/utils/__init__.py
index ff8db0dbc..15426396e 100644
--- a/yolox/utils/__init__.py
+++ b/yolox/utils/__init__.py
@@ -5,6 +5,7 @@
 from .allreduce_norm import *
 from .boxes import *
 from .checkpoint import load_ckpt, save_checkpoint
+from .compat import meshgrid
 from .demo_utils import *
 from .dist import *
 from .ema import *
diff --git a/yolox/utils/compat.py b/yolox/utils/compat.py
new file mode 100644
index 000000000..1324077e6
--- /dev/null
+++ b/yolox/utils/compat.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+
+_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
+
+__all__ = ["meshgrid"]
+
+
+def meshgrid(*tensors):
+    if _TORCH_VER >= [1, 10]:
+        return torch.meshgrid(*tensors, indexing="ij")
+    else:
+        return torch.meshgrid(*tensors)