
Commit 4a0b0c3

Support setting the cache_size_limit parameter of dynamo in PyTorch 2.0 (#10054)
1 parent a2f33db commit 4a0b0c3

File tree

6 files changed: +115 -2 lines changed

docs/en/notes/faq.md

Lines changed: 36 additions & 0 deletions
@@ -2,6 +2,42 @@

We list some common troubles faced by many users and their corresponding solutions here. Feel free to enrich the list if you find any frequent issues and have ways to help others solve them. If the contents here do not cover your issue, please create an issue using the [provided templates](https://github.com/open-mmlab/mmdetection/blob/master/.github/ISSUE_TEMPLATE/error-report.md/) and make sure you fill in all required information in the template.

## PyTorch 2.0 Support

The vast majority of algorithms in MMDetection now support PyTorch 2.0 and its `torch.compile` function. Users only need to install MMDetection 3.0.0rc7 or a later version to enjoy this feature. If you find any unsupported algorithms during use, please feel free to give us feedback. We also welcome contributions from the community to benchmark the speed improvement brought by using the `torch.compile` function.

To enable the `torch.compile` function, simply add `--cfg-options compile=True` after `train.py` or `test.py`. For example, to enable `torch.compile` for RTMDet, you can use the following command:

```shell
# Single GPU
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node multiple GPUs
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True

# Single node multiple GPUs + AMP
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True --amp
```

It is important to note that PyTorch 2.0's support for dynamic shapes is not yet mature. In most object detection algorithms, not only are the input shapes dynamic, but the loss calculation and post-processing parts are dynamic as well, which can make training slower when the `torch.compile` function is enabled. Therefore, if you wish to enable `torch.compile`, you should follow these principles:

1. Input images to the network have a fixed shape, not multi-scale.
2. Set the `torch._dynamo.config.cache_size_limit` parameter. TorchDynamo converts and caches Python bytecode, and the compiled functions are stored in the cache. When a check finds that a function needs to be recompiled, it is recompiled and cached again. However, once the number of recompilations exceeds the configured maximum (64 by default), the function is no longer cached or recompiled. As mentioned above, the loss calculation and post-processing parts of an object detection algorithm are computed dynamically and need to be recompiled on every iteration, so setting `torch._dynamo.config.cache_size_limit` to a smaller value can effectively reduce compilation time.

In MMDetection, you can set the `torch._dynamo.config.cache_size_limit` parameter through the environment variable `DYNAMO_CACHE_SIZE_LIMIT`. For example:

```shell
# Single GPU
export DYNAMO_CACHE_SIZE_LIMIT=4
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node multiple GPUs
export DYNAMO_CACHE_SIZE_LIMIT=4
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True
```

For common questions about PyTorch 2.0's dynamo, refer to the [PyTorch dynamo FAQ](https://pytorch.org/docs/stable/dynamo/faq.html).

## Installation

- Compatibility issue between MMCV and MMDetection; "ConvWS is already registered in conv layer"; "AssertionError: MMCV==xxx is used but incompatible. Please install mmcv>=xxx, \<=xxx."
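The cache-limit behavior described in the FAQ above can be sketched with a small, torch-free stand-in (the `CompileCache` class and its method names are hypothetical, for illustration only): every new input shape triggers a recompile, and once the number of compiles reaches the limit, the function falls back to eager execution, so a small limit caps the total time spent compiling under dynamic shapes.

```python
import os


# Hypothetical stand-in for dynamo's compile cache, for illustration only:
# each new input shape triggers a recompile, and once the number of
# compiles reaches cache_size_limit the function falls back to eager mode.
class CompileCache:

    def __init__(self, cache_size_limit=64):  # 64 mirrors dynamo's default
        self.cache_size_limit = cache_size_limit
        self.compiled = {}  # input shape -> compiled artifact
        self.compile_count = 0

    def run(self, fn, shape):
        if shape in self.compiled:
            return self.compiled[shape]  # cache hit: no compile cost
        if self.compile_count >= self.cache_size_limit:
            return f'eager:{fn.__name__}'  # limit reached: stop compiling
        self.compile_count += 1  # pay the compile cost once per shape
        self.compiled[shape] = f'compiled:{fn.__name__}:{shape}'
        return self.compiled[shape]


def loss(x):  # stands in for a dynamic-shape loss/post-processing function
    return x


# Mirrors the DYNAMO_CACHE_SIZE_LIMIT env-var pattern used by MMDetection.
cache = CompileCache(
    cache_size_limit=int(os.environ.get('DYNAMO_CACHE_SIZE_LIMIT', '4')))
for s in (480, 512, 544, 576, 608, 640):  # multi-scale inputs
    cache.run(loss, (640, s))
print(cache.compile_count)  # stops growing once the limit is reached
```

With a limit of 4 and six distinct shapes, only the first four shapes are compiled; the rest run eagerly, which is exactly why fixed-shape inputs plus a small `cache_size_limit` keep compile overhead bounded.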

docs/zh_cn/notes/faq.md

Lines changed: 36 additions & 0 deletions
@@ -2,6 +2,42 @@

We list some common problems encountered during use and their corresponding solutions here. If you find that some issues are missing, feel free to open a PR to enrich this list. If you cannot get help here, please create an issue using the [issue template](https://github.com/open-mmlab/mmdetection/blob/master/.github/ISSUE_TEMPLATE/error-report.md/), and fill in all required information in the template, which helps us locate the problem faster.

## PyTorch 2.0 Support

The vast majority of algorithms in MMDetection now support PyTorch 2.0 and its `torch.compile` function. Users only need to install MMDetection 3.0.0rc7 or a later version. If you find any unsupported algorithms during use, please give us feedback. We also warmly welcome community contributors to benchmark the speed improvement brought by the `torch.compile` function.

To enable the `torch.compile` function, simply add `--cfg-options compile=True` after `train.py` or `test.py`. Taking RTMDet as an example, you can enable `torch.compile` with the following command:

```shell
# Single GPU
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node, 8 GPUs
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True

# Single node, 8 GPUs + AMP mixed-precision training
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True --amp
```

Note in particular that PyTorch 2.0's support for dynamic shapes is not yet mature. In most object detection algorithms, not only the input shapes but also the loss calculation and post-processing are dynamic, which can make training slower after enabling `torch.compile`. Therefore, if you want to enable `torch.compile`, you should follow these principles:

1. Input images to the network have a fixed shape, not multi-scale.
2. Set the `torch._dynamo.config.cache_size_limit` parameter. TorchDynamo converts and caches Python bytecode, and compiled functions are stored in the cache. When a check finds that a function needs to be recompiled, it is recompiled and cached again. However, once the number of recompilations exceeds the preset maximum (64), the function is no longer cached or recompiled. As mentioned above, the loss calculation and post-processing parts of an object detection algorithm are also computed dynamically and need to be recompiled on every iteration, so setting `torch._dynamo.config.cache_size_limit` to a smaller value can effectively reduce compilation time.

In MMDetection, you can set the `torch._dynamo.config.cache_size_limit` parameter through the environment variable `DYNAMO_CACHE_SIZE_LIMIT`. Taking RTMDet as an example, the command is as follows:

```shell
# Single GPU
export DYNAMO_CACHE_SIZE_LIMIT=4
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node, 8 GPUs
export DYNAMO_CACHE_SIZE_LIMIT=4
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True
```

For common questions about PyTorch 2.0's dynamo, refer to the [PyTorch dynamo FAQ](https://pytorch.org/docs/stable/dynamo/faq.html).

## Installation

- Compatibility issue between MMCV and MMDetection: "ConvWS is already registered in conv layer"; "AssertionError: MMCV==xxx is used but incompatible. Please install mmcv>=xxx, \<=xxx."

mmdet/utils/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -8,7 +8,8 @@
 from .misc import (find_latest_checkpoint, get_test_pipeline_cfg,
                    update_data_root)
 from .replace_cfg_vals import replace_cfg_vals
-from .setup_env import register_all_modules, setup_multi_processes
+from .setup_env import (register_all_modules, setup_cache_size_limit_of_dynamo,
+                        setup_multi_processes)
 from .split_batch import split_batch
 from .typing_utils import (ConfigType, InstanceList, MultiConfig,
                           OptConfigType, OptInstanceList, OptMultiConfig,
@@ -21,5 +22,6 @@
     'AvoidCUDAOOM', 'all_reduce_dict', 'allreduce_grads', 'reduce_mean',
     'sync_random_seed', 'ConfigType', 'InstanceList', 'MultiConfig',
     'OptConfigType', 'OptInstanceList', 'OptMultiConfig', 'OptPixelList',
-    'PixelList', 'RangeType', 'get_test_pipeline_cfg'
+    'PixelList', 'RangeType', 'get_test_pipeline_cfg',
+    'setup_cache_size_limit_of_dynamo'
 ]

mmdet/utils/setup_env.py

Lines changed: 28 additions & 0 deletions
@@ -1,12 +1,40 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import datetime
+import logging
 import os
 import platform
 import warnings

 import cv2
 import torch.multiprocessing as mp
 from mmengine import DefaultScope
+from mmengine.logging import print_log
+from mmengine.utils import digit_version
+
+
+def setup_cache_size_limit_of_dynamo():
+    """Setup cache size limit of dynamo.
+
+    Note: Due to the dynamic shape of the loss calculation and
+    post-processing parts in the object detection algorithm, these
+    functions must be compiled every time they are run.
+    Setting a large value for torch._dynamo.config.cache_size_limit
+    may result in repeated compilation, which can slow down training
+    and testing speed. Therefore, we need to set the default value of
+    cache_size_limit smaller. An empirical value is 4.
+    """
+
+    import torch
+    if digit_version(torch.__version__) >= digit_version('2.0.0'):
+        if 'DYNAMO_CACHE_SIZE_LIMIT' in os.environ:
+            import torch._dynamo
+            cache_size_limit = int(os.environ['DYNAMO_CACHE_SIZE_LIMIT'])
+            torch._dynamo.config.cache_size_limit = cache_size_limit
+            print_log(
+                f'torch._dynamo.config.cache_size_limit is forcibly '
+                f'set to {cache_size_limit}.',
+                logger='current',
+                level=logging.WARNING)


 def setup_multi_processes(cfg):
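The version gate above relies on mmengine's `digit_version` helper. A simplified, hypothetical stand-in (the real helper also handles pre-release tags such as `2.0.0rc7`) shows the comparison it performs: turn each version string into a tuple of integers so that tuple comparison orders versions correctly.

```python
# Simplified, hypothetical stand-in for mmengine's digit_version helper;
# the real implementation also handles pre-release tags like '2.0.0rc7'.
def digit_version(version_str):
    """Turn a version string into a comparable tuple of ints.

    Non-numeric suffixes within a component are truncated, so a local
    build tag such as '2.0.0+cu118' compares like (2, 0, 0).
    """
    parts = []
    for component in version_str.split('.'):
        digits = ''
        for ch in component:
            if not ch.isdigit():
                break  # stop at the first non-digit character
            digits += ch
        if digits:
            parts.append(int(digits))
    return tuple(parts)


# The gate in setup_cache_size_limit_of_dynamo only applies the override
# when torch is at least 2.0.0, since older torch has no dynamo config.
print(digit_version('2.0.1') >= digit_version('2.0.0'))   # True
print(digit_version('1.13.1') >= digit_version('2.0.0'))  # False
```

Tuple comparison is what makes `(1, 13, 1) < (2, 0, 0)` come out right even though the string `'1.13.1'` would sort after `'2.0.0'` digit-by-digit in some naive schemes.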

tools/test.py

Lines changed: 5 additions & 0 deletions
@@ -12,6 +12,7 @@
 from mmdet.engine.hooks.utils import trigger_visualization_hook
 from mmdet.evaluation import DumpDetResults
 from mmdet.registry import RUNNERS
+from mmdet.utils import setup_cache_size_limit_of_dynamo


 # TODO: support fuse_conv_bn and format_only
@@ -65,6 +66,10 @@ def parse_args():
 def main():
     args = parse_args()

+    # Reduce the number of repeated compilations and improve
+    # testing speed.
+    setup_cache_size_limit_of_dynamo()
+
     # load config
     cfg = Config.fromfile(args.config)
     cfg.launcher = args.launcher

tools/train.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,8 @@
 from mmengine.registry import RUNNERS
 from mmengine.runner import Runner

+from mmdet.utils import setup_cache_size_limit_of_dynamo
+

 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
@@ -60,6 +62,10 @@ def parse_args():
 def main():
     args = parse_args()

+    # Reduce the number of repeated compilations and improve
+    # training speed.
+    setup_cache_size_limit_of_dynamo()
+
     # load config
     cfg = Config.fromfile(args.config)
     cfg.launcher = args.launcher
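Both entry points follow the same ordering: configure the dynamo cache limit at the very top of `main()`, before the config is loaded and any model code could be compiled. A minimal sketch of that ordering, using hypothetical stand-ins (not the real mmdet/mmengine APIs):

```python
import argparse

# Hypothetical sketch of the ordering used by tools/train.py and
# tools/test.py: the dynamo cache limit is configured first in main(),
# before the config is even loaded, so no function can be compiled under
# the wrong limit. The stand-in functions record their call order.
calls = []


def setup_cache_size_limit_of_dynamo():  # stand-in for the mmdet.utils helper
    calls.append('setup_dynamo')


def config_fromfile(path):  # stand-in for mmengine's Config.fromfile
    calls.append('load_config')
    return {'config_path': path}


def main(argv):
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config')
    args = parser.parse_args(argv)

    # Reduce repeated compilations before anything else touches the model.
    setup_cache_size_limit_of_dynamo()

    cfg = config_fromfile(args.config)
    return cfg


main(['configs/rtmdet/rtmdet_s_8xb32-300e_coco.py'])
print(calls)  # setup runs before config loading
```

Keeping the setup call ahead of `Config.fromfile` is the design choice the diff makes in both scripts; a limit applied after compilation has started would not help the already-compiled functions.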
