diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml
index 02e009fd74..80ffc77039 100644
--- a/.pre-commit-config-zh-cn.yaml
+++ b/.pre-commit-config-zh-cn.yaml
@@ -4,18 +4,6 @@ repos:
     rev: v4.0.0
     hooks:
       - id: validate_manifest
-  - repo: https://github.com/PyCQA/flake8
-    rev: 7.1.1
-    hooks:
-      - id: flake8
-  - repo: https://gitee.com/openmmlab/mirrors-isort
-    rev: 5.11.5
-    hooks:
-      - id: isort
-  - repo: https://gitee.com/openmmlab/mirrors-yapf
-    rev: v0.32.0
-    hooks:
-      - id: yapf
   - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
     rev: v5.0.0
     hooks:
diff --git a/docs/en/conf.py b/docs/en/conf.py
index c2b4961477..2f4c3c5255 100644
--- a/docs/en/conf.py
+++ b/docs/en/conf.py
@@ -15,18 +15,19 @@
 import pytorch_sphinx_theme

-sys.path.insert(0, os.path.abspath('../..'))
+
+sys.path.insert(0, os.path.abspath("../.."))

 # -- Project information -----------------------------------------------------

-project = 'mmengine'
-copyright = '2022, mmengine contributors'
-author = 'mmengine contributors'
+project = "mmengine"
+copyright = "2022, mmengine contributors"
+author = "mmengine contributors"

-version_file = '../../mmengine/version.py'
+version_file = "../../mmengine/version.py"
 with open(version_file) as f:
-    exec(compile(f.read(), version_file, 'exec'))
-__version__ = locals()['__version__']
+    exec(compile(f.read(), version_file, "exec"))
+__version__ = locals()["__version__"]
 # The short X.Y version
 version = __version__
 # The full version, including alpha/beta/rc tags
@@ -49,52 +50,49 @@
     'sphinx.ext.autodoc.typehints',
     'sphinx_tabs.tabs',
 ] # yapf: disable

-autodoc_typehints = 'description'
+autodoc_typehints = "description"

 myst_heading_anchors = 4

-myst_enable_extensions = ['colon_fence']
+myst_enable_extensions = ["colon_fence"]

 # Configuration for intersphinx
 intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-    'mmcv': ('https://mmcv.readthedocs.io/en/2.x/', None),
+    "python": ("https://docs.python.org/3", None),
+    "numpy": ("https://numpy.org/doc/stable", None),
+    "torch": ("https://pytorch.org/docs/stable/", None),
+    "mmcv": ("https://mmcv.readthedocs.io/en/2.x/", None),
 }

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

 # -- Options for HTML output -------------------------------------------------

 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'pytorch_sphinx_theme'
+html_theme = "pytorch_sphinx_theme"
 html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/open-mmlab/mmengine'
-        },
+    "menu": [
+        {"name": "GitHub", "url": "https://github.com/open-mmlab/mmengine"},
     ],
     # Specify the language of shared menu
-    'menu_lang': 'en',
+    "menu_lang": "en",
 }

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static'] -html_css_files = ['css/readthedocs.css'] +html_static_path = ["_static"] +html_css_files = ["css/readthedocs.css"] # -- Extension configuration ------------------------------------------------- # Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True diff --git a/docs/resources/config/config_sgd.py b/docs/resources/config/config_sgd.py index 9afaf8e54e..c14c943471 100644 --- a/docs/resources/config/config_sgd.py +++ b/docs/resources/config/config_sgd.py @@ -1 +1 @@ -optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) +optimizer = dict(type="SGD", lr=0.1, momentum=0.9, weight_decay=0.0001) diff --git a/docs/resources/config/cross_repo.py b/docs/resources/config/cross_repo.py index 9472668956..9961671a45 100644 --- a/docs/resources/config/cross_repo.py +++ b/docs/resources/config/cross_repo.py @@ -1,6 +1,6 @@ _base_ = [ - 'mmdet::_base_/schedules/schedule_1x.py', - 'mmdet::_base_/datasets/coco_instance.py', - 'mmdet::_base_/default_runtime.py', - 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', + "mmdet::_base_/schedules/schedule_1x.py", + "mmdet::_base_/datasets/coco_instance.py", + "mmdet::_base_/default_runtime.py", + "mmdet::_base_/models/faster-rcnn_r50_fpn.py", ] diff --git a/docs/resources/config/custom_imports.py b/docs/resources/config/custom_imports.py index adb5d0489a..555edf1176 100644 --- a/docs/resources/config/custom_imports.py +++ b/docs/resources/config/custom_imports.py @@ -1,2 +1,2 @@ -custom_imports = dict(imports=['my_module'], allow_failed_imports=False) -optimizer = dict(type='CustomOptim') +custom_imports = dict(imports=["my_module"], allow_failed_imports=False) +optimizer = dict(type="CustomOptim") diff --git a/docs/resources/config/demo_train.py b/docs/resources/config/demo_train.py index 411ef6a98d..510bcb5635 100644 --- a/docs/resources/config/demo_train.py +++ b/docs/resources/config/demo_train.py @@ -4,18 +4,19 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Train a model') - parser.add_argument('config', help='train config file path') + parser = argparse.ArgumentParser(description="Train a model") + parser.add_argument("config", help="train config file path") parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) args = parser.parse_args() return args @@ -29,5 +30,5 @@ def main(): print(cfg) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/docs/resources/config/example.py b/docs/resources/config/example.py index d3df48c6a6..096485466e 100644 --- a/docs/resources/config/example.py +++ b/docs/resources/config/example.py @@ -1,2 +1,2 @@ -model = dict(type='CustomModel', in_channels=[1, 2, 3]) -optimizer = dict(type='SGD', lr=0.01) +model = dict(type="CustomModel", in_channels=[1, 2, 3]) +optimizer = dict(type="SGD", lr=0.01) diff --git a/docs/resources/config/learn_read_config.py b/docs/resources/config/learn_read_config.py index 822aa66257..19080a28e4 100644 --- a/docs/resources/config/learn_read_config.py +++ b/docs/resources/config/learn_read_config.py @@ -1,3 +1,3 @@ test_int = 1 test_list = [1, 2, 3] -test_dict = dict(key1='value1', key2=0.1) +test_dict = dict(key1="value1", key2=0.1) diff --git a/docs/resources/config/optimizer_cfg.py b/docs/resources/config/optimizer_cfg.py index b6f55cd3c5..a0a0f7af15 100644 --- a/docs/resources/config/optimizer_cfg.py +++ b/docs/resources/config/optimizer_cfg.py @@ -1 +1 @@ -optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer = dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001) diff --git a/docs/resources/config/predefined_var.py b/docs/resources/config/predefined_var.py index f072068c05..452ac53411 100644 --- a/docs/resources/config/predefined_var.py +++ b/docs/resources/config/predefined_var.py @@ -1 +1 @@ -work_dir = './work_dir/{{fileBasenameNoExtension}}' +work_dir = "./work_dir/{{fileBasenameNoExtension}}" diff --git a/docs/resources/config/refer_base_var.py b/docs/resources/config/refer_base_var.py index 1e98a1ef08..8a4e1d4d3f 100644 --- a/docs/resources/config/refer_base_var.py +++ b/docs/resources/config/refer_base_var.py @@ -1,2 +1,2 @@ -_base_ = ['resnet50.py'] +_base_ = ["resnet50.py"] a = {{_base_.model}} diff --git a/docs/resources/config/replace_data_root.py b/docs/resources/config/replace_data_root.py index 17c01e99ad..098af136ac 100644 --- a/docs/resources/config/replace_data_root.py +++ b/docs/resources/config/replace_data_root.py @@ -1,3 +1,3 @@ -dataset_type = 'CocoDataset' -data_root = '{{$DATASET:/data/coco/}}' -dataset = dict(ann_file=data_root + 'train.json') +dataset_type = "CocoDataset" +data_root = "{{$DATASET:/data/coco/}}" +dataset = dict(ann_file=data_root + "train.json") diff --git a/docs/resources/config/replace_num_classes.py b/docs/resources/config/replace_num_classes.py index 76dc436508..35c5c1b091 100644 --- a/docs/resources/config/replace_num_classes.py +++ b/docs/resources/config/replace_num_classes.py @@ -1 +1 @@ -model = dict(bbox_head=dict(num_classes={{'$NUM_CLASSES:80'}})) +model = dict(bbox_head=dict(num_classes={{"$NUM_CLASSES:80"}})) diff --git a/docs/resources/config/resnet50.py b/docs/resources/config/resnet50.py index 43f0187d9a..bc0559c805 100644 --- a/docs/resources/config/resnet50.py +++ b/docs/resources/config/resnet50.py @@ -1,2 +1,2 @@ -_base_ = ['optimizer_cfg.py'] -model = dict(type='ResNet', depth=50) +_base_ = ["optimizer_cfg.py"] +model = dict(type="ResNet", depth=50) diff --git a/docs/resources/config/resnet50_delete_key.py b/docs/resources/config/resnet50_delete_key.py index b7ada136a7..a778575ac7 100644 --- 
a/docs/resources/config/resnet50_delete_key.py +++ b/docs/resources/config/resnet50_delete_key.py @@ -1,3 +1,3 @@ -_base_ = ['optimizer_cfg.py', 'runtime_cfg.py'] -model = dict(type='ResNet', depth=50) -optimizer = dict(_delete_=True, type='SGD', lr=0.01) +_base_ = ["optimizer_cfg.py", "runtime_cfg.py"] +model = dict(type="ResNet", depth=50) +optimizer = dict(_delete_=True, type="SGD", lr=0.01) diff --git a/docs/resources/config/resnet50_lr0.01.py b/docs/resources/config/resnet50_lr0.01.py index 47a83994b6..1fdc5cf490 100644 --- a/docs/resources/config/resnet50_lr0.01.py +++ b/docs/resources/config/resnet50_lr0.01.py @@ -1,3 +1,3 @@ -_base_ = ['optimizer_cfg.py', 'runtime_cfg.py'] -model = dict(type='ResNet', depth=50) +_base_ = ["optimizer_cfg.py", "runtime_cfg.py"] +model = dict(type="ResNet", depth=50) optimizer = dict(lr=0.01) diff --git a/docs/resources/config/resnet50_runtime.py b/docs/resources/config/resnet50_runtime.py index 4b7c8e1a43..ad1032e22c 100644 --- a/docs/resources/config/resnet50_runtime.py +++ b/docs/resources/config/resnet50_runtime.py @@ -1,2 +1,2 @@ -_base_ = ['optimizer_cfg.py', 'runtime_cfg.py'] -model = dict(type='ResNet', depth=50) +_base_ = ["optimizer_cfg.py", "runtime_cfg.py"] +model = dict(type="ResNet", depth=50) diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index ad611187f9..2aefab3839 100644 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -16,18 +16,19 @@ import pytorch_sphinx_theme -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'mmengine' -copyright = '2022, mmengine contributors' -author = 'mmengine contributors' +project = "mmengine" +copyright = "2022, mmengine contributors" +author = "mmengine contributors" -version_file = '../../mmengine/version.py' +version_file = "../../mmengine/version.py" with open(version_file) as f: - exec(compile(f.read(), version_file, 'exec')) -__version__ = locals()['__version__'] + exec(compile(f.read(), version_file, "exec")) +__version__ = locals()["__version__"] # The short X.Y version version = __version__ # The full version, including alpha/beta/rc tags @@ -37,7 +38,7 @@ # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = 'zh_CN' +language = "zh_CN" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -54,60 +55,57 @@ 'sphinx.ext.autodoc.typehints', 'sphinx_tabs.tabs', ] # yapf: disable -autodoc_typehints = 'description' +autodoc_typehints = "description" myst_heading_anchors = 4 -myst_enable_extensions = ['colon_fence'] +myst_enable_extensions = ["colon_fence"] # Configuration for intersphinx intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'mmcv': ('https://mmcv.readthedocs.io/zh_CN/2.x/', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "mmcv": ("https://mmcv.readthedocs.io/zh_CN/2.x/", None), } # Add any paths that contain templates here, relative to this directory. 
-templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] html_theme_options = { - 'menu': [ - { - 'name': 'GitHub', - 'url': 'https://github.com/open-mmlab/mmengine' - }, + "menu": [ + {"name": "GitHub", "url": "https://github.com/open-mmlab/mmengine"}, ], # Specify the language of shared menu - 'menu_lang': 'cn', + "menu_lang": "cn", } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_css_files = ['css/readthedocs.css'] +html_static_path = ["_static"] +html_css_files = ["css/readthedocs.css"] # -- Extension configuration ------------------------------------------------- # Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True def builder_inited_handler(app): - subprocess.run(['./cp_origin_docs.sh']) + subprocess.run(["./cp_origin_docs.sh"]) def setup(app): - app.connect('builder-inited', builder_inited_handler) + app.connect("builder-inited", builder_inited_handler) diff --git a/examples/distributed_training.py b/examples/distributed_training.py index 236bee234c..32c88dedf6 100644 --- a/examples/distributed_training.py +++ b/examples/distributed_training.py @@ -12,42 +12,38 @@ class MMResNet50(BaseModel): - def __init__(self): super().__init__() self.resnet = torchvision.models.resnet50() def forward(self, imgs, labels, mode): x = self.resnet(imgs) - if mode == 'loss': - return {'loss': F.cross_entropy(x, labels)} - elif mode == 'predict': + if mode == "loss": + return {"loss": F.cross_entropy(x, labels)} + elif mode == "predict": return x, labels class Accuracy(BaseMetric): - def process(self, data_batch, data_samples): score, gt = data_samples - self.results.append({ - 'batch_size': len(gt), - 'correct': (score.argmax(dim=1) == gt).sum().cpu(), - }) + self.results.append( + { + "batch_size": len(gt), + "correct": (score.argmax(dim=1) == gt).sum().cpu(), + } + ) def compute_metrics(self, results): - total_correct = sum(item['correct'] for item in results) - total_size = sum(item['batch_size'] for item in results) + total_correct = sum(item["correct"] for item in results) + total_size = sum(item["batch_size"] for item in results) return dict(accuracy=100 * total_correct / total_size) def parse_args(): - parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + parser = argparse.ArgumentParser(description="Distributed Training") + parser.add_argument("--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="none", help="job launcher") + 
parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() return args @@ -57,35 +53,39 @@ def main(): args = parse_args() norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) train_set = torchvision.datasets.CIFAR10( - 'data/cifar10', + "data/cifar10", train=True, download=True, - transform=transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(**norm_cfg) - ])) + transform=transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg), + ] + ), + ) valid_set = torchvision.datasets.CIFAR10( - 'data/cifar10', + "data/cifar10", train=False, download=True, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(**norm_cfg)])) + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(**norm_cfg)]), + ) train_dataloader = dict( batch_size=32, dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=True), + collate_fn=dict(type="default_collate"), + ) val_dataloader = dict( batch_size=32, dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=False), + collate_fn=dict(type="default_collate"), + ) runner = Runner( model=MMResNet50(), - work_dir='./work_dirs', + work_dir="./work_dirs", train_dataloader=train_dataloader, optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), @@ -97,5 +97,5 @@ def main(): runner.train() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/distributed_training_with_flexible_runner.py b/examples/distributed_training_with_flexible_runner.py index 99d2cf257d..d55afab1e4 100644 --- a/examples/distributed_training_with_flexible_runner.py +++ b/examples/distributed_training_with_flexible_runner.py @@ -12,40 +12,40 @@ class MMResNet50(BaseModel): - def __init__(self): super().__init__() self.resnet = torchvision.models.resnet50() def forward(self, imgs, labels, mode): x = self.resnet(imgs) - if mode == 'loss': - return {'loss': F.cross_entropy(x, labels)} - elif mode == 'predict': + if mode == "loss": + return {"loss": F.cross_entropy(x, labels)} + elif mode == "predict": return x, labels class Accuracy(BaseMetric): - def process(self, data_batch, data_samples): score, gt = data_samples - self.results.append({ - 'batch_size': len(gt), - 'correct': (score.argmax(dim=1) == gt).sum().cpu(), - }) + self.results.append( + { + "batch_size": len(gt), + "correct": (score.argmax(dim=1) == gt).sum().cpu(), + } + ) def compute_metrics(self, results): - total_correct = sum(item['correct'] for item in results) - total_size = sum(item['batch_size'] for item in results) + total_correct = sum(item["correct"] for item in results) + total_size = sum(item["batch_size"] for item in results) return dict(accuracy=100 * total_correct / total_size) def parse_args(): - parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument('--local_rank', '--local-rank', type=int, default=0) - parser.add_argument('--use-fsdp', action='store_true') - parser.add_argument('--use-deepspeed', action='store_true') - parser.add_argument('--use-colossalai', action='store_true') + parser = 
argparse.ArgumentParser(description="Distributed Training") + parser.add_argument("--local_rank", "--local-rank", type=int, default=0) + parser.add_argument("--use-fsdp", action="store_true") + parser.add_argument("--use-deepspeed", action="store_true") + parser.add_argument("--use-colossalai", action="store_true") args = parser.parse_args() return args @@ -54,36 +54,40 @@ def main(): args = parse_args() norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) train_set = torchvision.datasets.CIFAR10( - 'data/cifar10', + "data/cifar10", train=True, download=True, - transform=transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(**norm_cfg) - ])) + transform=transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg), + ] + ), + ) valid_set = torchvision.datasets.CIFAR10( - 'data/cifar10', + "data/cifar10", train=False, download=True, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(**norm_cfg)])) + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(**norm_cfg)]), + ) train_dataloader = dict( batch_size=128, dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=True), + collate_fn=dict(type="default_collate"), + ) val_dataloader = dict( batch_size=128, dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=False), + collate_fn=dict(type="default_collate"), + ) if args.use_deepspeed: strategy = dict( - type='DeepSpeedStrategy', + type="DeepSpeedStrategy", fp16=dict( enabled=True, fp16_master_weights_and_grads=False, @@ -105,22 +109,18 @@ def main(): reduce_bucket_size=50000000, overlap_comm=True, contiguous_gradients=True, - cpu_offload=False), + cpu_offload=False, + ), ) - optim_wrapper = dict( - type='DeepSpeedOptimWrapper', - optimizer=dict(type='AdamW', lr=1e-3)) + optim_wrapper = dict(type="DeepSpeedOptimWrapper", optimizer=dict(type="AdamW", lr=1e-3)) elif args.use_fsdp: from functools import partial from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy - size_based_auto_wrap_policy = partial( - size_based_auto_wrap_policy, min_num_params=1e7) - strategy = dict( - type='FSDPStrategy', - model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy)) - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) + + size_based_auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7) + strategy = dict(type="FSDPStrategy", model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy)) + optim_wrapper = dict(type="AmpOptimWrapper", optimizer=dict(type="AdamW", lr=1e-3)) elif args.use_colossalai: from colossalai.tensor.op_wrapper import colo_op_impl @@ -138,28 +138,28 @@ def main(): # since PyTorch consider the custom op before it could not handle the # backward graph modification colo_op_impl(torch.Tensor.add_)(torch.add) - strategy = dict(type='ColossalAIStrategy') - optim_wrapper = dict(optimizer=dict(type='HybridAdam', lr=1e-3)) + strategy = dict(type="ColossalAIStrategy") + optim_wrapper = dict(optimizer=dict(type="HybridAdam", lr=1e-3)) else: strategy = None - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) + 
optim_wrapper = dict(type="AmpOptimWrapper", optimizer=dict(type="AdamW", lr=1e-3)) runner = FlexibleRunner( model=MMResNet50(), - work_dir='./work_dirs', + work_dir="./work_dirs", strategy=strategy, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, - param_scheduler=dict(type='LinearLR'), + param_scheduler=dict(type="LinearLR"), train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1), val_dataloader=val_dataloader, val_cfg=dict(), - val_evaluator=dict(type=Accuracy)) + val_evaluator=dict(type=Accuracy), + ) runner.train() -if __name__ == '__main__': +if __name__ == "__main__": # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp # noqa: 501 # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed # noqa: 501 # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py diff --git a/examples/llama2/fsdp_finetune.py b/examples/llama2/fsdp_finetune.py index 0d7e2751b7..4403d9d534 100644 --- a/examples/llama2/fsdp_finetune.py +++ b/examples/llama2/fsdp_finetune.py @@ -18,24 +18,26 @@ from mmengine.utils import apply_to from mmengine.visualization import Visualizer, WandbVisBackend + ORI_BATCH_SIZE = 4 PROMPT_DICT = { - 'prompt_input': - ('Below is an instruction that describes a task, paired with an input ' - 'that provides further context. ' - 'Write a response that appropriately completes the request.\n\n' - '### Instruction:\n{instruction}\n\n' - '### Input:\n{input}\n\n### Response:'), - 'prompt_no_input': - ('Below is an instruction that describes a task. ' - 'Write a response that appropriately completes the request.\n\n' - '### Instruction:\n{instruction}\n\n### Response:'), + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input " + "that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n" + "### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. 
" + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), } # Modified from https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/alpaca_dataset.py # noqa: E501 class AlpacaDataset(Dataset): - def __init__(self, data_path, tokenizer, max_words=224): self.ann = load(data_path) self.max_words = max_words @@ -46,23 +48,22 @@ def __len__(self): def __getitem__(self, index): ann = self.ann[index] - if ann.get('input', '') == '': - prompt = PROMPT_DICT['prompt_no_input'].format_map(ann) + if ann.get("input", "") == "": + prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) else: - prompt = PROMPT_DICT['prompt_input'].format_map(ann) - example = prompt + ann['output'] + prompt = PROMPT_DICT["prompt_input"].format_map(ann) + example = prompt + ann["output"] prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64) example = self.tokenizer.encode(example) example.append(self.tokenizer.eos_token_id) example = torch.tensor(example, dtype=torch.int64) padding = self.max_words - example.shape[0] if padding > 0: - example = torch.cat( - (example, torch.zeros(padding, dtype=torch.int64) - 1)) + example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) elif padding < 0: - example = example[:self.max_words] + example = example[: self.max_words] labels = copy.deepcopy(example) - labels[:len(prompt)] = -1 + labels[: len(prompt)] = -1 example_mask = example.ge(0) label_mask = labels.ge(0) example[~example_mask] = 0 @@ -71,20 +72,20 @@ def __getitem__(self, index): label_mask = label_mask.float() return { - 'input_ids': example, - 'labels': labels, - 'attention_mask': example_mask, + "input_ids": example, + "labels": labels, + "attention_mask": example_mask, } def parse_args(): - parser = argparse.ArgumentParser(description='Train alpaca with llama2') - parser.add_argument('data_root', type=str) - parser.add_argument('checkpoint', type=str) - parser.add_argument('--output-dir', type=str, default='work_dirs') - parser.add_argument('--max-epoch', type=int, default=3) - parser.add_argument('--batch-size', type=int, default=4) - parser.add_argument('--save-interval', type=int, default=500) + parser = argparse.ArgumentParser(description="Train alpaca with llama2") + parser.add_argument("data_root", type=str) + parser.add_argument("checkpoint", type=str) + parser.add_argument("--output-dir", type=str, default="work_dirs") + parser.add_argument("--max-epoch", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--save-interval", type=int, default=500) args = parser.parse_args() return args @@ -94,69 +95,68 @@ def train(): # Setup distributed related component in Strategy. 
strategy = FSDPStrategy( model_wrapper=dict( - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={LlamaDecoderLayer})), - state_dict_cfg='full', - env_kwargs=dict(randomness=dict(seed=42))) - visualizer = Visualizer( - name='mmengine', - save_dir=args.output_dir, - vis_backends=[dict(type=WandbVisBackend)]) + auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}) + ), + state_dict_cfg="full", + env_kwargs=dict(randomness=dict(seed=42)), + ) + visualizer = Visualizer(name="mmengine", save_dir=args.output_dir, vis_backends=[dict(type=WandbVisBackend)]) # Prepare model tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint) - tokenizer.add_special_tokens({'pad_token': ''}) + tokenizer.add_special_tokens({"pad_token": ""}) model = LlamaForCausalLM.from_pretrained(args.checkpoint) model.to(torch.bfloat16) model.train() # Prepare dataset - train_dataset = AlpacaDataset( - tokenizer=tokenizer, data_path=args.data_root) + train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.data_root) train_dataloader = DataLoader( train_dataset, batch_size=args.batch_size, sampler=DefaultSampler(train_dataset, seed=0), collate_fn=default_data_collator, - drop_last=True) + drop_last=True, + ) # Get the prepared model, scheduler and optimizer from strategy epoch_length = len(train_dataloader) max_iters = epoch_length * args.max_epoch optim_cfg = dict( - optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), - accumulative_counts=ORI_BATCH_SIZE / args.batch_size) + optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), accumulative_counts=ORI_BATCH_SIZE / args.batch_size + ) scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)] model, optimizer, schedulers = strategy.prepare( model, optim_wrapper=optim_cfg, param_scheduler=scheduler_cfgs, - dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch)) + dispatch_kwargs=dict(max_iters=max_iters, max_epochs=args.max_epoch), + ) for epoch in range(args.max_epoch): for idx, inputs in enumerate(train_dataloader): # Convert inputs to target device. 
- inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor), - lambda m: m.cuda()) + inputs = apply_to(inputs, lambda m: isinstance(m, torch.Tensor), lambda m: m.cuda()) loss = model(**inputs).loss optimizer.update_params(loss) max_memory = torch.cuda.max_memory_allocated() - strategy.logger.info(f'Epoch: {epoch+1}/{args.max_epoch}, ' - f'Iter: {idx+1}/{epoch_length}, ' - f'Loss: {loss.item():.3f}, ' - f'Lr: {optimizer.get_lr()["lr"][0]:.6f} ' - f'Memory: {max_memory/1e9:.3f}G') - visualizer.add_scalars({'loss': loss.item()}) + strategy.logger.info( + f"Epoch: {epoch + 1}/{args.max_epoch}, " + f"Iter: {idx + 1}/{epoch_length}, " + f"Loss: {loss.item():.3f}, " + f"Lr: {optimizer.get_lr()['lr'][0]:.6f} " + f"Memory: {max_memory / 1e9:.3f}G" + ) + visualizer.add_scalars({"loss": loss.item()}) torch.cuda.reset_peak_memory_stats() for scheduler in schedulers: scheduler.step() - save_dir = f'{args.output_dir}/epoch_{epoch+1}' + save_dir = f"{args.output_dir}/epoch_{epoch + 1}" state_dict = model.state_dict() if is_main_process(): @@ -164,5 +164,5 @@ def train(): tokenizer.save_pretrained(save_dir) -if __name__ == '__main__': +if __name__ == "__main__": train() diff --git a/examples/llama2/generate.py b/examples/llama2/generate.py index 85635c37ae..51720763ff 100644 --- a/examples/llama2/generate.py +++ b/examples/llama2/generate.py @@ -14,23 +14,19 @@ def parse_args(): - parser = argparse.ArgumentParser(description='llama2 inference') - parser.add_argument('checkpoint', type=str) + parser = argparse.ArgumentParser(description="llama2 inference") + parser.add_argument("checkpoint", type=str) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() tokenizer = AutoTokenizer.from_pretrained(args.checkpoint) model = LlamaForCausalLM.from_pretrained(args.checkpoint).half().cuda() model.eval() - inputs = tokenizer(prompt, return_tensors='pt') + inputs = tokenizer(prompt, return_tensors="pt") with torch.no_grad(): generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300) - print( - tokenizer.batch_decode( - generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False)[0]) + print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) diff --git a/examples/segmentation/train.ipynb b/examples/segmentation/train.ipynb index d644ff71de..7dc3372299 100644 --- a/examples/segmentation/train.ipynb +++ b/examples/segmentation/train.ipynb @@ -54,48 +54,39 @@ "metadata": {}, "outputs": [], "source": [ + "import csv\n", "import os\n", + "\n", "import numpy as np\n", - "from torchvision.datasets import VisionDataset\n", "from PIL import Image\n", - "import csv\n", + "from torchvision.datasets import VisionDataset\n", "\n", "\n", "def create_palette(csv_filepath):\n", " color_to_class = {}\n", - " with open(csv_filepath, newline='') as csvfile:\n", + " with open(csv_filepath, newline=\"\") as csvfile:\n", " reader = csv.DictReader(csvfile)\n", " for idx, row in enumerate(reader):\n", - " r, g, b = int(row['r']), int(row['g']), int(row['b'])\n", + " r, g, b = int(row[\"r\"]), int(row[\"g\"]), int(row[\"b\"])\n", " color_to_class[(r, g, b)] = idx\n", " return color_to_class\n", "\n", - "class CamVid(VisionDataset):\n", "\n", - " def __init__(self,\n", - " root,\n", - " img_folder,\n", - " mask_folder,\n", - " transform=None,\n", - " target_transform=None):\n", - " super().__init__(\n", - " root, transform=transform, target_transform=target_transform)\n", + "class 
CamVid(VisionDataset):\n", + " def __init__(self, root, img_folder, mask_folder, transform=None, target_transform=None):\n", + " super().__init__(root, transform=transform, target_transform=target_transform)\n", " self.img_folder = img_folder\n", " self.mask_folder = mask_folder\n", - " self.images = list(\n", - " sorted(os.listdir(os.path.join(self.root, img_folder))))\n", - " self.masks = list(\n", - " sorted(os.listdir(os.path.join(self.root, mask_folder))))\n", - " self.color_to_class = create_palette(\n", - " os.path.join(self.root, 'class_dict.csv'))\n", + " self.images = sorted(os.listdir(os.path.join(self.root, img_folder)))\n", + " self.masks = sorted(os.listdir(os.path.join(self.root, mask_folder)))\n", + " self.color_to_class = create_palette(os.path.join(self.root, \"class_dict.csv\"))\n", "\n", " def __getitem__(self, index):\n", " img_path = os.path.join(self.root, self.img_folder, self.images[index])\n", - " mask_path = os.path.join(self.root, self.mask_folder,\n", - " self.masks[index])\n", + " mask_path = os.path.join(self.root, self.mask_folder, self.masks[index])\n", "\n", - " img = Image.open(img_path).convert('RGB')\n", - " mask = Image.open(mask_path).convert('RGB') # Convert to RGB\n", + " img = Image.open(img_path).convert(\"RGB\")\n", + " mask = Image.open(mask_path).convert(\"RGB\") # Convert to RGB\n", "\n", " if self.transform is not None:\n", " img = self.transform(img)\n", @@ -110,12 +101,11 @@ "\n", " if self.target_transform is not None:\n", " labels = self.target_transform(labels)\n", - " data_samples = dict(\n", - " labels=labels, img_path=img_path, mask_path=mask_path)\n", + " data_samples = dict(labels=labels, img_path=img_path, mask_path=mask_path)\n", " return img, data_samples\n", "\n", " def __len__(self):\n", - " return len(self.images)\n" + " return len(self.images)" ] }, { @@ -134,39 +124,37 @@ "import torch\n", "import torchvision.transforms as transforms\n", "\n", + "\n", "norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", - "transform = transforms.Compose(\n", - " [transforms.ToTensor(),\n", - " transforms.Normalize(**norm_cfg)])\n", + "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(**norm_cfg)])\n", "\n", - "target_transform = transforms.Lambda(\n", - " lambda x: torch.tensor(np.array(x), dtype=torch.long))\n", + "target_transform = transforms.Lambda(lambda x: torch.tensor(np.array(x), dtype=torch.long))\n", "\n", "train_set = CamVid(\n", - " 'data/CamVid',\n", - " img_folder='train',\n", - " mask_folder='train_labels',\n", + " \"data/CamVid\",\n", + " img_folder=\"train\",\n", + " mask_folder=\"train_labels\",\n", " transform=transform,\n", - " target_transform=target_transform)\n", + " target_transform=target_transform,\n", + ")\n", "\n", "valid_set = CamVid(\n", - " 'data/CamVid',\n", - " img_folder='val',\n", - " mask_folder='val_labels',\n", - " transform=transform,\n", - " target_transform=target_transform)\n", + " \"data/CamVid\", img_folder=\"val\", mask_folder=\"val_labels\", transform=transform, target_transform=target_transform\n", + ")\n", "\n", "train_dataloader = dict(\n", " batch_size=3,\n", " dataset=train_set,\n", - " sampler=dict(type='DefaultSampler', shuffle=True),\n", - " collate_fn=dict(type='default_collate'))\n", + " sampler=dict(type=\"DefaultSampler\", shuffle=True),\n", + " collate_fn=dict(type=\"default_collate\"),\n", + ")\n", "\n", "val_dataloader = dict(\n", " batch_size=3,\n", " dataset=valid_set,\n", - " sampler=dict(type='DefaultSampler', shuffle=False),\n", 
- " collate_fn=dict(type='default_collate'))" + " sampler=dict(type=\"DefaultSampler\", shuffle=False),\n", + " collate_fn=dict(type=\"default_collate\"),\n", + ")" ] }, { @@ -186,22 +174,22 @@ "metadata": {}, "outputs": [], "source": [ - "from mmengine.model import BaseModel\n", - "from torchvision.models.segmentation import deeplabv3_resnet50\n", "import torch.nn.functional as F\n", + "from torchvision.models.segmentation import deeplabv3_resnet50\n", + "\n", + "from mmengine.model import BaseModel\n", "\n", "\n", "class MMDeeplabV3(BaseModel):\n", - "\n", " def __init__(self, num_classes):\n", " super().__init__()\n", " self.deeplab = deeplabv3_resnet50(num_classes=num_classes)\n", "\n", - " def forward(self, imgs, data_samples=None, mode='tensor'):\n", - " x = self.deeplab(imgs)['out']\n", - " if mode == 'loss':\n", - " return {'loss': F.cross_entropy(x, data_samples['labels'])}\n", - " elif mode == 'predict':\n", + " def forward(self, imgs, data_samples=None, mode=\"tensor\"):\n", + " x = self.deeplab(imgs)[\"out\"]\n", + " if mode == \"loss\":\n", + " return {\"loss\": F.cross_entropy(x, data_samples[\"labels\"])}\n", + " elif mode == \"predict\":\n", " return x, data_samples" ] }, @@ -222,20 +210,19 @@ "source": [ "from mmengine.evaluator import BaseMetric\n", "\n", - "class IoU(BaseMetric):\n", "\n", + "class IoU(BaseMetric):\n", " def process(self, data_batch, data_samples):\n", - " preds, labels = data_samples[0], data_samples[1]['labels']\n", + " preds, labels = data_samples[0], data_samples[1][\"labels\"]\n", " preds = torch.argmax(preds, dim=1)\n", " intersect = (labels == preds).sum()\n", " union = (torch.logical_or(preds, labels)).sum()\n", " iou = (intersect / union).cpu()\n", - " self.results.append(\n", - " dict(batch_size=len(labels), iou=iou * len(labels)))\n", + " self.results.append(dict(batch_size=len(labels), iou=iou * len(labels)))\n", "\n", " def compute_metrics(self, results):\n", - " total_iou = sum(result['iou'] for result in self.results)\n", - " num_samples = sum(result['batch_size'] for result in self.results)\n", + " total_iou = sum(result[\"iou\"] for result in self.results)\n", + " num_samples = sum(result[\"batch_size\"] for result in self.results)\n", " return dict(iou=total_iou / num_samples)" ] }, @@ -252,33 +239,29 @@ "metadata": {}, "outputs": [], "source": [ - "from mmengine.hooks import Hook\n", + "import os.path as osp\n", "import shutil\n", + "\n", "import cv2\n", - "import os.path as osp\n", "\n", + "from mmengine.hooks import Hook\n", "\n", - "class SegVisHook(Hook):\n", "\n", + "class SegVisHook(Hook):\n", " def __init__(self, data_root, vis_num=1) -> None:\n", " super().__init__()\n", " self.vis_num = vis_num\n", - " self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))\n", + " self.palette = create_palette(osp.join(data_root, \"class_dict.csv\"))\n", "\n", - " def after_val_iter(self,\n", - " runner,\n", - " batch_idx: int,\n", - " data_batch=None,\n", - " outputs=None) -> None:\n", + " def after_val_iter(self, runner, batch_idx: int, data_batch=None, outputs=None) -> None:\n", " if batch_idx > self.vis_num:\n", " return\n", " preds, data_samples = outputs\n", - " img_paths = data_samples['img_path']\n", - " mask_paths = data_samples['mask_path']\n", + " img_paths = data_samples[\"img_path\"]\n", + " mask_paths = data_samples[\"mask_path\"]\n", " _, C, H, W = preds.shape\n", " preds = torch.argmax(preds, dim=1)\n", - " for idx, (pred, img_path,\n", - " mask_path) in enumerate(zip(preds, img_paths, mask_paths)):\n", + " for 
idx, (pred, img_path, mask_path) in enumerate(zip(preds, img_paths, mask_paths, strict=False)):\n", " pred_mask = np.zeros((H, W, 3), dtype=np.uint8)\n", " runner.visualizer.set_image(pred_mask)\n", " for color, class_id in self.palette.items():\n", @@ -289,16 +272,12 @@ " )\n", " # Convert RGB to BGR\n", " pred_mask = runner.visualizer.get_image()[..., ::-1]\n", - " saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))\n", + " saved_dir = osp.join(runner.log_dir, \"vis_data\", str(idx))\n", " os.makedirs(saved_dir, exist_ok=True)\n", "\n", - " shutil.copyfile(img_path,\n", - " osp.join(saved_dir, osp.basename(img_path)))\n", - " shutil.copyfile(mask_path,\n", - " osp.join(saved_dir, osp.basename(mask_path)))\n", - " cv2.imwrite(\n", - " osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),\n", - " pred_mask)" + " shutil.copyfile(img_path, osp.join(saved_dir, osp.basename(img_path)))\n", + " shutil.copyfile(mask_path, osp.join(saved_dir, osp.basename(mask_path)))\n", + " cv2.imwrite(osp.join(saved_dir, f\"pred_{osp.basename(img_path)}\"), pred_mask)" ] }, { @@ -315,6 +294,7 @@ "outputs": [], "source": [ "from torch.optim import AdamW\n", + "\n", "from mmengine.optim import AmpOptimWrapper\n", "from mmengine.runner import Runner\n", "\n", @@ -323,16 +303,15 @@ "\n", "runner = Runner(\n", " model=MMDeeplabV3(num_classes),\n", - " work_dir='./work_dir',\n", + " work_dir=\"./work_dir\",\n", " train_dataloader=train_dataloader,\n", - " optim_wrapper=dict(\n", - " type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),\n", + " optim_wrapper=dict(type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),\n", " train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),\n", " val_dataloader=val_dataloader,\n", " val_cfg=dict(),\n", " val_evaluator=dict(type=IoU),\n", - " custom_hooks=[SegVisHook('data/CamVid')],\n", - " default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),\n", + " custom_hooks=[SegVisHook(\"data/CamVid\")],\n", + " default_hooks=dict(checkpoint=dict(type=\"CheckpointHook\", interval=1)),\n", ")\n", "runner.train()" ] diff --git a/examples/segmentation/train.py b/examples/segmentation/train.py index dc045a18b9..c55e034cf4 100644 --- a/examples/segmentation/train.py +++ b/examples/segmentation/train.py @@ -24,40 +24,29 @@ def create_palette(csv_filepath): color_to_class = {} - with open(csv_filepath, newline='') as csvfile: + with open(csv_filepath, newline="") as csvfile: reader = csv.DictReader(csvfile) for idx, row in enumerate(reader): - r, g, b = int(row['r']), int(row['g']), int(row['b']) + r, g, b = int(row["r"]), int(row["g"]), int(row["b"]) color_to_class[(r, g, b)] = idx return color_to_class class CamVid(VisionDataset): - - def __init__(self, - root, - img_folder, - mask_folder, - transform=None, - target_transform=None): - super().__init__( - root, transform=transform, target_transform=target_transform) + def __init__(self, root, img_folder, mask_folder, transform=None, target_transform=None): + super().__init__(root, transform=transform, target_transform=target_transform) self.img_folder = img_folder self.mask_folder = mask_folder - self.images = list( - sorted(os.listdir(os.path.join(self.root, img_folder)))) - self.masks = list( - sorted(os.listdir(os.path.join(self.root, mask_folder)))) - self.color_to_class = create_palette( - os.path.join(self.root, 'class_dict.csv')) + self.images = sorted(os.listdir(os.path.join(self.root, img_folder))) + self.masks = sorted(os.listdir(os.path.join(self.root, mask_folder))) + self.color_to_class = 
create_palette(os.path.join(self.root, "class_dict.csv")) def __getitem__(self, index): img_path = os.path.join(self.root, self.img_folder, self.images[index]) - mask_path = os.path.join(self.root, self.mask_folder, - self.masks[index]) + mask_path = os.path.join(self.root, self.mask_folder, self.masks[index]) - img = Image.open(img_path).convert('RGB') - mask = Image.open(mask_path).convert('RGB') # Convert to RGB + img = Image.open(img_path).convert("RGB") + mask = Image.open(mask_path).convert("RGB") # Convert to RGB if self.transform is not None: img = self.transform(img) @@ -72,8 +61,7 @@ def __getitem__(self, index): if self.target_transform is not None: labels = self.target_transform(labels) - data_samples = dict( - labels=labels, img_path=img_path, mask_path=mask_path) + data_samples = dict(labels=labels, img_path=img_path, mask_path=mask_path) return img, data_samples def __len__(self): @@ -81,59 +69,50 @@ def __len__(self): class MMDeeplabV3(BaseModel): - def __init__(self, num_classes): super().__init__() self.deeplab = deeplabv3_resnet50(num_classes=num_classes) - def forward(self, imgs, data_samples=None, mode='tensor'): - x = self.deeplab(imgs)['out'] - if mode == 'loss': - return {'loss': F.cross_entropy(x, data_samples['labels'])} - elif mode == 'predict': + def forward(self, imgs, data_samples=None, mode="tensor"): + x = self.deeplab(imgs)["out"] + if mode == "loss": + return {"loss": F.cross_entropy(x, data_samples["labels"])} + elif mode == "predict": return x, data_samples class IoU(BaseMetric): - def process(self, data_batch, data_samples): - preds, labels = data_samples[0], data_samples[1]['labels'] + preds, labels = data_samples[0], data_samples[1]["labels"] preds = torch.argmax(preds, dim=1) intersect = (labels == preds).sum() union = (torch.logical_or(preds, labels)).sum() iou = (intersect / union).cpu() - self.results.append( - dict(batch_size=len(labels), iou=iou * len(labels))) + self.results.append(dict(batch_size=len(labels), iou=iou * len(labels))) def compute_metrics(self, results): - total_iou = sum(result['iou'] for result in self.results) - num_samples = sum(result['batch_size'] for result in self.results) + total_iou = sum(result["iou"] for result in self.results) + num_samples = sum(result["batch_size"] for result in self.results) return dict(iou=total_iou / num_samples) class SegVisHook(Hook): - def __init__(self, data_root, vis_num=1) -> None: super().__init__() self.vis_num = vis_num - self.palette = create_palette(osp.join(data_root, 'class_dict.csv')) + self.palette = create_palette(osp.join(data_root, "class_dict.csv")) @master_only - def after_val_iter(self, - runner, - batch_idx: int, - data_batch=None, - outputs=None) -> None: + def after_val_iter(self, runner, batch_idx: int, data_batch=None, outputs=None) -> None: if batch_idx > self.vis_num: return preds, data_samples = outputs - img_paths = data_samples['img_path'] - mask_paths = data_samples['mask_path'] + img_paths = data_samples["img_path"] + mask_paths = data_samples["mask_path"] _, C, H, W = preds.shape preds = torch.argmax(preds, dim=1) - for idx, (pred, img_path, - mask_path) in enumerate(zip(preds, img_paths, mask_paths)): + for idx, (pred, img_path, mask_path) in enumerate(zip(preds, img_paths, mask_paths, strict=False)): pred_mask = np.zeros((H, W, 3), dtype=np.uint8) runner.visualizer.set_image(pred_mask) for color, class_id in self.palette.items(): @@ -144,26 +123,18 @@ def after_val_iter(self, ) # Convert RGB to BGR pred_mask = runner.visualizer.get_image()[..., ::-1] - 
saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx)) + saved_dir = osp.join(runner.log_dir, "vis_data", str(idx)) os.makedirs(saved_dir, exist_ok=True) - shutil.copyfile(img_path, - osp.join(saved_dir, osp.basename(img_path))) - shutil.copyfile(mask_path, - osp.join(saved_dir, osp.basename(mask_path))) - cv2.imwrite( - osp.join(saved_dir, f'pred_{osp.basename(img_path)}'), - pred_mask) + shutil.copyfile(img_path, osp.join(saved_dir, osp.basename(img_path))) + shutil.copyfile(mask_path, osp.join(saved_dir, osp.basename(mask_path))) + cv2.imwrite(osp.join(saved_dir, f"pred_{osp.basename(img_path)}"), pred_mask) def parse_args(): - parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + parser = argparse.ArgumentParser(description="Distributed Training") + parser.add_argument("--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="none", help="job launcher") + parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() return args @@ -174,54 +145,54 @@ def main(): num_classes = 32 # Modify to actual number of categories. norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(**norm_cfg)]) + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(**norm_cfg)]) - target_transform = transforms.Lambda( - lambda x: torch.tensor(np.array(x), dtype=torch.long)) + target_transform = transforms.Lambda(lambda x: torch.tensor(np.array(x), dtype=torch.long)) train_set = CamVid( - 'data/CamVid', - img_folder='train', - mask_folder='train_labels', + "data/CamVid", + img_folder="train", + mask_folder="train_labels", transform=transform, - target_transform=target_transform) + target_transform=target_transform, + ) valid_set = CamVid( - 'data/CamVid', - img_folder='val', - mask_folder='val_labels', + "data/CamVid", + img_folder="val", + mask_folder="val_labels", transform=transform, - target_transform=target_transform) + target_transform=target_transform, + ) train_dataloader = dict( batch_size=3, dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=True), + collate_fn=dict(type="default_collate"), + ) val_dataloader = dict( batch_size=3, dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + sampler=dict(type="DefaultSampler", shuffle=False), + collate_fn=dict(type="default_collate"), + ) runner = Runner( model=MMDeeplabV3(num_classes), - work_dir='./work_dir', + work_dir="./work_dir", train_dataloader=train_dataloader, - optim_wrapper=dict( - type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)), + optim_wrapper=dict(type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)), train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10), val_dataloader=val_dataloader, val_cfg=dict(), val_evaluator=dict(type=IoU), launcher=args.launcher, - custom_hooks=[SegVisHook('data/CamVid')], - default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)), + custom_hooks=[SegVisHook("data/CamVid")], + default_hooks=dict(checkpoint=dict(type="CheckpointHook", interval=1)), ) runner.train() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git 
a/examples/test_time_augmentation.py b/examples/test_time_augmentation.py index 0a896a05a2..41dd3e499d 100644 --- a/examples/test_time_augmentation.py +++ b/examples/test_time_augmentation.py @@ -8,7 +8,6 @@ @MODELS.register_module() class ClsTTAModel(BaseTTAModel): - def merge_preds(self, data_samples_list): merged_data_samples = [] for data_samples in data_samples_list: @@ -17,29 +16,27 @@ def merge_preds(self, data_samples_list): def _merge_single_sample(self, data_samples): merged_data_sample = data_samples[0].new() - merged_score = sum(data_sample.pred_label.score - for data_sample in data_samples) / len(data_samples) + merged_score = sum(data_sample.pred_label.score for data_sample in data_samples) / len(data_samples) merged_data_sample.set_pred_score(merged_score) return merged_data_sample -if __name__ == '__main__': - cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py') - cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10' - cfg.model = dict(type='ClsTTAModel', module=cfg.model) +if __name__ == "__main__": + cfg = get_config("mmcls::resnet/resnet50_8xb16_cifar10.py") + cfg.work_dir = "work_dirs/resnet50_8xb16_cifar10" + cfg.model = dict(type="ClsTTAModel", module=cfg.model) test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) flip_tta = dict( - type='TestTimeAug', + type="TestTimeAug", transforms=[ - [ - dict(type='RandomFlip', prob=1.), - dict(type='RandomFlip', prob=0.) - ], + [dict(type="RandomFlip", prob=1.0), dict(type="RandomFlip", prob=0.0)], [test_pipeline[-1]], - ]) + ], + ) # Replace the last transform with `TestTimeAug` cfg.test_dataloader.dataset.pipeline[-1] = flip_tta - cfg.load_from = 'https://download.openmmlab.com/mmclassification/v0' \ - '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth' + cfg.load_from = ( + "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth" + ) runner = Runner.from_cfg(cfg) runner.test() diff --git a/examples/text_classification/train.py b/examples/text_classification/train.py index 84a2841729..ff6cc1abe6 100644 --- a/examples/text_classification/train.py +++ b/examples/text_classification/train.py @@ -11,46 +11,40 @@ class MMBertForClassify(BaseModel): - def __init__(self, model): super().__init__() self.model = model def forward(self, label, input_ids, token_type_ids, attention_mask, mode): output = self.model( - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - labels=label) - if mode == 'loss': - return {'loss': output.loss} - elif mode == 'predict': + input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=label + ) + if mode == "loss": + return {"loss": output.loss} + elif mode == "predict": return output.logits, label class Accuracy(BaseMetric): - def process(self, data_batch, data_samples): score, gt = data_samples - self.results.append({ - 'batch_size': len(gt), - 'correct': (score.argmax(dim=1) == gt).sum().cpu(), - }) + self.results.append( + { + "batch_size": len(gt), + "correct": (score.argmax(dim=1) == gt).sum().cpu(), + } + ) def compute_metrics(self, results): - total_correct = sum(item['correct'] for item in results) - total_size = sum(item['batch_size'] for item in results) + total_correct = sum(item["correct"] for item in results) + total_size = sum(item["batch_size"] for item in results) return dict(accuracy=100 * total_correct / total_size) def parse_args(): - parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - 
choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + parser = argparse.ArgumentParser(description="Distributed Training") + parser.add_argument("--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="none", help="job launcher") + parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() return args @@ -62,50 +56,37 @@ def collate_fn(data): token_type_ids = [] attention_mask = [] for item in data: - labels.append(item['label']) - input_ids.append(torch.tensor(item['input_ids'])) - token_type_ids.append(torch.tensor(item['token_type_ids'])) - attention_mask.append(torch.tensor(item['attention_mask'])) + labels.append(item["label"]) + input_ids.append(torch.tensor(item["input_ids"])) + token_type_ids.append(torch.tensor(item["token_type_ids"])) + attention_mask.append(torch.tensor(item["attention_mask"])) input_ids = torch.stack(input_ids) token_type_ids = torch.stack(token_type_ids) attention_mask = torch.stack(attention_mask) label = torch.tensor(labels) - return dict( - label=label, - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) + return dict(label=label, input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) def main(): args = parse_args() - model = BertForSequenceClassification.from_pretrained( - 'bert-base-uncased', num_labels=2) - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - train_set = load_dataset('imdb', split='train') - test_set = load_dataset('imdb', split='test') + train_set = load_dataset("imdb", split="train") + test_set = load_dataset("imdb", split="test") train_set = train_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) - test_set = test_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) + lambda x: tokenizer(x["text"], truncation=True, padding=True, max_length=128), batched=True + ) + test_set = test_set.map(lambda x: tokenizer(x["text"], truncation=True, padding=True, max_length=128), batched=True) train_loader = dict( - batch_size=32, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) + batch_size=32, dataset=train_set, sampler=dict(type="DefaultSampler", shuffle=True), collate_fn=collate_fn + ) test_loader = dict( - batch_size=32, - dataset=test_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + batch_size=32, dataset=test_set, sampler=dict(type="DefaultSampler", shuffle=False), collate_fn=collate_fn + ) runner = Runner( model=MMBertForClassify(model), train_dataloader=train_loader, @@ -113,12 +94,12 @@ def main(): optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)), train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), val_cfg=dict(), - work_dir='bert_work_dir', + work_dir="bert_work_dir", val_evaluator=dict(type=Accuracy), launcher=args.launcher, ) runner.train() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/text_translation/train.py b/examples/text_translation/train.py index 61f43bafef..233d74773b 100644 --- a/examples/text_translation/train.py +++ b/examples/text_translation/train.py @@ -8,23 +8,20 @@ from mmengine.model import BaseModel 
from mmengine.runner import Runner -tokenizer = AutoTokenizer.from_pretrained('t5-small') +tokenizer = AutoTokenizer.from_pretrained("t5-small") -class MMT5ForTranslation(BaseModel): +class MMT5ForTranslation(BaseModel): def __init__(self, model): super().__init__() self.model = model def forward(self, label, input_ids, attention_mask, mode): - if mode == 'loss': - output = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=label) - return {'loss': output.loss} - elif mode == 'predict': + if mode == "loss": + output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=label) + return {"loss": output.loss} + elif mode == "predict": output = self.model.generate(input_ids) return output, label @@ -39,70 +36,61 @@ def post_process(preds, labels): class Accuracy(BaseMetric): - def process(self, data_batch, data_samples): outputs, labels = data_samples decoded_preds, decoded_labels = post_process(outputs, labels) score = bleu_score(decoded_preds, decoded_labels) - prediction_lens = torch.tensor([ - torch.count_nonzero(pred != tokenizer.pad_token_id) - for pred in outputs - ], - dtype=torch.float64) + prediction_lens = torch.tensor( + [torch.count_nonzero(pred != tokenizer.pad_token_id) for pred in outputs], dtype=torch.float64 + ) gen_len = torch.mean(prediction_lens).item() - self.results.append({ - 'gen_len': gen_len, - 'bleu': score, - }) + self.results.append( + { + "gen_len": gen_len, + "bleu": score, + } + ) def compute_metrics(self, results): return dict( - gen_len=np.mean([item['gen_len'] for item in results]), - bleu_score=np.mean([item['bleu'] for item in results]), + gen_len=np.mean([item["gen_len"] for item in results]), + bleu_score=np.mean([item["bleu"] for item in results]), ) def collate_fn(data): - prefix = 'translate English to French: ' - input_sequences = [prefix + item['translation']['en'] for item in data] - target_sequences = [item['translation']['fr'] for item in data] + prefix = "translate English to French: " + input_sequences = [prefix + item["translation"]["en"] for item in data] + target_sequences = [item["translation"]["fr"] for item in data] input_dict = tokenizer( input_sequences, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", ) label = tokenizer( target_sequences, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", ).input_ids - label[label == - tokenizer.pad_token_id] = -100 # ignore contribution to loss - return dict( - label=label, - input_ids=input_dict.input_ids, - attention_mask=input_dict.attention_mask) + label[label == tokenizer.pad_token_id] = -100 # ignore contribution to loss + return dict(label=label, input_ids=input_dict.input_ids, attention_mask=input_dict.attention_mask) def main(): - model = T5ForConditionalGeneration.from_pretrained('t5-small') + model = T5ForConditionalGeneration.from_pretrained("t5-small") - books = load_dataset('opus_books', 'en-fr') - books = books['train'].train_test_split(test_size=0.2) - train_set, test_set = books['train'], books['test'] + books = load_dataset("opus_books", "en-fr") + books = books["train"].train_test_split(test_size=0.2) + train_set, test_set = books["train"], books["test"] train_loader = dict( - batch_size=16, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) + batch_size=16, dataset=train_set, sampler=dict(type="DefaultSampler", shuffle=True), collate_fn=collate_fn + ) test_loader = dict( - batch_size=32, - dataset=test_set, - 
sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + batch_size=32, dataset=test_set, sampler=dict(type="DefaultSampler", shuffle=False), collate_fn=collate_fn + ) runner = Runner( model=MMT5ForTranslation(model), train_dataloader=train_loader, @@ -110,11 +98,12 @@ def main(): optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)), train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), val_cfg=dict(), - work_dir='t5_work_dir', - val_evaluator=dict(type=Accuracy)) + work_dir="t5_work_dir", + val_evaluator=dict(type=Accuracy), + ) runner.train() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmengine/__init__.py b/mmengine/__init__.py deleted file mode 100644 index a436c950e8..0000000000 --- a/mmengine/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# flake8: noqa -from .config import * -from .fileio import * -from .logging import * -from .registry import * -from .utils import * -from .version import __version__, version_info diff --git a/mmengine/_strategy/__init__.py b/mmengine/_strategy/__init__.py deleted file mode 100644 index 764abcf868..0000000000 --- a/mmengine/_strategy/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION -from .base import BaseStrategy -from .colossalai import ColossalAIStrategy -from .deepspeed import DeepSpeedStrategy -from .distributed import DDPStrategy -from .single_device import SingleDeviceStrategy - -__all__ = [ - 'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy', - 'ColossalAIStrategy' -] - -if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): - try: - from .fsdp import FSDPStrategy # noqa:F401 - __all__.append('FSDPStrategy') - except: # noqa: E722 - pass diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py deleted file mode 100644 index a713da9a70..0000000000 --- a/mmengine/_strategy/base.py +++ /dev/null @@ -1,979 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import os.path as osp -import platform -import time -from abc import ABCMeta, abstractmethod -from collections import OrderedDict -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union - -import torch -import torch.nn as nn -from torch.optim import Optimizer - -import mmengine -from mmengine.config import Config, ConfigDict -from mmengine.dist import (broadcast, get_dist_info, infer_launcher, - is_distributed) -from mmengine.logging import MMLogger -from mmengine.model.wrappers import is_model_wrapper -from mmengine.optim import (BaseOptimWrapper, OptimWrapperDict, - _ParamScheduler, build_optim_wrapper) -from mmengine.registry import MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, - set_multi_processing) - -ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, - List[_ParamScheduler]]] - - -class BaseStrategy(metaclass=ABCMeta): - """Base class for all strategies. - - In the process of supporting FSDP, DeepSpeed, and ColossalAI, the - scalability of the Runner faced challenges, which led to the redefinition - of the Runner's responsibilities. The Strategy abstraction was split out, - which is responsible for constructing, initializing, and saving/loading - the state of training components such as models, optimizers, and parameter - schedulers. 
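For orientation, here is a minimal sketch of how a concrete strategy ties these responsibilities together, using the `SingleDeviceStrategy` exported from `mmengine._strategy`; the call signatures are assumptions based on the `BaseStrategy` interface shown here and may differ between versions:

import torch.nn as nn
from mmengine._strategy import SingleDeviceStrategy

# Sketch only: prepare() builds the model, optimizer wrapper and parameter
# schedulers in one call and returns the prepared components.
strategy = SingleDeviceStrategy(work_dir="work_dirs/strategy_demo")
model, optim_wrapper, schedulers = strategy.prepare(
    nn.Linear(2, 2),
    optim_wrapper=dict(optimizer=dict(type="SGD", lr=0.01)),
    param_scheduler=dict(type="MultiStepLR", milestones=[1, 2]),
    dispatch_kwargs=dict(max_epochs=2, max_iters=20, epoch_length=10),
)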
- - Warning: - This is an experimental feature, and its interface is subject to - change. - - Keyword Args: - work_dir (str): The working directory to save checkpoints. The logs - will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. Defaults to 'work_dirs'. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as :attr:`experiment_name`. - Defaults to None. - env_kwargs (dict, optional): Environment config passed in - :meth:`setup_env`. Defaults to None. - log_kwargs (dict, optional): Logger config passed in - :meth:`build_logger`. Defaults to None. - auto_scale_lr (dict, Optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. - """ - model: nn.Module - optim_wrapper: BaseOptimWrapper - param_schedulers: ParamSchedulerType - - def __init__( - self, - *, - work_dir: str = 'work_dirs', - experiment_name: Optional[str] = None, - env_kwargs: Optional[dict] = None, - log_kwargs: Optional[dict] = None, - auto_scale_lr: Optional[dict] = None, - ): - self._work_dir = osp.abspath(work_dir) - mmengine.mkdir_or_exist(self._work_dir) - - self._env_kwargs = env_kwargs or {} - self._setup_env(**self._env_kwargs) - - if experiment_name is not None: - self._experiment_name = f'{experiment_name}_{self.timestamp}' - else: - self._experiment_name = self.timestamp - - self._log_dir = osp.join(self.work_dir, self.timestamp) - mmengine.mkdir_or_exist(self._log_dir) - - log_kwargs = log_kwargs or {} - self.logger = self.build_logger(**log_kwargs) - - self._auto_scale_lr = auto_scale_lr - - self.dispatch_kwargs: dict = {} - self._prepared = False - - @property - def work_dir(self): - return self._work_dir - - @property - def log_dir(self): - return self._log_dir - - @property - def experiment_name(self): - return self._experiment_name - - @property - def launcher(self): - return self._launcher - - @property - def distributed(self): - return self._distributed - - @property - def seed(self): - return self._seed - - @property - def rank(self): - return self._rank - - @property - def world_size(self): - return self._world_size - - @property - def timestamp(self): - return self._timestamp - - @property - def randomness(self): - return self._randomness - - @abstractmethod - def prepare( - self, - model: Union[nn.Module, dict], - *, - optim_wrapper: Union[BaseOptimWrapper, dict, None] = None, - param_scheduler: Union[_ParamScheduler, Dict, List, None] = None, - compile: Union[dict, bool] = False, - dispatch_kwargs: Optional[dict] = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for building a model. - - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. 
- """ - - def _setup_env( - self, - *, - launcher: Optional[str] = None, - cudnn_benchmark: bool = False, - mp_cfg: Optional[dict] = None, - dist_cfg: Optional[dict] = None, - resource_limit: int = 4096, - randomness: dict = dict(seed=None), - ): - """Setup environment. - - This method will do the following things: - - 1. setup multi-processing - 2. setup distributed - 3. set random seed - - Keyword Args: - launcher (str, optional): Way to launcher multi-process. Supported - launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' - is provided, non-distributed environment will be launched. - If launcher is None, the launcher will be inferred according - some specified environments. Defaults to None. - cudnn_benchmark (bool): Whether to enable cudnn benchmark. - Defaults to False. - mp_cfg (dict, optional): Multi-processing config. Defaults to None. - dist_cfg (dict, optional): Distributed config. Defaults to None. - resource_limit (int): Resource limit. Defaults to 4096. - randomness (dict): Some settings to make the experiment as - reproducible as possible like seed and deterministic. - Defaults to ``dict(seed=None)``. If seed is None, a random - number will be generated and it will be broadcasted to all - other processes if in distributed environment. - If ``cudnn_benchmark`` is ``True`` in but ``deterministic`` is - ``True`` in ``randomness``, the value of - ``torch.backends.cudnn.benchmark`` will be ``False`` finally. - """ - if launcher is None: - launcher = infer_launcher() - - self._launcher = launcher - if self._launcher == 'none': - self._distributed = False - else: - self._distributed = True - - if cudnn_benchmark: - torch.backends.cudnn.benchmark = True - - mp_cfg = mp_cfg if mp_cfg is not None else {} - set_multi_processing(**mp_cfg, distributed=self._distributed) - - # init distributed env first, since logger depends on the dist info. - if self._distributed and not is_distributed(): - dist_cfg = dist_cfg if dist_cfg is not None else {} - self._setup_distributed(launcher, **dist_cfg) - - self._rank, self._world_size = get_dist_info() - - timestamp = torch.tensor(time.time(), dtype=torch.float64) - # broadcast timestamp from 0 process to other processes - broadcast(timestamp) - self._timestamp = time.strftime('%Y%m%d_%H%M%S', - time.localtime(timestamp.item())) - - # https://github.com/pytorch/pytorch/issues/973 - # set resource limit - if platform.system() != 'Windows': - import resource - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - base_soft_limit = rlimit[0] - hard_limit = rlimit[1] - soft_limit = min(max(resource_limit, base_soft_limit), hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, - (soft_limit, hard_limit)) - - self._randomness = randomness - self._set_randomness(**randomness) - - def _setup_distributed(self, *args, **kwargs): - """Setup distributed environment.""" - pass - - def _set_randomness( - self, - seed: Optional[int] = None, - diff_rank_seed: bool = False, - deterministic: bool = False, - ) -> None: - """Set random seed to guarantee reproducible results. - - Args: - seed (int, optional): A number to set random modules. - Defaults to None. - diff_rank_seed (bool): Whether or not set different seeds according - to global rank. Defaults to False. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - See https://pytorch.org/docs/stable/notes/randomness.html for - more details. 
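A hedged example of the environment options described above; every key below mirrors a keyword argument of `_setup_env` and would normally be passed through a strategy's `env_kwargs`:

from mmengine._strategy import SingleDeviceStrategy

# Sketch only: keys mirror _setup_env()'s keyword arguments.
env_kwargs = dict(
    launcher="none",                      # or "pytorch", "slurm", "mpi"; None lets it be inferred
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method="fork"),  # forwarded to set_multi_processing()
    dist_cfg=dict(backend="nccl"),        # forwarded to _setup_distributed() when distributed
    randomness=dict(seed=42, diff_rank_seed=False, deterministic=True),
)
strategy = SingleDeviceStrategy(work_dir="work_dirs/strategy_demo", env_kwargs=env_kwargs)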
- """ - from mmengine.runner import set_random_seed - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) - - def build_model(self, model: Union[nn.Module, dict]) -> nn.Module: - """Build model. - - If ``model`` is a dict, it will be used to build a ``nn.Module`` - object. Otherwise, if ``model`` is a ``nn.Module`` object it will be - returned directly. - - An example of ``model``:: - - model = dict(type='ResNet') - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build ``nn.Module`` object. If ``model`` is a ``nn.Module`` - object, just returns itself. - - Note: - The returned model must implement ``train_step``, ``test_step`` - if ``runner.train`` or ``runner.test`` will be called. If - ``runner.val`` will be called or ``val_cfg`` is configured, - model must implement `val_step`. - - Returns: - nn.Module: Model build from ``model``. - """ - if isinstance(model, nn.Module): - return model - elif isinstance(model, dict): - model = MODELS.build(model) - return model # type: ignore - else: - raise TypeError('model should be a nn.Module object or dict, ' - f'but got {model}') - - def compile_model( - self, - model: nn.Module, - compile: Union[dict, bool] = False, - ) -> nn.Module: - """Compile model. - - Args: - model (nn.Module): Model to compile. - - Returns: - nn.Module: Compiled model. - """ - if isinstance(compile, bool) and not compile: - return model - - assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( - 'PyTorch >= 2.0.0 is required to enable torch.compile') - - if isinstance(compile, bool): - compile = dict() - - target = compile.pop('target', 'forward') - func = getattr(model, target) - compiled_func = torch.compile(func, **compile) - setattr(model, target, compiled_func) - self.logger.info('Model has been "compiled". The first few iterations ' - 'will be slow, please be patient.') - - return model - - def _init_model_weights(self, model: nn.Module) -> nn.Module: - """Initialize the model weights if the model has - :meth:`init_weights`""" - if (hasattr(model, 'init_weights') and self.dispatch_kwargs.get( - 'init_weights_for_test_or_val', True)): - model.init_weights() - # sync params and buffers - for _, params in model.state_dict().items(): - broadcast(params) - - return model - - def build_optim_wrapper( - self, - optim_wrapper: Union[Optimizer, BaseOptimWrapper, dict], - model: Optional[nn.Module] = None, - ) -> BaseOptimWrapper: - """Build optimizer wrapper. - - If ``optim_wrapper`` is a config dict for only one optimizer, - the keys must contain ``optimizer``, and ``type`` is optional. - It will build a :obj:`OptimWrapper` by default. - - If ``optim_wrapper`` is a config dict for multiple optimizers, i.e., - it has multiple keys and each key is for an optimizer wrapper. The - constructor must be specified since - :obj:`DefaultOptimizerConstructor` cannot handle the building of - training with multiple optimizers. - - If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e., - each value of ``optim_wrapper`` represents an ``OptimWrapper`` - instance. ``build_optim_wrapper`` will directly build the - :obj:`OptimWrapperDict` instance from ``optim_wrapper``. - - Args: - optim_wrapper (BaseOptimWrapper or dict): An OptimWrapper object or a - dict to build OptimWrapper objects. If ``optim_wrapper`` is an - OptimWrapper, just return an ``OptimizeWrapper`` instance. 
- - Note: - For single optimizer training, if `optim_wrapper` is a config - dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it - must contain `optimizer` to build the corresponding optimizer. - - Examples: - >>> # build an optimizer - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)) - >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> # is also valid. - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build optimizer without `type` - >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - maximize: False - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build multiple optimizers - >>> optim_wrapper_cfg = dict( - ... generator=dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)), - ... discriminator=dict(type='OptimWrapper', optimizer=dict( - ... type='Adam', lr=0.001)) - ... # need to customize a multiple optimizer constructor - ... constructor='CustomMultiOptimizerConstructor', - ...) - >>> optim_wrapper = runner.optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - name: generator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.1 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - name: discriminator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - 'discriminator': Adam ( - Parameter Group 0 - dampening: 0 - lr: 0.02 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - - Important: - If you need to build multiple optimizers, you should implement a - MultiOptimWrapperConstructor which gets parameters passed to - corresponding optimizers and compose the ``OptimWrapperDict``. - More details about how to customize OptimizerConstructor can be - found at `optimizer-docs`_. - - Returns: - BaseOptimWrapper: Optimizer wrapper build from ``optimizer_cfg``. - """ # noqa: E501 - if isinstance(optim_wrapper, BaseOptimWrapper): - return optim_wrapper - if isinstance(optim_wrapper, (dict, ConfigDict, Config)): - # optimizer must be defined for single optimizer training. - optimizer = optim_wrapper.get('optimizer', None) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. - if isinstance(optimizer, Optimizer): - optim_wrapper.setdefault('type', 'OptimWrapper') - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or 'constructor' in optim_wrapper: - assert model is not None - return build_optim_wrapper(model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. 
`build_optim_wrapper` will directly build the - # `OptimWrapperDict` instance from `optim_wrapper.` - optim_wrappers = OrderedDict() - for name, optim in optim_wrapper.items(): - if not isinstance(optim, BaseOptimWrapper): - raise ValueError( - 'each item mush be an optimizer object when ' - '"type" and "constructor" are not in ' - f'optimizer, but got {name}={optim}') - optim_wrappers[name] = optim - return OptimWrapperDict(**optim_wrappers) # type: ignore - else: - raise TypeError('optimizer wrapper should be an OptimWrapper ' - f'object or dict, but got {optim_wrapper}') - - def _build_param_scheduler( - self, - scheduler: Union[_ParamScheduler, Dict, List], - optim_wrapper: BaseOptimWrapper, - default_args: dict, - ) -> List[_ParamScheduler]: - """Build parameter schedulers for a single optimizer. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - optim_wrapper (BaseOptimWrapper): An optimizer wrapper object is - passed to construct ParamScheduler object. - - Returns: - list[_ParamScheduler]: List of parameter schedulers build from - ``scheduler``. - - Note: - If the train loop is built, when building parameter schedulers, - it supports setting the max epochs/iters as the default ``end`` - of schedulers, and supports converting epoch-based schedulers - to iter-based according to the ``convert_to_iter_based`` key. - """ - if not isinstance(scheduler, Sequence): - schedulers = [scheduler] - else: - schedulers = scheduler - - max_epochs = default_args.pop('max_epochs', None) - max_iters = default_args.pop('max_iters', None) - - param_schedulers = [] - for scheduler in schedulers: - if isinstance(scheduler, _ParamScheduler): - param_schedulers.append(scheduler) - elif isinstance(scheduler, dict): - _scheduler = copy.deepcopy(scheduler) - - # Set default end - if _scheduler.get('by_epoch', True): - if max_epochs is None: - raise ValueError( - 'max_epochs must be specified in default_args') - default_end = max_epochs - else: - if max_iters is None: - raise ValueError( - 'max_iters must be specified in default_args') - default_end = max_iters - _scheduler.setdefault('end', default_end) - self.logger.debug( - f'The `end` of {_scheduler["type"]} is not set. ' - 'Use the max epochs/iters of train loop as default.') - - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) - else: - raise TypeError( - 'scheduler should be a _ParamScheduler object or dict, ' - f'but got {scheduler}') - return param_schedulers - - def build_param_scheduler( - self, - scheduler: Union[_ParamScheduler, Dict, List], - optim_wrapper: BaseOptimWrapper, - default_args: Optional[dict] = None, - ) -> ParamSchedulerType: - """Build parameter schedulers. - - ``build_param_scheduler`` should be called after - ``build_optim_wrapper`` because the building logic will change - according to the number of optimizers built by the runner. - The cases are as below: - - - Single optimizer: When only one optimizer is built and used in the - runner, ``build_param_scheduler`` will return a list of - parameter schedulers. - - Multiple optimizers: When two or more optimizers are built and used - in runner, ``build_param_scheduler`` will return a dict containing - the same keys with multiple optimizers and each value is a list of - parameter schedulers. 
Note that, if you want different optimizers to - use different parameter schedulers to update optimizer's - hyper-parameters, the input parameter ``scheduler`` also needs to be - a dict and its key are consistent with multiple optimizers. - Otherwise, the same parameter schedulers will be used to update - optimizer's hyper-parameters. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - - Examples: - >>> # build one scheduler - >>> optim_cfg = dict(dict(type='SGD', lr=0.01)) - >>> runner.optim_wrapper = runner.build_optim_wrapper( - >>> optim_cfg) - >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2]) - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [] # noqa: E501 - - >>> # build multiple schedulers - >>> scheduler_cfg = [ - ... dict(type='MultiStepLR', milestones=[1, 2]), - ... dict(type='StepLR', step_size=1) - ... ] - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [, # noqa: E501 - ] - - Above examples only provide the case of one optimizer and one scheduler - or multiple schedulers. If you want to know how to set parameter - scheduler when using multiple optimizers, you can find more examples - `optimizer-docs`_. - - Returns: - list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of - parameter schedulers or a dictionary contains list of parameter - schedulers build from ``scheduler``. - - .. _optimizer-docs: - https://mmengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html - """ - if default_args is None: - default_args = {} - if 'epoch_length' in self.dispatch_kwargs: - default_args['epoch_length'] = self.dispatch_kwargs[ - 'epoch_length'] - if 'max_epochs' in self.dispatch_kwargs: - default_args['max_epochs'] = self.dispatch_kwargs['max_epochs'] - if 'max_iters' in self.dispatch_kwargs: - default_args['max_iters'] = self.dispatch_kwargs['max_iters'] - - param_schedulers: ParamSchedulerType - if not isinstance(optim_wrapper, OptimWrapperDict): - # Since `OptimWrapperDict` inherits from `OptimWrapper`, - # `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell - # whether `self.optim_wrapper` is an `OptimizerWrapper` or - # `OptimWrapperDict` instance. Therefore, here we simply check - # self.optim_wrapper is not an `OptimWrapperDict` instance and - # then assert it is an OptimWrapper instance. - assert isinstance(optim_wrapper, BaseOptimWrapper), ( - '`build_optimizer` should be called before' - '`build_param_scheduler` because the latter depends ' - 'on the former') - param_schedulers = self._build_param_scheduler( - scheduler, optim_wrapper, default_args) # type: ignore - return param_schedulers - else: - param_schedulers = dict() - for name, optimizer in optim_wrapper.items(): - if isinstance(scheduler, dict) and 'type' not in scheduler: - # scheduler is a dict and each item is a ParamScheduler - # object or a config to build ParamScheduler objects - param_schedulers[name] = self._build_param_scheduler( - scheduler[name], optimizer, default_args) - else: - param_schedulers[name] = self._build_param_scheduler( - scheduler, optimizer, default_args) - - return param_schedulers - - def _scale_lr(self) -> None: - """Automatically scaling learning rate in training according to the - ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch - size. - - It scales the learning rate linearly according to the - `paper `_. 
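As a concrete illustration of the linear scaling rule applied by `_scale_lr` (hypothetical numbers):

# The optimizer LR is multiplied by real_batch_size / base_batch_size.
auto_scale_lr = dict(enable=True, base_batch_size=256)
world_size = 8
train_micro_batch_size_per_gpu = 64
real_bs = world_size * train_micro_batch_size_per_gpu   # 512
ratio = real_bs / auto_scale_lr["base_batch_size"]      # 2.0
scaled_lr = 0.1 * ratio                                 # an original lr of 0.1 becomes 0.2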
- - Note: - ``scale_lr`` must be called after building optimizer wrappers - and before building parameter schedulers. - """ - if (self._auto_scale_lr is None - or not self._auto_scale_lr.get('enable', False)): - return None - - assert 'base_batch_size' in self._auto_scale_lr, \ - 'Lack of `base_batch_size` in `auto_scale_lr`.' - - real_bs = self.world_size * self.dispatch_kwargs[ - 'train_micro_batch_size_per_gpu'] - base_bs = self._auto_scale_lr['base_batch_size'] - ratio = float(real_bs) / float(base_bs) - self.logger.info(f'LR is set based on batch size of {base_bs} ' - f'and the current batch size is {real_bs}. ' - f'Scaling the original LR by {ratio}.') - - def _is_built(schedulers): - if isinstance(schedulers, dict): - return False if 'type' in schedulers else any( - _is_built(s) for s in schedulers.values()) - if isinstance(schedulers, list): - return any(_is_built(s) for s in schedulers) - return isinstance(schedulers, _ParamScheduler) - - if hasattr(self, 'param_schedulers') and _is_built( - self.param_schedulers): - raise RuntimeError('`scale_lr` should be called before building ' - 'ParamScheduler because ParamScheduler will ' - 'store initial lr from optimizer wrappers') - - assert isinstance(self.optim_wrapper, BaseOptimWrapper), \ - '`scale_lr should be called after building OptimWrapper' - - if isinstance(self.optim_wrapper, OptimWrapperDict): - wrappers = list(self.optim_wrapper.values()) - else: - wrappers = [self.optim_wrapper] # type: ignore - - for wrapper in wrappers: - for group in wrapper.optimizer.param_groups: - group['lr'] = group['lr'] * ratio - - def build_logger( - self, - log_level: Union[int, str] = 'INFO', - log_file: Optional[str] = None, - **kwargs, - ) -> MMLogger: - """Build a global asscessable MMLogger. - - Args: - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - log_file (str, optional): Path of filename to save log. - Defaults to None. - **kwargs: Remaining parameters passed to ``MMLogger``. - - Returns: - MMLogger: A MMLogger object build from ``logger``. - """ - if log_file is None: - log_file = osp.join(self.log_dir, f'{self._timestamp}.log') - - log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs) - log_cfg.setdefault('name', self.experiment_name) - # `torch.compile` in PyTorch 2.0 could close all user defined handlers - # unexpectedly. Using file mode 'a' can help prevent abnormal - # termination of the FileHandler and ensure that the log file could - # be continuously updated during the lifespan of the runner. 
- log_cfg.setdefault('file_mode', 'a') - - return MMLogger.get_instance(**log_cfg) # type: ignore - - def model_state_dict(self) -> dict: - """Get model state dict.""" - from mmengine.runner import weights_to_cpu - return weights_to_cpu(self.model.state_dict()) - - def optim_state_dict(self) -> dict: - """Get optimizer state dict.""" - if isinstance(self.optim_wrapper, BaseOptimWrapper): - return self.optim_wrapper.state_dict() - else: - raise TypeError('self.optim_wrapper should be a `BaseOptimWrapper`' - f' instance, but got {self.optim_wrapper}') - - def scheduler_state_dict(self) -> Union[dict, list]: - """Get parameter scheduler state dict.""" - if isinstance(self.param_schedulers, dict): - state_dict: dict = dict() - for name, schedulers in self.param_schedulers.items(): - state_dict[name] = [] - for scheduler in schedulers: - state_dict[name].append(scheduler.state_dict()) - return state_dict - else: - state_list = [] - for scheduler in self.param_schedulers: # type: ignore - state_list.append(scheduler.state_dict()) - return state_list - - def load_model_state_dict( - self, - state_dict: dict, - *, - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - ) -> None: - """Load model state from dict.""" - from mmengine.runner.checkpoint import _load_checkpoint_to_model - - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - _load_checkpoint_to_model( - model, state_dict, strict=strict, revise_keys=revise_keys) - - def load_optim_state_dict(self, state_dict: dict) -> None: - """Load optimizer state from dict.""" - self.optim_wrapper.load_state_dict(state_dict) - - def load_scheduler_state_dict(self, state_dict: Union[dict, list]) -> None: - """Load scheduler state from dict.""" - if isinstance(self.param_schedulers, dict): - assert isinstance(state_dict, dict) - for name, schedulers in self.param_schedulers.items(): - for scheduler, ckpt_scheduler in zip(schedulers, - state_dict[name]): - scheduler.load_state_dict(ckpt_scheduler) - else: - for scheduler, ckpt_scheduler in zip( - self.param_schedulers, # type: ignore - state_dict): - scheduler.load_state_dict(ckpt_scheduler) - - def load_or_resume( - self, - *, - load_from: Optional[str] = None, - resume: Union[bool, str] = False, - ) -> Optional[dict]: - """Load checkpoint or resume from checkpoint. - - Args: - load_from (str, optional): The checkpoint file to load from. - Defaults to None. - resume (bool or str): Whether to resume training. Defaults to - False. If ``resume`` is True and ``load_from`` is None, - automatically to find latest checkpoint from ``work_dir``. - If not found, resuming does nothing. If ``resume`` is a string, - it will be treated as the checkpoint file to resume from. 
- """ - from mmengine.runner import find_latest_checkpoint - - if not resume and load_from is None: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if isinstance(resume, str): - resume_from = resume - elif resume and load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self._work_dir) - self.logger.info( - f'Auto resumed from the latest checkpoint {resume_from}.') - elif resume and load_from is not None: - # resume from the specified checkpoint - resume_from = load_from - - if resume_from is not None: - return self.resume(resume_from) - elif load_from is not None: - return self.load_checkpoint(load_from) - - return None - - @abstractmethod - def load_checkpoint( - self, - filename: str, - *, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - callback: Optional[Callable] = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - callback (callable, callable): Callback function to modify the - checkpoint after loading the checkpoint. - Defaults to None. - """ - - @abstractmethod - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default', - callback: Optional[Callable] = None, - ) -> dict: - """Resume training from given ``filename``. - - Four types of states will be resumed. - - - model state - - optimizer state - - scheduler state - - randomness state - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - - @abstractmethod - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. 
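A short usage sketch of the checkpoint interface outlined above, assuming `strategy` is the already-prepared concrete strategy from the earlier sketch:

# Save model, optimizer and scheduler state (plus any extra metadata).
strategy.save_checkpoint("work_dirs/strategy_demo/iter_1000.pth")

# resume=True with load_from=None searches work_dir for the latest checkpoint;
# passing a string for resume treats it as an explicit checkpoint path.
extra_ckpt = strategy.load_or_resume(resume=True)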
- """ - - def collect_env(self) -> Tuple[dict, dict]: - """Collect the information of the running environments.""" - system_env = collect_env() - runtime_env: OrderedDict = OrderedDict() - runtime_env.update(self._env_kwargs) - runtime_env.update(self.randomness) - runtime_env['Distributed launcher'] = self.launcher - runtime_env['Distributed training'] = self.distributed - runtime_env['GPU number'] = self.world_size - - return system_env, runtime_env - - def _prepared_components(self): - return_items = [self.model] - if hasattr(self, 'optim_wrapper'): - return_items.append(self.optim_wrapper) - - if hasattr(self, 'param_schedulers'): - return_items.append(self.param_schedulers) - - return return_items[0] if len(return_items) == 1 else return_items diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py deleted file mode 100644 index 13d9f38fc3..0000000000 --- a/mmengine/_strategy/colossalai.py +++ /dev/null @@ -1,565 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import os.path as osp -import time -from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -try: - import colossalai - import colossalai.booster.mixed_precision as colo_precision - import colossalai.booster.plugin as colo_plugin - import colossalai.nn.optimizer as colo_optimizer - from colossalai.booster import Booster - from colossalai.interface import ModelWrapper -except Exception as e: # noqa: F841 - colossalai = None - colo_precision = None - colo_plugin = None - colo_optimizer = None - Booster = None - ModelWrapper = None - -import torch -import torch.nn as nn - -import mmengine -from mmengine import mkdir_or_exist -from mmengine._strategy import BaseStrategy -from mmengine.device import get_device -from mmengine.dist import init_dist, is_main_process -from mmengine.fileio import join_path -from mmengine.model import BaseDataPreprocessor -from mmengine.optim import BaseOptimWrapper, OptimWrapper, _ParamScheduler -from mmengine.registry import STRATEGIES, Registry -from mmengine.registry.root import MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS -from mmengine.runner.checkpoint import _load_checkpoint, save_checkpoint -from mmengine.utils import get_git_hash - -# Component for colossalai `plugins` and `mixed_precisions` -PLUGINS = Registry('plugin') -MIXED_PRECISIONS = Registry('mixed_precision') - - -def register_plugins(): - _plugins = inspect.getmembers( - colo_plugin, - lambda x: inspect.isclass(x) and issubclass(x, colo_plugin.Plugin)) - - for name, plugin in _plugins: - PLUGINS.register_module(name=name, module=plugin) - - -def register_optimizers(): - _colo_optimizer = inspect.getmembers( - colo_optimizer, - lambda x: inspect.isclass(x) and issubclass(x, torch.optim.Optimizer)) - for name, optim_type in _colo_optimizer: - OPTIMIZERS.register_module(name=name, module=optim_type, force=True) - - -def register_mixed_precisions(): - _mixed_precisions = inspect.getmembers( - colo_precision, lambda x: inspect.isclass(x) and issubclass( - x, colo_precision.MixedPrecision)) - - for name, mixed_precision in _mixed_precisions: - MIXED_PRECISIONS.register_module(name=name, module=mixed_precision) - - -@OPTIM_WRAPPERS.register_module() -class ColossalAIOptimWrapper(OptimWrapper): - """OptimWrapper for ColossalAI. 
- - The available optimizers are: - - CPUAdam - - FusedAdam - - FusedLAMB - - FusedSGD - - HybridAdam - - Lamb - - Lars - - You can find more details in the `colossalai tutorial`_ - - Args: - optimizer (dict or torch.optim.Optimizer): The optimizer to be - wrapped. - accumulative_counts (int): The number of iterations to accumulate - gradients. The parameters will be updated per - ``accumulative_counts``. - - .. _colossalai tutorial: https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/optimizer - """ # noqa: E501 - - def __init__(self, - optimizer: torch.optim.Optimizer, - booster: Optional[Booster] = None, - accumulative_counts: int = 1): - super().__init__(optimizer, accumulative_counts=accumulative_counts) - self.booster = booster - - @contextmanager - def optim_context(self, model: nn.Module): - assert isinstance(self.booster, Booster), \ - 'Please set the booster attribute before using ' \ - '`ColossalAIOptimWrapper`.' - if self.booster.plugin.support_no_sync(): - no_sync_context = self.booster.no_sync(model, self.optimizer) - else: - yield - return - if self.should_sync(): - yield - else: - with no_sync_context: - yield - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - self._inner_count += 1 - self.optimizer.backward(loss, **kwargs) - - -@MODEL_WRAPPERS.register_module( - name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper']) -class ColossalAIModelWrapper: - - def __init__(self, model_wrapper: ModelWrapper, model: nn.Module): - self.model_wrapper = model_wrapper - self.model = model - - def __call__(self, *args, **kwargs) -> Any: - return self.model_wrapper(*args, **kwargs) - - def train_step( - self, - data: Union[dict, tuple, list], - optim_wrapper: ColossalAIOptimWrapper, - ) -> Dict[str, torch.Tensor]: - data = self.model.data_preprocessor(data, training=True) - with optim_wrapper.optim_context(self.model): - losses = self._run_forward(data, mode='loss') - parsed_loss, log_vars = self.model.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - return log_vars - - def val_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.model.data_preprocessor(data, False) - return self._run_forward(data, mode='predict') - - test_step = val_step - - def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self.model_wrapper(**data, mode=mode) - elif isinstance(data, (list, tuple)): - results = self.model_wrapper(*data, mode=mode) - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list, tuple or dict, but got {type(data)}') - return results - - def __getattr__(self, name): - if hasattr(self.model_wrapper, name): - return getattr(self.model_wrapper, name) - elif hasattr(self.model, name): - return getattr(self.model, name) - else: - raise AttributeError( - f'{self.model_wrapper} and {self.model} has no ' - f'attribute {name}') - - -@STRATEGIES.register_module() -class ColossalAIStrategy(BaseStrategy): - """ - Args: - config: (str or dict): The colossalai config file to setup distributed - environment. See more details in the `colossalai config tutorial`_. 
- mixed_precision (str or MixedPrecision): The mixed precision to run the - training. Defaults to None. If the argument is a string, it can be - 'fp16', 'fp16_apex', 'bf16', or 'fp8' fp16' would use PyTorch AMP - while `fp16_apex` would use Nvidia Apex. - plugin (Plugin): The plugin to run the training. The type of `plugin` - could be: - - - str: The available plugins are ``gemini`` and ``lowlevel-zero``. - - ``gemini`` means a `ZeRO`_ implementation with chunk-based - memory management. You could find more details in the - `colossalai gemini tutorial`_. ``lowlevel-zero`` means a - Zero-1 and Zero-2 implementation. Although gemini is more - memory saving, some unexpceted error could happen for - some spectial model structure. lowlevel-zero is more stable. - - - dict: **dict-type style config to build a colossalai plugin**. - - See the `booster plugin tutorial`_ for more details. - - model_wrapper (dict, optional): Dict for model wrapper. Defaults to - None. - work_dir (str): The working directory to save checkpoints. The logs - will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. Defaults to 'work_dirs'. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as :attr:`experiment_name`. - Defaults to None. - env_kwargs (dict, optional): Environment config passed in - :meth:`setup_env`. Defaults to None. - log_kwargs (dict, optional): Logger config passed in - :meth:`build_logger`. Defaults to None. - auto_scale_lr (dict, Optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. - - .. _colossalai config tutorial: https://colossalai.org/docs/basics/configure_parallelization - .. _ZeRO: https://arxiv.org/abs/1910.02054 - .. _colossalai gemini tutorial: https://colossalai.org/docs/features/zero_with_chunk/#geminiddp - .. _booster plugin tutorial: https://colossalai.org/docs/basics/booster_plugins - - """ # noqa: E501 - OPTIMIZER_DIR = 'optimizer' # directory to save optimizer state. - MODEL_DIR = 'model' # directory to save model - SCHEDULER_DIR = 'scheduler' # directory to save scheduelrs - model: ColossalAIModelWrapper # type: ignore - optim_wrapper: ColossalAIOptimWrapper # type: ignore - - def __init__( - self, - *, - config: Union[str, dict, None] = None, - mixed_precision: Union[str, dict, None] = None, - plugin: str = 'gemini', - model_wrapper: Optional[dict] = None, - **kwargs, - ): - if colossalai is None: - raise ModuleNotFoundError( - 'Please install colossalai by `pip install -U colossalai`') - register_plugins() - register_mixed_precisions() - register_optimizers() - - self.config = config or {} - super().__init__(**kwargs) - if mixed_precision is not None: - mixed_precision = self._build_mixed_precision(mixed_precision) - - if plugin is not None: - plugin = self._build_plugin(plugin) - self.booster = Booster(mixed_precision=mixed_precision, plugin=plugin) - self.model_wrapper = model_wrapper - - def prepare( - self, - model: Union[nn.Module, dict], - *, - optim_wrapper: Union[BaseOptimWrapper, dict, None] = None, - param_scheduler: Union[_ParamScheduler, Dict, List, None] = None, - compile: Union[dict, bool] = False, - dispatch_kwargs: Optional[dict] = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for build a model. 
- - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - If ``accumulative_counts`` is set in ``optim_wrapper``, you - need to provide ``max_iters`` in ``dispatch_kwargs``. - """ - if self._prepared: - return self._prepared_components() - if dispatch_kwargs is not None: - self.dispatch_kwargs.update(dispatch_kwargs) - - model = self.build_model(model) - model = self._init_model_weights(model) - - # optim_wrapper is required by booster - if optim_wrapper is not None and isinstance(optim_wrapper, dict): - optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper') - optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type']) - if optim_wrapper_type is None: - raise ValueError(f'Failed to find {optim_wrapper["type"]} in ' - '`OPTIM_WRAPPERS`.') - if 'clip_grad' in optim_wrapper: - raise ValueError('`Please configure `clip_grad` in `plugin`') - if not issubclass(optim_wrapper_type, ColossalAIOptimWrapper): - raise ValueError( - 'The type of `optim_wrapper` must be ' - '`ColossalAIOptimWrapper` (or subclass), but got ' - f'{optim_wrapper_type}') - optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - optim_wrapper.booster = self.booster # type: ignore - - if optim_wrapper is not None: - self.model, self.optim_wrapper = self._wrap( - model, optim_wrapper) # type: ignore - else: - self.model = self._wrap(model) # type: ignore - # TODO: Check whether `compile` is compatible with colossalai. 
- - if param_scheduler is not None: - self.param_schedulers = self.build_param_scheduler( - param_scheduler, optim_wrapper) # type: ignore - - if optim_wrapper is not None: - self._scale_lr() - accumulative_counts = getattr(self.optim_wrapper, - '_accumulative_counts', 1) - if accumulative_counts > 1: - if 'max_iters' not in self.dispatch_kwargs: - raise ValueError( - '"max_iters" must be specified because ' - '"accumulative_counts" was set as ' - f'{accumulative_counts} which is greater than 1.') - - self.optim_wrapper.initialize_count_status( # type: ignore - self.model, 0, self.dispatch_kwargs['max_iters']) - self._prepared = True - return self._prepared_components() - - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default', - callback: Optional[Callable] = None, - ) -> dict: - """Override this method since colossalai resume optimizer from filename - directly.""" - self.logger.info(f'Resume checkpoint from {filename}') - - extra_ckpt = self.load_checkpoint( - filename, map_location=map_location, callback=callback) - - if resume_optimizer: - self.booster.load_optimizer( - self.optim_wrapper.optimizer, - join_path(filename, self.OPTIMIZER_DIR)) - - if resume_param_scheduler: - schedulers_dir = join_path(filename, self.SCHEDULER_DIR) - for i, scheduler in enumerate(self.param_schedulers): - self.booster.load_lr_scheduler( - scheduler, f'{schedulers_dir}/scheduler_{i}.pth') - - # resume random seed - resumed_seed = extra_ckpt['meta'].get('seed', None) - current_seed = self._randomness.get('seed') - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning(f'The value of random seed in the ' - f'checkpoint "{resumed_seed}" is ' - f'different from the value in ' - f'`randomness` config "{current_seed}"') - self._randomness.update(seed=resumed_seed) - self._set_randomness(**self._randomness) - - # resume iter - self.dispatch_kwargs['cur_iter'] = extra_ckpt['meta']['iter'] - - return extra_ckpt - - def load_checkpoint( - self, - filename: str, - *, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - callback: Optional[Callable] = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Warning: - `map_localtion` and `callback` parameters are not supported yet. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. 
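A hedged usage sketch: with this strategy a "checkpoint" is a directory containing `model/`, `optimizer/` and `scheduler/` sub-directories plus a `meta.pth` file, so `filename` points at that directory:

# `strategy` is assumed to be an already-prepared ColossalAIStrategy instance.
extra_ckpt = strategy.load_checkpoint("work_dirs/strategy_demo/epoch_0.pth")
print(extra_ckpt["meta"]["seed"])  # meta.pth carries seed/time/version info written at save time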
- """ - self.logger.info(f'Load checkpoint from {filename}') - self.booster.load_model(self.model.model_wrapper, - join_path(filename, self.MODEL_DIR)) - meta = _load_checkpoint(osp.join(filename, 'meta.pth')) - return meta - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None, - ) -> None: - # The checkpoint directory will be: - # |--epoch_0.pth - # |---model/ - # |---optimizer/ - # |---scheduler/ - if extra_ckpt is None: - extra_ckpt = dict() - if 'meta' not in extra_ckpt: - extra_ckpt['meta'] = dict() - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash()) - - model_dir = join_path(filename, self.MODEL_DIR) - optimizer_dir = join_path(filename, self.OPTIMIZER_DIR) - schedulers_dir = join_path(filename, self.SCHEDULER_DIR) - mkdir_or_exist(model_dir) - mkdir_or_exist(optimizer_dir) - mkdir_or_exist(schedulers_dir) - - self.booster.save_model( - self.model.model_wrapper, checkpoint=model_dir, shard=True) - - if save_optimizer: - self.booster.save_optimizer( - self.optim_wrapper.optimizer, - checkpoint=optimizer_dir, - shard=True) - - if is_main_process() and save_param_scheduler: - for i, scheduler in enumerate(self.param_schedulers): - self.booster.save_lr_scheduler( - scheduler, f'{schedulers_dir}/scheduler_{i}.pth') - - save_checkpoint(extra_ckpt, join_path(filename, 'meta.pth')) - - def _build_plugin(self, plugin: Union[str, dict]): - if isinstance(plugin, str): - if plugin == 'gemini': - try: - plugin = colo_plugin.GeminiPlugin( - precision='bf16', placement_policy='auto') - except AssertionError: - from colossalai.zero.gemini.placement_policy import \ - PlacementPolicyFactory as colo_placement - raise ValueError('placement policy must be one of ' + - f'{list(colo_placement.policies.keys())}') - elif plugin == 'lowlevel-zero': - plugin = colo_plugin.LowLevelZeroPlugin() - else: - raise ValueError('`plugin` must be "gemini" or ' - '"lowlevel-zero"') - elif isinstance(plugin, dict): - plugin = PLUGINS.build(plugin) - else: - raise ValueError('`plugin` must be dict or str, but got a ' - f'{type(plugin)} object)') - return plugin - - def _build_mixed_precision(self, mixed_precision: Union[str, dict]): - if isinstance(mixed_precision, str): - if mixed_precision == 'fp16': - mixed_precision = colo_precision.FP16TorchMixedPrecision() - elif mixed_precision == 'fp16_apex': - mixed_precision = colo_precision.FP16ApexMixedPrecision() - elif mixed_precision == 'bf16': - mixed_precision = colo_precision.BF16MixedPrecision() - elif mixed_precision == 'fp8': - mixed_precision = colo_precision.FP8MixedPrecision() - else: - raise ValueError( - 'If `mixed_precision` is a string, it must be one of ' - '"fp16", "fp16_apex", "bf16" and "fp8", but got ' - f'{mixed_precision}') - elif isinstance(mixed_precision, dict): - mixed_precision = MIXED_PRECISIONS.build(mixed_precision) - else: - raise ValueError('mixed precision should be dict or str, but got ' - f'a {type(mixed_precision)} object') - return mixed_precision - - def _wrap( - self, - model: nn.Module, - optim_wrapper: Optional[OptimWrapper] = None, - ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper], - ColossalAIModelWrapper]: # type: ignore - """Wrap model with :class:`ModelWrapper`.""" - if self.model_wrapper is None: - self.model_wrapper = {'type': 'ColossalAIModelWrapper'} - - # For zero 
series parallel, move `data_preprocessor` to current device - # is reasonable. We need to `BaseDataPreprocessor.to` manually since - # framework like colossalai and deepspeed could not handle it, leading - # to `data_preprocessor` move data to cpu. - for module in model.modules(): - if isinstance(module, BaseDataPreprocessor): - module.to(get_device()) - - if optim_wrapper is not None: - optimizer = optim_wrapper.optimizer - if not hasattr(optimizer, '_hook_for_profile'): - # PyTorch 2.0 removes the `_hook_for_profile` in - # `torch.optim.Optimizer`. We maintain this function here to - # keep compatibility. - # TODO: Remove this hardcode when ColossalAI supports - # PyTorch 2.0 - optimizer.__class__._hook_for_profile = object - - # We do not pass `scheduler` and `Dataloader` here for: - # 1. `Booster.boost` cannot accept a list of schedulers. - # 2. `Strategy` cannot not accept dataloader now. - model_wrapper, optimizer, *_ = self.booster.boost(model, optimizer) - optim_wrapper.optimizer = optimizer - default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) - return model_wrapper, optim_wrapper # type: ignore - else: - model_wrapper, *_ = self.booster.boost(model) - default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) - return model_wrapper - - def _setup_distributed( # type: ignore - self, - launcher: Optional[str] = None, - backend: str = 'nccl', - **kwargs, - ): - init_dist( - launcher, backend, init_backend='colossalai', config=self.config) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py deleted file mode 100644 index 3f89ff760d..0000000000 --- a/mmengine/_strategy/deepspeed.py +++ /dev/null @@ -1,581 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import json -import os.path as osp -import time -from typing import Any, Callable, Dict, List, Optional, Union - -import torch - -from mmengine.logging import print_log - -try: - import deepspeed -except ImportError: - deepspeed = None - -import logging - -import torch.nn as nn - -import mmengine -from mmengine.dist import init_dist, is_main_process -from mmengine.optim import BaseOptimWrapper, _ParamScheduler -from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS, - STRATEGIES) -from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu -from mmengine.utils import apply_to, digit_version, get_git_hash -from .base import BaseStrategy - - -def register_deepspeed_optimizers() -> List[str]: - """Register optimizers in ``deepspeed`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. 
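Once registered, these optimizer classes can be referenced by name from a config; a sketch of an optimizer-wrapper config (assuming deepspeed is installed so that `FusedAdam` was registered):

# Hypothetical config snippet; FusedAdam is only registered when deepspeed is importable.
optim_wrapper = dict(
    type="DeepSpeedOptimWrapper",
    optimizer=dict(type="FusedAdam", lr=1e-4, weight_decay=0.01),
)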
- """ - deepspeed_optimizers = [] - try: - import deepspeed # noqa: F401 - except ImportError: - pass - else: - from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam - from deepspeed.ops.lamb import FusedLamb - from deepspeed.runtime.fp16.onebit import (OnebitAdam, OnebitLamb, - ZeroOneAdam) - - OPTIMIZERS.register_module(module=DeepSpeedCPUAdam) - deepspeed_optimizers.append('DeepSpeedCPUAdam') - OPTIMIZERS.register_module(module=FusedAdam) - deepspeed_optimizers.append('FusedAdam') - OPTIMIZERS.register_module(module=FusedLamb) - deepspeed_optimizers.append('FusedLamb') - OPTIMIZERS.register_module(module=OnebitAdam) - deepspeed_optimizers.append('OnebitAdam') - OPTIMIZERS.register_module(module=OnebitLamb) - deepspeed_optimizers.append('OnebitLamb') - OPTIMIZERS.register_module(module=ZeroOneAdam) - deepspeed_optimizers.append('ZeroOneAdam') - - return deepspeed_optimizers - - -@OPTIM_WRAPPERS.register_module() -class DeepSpeedOptimWrapper(BaseOptimWrapper): - - def __init__(self, optimizer): - super().__init__(optimizer) - self._model = None - - @property - def model(self): - if self._model is None: - raise ValueError('model attribute should be set before accessing.') - return self._model - - @model.setter - def model(self, value): - self._model = value - - def update_params(self, loss) -> None: # type: ignore - """Update parameters in :attr:`optimizer`.""" - self.backward(loss) - self.step() - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """"Perform gradient back propagation.""" - self.model.backward(loss) - - def zero_grad(self, **kwargs) -> None: - raise NotImplementedError( - 'DeepSpeedOptimWrapper does not support zero_grad method ' - 'currently.') - - def step(self, **kwargs): - self.model.step() - - def state_dict(self) -> dict: - state_dict = {} - if self.base_param_settings is not None: - state_dict['base_param_settings'] = self.base_param_settings - - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - base_param_settings = state_dict.pop('base_param_settings', None) - - if base_param_settings is not None: - self.base_param_settings = base_param_settings - - -@MODEL_WRAPPERS.register_module() -class MMDeepSpeedEngineWrapper: - - def __init__( - self, - *, - model: 'deepspeed.DeepSpeedEngine', - inputs_to_half: Optional[List[Union[int, str]]] = None, - ): - self.model = model - self._inputs_to_half = inputs_to_half - - def __getattr__(self, name): - return getattr(self.model, name) - - def train_step( - self, - data: Union[dict, tuple, list], - optim_wrapper: DeepSpeedOptimWrapper, - ) -> Dict[str, torch.Tensor]: - data = self.model.module.data_preprocessor(data, training=True) - data = self._cast_inputs_half(data) - losses = self._run_forward(data, mode='loss') - parsed_loss, log_vars = self.model.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - - return log_vars - - def val_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.model.module.data_preprocessor(data, False) - data = self._cast_inputs_half(data) - return self._run_forward(data, mode='predict') - - def test_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the predictions of module during testing process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. 
- """ - data = self.model.module.data_preprocessor(data, False) - data = self._cast_inputs_half(data) - return self._run_forward(data, mode='predict') - - def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self.model(**data, mode=mode) - elif isinstance(data, (list, tuple)): - results = self.model(*data, mode=mode) - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list, tuple or dict, but got {type(data)}') - return results - - def _cast_inputs_half(self, inputs: Union[list, tuple, dict, None]): - """Cast inputs to half precision if needed. - - Args: - inputs (list or tuple or dict or None): Inputs to be casted. - - Returns: - list or tuple or dict or None: Casted inputs. - """ - if self._inputs_to_half is None: - return inputs - - dtype = next(self.model.parameters()).dtype - if isinstance(inputs, (list, tuple)): - new_inputs = [] - for i, v in enumerate(inputs): - if i in self._inputs_to_half: - new_inputs.append( - apply_to(v, lambda x: hasattr(x, 'to'), - lambda x: x.to(dtype))) - else: - new_inputs.append(v) - return inputs.__class__(new_inputs) - elif isinstance(inputs, dict): - for k, v in inputs.items(): - if k in self._inputs_to_half: - inputs[k] = apply_to(v, lambda x: hasattr(x, 'to'), - lambda x: x.to(dtype)) - return inputs - else: - raise TypeError('inputs should be list, tuple or dict, ' - f'but got {type(inputs)}') - - -@STRATEGIES.register_module() -class DeepSpeedStrategy(BaseStrategy): - """Support training models with DeepSpeed. - - Note: - The detailed usage of parameters can be found at - https://www.deepspeed.ai/docs/config-json/. - - Args: - config (str or dict, optional): If it is a string, it is a path to load - config for deepspeed. Defaults to None. - zero_optimization (dict, optional): Enabling and configuring ZeRO - memory optimizations. Defaults to None. - gradient_clipping (float, optional): Enable gradient clipping with - value. Defaults to None. - fp16 (dict, optional): Configuration for using mixed precision/FP16 - training that leverages NVIDIA's Apex package. Defaults to None. - inputs_to_half (list[int or str], optional): Which inputs are to - converted to half precision. Defaults to None. - If ``fp16`` is enabled, it also should be set. - bf16 (dict, optional): Configuration for using bfloat16 floating-point - format as an alternative to FP16. Defaults to None. - amp (dict, optional): Configuration for using automatic mixed - precision (AMP) training that leverages NVIDIA's Apex AMP package. - Defaults to None. - activation_checkpointing (dict, optional): Reduce memory usage by - clearing activations of certain layers and recomputing them - during a backward pass. - Defaults to None. - aio (dict, optional): Configuring the asynchronous I/O module for - offloading parameter and optimizer states to persistent (NVMe) - storage. This module uses Linux native asynchronous I/O (libaio). - Defaults to None. - train_micro_batch_size_per_gpu (int, optional): Batch size to be - processed by one GPU in one step (without gradient accumulation). - Defaults to None. - gradient_accumulation_steps (int, optional): Number of training steps - to accumulate gradients before averaging and applying them. - Defaults to None. 
- exclude_frozen_parameters (bool, optional): Exclude frozen parameters - from saved checkpoint. - """ - - def __init__( - self, - *, - # the following args are for deepspeed - config: Union[str, dict, None] = None, - zero_optimization: Optional[dict] = None, - gradient_clipping: Optional[float] = None, - fp16: Optional[dict] = None, - inputs_to_half: Optional[List[Union[int, str]]] = None, - bf16: Optional[dict] = None, - amp: Optional[dict] = None, - activation_checkpointing: Optional[dict] = None, - aio: Optional[dict] = None, - train_micro_batch_size_per_gpu: Optional[int] = None, - gradient_accumulation_steps: Optional[int] = None, - # disable the log printed by deepseed - steps_per_print: int = 10000000000000, - # the following args are for BaseStrategy - exclude_frozen_parameters: Optional[bool] = None, - **kwargs, - ): - assert deepspeed is not None, \ - 'DeepSpeed is not installed. Please check ' \ - 'https://github.com/microsoft/DeepSpeed#installation.' - - super().__init__(**kwargs) - - self.config = self._parse_config(config) - if zero_optimization is not None: - self.config['zero_optimization'] = zero_optimization - if gradient_clipping is not None: - self.config['gradient_clipping'] = gradient_clipping - if fp16 is not None: - self.config['fp16'] = fp16 - if bf16 is not None: - self.config['bf16'] = bf16 - if amp is not None: - self.config['amp'] = amp - if activation_checkpointing is not None: - self.config['activation_checkpointing'] = activation_checkpointing - if aio is not None: - self.config['aio'] = aio - if train_micro_batch_size_per_gpu is not None: - self.config['train_micro_batch_size_per_gpu'] = \ - train_micro_batch_size_per_gpu - if gradient_accumulation_steps is not None: - self.config['gradient_accumulation_steps'] = \ - gradient_accumulation_steps - else: - self.config.setdefault('gradient_accumulation_steps', 1) - self.config['steps_per_print'] = steps_per_print - self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.13.2') - ), ('DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') - self.exclude_frozen_parameters = exclude_frozen_parameters - - register_deepspeed_optimizers() - - def _parse_config(self, config): - if config is None: - config = dict() - elif isinstance(config, str): - with open(config) as f: - config = json.load(f) - return config - - def _setup_distributed( # type: ignore - self, - launcher: Optional[str] = None, - backend: str = 'nccl', - **kwargs, - ): - """Setup distributed environment. - - Args: - launcher (str, optional): Way to launch multi processes. - DeepSpeedStrategy does not support the launcher argument. - backend (str): Communication Backends. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: Other arguments for :func:`deepspeed.init_distributed`. - """ - init_dist(launcher, backend, init_backend='deepspeed', **kwargs) - - def prepare( - self, - model: Union[nn.Module, dict], - *, - optim_wrapper: Union[BaseOptimWrapper, dict, None] = None, - param_scheduler: Union[_ParamScheduler, Dict, List, None] = None, - compile: Union[dict, bool] = False, - dispatch_kwargs: Optional[dict] = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for build a model. - - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. 
- Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - """ - if self._prepared: - return self._prepared_components() - assert dispatch_kwargs is not None - self.dispatch_kwargs.update(dispatch_kwargs) - - model = self.build_model(model) - model = self._init_model_weights(model) - - if optim_wrapper is not None: - self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - self.model = self._wrap_model(model) - - self.optim_wrapper.model = self.model # type: ignore - - else: - self.model = self._wrap_model(model) - - if param_scheduler is not None: - self.param_schedulers = self.build_param_scheduler( - param_scheduler, self.optim_wrapper) - self._prepared = True - return self._prepared_components() - - def _wrap_model(self, model: nn.Module) -> nn.Module: - if hasattr(self, 'optim_wrapper'): - engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize( - model=model, - optimizer=self.optim_wrapper.optimizer, - config=self.config) - else: - engine, *_ = deepspeed.initialize(model=model, config=self.config) - - wrapper = MMDeepSpeedEngineWrapper( - model=engine, inputs_to_half=self._inputs_to_half) - return wrapper - - def load_checkpoint( - self, - filename: str, - *, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - callback: Optional[Callable] = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Warning: - `map_localtion` and `callback` parameters are not supported yet. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - """ - self.logger.info(f'Load checkpoint from {filename}') - - dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): - _, extra_ckpt = self.model.load_checkpoint( - dirname, - tag=basename, - load_optimizer_states=False, - load_module_strict=not self.exclude_frozen_parameters) - else: - _, extra_ckpt = self.model.load_checkpoint( - dirname, tag=basename, load_optimizer_states=False) - - return extra_ckpt - - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default', - callback: Optional[Callable] = None, - ) -> dict: - """Resume training from given ``filename``. - - Warning: - `map_location` and `callback` parameters are not supported yet. - - Args: - filename (str): Accept local filepath. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. 
- """ - self.logger.info(f'Resume checkpoint from {filename}') - - dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): - _, extra_ckpt = self.model.load_checkpoint( - dirname, - tag=basename, - load_optimizer_states=resume_optimizer, - load_module_strict=not self.exclude_frozen_parameters) - else: - _, extra_ckpt = self.model.load_checkpoint( - dirname, tag=basename, load_optimizer_states=resume_optimizer) - - if resume_optimizer: - self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper')) - - if resume_param_scheduler and hasattr(self, 'param_schedulers'): - param_schedulers = extra_ckpt.pop('param_schedulers') - self.load_scheduler_state_dict(param_schedulers) - - # resume random seed - resumed_seed = extra_ckpt['meta'].get('seed', None) - current_seed = self._randomness.get('seed') - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning(f'The value of random seed in the ' - f'checkpoint "{resumed_seed}" is ' - f'different from the value in ' - f'`randomness` config "{current_seed}"') - self._randomness.update(seed=resumed_seed) - self._set_randomness(**self._randomness) - - return extra_ckpt - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Warning: - `callback` parameter is not supported yet. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - """ - if extra_ckpt is None: - extra_ckpt = dict() - if 'meta' not in extra_ckpt: - extra_ckpt['meta'] = dict() - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash(), - ) - - if save_param_scheduler and hasattr(self, 'param_schedulers'): - extra_ckpt['param_schedulers'] = self.scheduler_state_dict() - - if (not save_optimizer - and self.model.zero_optimization_partition_weights() - and not self.model.zero_gather_16bit_weights_on_model_save()): - print_log( - 'Configured to `save_optimizer=False`, but currently using ' - "DeepSpeed's ZeRO stage 3 with " - '`gather_16bit_weights_on_model_save=False`. In ' - 'this configuration, the model cannot be saved properly ' - 'and will be saved with the optimizer state. ' - 'To support `save_optimizer=False`, please set ' - '`gather_16bit_weights_on_model_save=True` in your ' - 'DeepSpeed config.', - logger='current', - level=logging.WARNING) - save_optimizer = True - - state_dict_kwargs = {} - if digit_version(deepspeed.__version__) >= digit_version('0.13.2'): - state_dict_kwargs[ - 'exclude_frozen_parameters'] = self.exclude_frozen_parameters - - if save_optimizer: - if hasattr(self, 'optim_wrapper'): - # The key can not be 'optimizer', otherwise error will be - # thrown when loading or resuming checkpoint. 
- extra_ckpt['optim_wrapper'] = self.optim_state_dict() - - dirname, basename = osp.split(filename) - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False, - **state_dict_kwargs) - else: - if self.model.zero_optimization_partition_weights(): - state_dict = self.model._zero3_consolidated_16bit_state_dict( - **state_dict_kwargs) - else: - state_dict = self.model.module_state_dict(**state_dict_kwargs) - - if is_main_process(): - ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt} - save_checkpoint(ckpt, filename) diff --git a/mmengine/_strategy/distributed.py b/mmengine/_strategy/distributed.py deleted file mode 100644 index dbe17d5aeb..0000000000 --- a/mmengine/_strategy/distributed.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -from typing import Callable, Optional - -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel - -from mmengine.device import get_device -from mmengine.dist import init_dist, is_distributed, master_only -from mmengine.model import convert_sync_batchnorm, is_model_wrapper -from mmengine.registry import MODEL_WRAPPERS, STRATEGIES -from .single_device import SingleDeviceStrategy - - -@STRATEGIES.register_module() -class DDPStrategy(SingleDeviceStrategy): - """Distribution strategy for distributed data parallel training. - - Args: - model_wrapper (dict): Dict for model wrapper. Defaults to None. - sync_bn (str): Type of sync batch norm. Defaults to None. - Options are 'torch' and 'mmcv'. - **kwargs: Other arguments for :class:`BaseStrategy`. - """ - - def __init__( - self, - *, - model_wrapper: Optional[dict] = None, - sync_bn: Optional[str] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.model_wrapper = model_wrapper - self.sync_bn = sync_bn - - def _setup_distributed( # type: ignore - self, - launcher: str = 'pytorch', - backend: str = 'nccl', - **kwargs, - ): - """Setup distributed environment. - - Args: - launcher (str): Way to launcher multi processes. Supported - launchers are 'pytorch', 'mpi' and 'slurm'. - backend (str): Communication Backends. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: Other arguments for :func:`init_dist`. - """ - if not is_distributed(): - init_dist(launcher, backend, **kwargs) - - def convert_model(self, model: nn.Module) -> nn.Module: - """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` - (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers. - - Args: - model (nn.Module): Model to be converted. - - Returns: - nn.Module: Converted model. - """ - if self.sync_bn is not None: - try: - model = convert_sync_batchnorm(model, self.sync_bn) - except ValueError as e: - self.logger.error('cfg.sync_bn should be "torch" or ' - f'"mmcv", but got {self.sync_bn}') - raise e - - return model - - def _wrap_model(self, model: nn.Module) -> DistributedDataParallel: - """Wrap the model to :obj:``MMDistributedDataParallel`` or other custom - distributed data-parallel module wrappers. - - Args: - model (nn.Module): Model to be wrapped. - - Returns: - nn.Module or DistributedDataParallel: nn.Module or subclass of - ``DistributedDataParallel``. 
- """ - if is_model_wrapper(model): - return model - - model = model.to(get_device()) - - model = self.convert_model(model) - - if self.model_wrapper is None: - # set broadcast_buffers as False to keep compatibility with - # OpenMMLab repos - self.model_wrapper = dict( - type='MMDistributedDataParallel', broadcast_buffers=False) - - default_args = dict( - type='MMDistributedDataParallel', - module=model, - device_ids=[int(os.environ['LOCAL_RANK'])]) - model = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) - return model - - @master_only - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None, - ) -> None: - super().save_checkpoint( - filename=filename, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=extra_ckpt, - callback=callback) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py deleted file mode 100644 index 0788fafdab..0000000000 --- a/mmengine/_strategy/fsdp.py +++ /dev/null @@ -1,643 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import inspect -import os -import os.path as osp -import time -from collections import OrderedDict -from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Union - -import torch.nn as nn -from torch.distributed.fsdp import (FullStateDictConfig, - FullyShardedDataParallel, - LocalStateDictConfig, StateDictType) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig, - StateDictConfig) -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler - -import mmengine -from mmengine.config import Config, ConfigDict -from mmengine.device import get_device -from mmengine.dist import get_rank, is_main_process -from mmengine.model import BaseDataPreprocessor, is_model_wrapper -from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper, - OptimWrapperDict, _ParamScheduler, - build_optim_wrapper) -from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS, - PARAM_SCHEDULERS, STRATEGIES, Registry) -from mmengine.utils import get_git_hash, mkdir_or_exist -from .distributed import DDPStrategy -from .utils import MetaTensorContext - -FSDP = FullyShardedDataParallel -FSDP_CONFIGS = Registry('fsdp configs') -FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig) -FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig) -FSDP_CONFIGS.register_module(module=FullStateDictConfig) -FSDP_CONFIGS.register_module(module=LocalStateDictConfig) - - -@STRATEGIES.register_module() -class FSDPStrategy(DDPStrategy): - """Support training model with FullyShardedDataParallel (FSDP). - - Keyword Args: - model_wrapper (dict, optional): Config dict for model wrapper. The - default configuration is: - - Examples: - >>> model_wrapper = dict( - >>> type='MMFullyShardedDataParallel', - >>> use_orig_params=True, - >>> ) - - See more configurable arguments in - :class:`MMFullyShardedDataParallel`. Defaults to None - skip_init_weights (bool, optional): Whether to skip initialization of - weights. Defaults to False. This is useful when the parameters of - the large model are loaded from a checkpoint, since skipping the - initialization of weights can save a lot of time. 
- state_dict_cfg (str or dict): Configuration for - how to save and load the state dict of the model, optimizer, and - scheduler. - - - "local": save and load the sharded state dict in all ranks. - - "full": save and load the full state dict in rank 0. - - `dict` object: save and load the state dict more flexibly. For - example, you can first offload the state dict to the 'cpu' and - then save it to the disk. This can help you to load the - checkpoint in a non-gpu environment: - - Examples: - >>> state_dict_cfg=dict( - >>> state_dict_type='FULL_STATE_DICT', - >>> state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True), - >>> optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True), - - See more configurable arguments for ``state_dict_cfg``, - ``state_dict_config``, and ``optim_state_dict_config``in - `FSDP official api documents`_ - kwargs (dict): Additional arguments passed to :class:`DDPStrategy`: - - - work_dir (str): The working directory to save checkpoints. - The logs will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. Defaults to 'work_dirs'. - - experiment_name (str, optional): Name of current experiment. If - not specified, timestamp will be used as :attr:`experiment_name`. - Defaults to None. - - env_kwargs (dict, optional): Environment config passed in - :meth:`setup_env`. Defaults to None. - - log_kwargs (dict, optional): Logger config passed in - :meth:`build_logger`. Defaults to None. - activation_checkpointing (dict, optional): Config dict for gradient - checkpoint. - - Examples: - >>> activation_checkpointing = dict(check_fn='CustomCheckFn') - >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1)) - - - ``check_fn`` field should behave consistently with - ``auto_wrap_policy`` defined in `model_wrapper`, and other - fields will be passed to ``apply_activation_checkpointing`` - - `New in version 0.9.0.` - - .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type - """ # noqa: E501 - - def __init__(self, - *, - model_wrapper: Optional[dict] = None, - skip_init_weights=False, - state_dict_cfg: Union[str, dict] = 'local', - activation_checkpointing: Optional[dict] = None, - **kwargs): - super().__init__(model_wrapper=model_wrapper, **kwargs) - self._init_state_dict_cfg(state_dict_cfg) - if not isinstance(skip_init_weights, bool): - raise TypeError('skip_init_weights must be a boolean, but got ' - f'{type(skip_init_weights)}') - self.skip_init_weights = skip_init_weights - self.activation_checkpointing = activation_checkpointing - - def _wrap_model(self, model: nn.Module) -> None: - """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other - custom fully sharded data parallel module wrappers. - - Args: - model (nn.Module): Model to be wrapped. - - Returns: - FullyShardedDataParallel: ``MMFullyShardedDataParallel`` - or subclass of ``FullyShardedDataParallel``. 
- """ - try: - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ - apply_activation_checkpointing # noqa: E501 - except ImportError: - apply_activation_checkpointing = None - - for module in model.modules(): - if isinstance(module, BaseDataPreprocessor): - module.to(get_device()) - - if is_model_wrapper(model): - return - - if self.model_wrapper is None: - self.model_wrapper = dict(type='MMFullyShardedDataParallel') - - default_args = dict( - module=model, - device_id=int(os.environ['LOCAL_RANK']), - type='MMFullyShardedDataParallel') - model = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) - model.set_state_dict_type(model, self.state_dict_type, - self.state_dict_config, - self.optim_state_dict_config) - - if self.activation_checkpointing is not None: - if apply_activation_checkpointing is None: - raise RuntimeError( - 'activation_checkpointing maybe deprecated by current ' - 'PyTorch version, maybe you could switch to PyTorch 2.0 ' - 'or 2.1 to use `activation_checkpointing`.') - cfg = copy.deepcopy(self.activation_checkpointing) - with FUNCTIONS.switch_scope_and_registry(None): - check_fn = cfg.pop('check_fn') - if isinstance(check_fn, str): - check_fn = FUNCTIONS.get(check_fn) - elif isinstance(check_fn, dict): - fn_type = check_fn.pop('type') - if isinstance(fn_type, str): - fn_type = FUNCTIONS.get(fn_type) - check_fn = partial(fn_type, **cfg) - - if not callable(check_fn): - raise TypeError('`check_fn` must be a callable function') - apply_activation_checkpointing(model, check_fn=check_fn, **cfg) - return model - - def _is_full_state_dict(self): - """Whether to save and load the full state_dict in rank 0.""" - return self.state_dict_type == StateDictType.FULL_STATE_DICT - - def build_model(self, model: Union[nn.Module, dict]) -> nn.Module: - """Build model. - - If skip_init_weights is True, the model will be built with an empty - weights. It means that :meth:`load_checkpoint` must be called to fill - the weights before training. - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build ``nn.Module`` object. If ``model`` is a ``nn.Module`` - object, just returns itself. - - Returns: - nn.Module: Model build from ``model``. - """ - if self.skip_init_weights: - if isinstance(model, dict): - # Accelerate initialization by skipping init weights - with MetaTensorContext(): - model = super().build_model(model) - model.to_empty(device='cpu') - else: - model = super().build_model(model) - - # `id_to_name` will be used to convert the `optim_state_dict` of the - # raw optimizer to the `optim_state_dict` - # returned by `FSDP.optim_state_dict` in - # `StateDictType.FULL_STATE_DICT` mode. - self.id_to_name = dict() - for name, param in model.named_parameters(): - self.id_to_name[id(param)] = name - return model - - def save_checkpoint(self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None) -> None: - """Save checkpoint to given ``filename``. - - If ``state_dict_type`` is `full`, the checkpoint will only be saved in - rank0. The structure of the saved checkpoint is the same as the one - saved by ``DDPStrategy`` - - If ``state_dict_type`` is `local`, each rank will save the sharded - state dict to a directory, which means the saved structure will look - like this: - - .. code-block:: bash - - ── epoch_0.pth - ├── rank0.pth - ├── rank1.pth - ├── ... 
- └── rank8.pth - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - from mmengine.runner.checkpoint import save_checkpoint - - state_dict: dict = dict() - state_dict['state_dict'] = self.model_state_dict() - - # save optimizer state dict - if save_optimizer and hasattr(self, 'optim_wrapper'): - state_dict['optimizer'] = self.optim_state_dict() - - # save param scheduler state dict - if save_param_scheduler and hasattr(self, 'param_schedulers'): - state_dict['param_schedulers'] = self.scheduler_state_dict() - - # save extra checkpoint passed by users - if extra_ckpt is None: - extra_ckpt = dict() - if 'meta' not in extra_ckpt: - extra_ckpt['meta'] = dict() - - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash(), - ) - state_dict.update(extra_ckpt) - - # users can do some modification before saving checkpoint - if callback is not None: - callback(state_dict) - - # In non-FULL_STATE_DICT model, FSDPStrategy will save checkpoint - # of different ranks in different files. - if not self._is_full_state_dict(): - rank = get_rank() - mkdir_or_exist(filename) - ckpt_name = f'rank{rank}.pth' - filename = osp.join(filename, ckpt_name) - save_checkpoint(state_dict, filename) - - if is_main_process(): - save_checkpoint(state_dict, filename) - - def model_state_dict(self) -> dict: - """Get model state dict based on the ``state_dict_type``. - - If ``state_dict_type`` is `full`, the model state dict will be the - same as the one of original unsharded model. - - If ``state_dict_type`` is ``local``, and ``use_orig_params`` is ``True`` - in ``model_wrapper``. The key of the state dict will be the same as - the one of original unsharded model, but its value will be the sharded - one - - If ``state_dict_type`` is `local`, and ```use_orig_params``` is - ``False`` in ``model_wrapper``, the flatten and sharded state dict will - be returned. - - See more details in the `official api documents`_ - - .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict - """ # noqa: E501 - # We've set state_dict by `FSDP.set_state_dict_type`, therefore we - # should get model state dict by `FSDP.state_dict` - return self.model.state_dict() - - def optim_state_dict(self) -> dict: - """Get model state dict based on the ``state_dict_type``. - - If ``state_dict_type`` is ``full``, the optimizer state dict can be - loaded by the original unsharded optimizer. - - Otherwise, the optimizer state dict could only be loaded by the - optimizer with sharded parameters. - - Note: - The optimizer state dict is not the same as the one of original - optimizer even if in ``full`` mode, although they can be loaded - correctly. - - See more details in the `official api documents`_ - - .. 
_official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict - """ # noqa: E501 - return FSDP.optim_state_dict(self.model, self.optim_wrapper) - - def load_checkpoint(self, filename: str, **kwargs) -> dict: - """Load checkpoint from given ``filename``. - - Note: - If ``state_dict_type`` is `local`, the filename should be a - directory contains ``rank{i}.pth``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - callback (callable, callable): Callback function to modify the - checkpoint after loading the checkpoint. - Defaults to None. - """ - if self._is_full_state_dict(): - return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) - else: - rank = get_rank() - filename = osp.join(filename, f'rank{rank}.pth') - return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) - - def load_model_state_dict( - self, - state_dict: dict, - *, - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - ) -> None: # type: ignore - """Load model state from dict. - - Warning: - `revise_keys` is not supported yet. - - Args: - state_dict (dict): Model state dict returned by - :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type`` - is ``full``. ``state_dict`` could be the result of - ``model.state_dict()`` - strict (bool): Whether to load model state dict strictly. - Defaults to False. - """ - # We should load state dict by `FSDP.load_state_dict` - self.model.load_state_dict(state_dict, strict=strict) - - def load_optim_state_dict(self, state_dict: dict) -> None: - """Load optimizer state from dict. - - Args: - state_dict (dict): The optimizer state dict. If ``state_dict_type`` - is ``full``. 
``state_dict`` could be the result of - ``optimizer.state_dict()`` - """ - optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) - self.optim_wrapper.load_state_dict(optim_state_dict) - - def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: - """Make ``state_dict_type`` and ``state_dict_config`` can be configured - with string.""" - if isinstance(state_dict_cfg, str): - if state_dict_cfg == 'full': - self.state_dict_type = StateDictType.FULL_STATE_DICT - self.state_dict_config = FullStateDictConfig( - rank0_only=True, offload_to_cpu=True) - self.optim_state_dict_config = FullOptimStateDictConfig( - rank0_only=True, offload_to_cpu=True) - elif state_dict_cfg == 'local': - self.state_dict_type = StateDictType.LOCAL_STATE_DICT - self.state_dict_config = LocalStateDictConfig() - self.optim_state_dict_config = LocalOptimStateDictConfig() - else: - raise ValueError('FSDP only supports `full` and `local` ' - f'state_dict_type, but got {state_dict_cfg}') - elif isinstance(state_dict_cfg, dict): - if 'state_dict_type' not in state_dict_cfg: - self.state_dict_type = StateDictType.LOCAL_STATE_DICT - else: - state_dict_type = state_dict_cfg['state_dict_type'] - if isinstance(state_dict_type, str): - self.state_dict_type = StateDictType[ - state_dict_cfg['state_dict_type']] - else: - self.state_dict_type = state_dict_type - state_dict_config = state_dict_cfg.get('state_dict_config') - if state_dict_config is None: - self.state_dict_config = LocalStateDictConfig() - elif isinstance(state_dict_config, dict): - self.state_dict_config = FSDP_CONFIGS.build( - state_dict_cfg['state_dict_config']) - else: - self.state_dict_config = state_dict_config - - optim_state_dict_config = state_dict_cfg.get( - 'optim_state_dict_config') - if optim_state_dict_config is None: - self.optim_state_dict_config = LocalOptimStateDictConfig() - elif isinstance(optim_state_dict_config, dict): - self.optim_state_dict_config = FSDP_CONFIGS.build( - state_dict_cfg['optim_state_dict_config']) - else: - self.optim_state_dict_config = optim_state_dict_config - else: - raise TypeError('state_dict_cfg should be a `str` or a `dict`, ' - f'but got {type(state_dict_cfg)}') - - if not isinstance(self.state_dict_type, StateDictType): - raise TypeError('state_dict_type must be StateDictType, but got ' - f'{type(self.state_dict_type)}') - if not isinstance(self.state_dict_config, StateDictConfig): - raise TypeError('state_dict_config must be StateDictConfig, but ' - f'got {type(self.state_dict_config)}') - if not isinstance(self.optim_state_dict_config, OptimStateDictConfig): - raise TypeError('optim_state_dict_config must be ' - 'OptimStateDictConfig, but got ' - f'{type(self.optim_state_dict_config)}') - - def build_optim_wrapper( - self, - optim_wrapper: Union[Optimizer, OptimWrapper, dict], - model: Optional[nn.Module] = None, - ) -> BaseOptimWrapper: - """Support sharding the optimizer state dict given a built optimizer or - optim_wrapper. - - See specific usage in :meth:`BaseStrategy.build_optim_wrapper`. 
- """ - if isinstance(optim_wrapper, Optimizer): - optim_wrapper = OptimWrapper(optim_wrapper) - if isinstance(optim_wrapper, BaseOptimWrapper): - assert model is not None - # NOTE: The only difference is that FSDPStrategy will shard - # the the built OptimWrapper - optimizer = optim_wrapper.optimizer - param_groups = optimizer.param_groups - optim_state_dict = optimizer.state_dict() - assert not optim_state_dict['state'], ( - 'Optimizer state_dict should be empty when giving an built ' - 'optim_wrapper to FSDPStrategy') - # Align the state_dict with state_dict generated by - # FSDP.full_optim_state_dict - new_param_groups = [] - for group in param_groups: - new_group = { - key: value - for key, value in group.items() if key != 'param' - } - new_group['params'] = [ - self.id_to_name[id(param)] for param in group['params'] - ] - new_param_groups.append(new_group) - optim_state_dict['param_groups'] = new_param_groups - defaults = { - k: v - for k, v in optimizer.defaults.items() if k != 'differentiable' - } - - params_dict = {} - for k, v in model.named_parameters(): - if '_fsdp_wrapped_module' in k: - k = k.replace('_fsdp_wrapped_module.', '') - params_dict[k] = v - - params = [] - for param_group in new_param_groups: - _params = [] - for param_name in param_group['params']: - if param_name not in params_dict: - raise RuntimeError( - 'Failed to reconstruct the sharded optimizer. ' - 'You can try to set `use_orig_params=True` in ' - '`model_wrapper`') - _params.append(params_dict[param_name]) - param_group = { - k: v - for k, v in param_group.items() if k != 'param' - } - param_group['params'] = _params - params.append(param_group) - - new_optimizer = optimizer.__class__(params, **defaults) - - # Force to load the converted optim_state_dict in full mode. - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): - optim_state_dict = FSDP.optim_state_dict_to_load( - optim_state_dict, model, new_optimizer) - new_optimizer.load_state_dict(optim_state_dict) - optim_wrapper.optimizer = new_optimizer - return optim_wrapper - if isinstance(optim_wrapper, (dict, ConfigDict, Config)): - assert model is not None - # optimizer must be defined for single optimizer training. - optimizer = optim_wrapper.get('optimizer', None) - optim_wrapper.setdefault('type', 'OptimWrapper') - if optim_wrapper.get('type', - 'AmpOptimWrapper') in ('AmpOptimWrapper', - AmpOptimWrapper): - optim_wrapper.setdefault('use_fsdp', True) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. - if isinstance(optimizer, Optimizer): - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or 'constructor' in optim_wrapper: - return build_optim_wrapper(model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. 
`build_optim_wrapper` will directly build the - # `OptimWrapperDict` instance from `optim_wrapper.` - optim_wrappers = OrderedDict() - for name, optim in optim_wrapper.items(): - if not isinstance(optim, OptimWrapper): - raise ValueError( - 'each item mush be an optimizer object when ' - '"type" and "constructor" are not in ' - f'optimizer, but got {name}={optim}') - optim_wrappers[name] = optim - return OptimWrapperDict(**optim_wrappers) - else: - raise TypeError('optimizer wrapper should be an OptimWrapper ' - f'object or dict, but got {optim_wrapper}') - - def _build_param_scheduler( - self, - scheduler: Union[_ParamScheduler, Dict, List], - optim_wrapper: BaseOptimWrapper, - default_args: dict, - ) -> List[_ParamScheduler]: - """Override this method to update the scheduler with the reconstructed - sharded optimzer.""" - if not isinstance(scheduler, Sequence): - schedulers = [scheduler] - else: - schedulers = scheduler - - max_epochs = default_args.pop('max_epochs', None) - max_iters = default_args.pop('max_iters', None) - - param_schedulers = [] - for scheduler in schedulers: - # Update the built scheduler with the sharded optimizer - if isinstance(scheduler, (_ParamScheduler, LRScheduler)): - parameter_keys = inspect.signature( - scheduler.__class__).parameters.keys() - kwargs = { - k: v - for k, v in scheduler.state_dict().items() - if k in parameter_keys - } - scheduler = scheduler.__class__(optim_wrapper, **kwargs) - elif isinstance(scheduler, dict): - _scheduler = copy.deepcopy(scheduler) - - # Set default end - if _scheduler.get('by_epoch', True): - if max_epochs is None: - raise ValueError( - 'max_epochs must be specified in default_args') - default_end = max_epochs - else: - if max_iters is None: - raise ValueError( - 'max_iters must be specified in default_args') - default_end = max_iters - _scheduler.setdefault('end', default_end) - self.logger.debug( - f'The `end` of {_scheduler["type"]} is not set. ' - 'Use the max epochs/iters of train loop as default.') - - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) - else: - raise TypeError( - 'scheduler should be a _ParamScheduler object or dict, ' - f'but got {scheduler}') - return param_schedulers diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py deleted file mode 100644 index c7d8accd5a..0000000000 --- a/mmengine/_strategy/single_device.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import time -from typing import Callable, Dict, List, Optional, Union - -import torch.nn as nn - -import mmengine -from mmengine.device import get_device -from mmengine.model import revert_sync_batchnorm -from mmengine.optim import BaseOptimWrapper, _ParamScheduler -from mmengine.registry import STRATEGIES -from mmengine.utils import get_git_hash -from .base import BaseStrategy - - -@STRATEGIES.register_module() -class SingleDeviceStrategy(BaseStrategy): - """Strategy for single device training.""" - - def prepare( - self, - model: Union[nn.Module, dict], - *, - optim_wrapper: Union[BaseOptimWrapper, dict, None] = None, - param_scheduler: Union[_ParamScheduler, Dict, List, None] = None, - compile: Union[dict, bool] = False, - dispatch_kwargs: Optional[dict] = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for build a model. 
- - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - If ``accumulative_counts`` is set in ``optim_wrapper``, you - need to provide ``max_iters`` in ``dispatch_kwargs``. - """ - if self._prepared: - return self._prepared_components() - if dispatch_kwargs is not None: - self.dispatch_kwargs.update(dispatch_kwargs) - - model = self.build_model(model) - model = self._init_model_weights(model) - model = self._wrap_model(model) - model = self.compile_model(model, compile=compile) - - self.model = model - - if optim_wrapper is not None: - self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - self._scale_lr() - - accumulative_counts = getattr(self.optim_wrapper, - '_accumulative_counts', 1) - if accumulative_counts > 1: - if 'max_iters' not in self.dispatch_kwargs: - raise ValueError( - '"max_iters" must be specified because ' - '"accumulative_counts" was set as ' - f'{accumulative_counts} which is greater than 1.') - - self.optim_wrapper.initialize_count_status( # type: ignore - self.model, 0, self.dispatch_kwargs['max_iters']) - - if param_scheduler is not None: - self.param_schedulers = self.build_param_scheduler( - param_scheduler, self.optim_wrapper) - - self._prepared = True - return self._prepared_components() - - def _wrap_model(self, model: nn.Module) -> nn.Module: - model = self.convert_model(model) - current_device = get_device() - return model.to(current_device) - - def convert_model(self, model: nn.Module) -> nn.Module: - """Convert layers of model. - - convert all ``SyncBatchNorm`` (SyncBN) and - ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers in the model to - ``BatchNormXd`` layers. - - Args: - model (nn.Module): Model to convert. - """ - self.logger.info( - 'Distributed training is not used, all SyncBatchNorm (SyncBN) ' - 'layers in the model will be automatically reverted to ' - 'BatchNormXd layers if they are used.') - model = revert_sync_batchnorm(model) - return model - - def load_checkpoint( - self, - filename: str, - *, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')], - callback: Optional[Callable] = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - callback (callable, callable): Callback function to modify the - checkpoint after loading the checkpoint. - Defaults to None. 
- """ - from mmengine.runner.checkpoint import _load_checkpoint - - self.logger.info(f'Load checkpoint from {filename}') - - if map_location == 'default': - device = get_device() - checkpoint = _load_checkpoint(filename, map_location=device) - else: - checkpoint = _load_checkpoint(filename, map_location=map_location) - - # users can do some modification after loading checkpoint - if callback is not None: - callback(checkpoint) - - state_dict = checkpoint.pop('state_dict') - self.load_model_state_dict( - state_dict, strict=strict, revise_keys=revise_keys) - - return checkpoint - - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default', - callback: Optional[Callable] = None, - ) -> dict: - """Resume training from given ``filename``. - - Four types of states will be resumed. - - - model state - - optimizer state - - scheduler state - - randomness state - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - self.logger.info(f'Resume checkpoint from {filename}') - - checkpoint = self.load_checkpoint( - filename, map_location=map_location, callback=callback) - - if resume_optimizer: - self.load_optim_state_dict(checkpoint.pop('optimizer')) - - if resume_param_scheduler and hasattr(self, 'param_schedulers'): - self.load_scheduler_state_dict(checkpoint.pop('param_schedulers')) - - # resume random seed - resumed_seed = checkpoint['meta'].get('seed', None) - current_seed = self._randomness.get('seed') - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning(f'The value of random seed in the ' - f'checkpoint "{resumed_seed}" is ' - f'different from the value in ' - f'`randomness` config "{current_seed}"') - self._randomness.update(seed=resumed_seed) - self._set_randomness(**self._randomness) - - # resume iter - cur_iter = checkpoint['meta']['iter'] - - if hasattr(self, 'optim_wrapper'): - accumulative_counts = getattr(self.optim_wrapper, - '_accumulative_counts', 1) - if accumulative_counts > 1: - if 'max_iters' not in self.dispatch_kwargs: - raise ValueError( - '"max_iters" must be specified because ' - '"accumulative_counts" was set as ' - f'{accumulative_counts} which is greater than 1.') - # Initiate inner count of `optim_wrapper`. - self.optim_wrapper.initialize_count_status( # type: ignore - self.model, cur_iter, self.dispatch_kwargs['max_iters']) - - return checkpoint - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: Optional[dict] = None, - callback: Optional[Callable] = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. 
- extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - from mmengine.runner.checkpoint import save_checkpoint - - state_dict: dict = dict() - state_dict['state_dict'] = self.model_state_dict() - - # save optimizer state dict - if save_optimizer and hasattr(self, 'optim_wrapper'): - state_dict['optimizer'] = self.optim_state_dict() - - if save_param_scheduler and hasattr(self, 'param_schedulers'): - state_dict['param_schedulers'] = self.scheduler_state_dict() - - # save extra checkpoint passed by users - if extra_ckpt is None: - extra_ckpt = dict() - if 'meta' not in extra_ckpt: - extra_ckpt['meta'] = dict() - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash(), - ) - - state_dict.update(extra_ckpt) - - # users can do some modification before saving checkpoint - if callback is not None: - callback(state_dict) - - save_checkpoint(state_dict, filename) diff --git a/mmengine/_strategy/utils.py b/mmengine/_strategy/utils.py deleted file mode 100644 index c691bd602b..0000000000 --- a/mmengine/_strategy/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from torch._subclasses.fake_tensor import _is_tensor_constructor -from torch.utils._python_dispatch import TorchDispatchMode - - -class MetaTensorContext(TorchDispatchMode): - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if _is_tensor_constructor(func): - device_idx = [arg.name - for arg in func._schema.arguments].index('device') - if len(args) > device_idx: - args = list(args) - args[device_idx] = 'meta' - else: - kwargs['device'] = 'meta' - return func(*args, **kwargs) diff --git a/mmengine/analysis/__init__.py b/mmengine/analysis/__init__.py deleted file mode 100644 index e51090c387..0000000000 --- a/mmengine/analysis/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer, - activation_count, flop_count, - parameter_count, parameter_count_table) -from .print_helper import get_model_complexity_info - -__all__ = [ - 'FlopAnalyzer', 'ActivationAnalyzer', 'flop_count', 'activation_count', - 'parameter_count', 'parameter_count_table', 'get_model_complexity_info' -] diff --git a/mmengine/analysis/complexity_analysis.py b/mmengine/analysis/complexity_analysis.py deleted file mode 100644 index 435e5fe5d3..0000000000 --- a/mmengine/analysis/complexity_analysis.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import typing -from collections import defaultdict -from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union - -import torch.nn as nn -from rich import box -from rich.console import Console -from rich.table import Table -from torch import Tensor - -from .jit_analysis import JitModelAnalysis -from .jit_handles import (Handle, addmm_flop_jit, batchnorm_flop_jit, - bmm_flop_jit, conv_flop_jit, einsum_flop_jit, - elementwise_flop_counter, generic_activation_jit, - linear_flop_jit, matmul_flop_jit, norm_flop_counter) - -# A dictionary that maps supported operations to their flop count jit handles. 
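# --- Illustrative sketch (not part of the removed module) ---
# The mapping defined just below pairs TorchScript operator names with
# "handles": callables that receive a traced op's (inputs, outputs) and return
# its flop count. A minimal usage sketch, assuming an mmengine release that
# still ships `mmengine.analysis` (this diff deletes it); the toy model is
# hypothetical and only used for illustration.
import torch
import torch.nn as nn

from mmengine.analysis import FlopAnalyzer, flop_count

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
inputs = (torch.randn(1, 3, 32, 32),)

flops = FlopAnalyzer(model, inputs)
# Ignore BatchNorm flops, as the note inside the dict below suggests.
flops.set_op_handle('aten::batch_norm', None)
print(flops.total())        # total flop count for the whole model
print(flops.by_operator())  # Counter keyed by operator type

# Functional interface: per-operator Gflops plus a Counter of unsupported ops.
gflops, unsupported = flop_count(model, inputs)
# --- end of sketch ---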
-_DEFAULT_SUPPORTED_FLOP_OPS: Dict[str, Handle] = { - 'aten::addmm': addmm_flop_jit, - 'aten::bmm': bmm_flop_jit, - 'aten::_convolution': conv_flop_jit, - 'aten::einsum': einsum_flop_jit, - 'aten::matmul': matmul_flop_jit, - 'aten::mm': matmul_flop_jit, - 'aten::linear': linear_flop_jit, - # You might want to ignore BN flops due to inference-time fusion. - # Use `set_op_handle("aten::batch_norm", None) - 'aten::batch_norm': batchnorm_flop_jit, - 'aten::group_norm': norm_flop_counter(2), - 'aten::layer_norm': norm_flop_counter(2), - 'aten::instance_norm': norm_flop_counter(1), - 'aten::upsample_nearest2d': elementwise_flop_counter(0, 1), - 'aten::upsample_bilinear2d': elementwise_flop_counter(0, 4), - 'aten::adaptive_avg_pool2d': elementwise_flop_counter(1, 0), - 'aten::grid_sampler': elementwise_flop_counter(0, 4), # assume bilinear -} - -# A dictionary that maps supported operations to -# their activation count handles. -_DEFAULT_SUPPORTED_ACT_OPS: Dict[str, Handle] = { - 'aten::_convolution': generic_activation_jit('conv'), - 'aten::addmm': generic_activation_jit(), - 'aten::bmm': generic_activation_jit(), - 'aten::einsum': generic_activation_jit(), - 'aten::matmul': generic_activation_jit(), - 'aten::linear': generic_activation_jit(), -} - - -class FlopAnalyzer(JitModelAnalysis): - """Provides access to per-submodule model flop count obtained by tracing a - model with pytorch's jit tracing functionality. - - By default, comes with standard flop counters for a few common operators. - - Note: - - Flop is not a well-defined concept. We just produce our best - estimate. - - We count one fused multiply-add as one flop. - - Handles for additional operators may be added, or the default ones - overwritten, using the ``.set_op_handle(name, func)`` method. - See the method documentation for details. - Flop counts can be obtained as: - - - ``.total(module_name="")``: total flop count for the module - - ``.by_operator(module_name="")``: flop counts for the module, as a - Counter over different operator types - - ``.by_module()``: Counter of flop counts for all submodules - - ``.by_module_and_operator()``: dictionary indexed by descendant of - Counters over different operator types - - An operator is treated as within a module if it is executed inside the - module's ``__call__`` method. Note that this does not include calls to - other methods of the module or explicit calls to ``module.forward(...)``. - - Modified from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py - - Args: - model (nn.Module): The model to analyze. - inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model. - - Examples: - >>> import torch.nn as nn - >>> import torch - >>> class TestModel(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.fc = nn.Linear(in_features=1000, out_features=10) - ... self.conv = nn.Conv2d( - ... in_channels=3, out_channels=10, kernel_size=1 - ... ) - ... self.act = nn.ReLU() - ... def forward(self, x): - ... 
return self.fc(self.act(self.conv(x)).flatten(1)) - >>> model = TestModel() - >>> inputs = (torch.randn((1,3,10,10)),) - >>> flops = FlopAnalyzer(model, inputs) - >>> flops.total() - 13000 - >>> flops.total("fc") - 10000 - >>> flops.by_operator() - Counter({"addmm" : 10000, "conv" : 3000}) - >>> flops.by_module() - Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0}) - >>> flops.by_module_and_operator() - {"" : Counter({"addmm" : 10000, "conv" : 3000}), - "fc" : Counter({"addmm" : 10000}), - "conv" : Counter({"conv" : 3000}), - "act" : Counter() - } - """ - - def __init__( - self, - model: nn.Module, - inputs: Union[Tensor, Tuple[Tensor, ...]], - ) -> None: - super().__init__(model=model, inputs=inputs) - self.set_op_handle(**_DEFAULT_SUPPORTED_FLOP_OPS) - - __init__.__doc__ = JitModelAnalysis.__init__.__doc__ - - -class ActivationAnalyzer(JitModelAnalysis): - """Provides access to per-submodule model activation count obtained by - tracing a model with pytorch's jit tracing functionality. - - By default, comes with standard activation counters for convolutional and - dot-product operators. Handles for additional operators may be added, or - the default ones overwritten, using the ``.set_op_handle(name, func)`` - method. See the method documentation for details. Activation counts can be - obtained as: - - - ``.total(module_name="")``: total activation count for a module - - ``.by_operator(module_name="")``: activation counts for the module, - as a Counter over different operator types - - ``.by_module()``: Counter of activation counts for all submodules - - ``.by_module_and_operator()``: dictionary indexed by descendant of - Counters over different operator types - - An operator is treated as within a module if it is executed inside the - module's ``__call__`` method. Note that this does not include calls to - other methods of the module or explicit calls to ``module.forward(...)``. - - Modified from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py - - Args: - model (nn.Module): The model to analyze. - inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model. - - Examples: - >>> import torch.nn as nn - >>> import torch - >>> class TestModel(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.fc = nn.Linear(in_features=1000, out_features=10) - ... self.conv = nn.Conv2d( - ... in_channels=3, out_channels=10, kernel_size=1 - ... ) - ... self.act = nn.ReLU() - ... def forward(self, x): - ... 
return self.fc(self.act(self.conv(x)).flatten(1)) - >>> model = TestModel() - >>> inputs = (torch.randn((1,3,10,10)),) - >>> acts = ActivationAnalyzer(model, inputs) - >>> acts.total() - 1010 - >>> acts.total("fc") - 10 - >>> acts.by_operator() - Counter({"conv" : 1000, "addmm" : 10}) - >>> acts.by_module() - Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0}) - >>> acts.by_module_and_operator() - {"" : Counter({"conv" : 1000, "addmm" : 10}), - "fc" : Counter({"addmm" : 10}), - "conv" : Counter({"conv" : 1000}), - "act" : Counter() - } - """ - - def __init__( - self, - model: nn.Module, - inputs: Union[Tensor, Tuple[Tensor, ...]], - ) -> None: - super().__init__(model=model, inputs=inputs) - self.set_op_handle(**_DEFAULT_SUPPORTED_ACT_OPS) - - __init__.__doc__ = JitModelAnalysis.__init__.__doc__ - - -def flop_count( - model: nn.Module, - inputs: Tuple[Any, ...], - supported_ops: Optional[Dict[str, Handle]] = None, -) -> Tuple[DefaultDict[str, float], Counter[str]]: - """Given a model and an input to the model, compute the per-operator Gflops - of the given model. - - Adopted from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py - - Args: - model (nn.Module): The model to compute flop counts. - inputs (tuple): Inputs that are passed to `model` to count flops. - Inputs need to be in a tuple. - supported_ops (dict(str,Callable) or None) : provide additional - handlers for extra ops, or overwrite the existing handlers for - convolution and matmul and einsum. The key is operator name and - the value is a function that takes (inputs, outputs) of the op. - We count one Multiply-Add as one FLOP. - - Returns: - tuple[defaultdict, Counter]: A dictionary that records the number of - gflops for each operation and a Counter that records the number of - unsupported operations. - """ - if supported_ops is None: - supported_ops = {} - flop_counter = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops) - giga_flops = defaultdict(float) - for op, flop in flop_counter.by_operator().items(): - giga_flops[op] = flop / 1e9 - return giga_flops, flop_counter.unsupported_ops() - - -def activation_count( - model: nn.Module, - inputs: Tuple[Any, ...], - supported_ops: Optional[Dict[str, Handle]] = None, -) -> Tuple[DefaultDict[str, float], Counter[str]]: - """Given a model and an input to the model, compute the total number of - activations of the model. - - Adopted from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py - - Args: - model (nn.Module): The model to compute activation counts. - inputs (tuple): Inputs that are passed to `model` to count activations. - Inputs need to be in a tuple. - supported_ops (dict(str,Callable) or None) : provide additional - handlers for extra ops, or overwrite the existing handlers for - convolution and matmul. The key is operator name and the value - is a function that takes (inputs, outputs) of the op. - - Returns: - tuple[defaultdict, Counter]: A dictionary that records the number of - activation (mega) for each operation and a Counter that records the - number of unsupported operations. 
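    Examples (an illustrative usage sketch, not part of the original
    docstring; the model and input shape below are hypothetical):
        >>> import torch
        >>> import torch.nn as nn
        >>> model = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
        >>> inputs = (torch.randn(1, 3, 16, 16),)
        >>> mega_acts, unsupported = activation_count(model, inputs)
        >>> # only the convolution has a default activation handle here
        >>> sorted(mega_acts)
        ['conv']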
- """ - if supported_ops is None: - supported_ops = {} - act_counter = ActivationAnalyzer(model, - inputs).set_op_handle(**supported_ops) - mega_acts = defaultdict(float) - for op, act in act_counter.by_operator().items(): - mega_acts[op] = act / 1e6 - return mega_acts, act_counter.unsupported_ops() - - -def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]: - """Count parameters of a model and its submodules. - - Adopted from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py - - Args: - model (nn.Module): the model to count parameters. - - Returns: - dict[str, int]: the key is either a parameter name or a module name. - The value is the number of elements in the parameter, or in all - parameters of the module. The key "" corresponds to the total - number of parameters of the model. - """ - count = defaultdict(int) # type: typing.DefaultDict[str, int] - for name, param in model.named_parameters(): - size = param.numel() - name = name.split('.') - for k in range(0, len(name) + 1): - prefix = '.'.join(name[:k]) - count[prefix] += size - return count - - -def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str: - """Format the parameter count of the model (and its submodules or - parameters) - - Adopted from - https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py - - Args: - model (nn.Module): the model to count parameters. - max_depth (int): maximum depth to recursively print submodules or - parameters - - Returns: - str: the table to be printed - """ - count: typing.DefaultDict[str, int] = parameter_count(model) - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - param_shape: typing.Dict[str, typing.Tuple] = { - k: tuple(v.shape) - for k, v in model.named_parameters() - } - - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - rows: typing.List[typing.Tuple] = [] - - def format_size(x: int) -> str: - if x > 1e8: - return f'{x / 1e9:.1f}G' - if x > 1e5: - return f'{x / 1e6:.1f}M' - if x > 1e2: - return f'{x / 1e3:.1f}K' - return str(x) - - def fill(lvl: int, prefix: str) -> None: - if lvl >= max_depth: - return - for name, v in count.items(): - if name.count('.') == lvl and name.startswith(prefix): - indent = ' ' * (lvl + 1) - if name in param_shape: - rows.append( - (indent + name, indent + str(param_shape[name]))) - else: - rows.append((indent + name, indent + format_size(v))) - fill(lvl + 1, name + '.') - - rows.append(('model', format_size(count.pop('')))) - fill(0, '') - - table = Table( - title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2) - table.add_column('name') - table.add_column('#elements or shape') - - for row in rows: - table.add_row(*row) - - console = Console() - with console.capture() as capture: - console.print(table, end='') - - return capture.get() diff --git a/mmengine/analysis/jit_analysis.py b/mmengine/analysis/jit_analysis.py deleted file mode 100644 index 17b294863a..0000000000 --- a/mmengine/analysis/jit_analysis.py +++ /dev/null @@ -1,684 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved -# Modified from -# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_analysis.py - -import logging -import typing -import warnings -from collections import Counter -from copy import copy -from dataclasses import dataclass -from numbers import Number -from typing import (Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, - TypeVar, Union) - -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor -from torch.jit import TracerWarning, _get_trace_graph - -from mmengine.logging import print_log -from .jit_handles import Handle - -T = TypeVar('T', bound='JitModelAnalysis') - -# Only ignore ops that are technically truly 0 flops: -# shape-manipulation ops, integer ops, memory copy ops -_IGNORED_OPS: Set[str] = { - 'aten::Int', - 'aten::ScalarImplicit', - 'aten::__and__', - 'aten::arange', - 'aten::bitwise_not', - 'aten::cat', - 'aten::chunk', - 'aten::clamp', - 'aten::clamp_', - 'aten::constant_pad_nd', - 'aten::contiguous', - 'aten::copy_', - 'aten::detach', - 'aten::dropout', - 'aten::empty', - 'aten::eq', - 'aten::expand', - 'aten::flatten', - 'aten::floor', - 'aten::floor_divide', - 'aten::full', - 'aten::full_like', - 'aten::gather', - 'aten::ge', - 'aten::gt', - 'aten::index', - 'aten::index_put_', - 'aten::masked_fill', - 'aten::max', - 'aten::narrow', - 'aten::new_empty', - 'aten::new_full', - 'aten::new_zeros', - 'aten::nonzero', - 'aten::ones', - 'aten::permute', - 'aten::relu', - 'aten::relu_', - 'aten::remainder', - 'aten::reshape', - 'aten::roll', - 'aten::select', - 'aten::size', - 'aten::slice', - 'aten::split', - 'aten::split_with_sizes', - 'aten::squeeze', - 'aten::stack', - 'aten::t', - 'aten::to', - 'aten::transpose', - 'aten::type_as', - 'aten::unbind', - 'aten::unsqueeze', - 'aten::unsqueeze_', - 'aten::view', - 'aten::zeros', - 'aten::zeros_like', -} - - -@dataclass -class Statistics: - """For keeping track of the various model statistics recorded during - analysis.""" - - counts: Dict[str, typing.Counter[str]] - unsupported_ops: Dict[str, typing.Counter[str]] - uncalled_mods: Set[str] - - -def _named_modules_with_dup(model: nn.Module, - prefix: str = '' - ) -> Iterable[Tuple[str, nn.Module]]: - """The same as `model.named_modules()`, except that it includes duplicated - modules that have more than one name.""" - yield prefix, model - for name, module in model._modules.items(): - if module is None: - continue - submodule_prefix = prefix + ('.' if prefix else '') + name - yield from _named_modules_with_dup(module, submodule_prefix) - - -def _named_modules_without_dup( - model: nn.Module) -> Iterator[Tuple[str, nn.Module]]: - """Like .named_modules(), but the results are slightly different for some - wrapped models.""" - seen = set() - for name, mod in _named_modules_with_dup(model): - if mod not in seen: - seen.add(mod) - yield name, mod - - -def _get_scoped_trace_graph( - module: nn.Module, - inputs: Union[Tensor, Tuple[Tensor, ...]], - aliases: Dict[Union[str, nn.Module], str], -) -> torch._C.Graph: - """Traces the provided module using torch.jit._get_trace_graph, but adds - submodule scope information to each graph node. - - The resulting graph is in-lined and has all model parameters treated as - inputs. The input model has the scope name '', while its descendants - have names of the form 'child.grandchild.grandgrandchild...'. 
- - Args: - model (nn.Module): The module to trace - inputs (tuple): Inputs used during the trace of the model - aliases (dict[str or nn.Module, str]): maps modules and module - names to the canonical name to be used as the scope for - that module. - - Returns: - graph (torch._C.Graph): The pytorch JIT trace of the model - """ - - # torch.jit._get_trace_graph can trace torch function like `aten::linear`, - # `aten::add` etc. However, the traced node(function) cannot tell it is - # called by which module. `ScopePushHook` and `ScopePopHook` can - # help traced node get the module name information by `node.scopeName()`. - class ScopePushHook: - - def __init__(self, name: str) -> None: - self.name = name - - def __call__(self, module: nn.Module, inputs: Any) -> Any: - tracing_state = torch._C._get_tracing_state() - if tracing_state: - tracing_state.push_scope(self.name) - return inputs - - class ScopePopHook: - - def __call__(self, module: nn.Module, inputs: Any, - outputs: Any) -> Any: - tracing_state = torch._C._get_tracing_state() - if tracing_state: - tracing_state.pop_scope() - return outputs - - hook_handles: List[Any] = [] - - def register_hooks(mod: nn.Module, name: str) -> None: - prehook = mod.register_forward_pre_hook(ScopePushHook(name)) - posthook = mod.register_forward_hook(ScopePopHook()) - hook_handles.append(prehook) - hook_handles.append(posthook) - - # Unwrap DDP, but correct the scope names for the root module. - module_list = (nn.parallel.distributed.DistributedDataParallel, - nn.DataParallel) - # Since DataParallel just wraps the model, add an extra set of hooks - # to the model it wraps to account for the wrapper. Then trace it. - if isinstance(module, module_list): - root_name = aliases[module] - module = module.module - register_hooks(module, root_name) - - for name, mod in _named_modules_without_dup(module): - name = aliases[mod] - register_hooks(mod, name) - - graph, _ = _get_trace_graph(module, inputs) - - for handle in hook_handles: - handle.remove() - - return graph - - -class JitModelAnalysis: - """Provides access to per-submodule model statistics obtained by tracing a - model with pytorch's jit tracing functionality. - - Calculates a statistic on a per-operator basis using the provided set of - functions that acts on the inputs and outputs to the operator, then - aggregates this over modules in the model. Can return the aggregate - statistic for any submodule in the model. Is lazily evaluated, and will - perform the trace when a statistic is first requested. Changing the - operator handles will cause the trace to be rerun on the next request. - - Submodules may be referred to using the module's name. The input model has - name "", while its descendants have names of the form - "child.grandchild.grandgrandchild...". - - An operator is treated as within the scope of a module if calling that - module directly resulted in that operator being run. In particular, this - means that calls to other functions owned by a module or explicit - calls to module.forward(...) will not register resulting operators as - contributing statistics to that module. - - We will trace the execution of `model.forward(inputs)`. This means - inputs have to be tensors or tuple of tensors (see - https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace). - In order to trace other methods or unsupported input types, - you may need to implement a wrapper module. - - Args: - model: The model to analyze - inputs: The inputs to the model for analysis. 
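    Examples (an illustrative usage sketch, not part of the original
    docstring; the handle import path assumes this module layout, and the
    traced op name depends on the PyTorch version, so both candidate
    handles are registered):
        >>> import torch
        >>> import torch.nn as nn
        >>> from mmengine.analysis.jit_handles import (addmm_flop_jit,
        ...                                            linear_flop_jit)
        >>> model = nn.Linear(10, 10)
        >>> inputs = (torch.randn(1, 10),)
        >>> analysis = JitModelAnalysis(model, inputs)
        >>> _ = analysis.set_op_handle('aten::addmm', addmm_flop_jit,
        ...                            'aten::linear', linear_flop_jit)
        >>> analysis.total()  # 1 batch * 10 in_features * 10 out_features
        100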
- """ - - def __init__( - self, - model: nn.Module, - inputs: Union[Tensor, Tuple[Tensor, ...]], - ) -> None: - self._model = model - self._inputs = inputs - self._op_handles: Dict[str, Handle] = {} - # Mapping from names to submodules - self._named_modules: Dict[str, nn.Module] = dict( - _named_modules_with_dup(model)) - # Mapping from submodules and their aliases to the canonical name - # of each submodule - self._aliases: Dict[Union[nn.Module, str], - str] = self._get_aliases(model) - self._stats: Optional[Statistics] = None - - self._ignored_ops: Set[str] = copy(_IGNORED_OPS) - self.unsupported_ops_warnings(True) - self.uncalled_modules_warnings(True) - self.tracer_warnings('no_tracer_warning') - self.ancestor_mode('owner') - - def total(self, module_name: str = '') -> int: - """Returns the total aggregated statistic across all operators for the - requested module. - - Args: - module_name (str): The submodule to get data for. Defaults to - the entire model. - - Returns: - int: The aggregated statistic. - """ - stats = self._analyze() - module_name = self.canonical_module_name(module_name) - total_count = sum(stats.counts[module_name].values()) - return total_count - - def by_operator(self, module_name: str = '') -> typing.Counter[str]: - """Returns the statistics for a requested module, grouped by operator - type. - - The operator handle determines the name associated with each - operator type. - - Args: - module_name (str): The submodule to get data for. Defaults - to the entire model. - - Returns: - Counter(str): The statistics for each operator. - """ - stats = self._analyze() - module_name = self.canonical_module_name(module_name) - return stats.counts[module_name] - - def by_module_and_operator(self) -> Dict[str, typing.Counter[str]]: - """Returns the statistics for all submodules, separated out by operator - type for each submodule. - - The operator handle determines the name associated with - each operator type. - - Returns: - dict[str, Counter(str)]: The statistics for each submodule - and each operator. Grouped by submodule names, then - by operator name. - """ - stats = self._analyze() - return stats.counts - - def by_module(self) -> typing.Counter[str]: - """Returns the statistics for all submodules, aggregated over all - operators. - - Returns: - Counter(str): statistics counter grouped by submodule names - """ - stats = self._analyze() - summed_counts = Counter() # type: Counter - for mod, results in stats.counts.items(): - summed_counts[mod] = sum(results.values()) - return summed_counts - - def unsupported_ops(self, module_name: str = '') -> typing.Counter[str]: - """Lists the number of operators that were encountered but unsupported - because no operator handle is available for them. - - Does not include operators that are explicitly ignored. - - Args: - module_name (str): The submodule to list unsupported ops. - Defaults to the entire model. - - Returns: - Counter(str): The number of occurrences each unsupported operator. - """ - if self._stats is None: - raise RuntimeError('Analysis results should be computed ' - 'before calling unsupported_ops()') - module_name = self.canonical_module_name(module_name) - return self._stats.unsupported_ops[module_name] # pyre-fixme - - def uncalled_modules(self) -> Set[str]: - """Returns a set of submodules that were never called during the trace - of the graph. - - This may be because they were unused, or because they were - accessed via direct calls .forward() or with other python methods. 
- In the latter case, statistics will not be attributed to the submodule, - though the statistics will be included - in the parent module. - - Returns: - set[str]: The set of submodule names that were never called - during the trace of the model. - """ - stats = self._analyze() - return stats.uncalled_mods - - def set_op_handle(self, *args, - **kwargs: Optional[Handle]) -> 'JitModelAnalysis': - """Sets additional operator handles, or replaces existing ones. - - If a handle is ``None``, the op will be explicitly ignored. Otherwise, - handle should be a function that calculates the desirable statistic - from an operator. The function must take two arguments, which are the - inputs and outputs of the operator, in the form of - ``list(torch._C.Value)``. The function should return a counter object - with per-operator statistics. - - Args: - args: (str, Handle) pairs of operator names and handles. - kwargs: mapping from operator names to handles. - - Examples: - >>> handlers = {"aten::linear": my_handler} - >>> counter.set_op_handle("aten::matmul", None, - ... "aten::bmm", my_handler2).set_op_handle(**handlers) - """ - self._stats = None - if len(args) % 2 != 0: - raise TypeError( - 'set_op_handle should be called with pairs of names and' - 'handles!') - for name, handle in zip(args[::2], args[1::2]): - kwargs[name] = handle - for name, handle in kwargs.items(): - if handle is None: - self._ignored_ops.add(name) - else: - self._op_handles[name] = handle - return self - - def clear_op_handles(self) -> 'JitModelAnalysis': - """Clears all operator handles currently set.""" - self._op_handles = {} - self._ignored_ops = copy(_IGNORED_OPS) - self._stats = None - return self - - def canonical_module_name(self, name: str) -> str: - """Returns the canonical module name of the given ``name``, which might - be different from the given ``name`` if the module is shared. - - This is the name that will be used as a key when statistics are - output using .by_module() and .by_module_and_operator(). - - Args: - name (str): The name of the module to find the canonical name for. - - Returns: - str: The canonical name of the module. - """ - # Blocks access by a direct module reference - assert isinstance(name, str), 'Module name must be a string.' - if name in self._aliases: - return self._aliases[name] - else: - raise KeyError('Requested module name is not among ' - 'the descendants of the analyzed model.') - - def copy( - self, - new_model: Optional[nn.Module] = None, - new_inputs: Union[None, Tensor, Tuple[Tensor, ...]] = None, - ) -> 'JitModelAnalysis': - """Returns a copy of the :class:`JitModelAnalysis` object, keeping all - settings, but on a new model or new inputs. - - Args: - new_model (nn.Module or None): a new model for the new - JitModelAnalysis. If None, uses the original model. - Defaults to None. - new_inputs (typing.Tuple[object, ...], optional): new inputs - for the new JitModelAnalysis. If None, uses the original - inputs. Defaults to None. 
- - Returns: - JitModelAnalysis: the new model analysis object - """ - model = self._model if new_model is None else new_model - inputs = self._inputs if new_inputs is None else new_inputs - return (JitModelAnalysis(model=model, inputs=inputs).set_op_handle( - **self._op_handles).unsupported_ops_warnings( - self._enable_warn_unsupported_ops).uncalled_modules_warnings( - self._enable_warn_uncalled_mods).tracer_warnings( - self._warn_trace)) - - def tracer_warnings(self: T, mode: str) -> T: - """Sets which warnings to print when tracing the graph to calculate - statistics. There are three modes. Defaults to 'no_tracer_warning'. - Allowed values are: - - * 'all' : keeps all warnings raised while tracing - * 'no_tracer_warning' : suppress torch.jit.TracerWarning only - * 'none' : suppress all warnings raised while tracing - - Args: - mode (str) : warning mode in one of the above values. - """ - if mode not in ['all', 'no_tracer_warning', 'none']: - raise ValueError(f'Unrecognized tracer warning mode {mode}.') - self._warn_trace = mode - return self - - def ancestor_mode(self: T, mode: str) -> T: - """Sets how to determine the ancestor modules of an operator. Must be - one of "owner" or "caller". - - * "caller": an operator belongs to all modules that are currently - executing `forward()` at the time the operator is called. - * "owner": an operator belongs to the last module that's executing - `forward()` at the time the operator is called, plus this - module's recursive parents. If an module has multiple parents - (e.g. a shared module), only one will be picked. - - For most cases, a module only calls submodules it owns, so both - options would work identically. In certain edge cases, this option - will affect the hierarchy of results, but won't affect the total - count. - """ - if mode not in ['owner', 'caller']: - raise ValueError(f'Unrecognized ancestor mode: {mode}') - self._ancestor_mode = mode - return self - - def unsupported_ops_warnings(self: T, enabled: bool) -> T: - """Sets if warnings for unsupported operators are shown. - - Defaults to True. Counts of unsupported operators may be - obtained from :meth:`unsupported_ops` regardless of this setting. - - Args: - enabled (bool): Set to 'True' to show unsupported operator - warnings. - """ - self._enable_warn_unsupported_ops = enabled - return self - - def uncalled_modules_warnings(self: T, enabled: bool) -> T: - """Sets if warnings from uncalled submodules are shown. - - Defaults to true. A submodule is considered "uncalled" if it is never - called during tracing. This may be because it is actually unused, or - because it is accessed via calls to ``.forward()`` or other methods of - the module. The set of uncalled modules may be obtained from - :meth:`uncalled_modules` regardless of this setting. - - Args: - enabled (bool): Set to 'True' to show warnings. - """ - self._enable_warn_uncalled_mods = enabled - return self - - def _warn_unsupported_ops(self, ops: typing.Counter[str]) -> None: - if not self._enable_warn_unsupported_ops: - return - - for op, freq in ops.items(): - print_log( - 'Unsupported operator {} encountered {} time(s)'.format( - op, freq), - 'current', - logging.WARNING, - ) - - def _warn_uncalled_mods(self, uncalled_mods: Set[str]) -> None: - if not self._enable_warn_uncalled_mods: - return - uncalled_mods = {x for x in uncalled_mods if self._has_forward(x)} - if len(uncalled_mods) == 0: - return - - print_log( - 'The following submodules of the model were never ' - 'called during the trace of the graph. 
They may be ' - 'unused, or they were accessed by direct calls to ' - '.forward() or via other python methods. In the latter ' - 'case they will have zeros for statistics, though their ' - 'statistics will still contribute to their parent calling ' - 'module.\n' + ', '.join(sorted(uncalled_mods)), 'current', - logging.WARNING) - - def _get_aliases(self, - model: nn.Module) -> Dict[Union[str, nn.Module], str]: - aliases = {} - for name, module in _named_modules_with_dup(model): - if module not in aliases: - aliases[module] = name - aliases[name] = aliases[module] - return aliases - - def _get_all_ancestors(self, module_name: str) -> Set[str]: - """Get all ancestors of the given module, defined by ownership. - - If the given module has multiple owners, use its canonical name. - """ - parts = self.canonical_module_name(module_name).split('.') - res = {''} - for k in range(len(parts) + 1): - res.add('.'.join(parts[:k])) - return res - - def _analyze(self) -> 'Statistics': - # Don't calculate if results are already stored. - stats = self._stats - if stats is not None: - return stats - - with warnings.catch_warnings(): - if self._warn_trace == 'none': - warnings.simplefilter('ignore') - elif self._warn_trace == 'no_tracer_warning': - warnings.filterwarnings('ignore', category=TracerWarning) - graph = _get_scoped_trace_graph(self._model, self._inputs, - self._aliases) - - # Assures even modules not in the trace graph are initialized to - # zero count - counts = {} # type: Dict - unsupported_ops = {} # type: Dict - # We don't need the duplication here, but self._model.named_modules() - # gives slightly different results for some wrapped models. - for _, mod in _named_modules_with_dup(self._model): - name = self._aliases[mod] - counts[name] = Counter() - unsupported_ops[name] = Counter() - - all_seen = set() - for node in graph.nodes(): - kind = node.kind() - if kind == 'prim::PythonOp': - # for PythonOp, pyname contains the actual name in Python - # pyre-fixme[16]: `Node` has no attribute `pyname`. - kind = kind + '.' + node.pyname() - scope_names = node.scopeName().split('/') - all_seen.update(scope_names) - # The result of node.scopeName() is like: `layer1/layer1.layer` - # Therefore, if there is not shared module ancestors will have the - # same value. However, if layer1.layer is used by multiple modules. - # scopeName() will return - # `layer1/layer1.layer` - # `layer2/layer1.layer` respectively - # If mode is `caller`, the ancestors will be: - # 'layer1', 'layer2', 'layer1.layer' - # else, the ancestors will be: - # 'layer1', 'layer1.layer' - # which means only the flops will only be counted into `layer1`. - if self._ancestor_mode == 'caller': - ancestors = set(scope_names) - else: - ancestors = self._get_all_ancestors(scope_names[-1]) - all_seen.update(ancestors) - if kind not in self._op_handles: - if self._should_ignore_node(node): - continue - for name in ancestors: - unsupported_ops[name][kind] += 1 - else: - inputs, outputs = list(node.inputs()), list(node.outputs()) - op_counts = self._op_handles[kind](inputs, outputs) - if isinstance(op_counts, Number): - op_counts = Counter( - {self._simplify_op_name(kind): op_counts}) - for v in op_counts.values(): # type: ignore - if not isinstance(v, (int, float, np.float64, np.int64)): - raise ValueError( - f'Invalid type {type(v)} for the flop count! 
' - 'Please use a wider type to avoid overflow.') - - # Assures an op contributes at most once to a module - for name in ancestors: - counts[name] += op_counts - - uncalled_mods = set(self._aliases.values()) - all_seen - stats = Statistics( - counts=counts, - unsupported_ops=unsupported_ops, - uncalled_mods=uncalled_mods) - self._stats = stats - self._warn_unsupported_ops(unsupported_ops['']) - self._warn_uncalled_mods(uncalled_mods) - return stats - - def _simplify_op_name(self, full_op_name: str) -> str: - """Get simplified name of the op without the preceding namespace, e.g. - aten::batch_norm -> batch_norm.""" - p = full_op_name.find('::') - if p != -1: - return full_op_name[p + 2:] - else: - return full_op_name - - def _has_forward(self, mod_name: str) -> bool: - # Whether the module has a valid forward method. - # Modules without forward are not expected to get called - # and therefore should not produce "uncalled" warnings - module = self._named_modules.get(mod_name) - if module is None: - return False - module_type = type(module) - # Containers are not meant to be called anyway (they don't have - # forward) - # NOTE: We add nn.Identity as well to silence the uncalled warning, - # but it's different from other containers: Identity has a forward - # but the forward does not contain ops, so it appears "uncalled" after - # tracing. A more proper way may be to use forward hooks (instead of - # the graph) to decide whether a module has been called. - no_forward_mods = { - nn.ModuleList, nn.ModuleDict, nn.Module, nn.Identity - } - for mod in no_forward_mods: - if module_type.forward is mod.forward: - return False - return True - - def _should_ignore_node(self, node) -> bool: - kind = node.kind() - if kind in self._ignored_ops: - return True - # Ignore all prim:: operators, with two exceptions: - # * prim::PythonOp can be a user-implemented `torch.autograd.Function` - # * prim::CallFunction an be a call to scripted module/function. - if kind.startswith('prim::PythonOp') or kind.startswith( - 'prim::CallFunction'): - return False - if kind.startswith('prim::'): - return True - return False diff --git a/mmengine/analysis/jit_handles.py b/mmengine/analysis/jit_handles.py deleted file mode 100644 index d4b9155e88..0000000000 --- a/mmengine/analysis/jit_handles.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# Modified from -# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py - -import typing -from collections import Counter, OrderedDict -from typing import Any, Callable, List, Optional, Union - -import numpy as np - -try: - from math import prod # type: ignore -except ImportError: - from numpy import prod as _prod # type: ignore - - # Patch `numpy.prod` to avoid overflow on Windows by converting its result - # from `np.int32` to `int`. - def prod(*args, **kwargs): # type: ignore - return _prod(*args, **kwargs).item() - - -Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], int]] - - -def get_shape(val: Any) -> Optional[List[int]]: - """Get the shapes from a jit value object. - - Args: - val (torch._C.Value): jit value object. - - Returns: - list(int): return a list of ints. - """ - if val.isCompleteTensor(): - return val.type().sizes() - else: - return None # type: ignore - - -""" -Below are flop/activation counters for various ops. 
-Every counter has the following signature: - -Args: - inputs (list(torch._C.Value)): - The inputs of the op in the form of a list of jit object. - outputs (list(torch._C.Value)): - The outputs of the op in the form of a list of jit object. - -Returns: - number: The number of flops/activations for the operation. - or Counter[str] -""" - - -def generic_activation_jit(op_name: Optional[str] = None) -> Handle: - """This method returns a handle that counts the number of activation from - the output shape for the specified operation. - - Args: - op_name (str): The name of the operation. If given, the handle will - return a counter using this name. - - Returns: - Callable: An activation handle for the given operation. - """ - - def _generic_activation_jit( - i: Any, outputs: List[Any]) -> Union[typing.Counter[str], int]: - """This is a generic jit handle that counts the number of activations - for any operation given the output shape.""" - out_shape = get_shape(outputs[0]) - ac_count = prod(out_shape) # type: ignore - if op_name is None: - return ac_count # type: ignore - else: - return Counter({op_name: ac_count}) - - return _generic_activation_jit - - -def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: - """Count flops for fully connected layers.""" - # Count flop for nn.Linear - # inputs is a list of length 3. - input_shapes = [get_shape(v) for v in inputs[1:3]] - # input_shapes[0]: [batch size, input feature dimension] - # input_shapes[1]: [batch size, output feature dimension] - assert len(input_shapes[0]) == 2, input_shapes[0] # type: ignore - assert len(input_shapes[1]) == 2, input_shapes[1] # type: ignore - batch_size, input_dim = input_shapes[0] # type: ignore - output_dim = input_shapes[1][1] # type: ignore - flops = batch_size * input_dim * output_dim - return flops - - -def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: - """Count flops for the aten::linear operator.""" - # Inputs is a list of length 3; unlike aten::addmm, it is the first - # two elements that are relevant. - input_shapes = [get_shape(v) for v in inputs[0:2]] - # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] - # input_shapes[1]: [output_feature_dim, input_feature_dim] - assert input_shapes[0][-1] == input_shapes[1][-1] # type: ignore - flops = prod(input_shapes[0]) * input_shapes[1][0] # type: ignore - return flops - - -def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: - """Count flops for the bmm operation.""" - # Inputs should be a list of length 2. - # Inputs contains the shapes of two tensor. - assert len(inputs) == 2, len(inputs) - input_shapes = [get_shape(v) for v in inputs] - n, c, t = input_shapes[0] # type: ignore - d = input_shapes[-1][-1] # type: ignore - flop = n * c * t * d - return flop - - -def conv_flop_count( - x_shape: List[int], - w_shape: List[int], - out_shape: List[int], - transposed: bool = False, -) -> Union[int, Any]: - """Count flops for convolution. Note only multiplication is counted. - Computation for addition and bias is ignored. Flops for a transposed - convolution are calculated as. - - flops = (x_shape[2:] * prod(w_shape) * batch_size). - - Args: - x_shape (list(int)): The input shape before convolution. - w_shape (list(int)): The filter shape. - out_shape (list(int)): The output shape after convolution. 
- transposed (bool): is the convolution transposed - - Returns: - int: the number of flops - """ - batch_size = x_shape[0] - conv_shape = (x_shape if transposed else out_shape)[2:] - flop = batch_size * prod(w_shape) * prod(conv_shape) - return flop - - -def conv_flop_jit(inputs: List[Any], - outputs: List[Any]) -> typing.Counter[str]: - """Count flops for convolution.""" - # Inputs of Convolution should be a list of length 12 or 13. - # They represent: - # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding, - # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn, - # 10) deterministic_cudnn and 11) user_enabled_cudnn. - # starting with #40737 it will be 12) user_enabled_tf32 - assert len(inputs) == 12 or len(inputs) == 13, len(inputs) - x, w = inputs[:2] - x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), - get_shape(outputs[0])) - transposed = inputs[6].toIValue() - - # use a custom name instead of "_convolution" - return Counter({ - 'conv': - conv_flop_count( - x_shape, # type: ignore - w_shape, # type: ignore - out_shape, # type: ignore - transposed=transposed) # type: ignore - }) - - -def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: - """Count flops for the einsum operation.""" - # Inputs of einsum should be a list of length 2+. - # Inputs[0] stores the equation used for einsum. - # Inputs[1] stores the list of input shapes. - assert len(inputs) >= 2, len(inputs) - equation = inputs[0].toIValue() - # Get rid of white space in the equation string. - equation = equation.replace(' ', '') - input_shapes_jit = inputs[1].node().inputs() - input_shapes = [get_shape(v) for v in input_shapes_jit] - - # Re-map equation so that same equation with different alphabet - # representations will look the same. - letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() - mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} - equation = equation.translate(mapping) - - if equation == 'abc,abd->acd': - n, c, t = input_shapes[0] # type: ignore - p = input_shapes[-1][-1] # type: ignore - flop = n * c * t * p - return flop - - elif equation == 'abc,adc->adb': - n, t, g = input_shapes[0] # type: ignore - c = input_shapes[-1][1] # type: ignore - flop = n * t * g * c - return flop - else: - np_arrs = [np.zeros(s) for s in input_shapes] - optim = np.einsum_path(equation, *np_arrs, optimize='optimal')[1] - for line in optim.split('\n'): - if 'optimized flop' in line.lower(): - # divided by 2 because we count MAC - # (multiply-add counted as one flop) - flop = float(np.floor(float(line.split(':')[-1]) / 2)) - return flop - raise NotImplementedError('Unsupported einsum operation.') - - -def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: - """Count flops for matmul.""" - # input_shapes is a list of length 2. - input_shapes: list = [get_shape(v) for v in inputs] - input1, input2 = input_shapes - if len(input1) == 1: - input1 = [1, input1[0]] - if len(input2) == 1: - input2 = [input2[0], 1] - - assert input1[-1] == input2[-2], input_shapes - flop = prod(input1) * input2[-1] - return flop - - -def norm_flop_counter(affine_arg_index: int) -> Handle: - """ - Args: - affine_arg_index: index of the affine argument in inputs - """ - - def norm_flop_jit(inputs: List[Any], - outputs: List[Any]) -> Union[int, Any]: - """Count flops for norm layers.""" - # Inputs[0] contains the shape of the input. 
- input_shape = get_shape(inputs[0]) - has_affine = get_shape(inputs[affine_arg_index]) is not None - assert 2 <= len(input_shape) <= 5, input_shape # type: ignore - # 5 is just a rough estimate - flop = prod(input_shape) * (5 if has_affine else 4) # type: ignore - return flop - - return norm_flop_jit - - -def batchnorm_flop_jit(inputs: List[Any], - outputs: List[Any]) -> Union[int, Any]: - training = inputs[5].toIValue() - assert isinstance(training, - bool), 'Signature of aten::batch_norm has changed!' - if training: - return norm_flop_counter(1)(inputs, outputs) # pyre-ignore - has_affine = get_shape(inputs[1]) is not None - input_shape = prod(get_shape(inputs[0])) # type: ignore - return input_shape * (2 if has_affine else 1) - - -def elementwise_flop_counter(input_scale: float = 1, - output_scale: float = 0) -> Handle: - """Count flops by. - - input_tensor.numel() * input_scale + - output_tensor.numel() * output_scale - - Args: - input_scale: scale of the input tensor (first argument) - output_scale: scale of the output tensor (first element in outputs) - """ - - def elementwise_flop(inputs: List[Any], - outputs: List[Any]) -> Union[int, Any]: - ret = 0 - if input_scale != 0: - shape = get_shape(inputs[0]) - ret += input_scale * prod(shape) # type: ignore - if output_scale != 0: - shape = get_shape(outputs[0]) - ret += output_scale * prod(shape) # type: ignore - return ret - - return elementwise_flop diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py deleted file mode 100644 index 3b87d42373..0000000000 --- a/mmengine/analysis/print_helper.py +++ /dev/null @@ -1,784 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -# Modified from -# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py - -from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -from rich import box -from rich.console import Console -from rich.table import Table -from torch import nn - -from mmengine.utils import is_tuple_of -from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer, - parameter_count) - - -def _format_size(x: int, sig_figs: int = 3, hide_zero: bool = False) -> str: - """Formats an integer for printing in a table or model representation. - - Expresses the number in terms of 'kilo', 'mega', etc., using - 'K', 'M', etc. as a suffix. - - Args: - x (int): The integer to format. - sig_figs (int): The number of significant figures to keep. - Defaults to 3. - hide_zero (bool): If True, x=0 is replaced with an empty string - instead of '0'. Defaults to False. - - Returns: - str: The formatted string. - """ - if hide_zero and x == 0: - return '' - - def fmt(x: float) -> str: - # use fixed point to avoid scientific notation - return f'{{:.{sig_figs}f}}'.format(x).rstrip('0').rstrip('.') - - if abs(x) > 1e14: - return fmt(x / 1e15) + 'P' - if abs(x) > 1e11: - return fmt(x / 1e12) + 'T' - if abs(x) > 1e8: - return fmt(x / 1e9) + 'G' - if abs(x) > 1e5: - return fmt(x / 1e6) + 'M' - if abs(x) > 1e2: - return fmt(x / 1e3) + 'K' - return str(x) - - -def _pretty_statistics(statistics: Dict[str, Dict[str, int]], - sig_figs: int = 3, - hide_zero: bool = False) -> Dict[str, Dict[str, str]]: - """Converts numeric statistics to strings with kilo/mega/giga/etc. labels. - - Args: - statistics (dict[str, dict[str, int]]) : the statistics to - format. 
Organized as a dictionary over modules, which are - each a dictionary over statistic types. - sig_figs (int): the number of significant figures for each stat. - Defaults to 3. - hide_zero (bool): if True, statistics that are zero will be - written as an empty string. Defaults to False. - - Returns: - dict[str, dict[str, str]]: the input statistics as pretty strings - """ - out_stats = {} - for mod, stats in statistics.items(): - out_stats[mod] = { - s: _format_size(val, sig_figs, hide_zero) - for s, val in stats.items() - } - return out_stats - - -def _group_by_module( - statistics: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: - """Converts statistics organized first by statistic type and then by module - to statistics organized first by module and then by statistic type. - - Args: - statistics (dict[str, dict[str, any]]): the statistics to convert - - Returns: - dict[str, dict[str, any]]: the reorganized statistics - """ - out_stats = defaultdict(dict) # type: Dict[str, Dict[str, Any]] - for stat_name, stat in statistics.items(): - for mod, val in stat.items(): - out_stats[mod][stat_name] = val - return dict(out_stats) - - -def _indicate_uncalled_modules( - statistics: Dict[str, Dict[str, str]], - stat_name: str, - uncalled_modules: Set[str], - uncalled_indicator: str = 'N/A', -) -> Dict[str, Dict[str, str]]: - """If a module is in the set of uncalled modules, replace its statistics - with the specified indicator, instead of using the existing string. - - Assumes the statistic is already formatting in string form. - - Args: - statistics (dict[str, dict[str, str]]): the statistics to - format. Organized as a dictionary over modules, which are - each a dictionary over statistic types. Expects statistics - have already been converted to strings. - stat_name (str): the name of the statistic being modified - uncalled_modules set(str): a set of names of uncalled modules. - indicator (str): the string that will be used to indicate - unused modules. Defaults to 'N/A'. - - Returns: - dict[str, dict[str, str]]: the modified statistics - """ - - stats_out = {mod: stats.copy() for mod, stats in statistics.items()} - for mod in uncalled_modules: - if mod not in stats_out: - stats_out[mod] = {} - stats_out[mod][stat_name] = uncalled_indicator - return stats_out - - -def _remove_zero_statistics( - statistics: Dict[str, Dict[str, int]], - force_keep: Optional[Set[str]] = None, - require_trivial_children: bool = False, -) -> Dict[str, Dict[str, int]]: - """Any module that has zero for all available statistics is removed from - the set of statistics. - - This can help declutter the reporting of statistics - if many submodules have zero statistics. Assumes the statistics have - a model hierarchy starting with a root that has name ''. - - Args: - statistics (dict[str, dict[str, int]]): the statistics to - remove zeros from. Organized as a dictionary over modules, - which are each a dictionary over statistic types. - force_keep (set[str] or None): a set of modules to always keep, even - if they are all zero. - require_trivial_children (bool): If True, a statistic will only - be deleted if all its children are also deleted. Defaults to - False. - - Returns: - dict[str, dict[str, int]]: the input statistics dictionary, - with submodules removed if they have zero for all statistics. - """ - out_stats: Dict[str, Dict[str, int]] = {} - _force_keep: Set[str] = force_keep if force_keep else set() | {''} - - def keep_stat(name: str) -> None: - prefix = name + ('.' 
if name else '') - trivial_children = True - for mod in statistics: - # 'if mod' excludes root = '', which is never a child - if mod and mod.count('.') == prefix.count('.') and mod.startswith( - prefix): - keep_stat(mod) - trivial_children &= mod not in out_stats - - if ((not all(val == 0 for val in statistics[name].values())) - or (name in _force_keep) - or (require_trivial_children and not trivial_children)): - out_stats[name] = statistics[name].copy() - - keep_stat('') - return out_stats - - -def _fill_missing_statistics( - model: nn.Module, - statistics: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]: - """If, for a given submodule name in the model, a statistic is missing from - statistics, fills it in with zero. - - This visually uniformizes the reporting of statistics. - - Args: - model (nn.Module): the model whose submodule names will be - used to fill in statistics - statistics (dict[str, dict[str, int]]) : the statistics to - fill in missing values for. Organized as a dictionary - over statistics, which are each a dictionary over submodules' - names. The statistics are assumed to be formatted already - to the desired string format for printing. - - Returns: - dict[str, dict[str, int]]: the input statistics with missing - values filled with zero. - """ - out_stats = {name: stat.copy() for name, stat in statistics.items()} - for mod_name, _ in model.named_modules(): - for stat in out_stats.values(): - if mod_name not in stat: - stat[mod_name] = 0 - return out_stats - - -def _model_stats_str(model: nn.Module, - statistics: Dict[str, Dict[str, str]]) -> str: - """This produces a representation of the model much like 'str(model)' - would, except the provided statistics are written out as additional - information for each submodule. - - Args: - model (nn.Module): the model to form a representation of. - statistics (dict[str, dict[str, str]]): the statistics to - include in the model representations. Organized as a dictionary - over module names, which are each a dictionary over statistics. - The statistics are assumed to be formatted already to the - desired string format for printing. - - Returns: - str: the string representation of the model with the statistics - inserted. - """ - - # Copied from nn.Module._addindent - def _addindent(s_: str, numSpaces: int) -> str: - s = s_.split('\n') - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * ' ') + line for line in s] - s = '\n'.join(s) # type: ignore - s = first + '\n' + s # type: ignore - return s # type: ignore - - def print_statistics(name: str) -> str: - if name not in statistics: - return '' - printed_stats = [f'{k}: {v}' for k, v in statistics[name].items()] - return ', '.join(printed_stats) - - # This comes directly from nn.Module.__repr__ with small changes - # to include the statistics. - def repr_with_statistics(module: nn.Module, name: str) -> str: - # We treat the extra repr like the sub-module, one item per line - extra_lines = [] - extra_repr = module.extra_repr() - printed_stats = print_statistics(name) - # empty string will be split into list [''] - if extra_repr: - extra_lines.extend(extra_repr.split('\n')) - if printed_stats: - extra_lines.extend(printed_stats.split('\n')) - child_lines = [] - for key, submod in module._modules.items(): - submod_name = name + ('.' if name else '') + key - # pyre-fixme[6]: Expected `Module` for 1st param but got - # `Optional[nn.modules.module.Module]`. 
- submod_str = repr_with_statistics(submod, submod_name) - submod_str = _addindent(submod_str, 2) - child_lines.append('(' + key + '): ' + submod_str) - lines = extra_lines + child_lines - - main_str = module._get_name() + '(' - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += '\n ' + '\n '.join(lines) + '\n' - - main_str += ')' - return main_str - - return repr_with_statistics(model, '') - - -def _get_input_sizes(iterable: Iterable[Any]) -> List[Any]: # type: ignore - """Gets the sizes of all torch tensors in an iterable. - - If an element of the iterable is a non-torch tensor iterable, it recurses - into that iterable to continue calculating sizes. Any non-iterable is given - a size of None. The output consists of nested lists with the same nesting - structure as the input iterables. - """ - out_list = [] - for i in iterable: - if isinstance(i, torch.Tensor): - out_list.append(list(i.size())) - elif isinstance(i, Iterable): - sublist_sizes = _get_input_sizes(i) - if all(j is None for j in sublist_sizes): - out_list.append(None) # type: ignore - else: - out_list.append(sublist_sizes) - else: - out_list.append(None) # type: ignore - return out_list - - -def _get_single_child(name: str, - statistics: Dict[str, Dict[str, str]]) -> Optional[str]: - """If the given module has only a single child in statistics, return it. - - Otherwise, return None. - """ - prefix = name + ('.' if name else '') - child = None - for mod in statistics: - # 'if mod' excludes root = '', which is never a child - if mod and mod.count('.') == prefix.count('.') and mod.startswith( - prefix): - if child is None: - child = mod - else: - return None # We found a second child, so return None - return child - - -def _try_combine(stats1: Dict[str, str], - stats2: Dict[str, str]) -> Optional[Dict[str, str]]: - """Try combine two statistics dict to display in one row. - - If they conflict, returns None. - """ - ret = {} - if set(stats1.keys()) != set(stats2.keys()): - return None - for k, v1 in stats1.items(): - v2 = stats2[k] - if v1 != v2 and len(v1) and len(v2): - return None - ret[k] = v1 if len(v1) else v2 - return ret - - -def _fastforward( - name: str, - statistics: Dict[str, Dict[str, str]]) -> Tuple[str, Dict[str, str]]: - """If the given module has only a single child and matches statistics with - that child, merge statistics and their names into one row. - - Then repeat until the condition isn't met. - - Returns: - tuple[str, dict]: the new name and the combined statistics of this row - """ - single_child = _get_single_child(name, statistics) - if single_child is None: - return name, statistics[name] - combined = _try_combine(statistics[name], statistics[single_child]) - if combined is None: - return name, statistics[name] - statistics[single_child] = combined - return _fastforward(single_child, statistics) - - -def _stats_table_format( - statistics: Dict[str, Dict[str, str]], - max_depth: int = 3, - stat_columns: Optional[List[str]] = None, -) -> str: - """Formats the statistics obtained from a model in a nice table. - - Args: - statistics (dict[str, dict[str, str]]): The statistics to print. - Organized as a dictionary over modules, then as a dictionary - over statistics in the model. The statistics are assumed to - already be formatted for printing. - max_depth (int): The maximum submodule depth to recurse to. - Defaults to 3. 
- stat_columns (list[str]): Specify the order of the columns to print. - If None, columns are found automatically from the provided - statistics. Defaults to None. - - Return: - str: The formatted table. - """ - if stat_columns is None: - stat_columns = set() # type: ignore - for stats in statistics.values(): - stat_columns.update(stats.keys()) # type: ignore - stat_columns = list(stat_columns) # type: ignore - - headers = ['module'] + stat_columns - rows: List[List[str]] = [] - - def build_row(name: str, stats: Dict[str, str], - indent_lvl: int) -> List[str]: - indent = ' ' * indent_lvl - row = [indent + name] - for stat_name in stat_columns: # type: ignore - row_str = (indent + stats[stat_name]) if stat_name in stats else '' - row.append(row_str) - return row - - def fill(indent_lvl: int, prefix: str) -> None: - if indent_lvl > max_depth: - return - for mod_name in statistics: - # 'if mod' excludes root = '', which is never a child - if (mod_name and mod_name.count('.') == prefix.count('.') - and mod_name.startswith(prefix)): - mod_name, curr_stats = _fastforward(mod_name, statistics) - if root_prefix and mod_name.startswith(root_prefix): - # Skip the root_prefix shared by all submodules as it - # carries 0 information - pretty_mod_name = mod_name[len(root_prefix):] - else: - pretty_mod_name = mod_name - row = build_row(pretty_mod_name, curr_stats, indent_lvl) - rows.append(row) - fill(indent_lvl + 1, mod_name + '.') - - root_name, curr_stats = _fastforward('', statistics) - row = build_row(root_name or 'model', curr_stats, indent_lvl=0) - rows.append(row) - root_prefix = root_name + ('.' if root_name else '') - fill(indent_lvl=1, prefix=root_prefix) - - table = Table(box=box.ASCII2) - for header in headers: - table.add_column(header) - - for row in rows: - table.add_row(*row) - - console = Console() - with console.capture() as capture: - console.print(table, end='') - - return capture.get() - - -def complexity_stats_str( - flops: FlopAnalyzer, - activations: Optional[ActivationAnalyzer] = None) -> str: - """Calculates the parameters and flops of the model with the given inputs - and returns a string representation of the model that includes the - parameters and flops of every submodule. The string is structured to be - similar that given by str(model), though it is not guaranteed to be - identical in form if the default string representation of a module has been - overridden. If a module has zero parameters and flops, statistics will not - be reported for succinctness. The trace can only register the scope of a - module if it is called directly, which means flops (and activations) - arising from explicit calls to .forward() or to other python functions of - the module will not be attributed to that module. Modules that are never - called will have 'N/A' listed for their flops; this means they are either - unused or their statistics are missing for this reason. Any such flops are - still counted towards the parent. - - Examples: - >>> import torch - >>> import torch.nn as nn - >>> class InnerNet(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.fc1 = nn.Linear(10,10) - ... self.fc2 = nn.Linear(10,10) - ... def forward(self, x): - ... return self.fc1(self.fc2(x)) - >>> class TestNet(nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.fc1 = nn.Linear(10,10) - ... self.fc2 = nn.Linear(10,10) - ... self.inner = InnerNet() - ... def forward(self, x): - ... 
return self.fc1(self.fc2(self.inner(x))) - >>> inputs = torch.randn((1,10)) - >>> print(complexity_stats_str(FlopAnalyzer(model, inputs))) - TestNet( - #params: 0.44K, #flops: 0.4K - (fc1): Linear( - in_features=10, out_features=10, bias=True - #params: 0.11K, #flops: 100 - ) - (fc2): Linear( - in_features=10, out_features=10, bias=True - #params: 0.11K, #flops: 100 - ) - (inner): InnerNet( - #params: 0.22K, #flops: 0.2K - (fc1): Linear( - in_features=10, out_features=10, bias=True - #params: 0.11K, #flops: 100 - ) - (fc2): Linear( - in_features=10, out_features=10, bias=True - #params: 0.11K, #flops: 100 - ) - ) - ) - - Args: - flops (FlopAnalyzer): the flop counting object - activations (ActivationAnalyzer or None): If given, the activations of - each layer will also be calculated and included in the - representation. Defaults to None. - - Returns: - str: a string representation of the model with the number of - parameters and flops included. - """ - # cast to dict since pyre doesn't like the implicit defaultdict->dict - model = flops._model - params = dict(parameter_count(model)) - - flops.unsupported_ops_warnings(False) - flops.uncalled_modules_warnings(False) - flops.tracer_warnings('none') - stats = {'#params': params, '#flops': flops.by_module()} - - if activations is not None: - activations.unsupported_ops_warnings(False) - activations.uncalled_modules_warnings(False) - activations.tracer_warnings('none') - stats['#acts'] = activations.by_module() - - all_uncalled = flops.uncalled_modules() | ( - activations.uncalled_modules() if activations is not None else set()) - stats = _fill_missing_statistics(model, stats) - stats = _group_by_module(stats) - stats = _remove_zero_statistics(stats, force_keep=all_uncalled) - stats = _pretty_statistics(stats, sig_figs=2) # type: ignore - stats = _indicate_uncalled_modules( # type: ignore - stats, # type: ignore - '#flops', # type: ignore - flops.uncalled_modules()) # type: ignore - if activations is not None: - stats = _indicate_uncalled_modules( # type: ignore - stats, # type: ignore - '#acts', # type: ignore - activations.uncalled_modules()) # type: ignore - - model_string = '' - if all_uncalled: - model_string += ( - 'N/A indicates a possibly missing statistic due to how ' - 'the module was called. Missing values are still included ' - "in the parent's total.\n") - model_string += _model_stats_str(model, stats) # type: ignore - return model_string - - -def complexity_stats_table( - flops: FlopAnalyzer, - max_depth: int = 3, - activations: Optional[ActivationAnalyzer] = None, - show_param_shapes: bool = True, -) -> str: - """ - Format the per-module parameters and flops of a model in a table. 
- It looks like this: - :: - | model | #parameters or shape| #flops | - |:---------------------------------|:--------------------|:----------| - | model | 34.6M | 65.7G | - | s1 | 15.4K | 4.32G | - | s1.pathway0_stem | 9.54K | 1.23G | - | s1.pathway0_stem.conv | 9.41K | 1.23G | - | s1.pathway0_stem.bn | 0.128K | | - | s1.pathway1_stem | 5.9K | 3.08G | - | s1.pathway1_stem.conv | 5.88K | 3.08G | - | s1.pathway1_stem.bn | 16 | | - | s1_fuse | 0.928K | 29.4M | - | s1_fuse.conv_f2s | 0.896K | 29.4M | - | s1_fuse.conv_f2s.weight | (16, 8, 7, 1, 1) | | - | s1_fuse.bn | 32 | | - | s1_fuse.bn.weight | (16,) | | - | s1_fuse.bn.bias | (16,) | | - | s2 | 0.226M | 7.73G | - | s2.pathway0_res0 | 80.1K | 2.58G | - | s2.pathway0_res0.branch1 | 20.5K | 0.671G | - | s2.pathway0_res0.branch1_bn | 0.512K | | - | s2.pathway0_res0.branch2 | 59.1K | 1.91G | - | s2.pathway0_res1.branch2 | 70.4K | 2.28G | - | s2.pathway0_res1.branch2.a | 16.4K | 0.537G | - | s2.pathway0_res1.branch2.a_bn | 0.128K | | - | s2.pathway0_res1.branch2.b | 36.9K | 1.21G | - | s2.pathway0_res1.branch2.b_bn | 0.128K | | - | s2.pathway0_res1.branch2.c | 16.4K | 0.537G | - | s2.pathway0_res1.branch2.c_bn | 0.512K | | - | s2.pathway0_res2.branch2 | 70.4K | 2.28G | - | s2.pathway0_res2.branch2.a | 16.4K | 0.537G | - | s2.pathway0_res2.branch2.a_bn | 0.128K | | - | s2.pathway0_res2.branch2.b | 36.9K | 1.21G | - | s2.pathway0_res2.branch2.b_bn | 0.128K | | - | s2.pathway0_res2.branch2.c | 16.4K | 0.537G | - | s2.pathway0_res2.branch2.c_bn | 0.512K | | - | ............................. | ...... | ...... | - - Args: - flops (FlopAnalyzer): the flop counting object - max_depth (int): The max depth of submodules to include in the - table. Defaults to 3. - activations (ActivationAnalyzer or None): If given, include - activation counts as an additional column in the table. - Defaults to None. - show_param_shapes (bool): If true, shapes for parameters will be - included in the table. Defaults to True. - - Returns: - str: The formatted table. 
- - Examples: - >>> print(complexity_stats_table(FlopAnalyzer(model, inputs))) - """ - params_header = '#parameters' + (' or shape' if show_param_shapes else '') - flops_header, acts_header = '#flops', '#activations' - - model = flops._model - # cast to dict since pyre doesn't like the implicit defaultdict->dict - params = dict(parameter_count(model)) - - flops.unsupported_ops_warnings(False) - flops.uncalled_modules_warnings(False) - flops.tracer_warnings('none') - - stats = {params_header: params, flops_header: flops.by_module()} - stat_columns = [params_header, flops_header] - - if activations is not None: - activations.unsupported_ops_warnings(False) - activations.uncalled_modules_warnings(False) - activations.tracer_warnings('none') - stats[acts_header] = activations.by_module() - stat_columns += [acts_header] - - stats = _group_by_module(stats) - stats = _remove_zero_statistics( - stats, # type: ignore - require_trivial_children=True) # type: ignore - stats = _pretty_statistics(stats, hide_zero=False) # type: ignore - stats = _indicate_uncalled_modules( # type: ignore - stats, # type: ignore - flops_header, # type: ignore - flops.uncalled_modules() & stats.keys(), # type: ignore - uncalled_indicator='', # type: ignore - ) - if activations: - stats = _indicate_uncalled_modules( # type: ignore - stats, # type: ignore - acts_header, # type: ignore - activations.uncalled_modules() & stats.keys(), # type: ignore - uncalled_indicator='', # type: ignore - ) - - # Swap in shapes for parameters or delete shapes from dict - param_shapes: Dict[str, Tuple[int, ...]] = { - k: tuple(v.shape) - for k, v in model.named_parameters() - } - to_delete = [] - for mod in stats: - if mod in param_shapes: - if show_param_shapes: - stats[mod][params_header] = str( # type: ignore - param_shapes[mod]) # type: ignore - else: - to_delete.append(mod) - for mod in to_delete: - del stats[mod] - - return _stats_table_format( - statistics=stats, # type: ignore - max_depth=max_depth, - stat_columns=stat_columns, - ) - - -def get_model_complexity_info( - model: nn.Module, - input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], - None] = None, - inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...], - None] = None, - show_table: bool = True, - show_arch: bool = True, -): - """Interface to get the complexity of a model. - - The parameter `inputs` are fed to the forward method of model. - If `inputs` is not specified, the `input_shape` is required and - it will be used to construct the dummy input fed to model. - If the forward of model requires two or more inputs, the `inputs` - should be a tuple of tensor or the `input_shape` should be a tuple - of tuple which each element will be constructed into a dumpy input. - - Examples: - >>> # the forward of model accepts only one input - >>> input_shape = (3, 224, 224) - >>> get_model_complexity_info(model, input_shape=input_shape) - >>> # the forward of model accepts two or more inputs - >>> input_shape = ((3, 224, 224), (3, 10)) - >>> get_model_complexity_info(model, input_shape=input_shape) - - Args: - model (nn.Module): The model to analyze. - input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...]], None]): - The input shape of the model. - If "inputs" is not specified, the "input_shape" should be set. - Defaults to None. - inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\ - optional]): - The input tensor(s) of the model. If not given the input tensor - will be generated automatically with the given input_shape. 
- Defaults to None. - show_table (bool): Whether to show the complexity table. - Defaults to True. - show_arch (bool): Whether to show the complexity arch. - Defaults to True. - - Returns: - dict: The complexity information of the model. - """ - if input_shape is None and inputs is None: - raise ValueError('One of "input_shape" and "inputs" should be set.') - elif input_shape is not None and inputs is not None: - raise ValueError('"input_shape" and "inputs" cannot be both set.') - - if inputs is None: - device = next(model.parameters()).device - if is_tuple_of(input_shape, int): # tuple of int, construct one tensor - inputs = (torch.randn(1, *input_shape).to(device), ) - elif is_tuple_of(input_shape, tuple) and all([ - is_tuple_of(one_input_shape, int) - for one_input_shape in input_shape # type: ignore - ]): # tuple of tuple of int, construct multiple tensors - inputs = tuple([ - torch.randn(1, *one_input_shape).to(device) - for one_input_shape in input_shape # type: ignore - ]) - else: - raise ValueError( - '"input_shape" should be either a `tuple of int` (to construct' - 'one input tensor) or a `tuple of tuple of int` (to construct' - 'multiple input tensors).') - - flop_handler = FlopAnalyzer(model, inputs) - activation_handler = ActivationAnalyzer(model, inputs) - - flops = flop_handler.total() - activations = activation_handler.total() - params = parameter_count(model)[''] - - flops_str = _format_size(flops) - activations_str = _format_size(activations) - params_str = _format_size(params) - - if show_table: - complexity_table = complexity_stats_table( - flops=flop_handler, - activations=activation_handler, - show_param_shapes=True, - ) - complexity_table = '\n' + complexity_table - else: - complexity_table = '' - - if show_arch: - complexity_arch = complexity_stats_str( - flops=flop_handler, - activations=activation_handler, - ) - complexity_arch = '\n' + complexity_arch - else: - complexity_arch = '' - - return { - 'flops': flops, - 'flops_str': flops_str, - 'activations': activations, - 'activations_str': activations_str, - 'params': params, - 'params_str': params_str, - 'out_table': complexity_table, - 'out_arch': complexity_arch - } diff --git a/mmengine/config/__init__.py b/mmengine/config/__init__.py deleted file mode 100644 index 9a1bc47db4..0000000000 --- a/mmengine/config/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .config import Config, ConfigDict, DictAction, read_base - -__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base'] diff --git a/mmengine/config/config.py b/mmengine/config/config.py deleted file mode 100644 index 801243c82d..0000000000 --- a/mmengine/config/config.py +++ /dev/null @@ -1,1858 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
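For reference, the removed ``get_model_complexity_info`` interface described above is typically driven as in the minimal sketch below; the toy model and the ``mmengine.analysis`` import path are assumptions for illustration, not part of the diff.

import torch.nn as nn
from mmengine.analysis import get_model_complexity_info  # assumed public entry point

# Toy model; any nn.Module whose forward takes a single tensor works here.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
info = get_model_complexity_info(model, input_shape=(10, ))
print(info['flops_str'], info['params_str'])  # human-readable totals
print(info['out_table'])                      # per-module table (show_table=True)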
-import ast -import copy -import difflib -import os -import os.path as osp -import platform -import shutil -import sys -import tempfile -import types -import uuid -import warnings -from argparse import Action, ArgumentParser, Namespace -from collections import OrderedDict, abc -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union - -from addict import Dict -from rich.console import Console -from rich.text import Text - -from mmengine.fileio import dump, load -from mmengine.logging import print_log -from mmengine.utils import (check_file_exist, digit_version, - get_installed_path, import_modules_from_strings, - is_installed) -from .lazy import LazyAttr, LazyObject -from .utils import (ConfigParsingError, ImportTransformer, RemoveAssignFromAST, - _gather_abs_import_lazyobj, _get_external_cfg_base_path, - _get_external_cfg_path, _get_package_and_cfg_path, - _is_builtin_module) - -BASE_KEY = '_base_' -DELETE_KEY = '_delete_' -DEPRECATION_KEY = '_deprecation_' -RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'env_variables'] - -if platform.system() == 'Windows': - import regex as re -else: - import re # type: ignore - - -def _lazy2string(cfg_dict, dict_type=None): - if isinstance(cfg_dict, dict): - dict_type = dict_type or type(cfg_dict) - return dict_type( - {k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict)}) - elif isinstance(cfg_dict, (tuple, list)): - return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) - elif isinstance(cfg_dict, (LazyAttr, LazyObject)): - return f'{cfg_dict.module}.{str(cfg_dict)}' - else: - return cfg_dict - - -class ConfigDict(Dict): - """A dictionary for config which has the same interface as python's built- - in dictionary and can be used as a normal dictionary. - - The Config class would transform the nested fields (dictionary-like fields) - in config file into ``ConfigDict``. - - If the class attribute ``lazy`` is ``False``, users will get the - object built by ``LazyObject`` or ``LazyAttr``, otherwise users will get - the ``LazyObject`` or ``LazyAttr`` itself. - - The ``lazy`` should be set to ``True`` to avoid building the imported - object during configuration parsing, and it should be set to False outside - the Config to ensure that users do not experience the ``LazyObject``. - """ - lazy = False - - def __init__(__self, *args, **kwargs): - object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) - object.__setattr__(__self, '__key', kwargs.pop('__key', None)) - object.__setattr__(__self, '__frozen', False) - for arg in args: - if not arg: - continue - # Since ConfigDict.items will convert LazyObject to real object - # automatically, we need to call super().items() to make sure - # the LazyObject will not be converted. 
- if isinstance(arg, ConfigDict): - for key, val in dict.items(arg): - __self[key] = __self._hook(val) - elif isinstance(arg, dict): - for key, val in arg.items(): - __self[key] = __self._hook(val) - elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): - __self[arg[0]] = __self._hook(arg[1]) - else: - for key, val in iter(arg): - __self[key] = __self._hook(val) - - for key, val in dict.items(kwargs): - __self[key] = __self._hook(val) - - def __missing__(self, name): - raise KeyError(name) - - def __getattr__(self, name): - try: - value = super().__getattr__(name) - if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy: - value = value.build() - except KeyError: - raise AttributeError(f"'{self.__class__.__name__}' object has no " - f"attribute '{name}'") - except Exception as e: - raise e - else: - return value - - @classmethod - def _hook(cls, item): - # avoid to convert user defined dict to ConfigDict. - if type(item) in (dict, OrderedDict): - return cls(item) - elif isinstance(item, (list, tuple)): - return type(item)(cls._hook(elem) for elem in item) - return item - - def __setattr__(self, name, value): - value = self._hook(value) - return super().__setattr__(name, value) - - def __setitem__(self, name, value): - value = self._hook(value) - return super().__setitem__(name, value) - - def __getitem__(self, key): - return self.build_lazy(super().__getitem__(key)) - - def __deepcopy__(self, memo): - other = self.__class__() - memo[id(self)] = other - for key, value in super().items(): - other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) - return other - - def __copy__(self): - other = self.__class__() - for key, value in super().items(): - other[key] = value - return other - - copy = __copy__ - - def __iter__(self): - # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict` - # to get the built lazy object - return iter(self.keys()) - - def get(self, key: str, default: Optional[Any] = None) -> Any: - """Get the value of the key. If class attribute ``lazy`` is True, the - LazyObject will be built and returned. - - Args: - key (str): The key. - default (any, optional): The default value. Defaults to None. - - Returns: - Any: The value of the key. - """ - return self.build_lazy(super().get(key, default)) - - def pop(self, key, default=None): - """Pop the value of the key. If class attribute ``lazy`` is True, the - LazyObject will be built and returned. - - Args: - key (str): The key. - default (any, optional): The default value. Defaults to None. - - Returns: - Any: The value of the key. - """ - return self.build_lazy(super().pop(key, default)) - - def update(self, *args, **kwargs) -> None: - """Override this method to make sure the LazyObject will not be built - during updating.""" - other = {} - if args: - if len(args) > 1: - raise TypeError('update only accept one positional argument') - # Avoid to used self.items to build LazyObject - for key, value in dict.items(args[0]): - other[key] = value - - for key, value in dict(kwargs).items(): - other[key] = value - for k, v in other.items(): - if ((k not in self) or (not isinstance(self[k], dict)) - or (not isinstance(v, dict))): - self[k] = self._hook(v) - else: - self[k].update(v) - - def build_lazy(self, value: Any) -> Any: - """If class attribute ``lazy`` is False, the LazyObject will be built - and returned. - - Args: - value (Any): The value to be built. - - Returns: - Any: The built value. 
- """ - if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy: - value = value.build() - return value - - def values(self): - """Yield the values of the dictionary. - - If class attribute ``lazy`` is False, the value of ``LazyObject`` or - ``LazyAttr`` will be built and returned. - """ - values = [] - for value in super().values(): - values.append(self.build_lazy(value)) - return values - - def items(self): - """Yield the keys and values of the dictionary. - - If class attribute ``lazy`` is False, the value of ``LazyObject`` or - ``LazyAttr`` will be built and returned. - """ - items = [] - for key, value in super().items(): - items.append((key, self.build_lazy(value))) - return items - - def merge(self, other: dict): - """Merge another dictionary into current dictionary. - - Args: - other (dict): Another dictionary. - """ - default = object() - - def _merge_a_into_b(a, b): - if isinstance(a, dict): - if not isinstance(b, dict): - a.pop(DELETE_KEY, None) - return a - if a.pop(DELETE_KEY, False): - b.clear() - all_keys = list(b.keys()) + list(a.keys()) - return { - key: - _merge_a_into_b(a.get(key, default), b.get(key, default)) - for key in all_keys if key != DELETE_KEY - } - else: - return a if a is not default else b - - merged = _merge_a_into_b(copy.deepcopy(other), copy.deepcopy(self)) - self.clear() - for key, value in merged.items(): - self[key] = value - - def __reduce_ex__(self, proto): - # Override __reduce_ex__ to avoid `self.items` will be - # called by CPython interpreter during pickling. See more details in - # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 - if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None) - else: - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None, None) - - def __eq__(self, other): - if isinstance(other, ConfigDict): - return other.to_dict() == self.to_dict() - elif isinstance(other, dict): - return {k: v for k, v in self.items()} == other - else: - return False - - def _to_lazy_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and keep - the ``LazyObject`` or ``LazyAttr`` object not built.""" - - def _to_dict(data): - if isinstance(data, ConfigDict): - return { - key: _to_dict(value) - for key, value in Dict.items(data) - } - elif isinstance(data, dict): - return {key: _to_dict(value) for key, value in data.items()} - elif isinstance(data, (list, tuple)): - return type(data)(_to_dict(item) for item in data) - else: - return data - - return _to_dict(self) - - def to_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and - convert the ``LazyObject`` or ``LazyAttr`` to string.""" - return _lazy2string(self, dict_type=dict) - - -def add_args(parser: ArgumentParser, - cfg: dict, - prefix: str = '') -> ArgumentParser: - """Add config fields into argument parser. - - Args: - parser (ArgumentParser): Argument parser. - cfg (dict): Config dictionary. - prefix (str, optional): Prefix of parser argument. - Defaults to ''. - - Returns: - ArgumentParser: Argument parser containing config fields. 
- """ - for k, v in cfg.items(): - if isinstance(v, str): - parser.add_argument('--' + prefix + k) - elif isinstance(v, bool): - parser.add_argument('--' + prefix + k, action='store_true') - elif isinstance(v, int): - parser.add_argument('--' + prefix + k, type=int) - elif isinstance(v, float): - parser.add_argument('--' + prefix + k, type=float) - elif isinstance(v, dict): - add_args(parser, v, prefix + k + '.') - elif isinstance(v, abc.Iterable): - parser.add_argument( - '--' + prefix + k, type=type(next(iter(v))), nargs='+') - else: - print_log( - f'cannot parse key {prefix + k} of type {type(v)}', - logger='current') - return parser - - -class Config: - """A facility for config and config files. - - It supports common file formats as configs: python/json/yaml. - ``Config.fromfile`` can parse a dictionary from a config file, then - build a ``Config`` instance with the dictionary. - The interface is the same as a dict object and also allows access config - values as attributes. - - Args: - cfg_dict (dict, optional): A config dictionary. Defaults to None. - cfg_text (str, optional): Text of config. Defaults to None. - filename (str or Path, optional): Name of config file. - Defaults to None. - format_python_code (bool): Whether to format Python code by yapf. - Defaults to True. - - Here is a simple example: - - Examples: - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> cfg.a - 1 - >>> cfg.b - {'b1': [0, 1]} - >>> cfg.b.b1 - [0, 1] - >>> cfg = Config.fromfile('tests/data/config/a.py') - >>> cfg.filename - "/home/username/projects/mmengine/tests/data/config/a.py" - >>> cfg.item4 - 'test' - >>> cfg - "Config [path: /home/username/projects/mmengine/tests/data/config/a.py] - :" - "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" - - You can find more advance usage in the `config tutorial`_. - - .. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html - """ # noqa: E501 - - def __init__( - self, - cfg_dict: Optional[dict] = None, - cfg_text: Optional[str] = None, - filename: Optional[Union[str, Path]] = None, - env_variables: Optional[dict] = None, - format_python_code: bool = True, - ): - filename = str(filename) if isinstance(filename, Path) else filename - if cfg_dict is None: - cfg_dict = dict() - elif not isinstance(cfg_dict, dict): - raise TypeError('cfg_dict must be a dict, but ' - f'got {type(cfg_dict)}') - for key in cfg_dict: - if key in RESERVED_KEYS: - raise KeyError(f'{key} is reserved for config file') - - if not isinstance(cfg_dict, ConfigDict): - cfg_dict = ConfigDict(cfg_dict) - super().__setattr__('_cfg_dict', cfg_dict) - super().__setattr__('_filename', filename) - super().__setattr__('_format_python_code', format_python_code) - if not hasattr(self, '_imported_names'): - super().__setattr__('_imported_names', set()) - - if cfg_text: - text = cfg_text - elif filename: - with open(filename, encoding='utf-8') as f: - text = f.read() - else: - text = '' - super().__setattr__('_text', text) - if env_variables is None: - env_variables = dict() - super().__setattr__('_env_variables', env_variables) - - @staticmethod - def fromfile(filename: Union[str, Path], - use_predefined_variables: bool = True, - import_custom_modules: bool = True, - use_environment_variables: bool = True, - lazy_import: Optional[bool] = None, - format_python_code: bool = True) -> 'Config': - """Build a Config instance from config file. - - Args: - filename (str or Path): Name of config file. 
- use_predefined_variables (bool, optional): Whether to use - predefined variables. Defaults to True. - import_custom_modules (bool, optional): Whether to support - importing custom modules in config. Defaults to None. - use_environment_variables (bool, optional): Whether to use - environment variables. Defaults to True. - lazy_import (bool): Whether to load config in `lazy_import` mode. - If it is `None`, it will be deduced by the content of the - config file. Defaults to None. - format_python_code (bool): Whether to format Python code by yapf. - Defaults to True. - - Returns: - Config: Config instance built from config file. - """ - filename = str(filename) if isinstance(filename, Path) else filename - if lazy_import is False or \ - lazy_import is None and not Config._is_lazy_import(filename): - cfg_dict, cfg_text, env_variables = Config._file2dict( - filename, use_predefined_variables, use_environment_variables, - lazy_import) - if import_custom_modules and cfg_dict.get('custom_imports', None): - try: - import_modules_from_strings(**cfg_dict['custom_imports']) - except ImportError as e: - err_msg = ( - 'Failed to import custom modules from ' - f"{cfg_dict['custom_imports']}, the current sys.path " - 'is: ') - for p in sys.path: - err_msg += f'\n {p}' - err_msg += ( - '\nYou should set `PYTHONPATH` to make `sys.path` ' - 'include the directory which contains your custom ' - 'module') - raise ImportError(err_msg) from e - return Config( - cfg_dict, - cfg_text=cfg_text, - filename=filename, - env_variables=env_variables, - ) - else: - # Enable lazy import when parsing the config. - # Using try-except to make sure ``ConfigDict.lazy`` will be reset - # to False. See more details about lazy in the docstring of - # ConfigDict - ConfigDict.lazy = True - try: - cfg_dict, imported_names = Config._parse_lazy_import(filename) - except Exception as e: - raise e - finally: - # disable lazy import to get the real type. See more details - # about lazy in the docstring of ConfigDict - ConfigDict.lazy = False - - cfg = Config( - cfg_dict, - filename=filename, - format_python_code=format_python_code) - object.__setattr__(cfg, '_imported_names', imported_names) - return cfg - - @staticmethod - def fromstring(cfg_str: str, file_format: str) -> 'Config': - """Build a Config instance from config text. - - Args: - cfg_str (str): Config text. - file_format (str): Config file format corresponding to the - config str. Only py/yml/yaml/json type are supported now! - - Returns: - Config: Config object generated from ``cfg_str``. - """ - if file_format not in ['.py', '.json', '.yaml', '.yml']: - raise OSError('Only py/yml/yaml/json type are supported now!') - if file_format != '.py' and 'dict(' in cfg_str: - # check if users specify a wrong suffix for python - warnings.warn( - 'Please check "file_format", the file format may be .py') - - # A temporary file can not be opened a second time on Windows. - # See https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile for more details. # noqa - # `temp_file` is opened first in `tempfile.NamedTemporaryFile` and - # second in `Config.from_file`. - # In addition, a named temporary file will be removed after closed. - # As a workaround we set `delete=False` and close the temporary file - # before opening again. 
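To make the two constructors documented above concrete, a self-contained call sketch (the config text is made up):

>>> from mmengine.config import Config
>>> cfg = Config.fromstring("optimizer = dict(type='SGD', lr=0.1)", file_format='.py')
>>> cfg.optimizer
{'type': 'SGD', 'lr': 0.1}
>>> # Config.fromfile('path/to/config.py') behaves the same, reading the text from disk.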
- - with tempfile.NamedTemporaryFile( - 'w', encoding='utf-8', suffix=file_format, - delete=False) as temp_file: - temp_file.write(cfg_str) - - cfg = Config.fromfile(temp_file.name) - os.remove(temp_file.name) # manually delete the temporary file - return cfg - - @staticmethod - def _get_base_modules(nodes: list) -> list: - """Get base module name from parsed code. - - Args: - nodes (list): Parsed code of the config file. - - Returns: - list: Name of base modules. - """ - - def _get_base_module_from_with(with_nodes: list) -> list: - """Get base module name from if statement in python file. - - Args: - with_nodes (list): List of if statement. - - Returns: - list: Name of base modules. - """ - base_modules = [] - for node in with_nodes: - assert isinstance(node, ast.ImportFrom), ( - 'Illegal syntax in config file! Only ' - '`from ... import ...` could be implemented` in ' - 'with read_base()`') - assert node.module is not None, ( - 'Illegal syntax in config file! Syntax like ' - '`from . import xxx` is not allowed in `with read_base()`') - base_modules.append(node.level * '.' + node.module) - return base_modules - - for idx, node in enumerate(nodes): - if (isinstance(node, ast.Assign) - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id == BASE_KEY): - raise ConfigParsingError( - 'The configuration file type in the inheritance chain ' - 'must match the current configuration file type, either ' - '"lazy_import" or non-"lazy_import". You got this error ' - f'since you use the syntax like `_base_ = "{node.targets[0].id}"` ' # noqa: E501 - 'in your config. You should use `with read_base(): ... to` ' # noqa: E501 - 'mark the inherited config file. See more information ' - 'in https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html' # noqa: E501 - ) - - if not isinstance(node, ast.With): - continue - - expr = node.items[0].context_expr - if (not isinstance(expr, ast.Call) - or not expr.func.id == 'read_base' or # type: ignore - len(node.items) > 1): - raise ConfigParsingError( - 'Only `read_base` context manager can be used in the ' - 'config') - - # The original code: - # ``` - # with read_base(): - # from .._base_.default_runtime import * - # ``` - # The processed code: - # ``` - # from .._base_.default_runtime import * - # ``` - # As you can see, the if statement is removed and the - # from ... import statement will be unindent - for nested_idx, nested_node in enumerate(node.body): - nodes.insert(idx + nested_idx + 1, nested_node) - nodes.pop(idx) - return _get_base_module_from_with(node.body) - return [] - - @staticmethod - def _validate_py_syntax(filename: str): - """Validate syntax of python config. - - Args: - filename (str): Filename of python config file. - """ - with open(filename, encoding='utf-8') as f: - content = f.read() - try: - ast.parse(content) - except SyntaxError as e: - raise SyntaxError('There are syntax errors in config ' - f'file {filename}: {e}') - - @staticmethod - def _substitute_predefined_vars(filename: str, temp_config_name: str): - """Substitute predefined variables in config with actual values. - - Sometimes we want some variables in the config to be related to the - current path or file name, etc. - - Here is an example of a typical usage scenario. When training a model, - we define a working directory in the config that save the models and - logs. For different configs, we expect to define different working - directories. A common way for users is to use the config file name - directly as part of the working directory name, e.g. 
for the config
-        ``config_setting1.py``, the working directory is
-        ``./work_dir/config_setting1``.
-
-        This can be easily achieved using predefined variables, which can be
-        written in the config `config_setting1.py` as follows
-
-        .. code-block:: python
-
-            work_dir = './work_dir/{{ fileBasenameNoExtension }}'
-
-
-        Here `{{ fileBasenameNoExtension }}` indicates the file name of the
-        config (without the extension), and when the config class reads the
-        config file, it will automatically parse this double-bracketed string
-        to the corresponding actual value.
-
-        .. code-block:: python
-
-            cfg = Config.fromfile('./config_setting1.py')
-            cfg.work_dir  # "./work_dir/config_setting1"
-
-
-        For details, Please refer to docs/zh_cn/advanced_tutorials/config.md .
-
-        Args:
-            filename (str): Filename of config.
-            temp_config_name (str): Temporary filename to save substituted
-                config.
-        """
-        file_dirname = osp.dirname(filename)
-        file_basename = osp.basename(filename)
-        file_basename_no_extension = osp.splitext(file_basename)[0]
-        file_extname = osp.splitext(filename)[1]
-        support_templates = dict(
-            fileDirname=file_dirname,
-            fileBasename=file_basename,
-            fileBasenameNoExtension=file_basename_no_extension,
-            fileExtname=file_extname)
-        with open(filename, encoding='utf-8') as f:
-            config_file = f.read()
-        for key, value in support_templates.items():
-            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
-            value = value.replace('\\', '/')
-            config_file = re.sub(regexp, value, config_file)
-        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
-            tmp_config_file.write(config_file)
-
-    @staticmethod
-    def _substitute_env_variables(filename: str, temp_config_name: str):
-        """Substitute environment variables in config with actual values.
-
-        Sometimes, we want to change some items in the config with environment
-        variables. For examples, we expect to change dataset root by setting
-        ``DATASET_ROOT=/dataset/root/path`` in the command line. This can be
-        easily achieved by writing lines in the config as follows
-
-        .. code-block:: python
-
-            data_root = '{{$DATASET_ROOT:/default/dataset}}/images'
-
-
-        Here, ``{{$DATASET_ROOT:/default/dataset}}`` indicates using the
-        environment variable ``DATASET_ROOT`` to replace the part between
-        ``{{}}``. If the ``DATASET_ROOT`` is not set, the default value
-        ``/default/dataset`` will be used.
-
-        Environment variables not only can replace items in the string, they
-        can also substitute other types of data in config. In this situation,
-        we can write the config as below
-
-        .. code-block:: python
-
-            model = dict(
-                bbox_head = dict(num_classes={{'$NUM_CLASSES:80'}}))
-
-
-        For details, Please refer to docs/zh_cn/tutorials/config.md .
-
-        Args:
-            filename (str): Filename of config.
-            temp_config_name (str): Temporary filename to save substituted
-                config.
-        """
-        with open(filename, encoding='utf-8') as f:
-            config_file = f.read()
-        regexp = r'\{\{[\'\"]?\s*\$(\w+)\s*\:\s*(\S*?)\s*[\'\"]?\}\}'
-        keys = re.findall(regexp, config_file)
-        env_variables = dict()
-        for var_name, value in keys:
-            regexp = r'\{\{[\'\"]?\s*\$' + var_name + r'\s*\:\s*' \
-                + value + r'\s*[\'\"]?\}\}'
-            if var_name in os.environ:
-                value = os.environ[var_name]
-                env_variables[var_name] = value
-                print_log(
-                    f'Using env variable `{var_name}` with value of '
-                    f'{value} to replace item in config.',
-                    logger='current')
-            if not value:
-                raise KeyError(f'`{var_name}` cannot be found in `os.environ`.'
- f' Please set `{var_name}` in environment or ' - 'give a default value.') - config_file = re.sub(regexp, value, config_file) - - with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: - tmp_config_file.write(config_file) - return env_variables - - @staticmethod - def _pre_substitute_base_vars(filename: str, - temp_config_name: str) -> dict: - """Preceding step for substituting variables in base config with actual - value. - - Args: - filename (str): Filename of config. - temp_config_name (str): Temporary filename to save substituted - config. - - Returns: - dict: A dictionary contains variables in base config. - """ - with open(filename, encoding='utf-8') as f: - config_file = f.read() - base_var_dict = {} - regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' - base_vars = set(re.findall(regexp, config_file)) - for base_var in base_vars: - randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' - base_var_dict[randstr] = base_var - regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' - config_file = re.sub(regexp, f'"{randstr}"', config_file) - with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: - tmp_config_file.write(config_file) - return base_var_dict - - @staticmethod - def _substitute_base_vars(cfg: Any, base_var_dict: dict, - base_cfg: dict) -> Any: - """Substitute base variables from strings to their actual values. - - Args: - Any : Config dictionary. - base_var_dict (dict): A dictionary contains variables in base - config. - base_cfg (dict): Base config dictionary. - - Returns: - Any : A dictionary with origin base variables - substituted with actual values. - """ - cfg = copy.deepcopy(cfg) - - if isinstance(cfg, dict): - for k, v in cfg.items(): - if isinstance(v, str) and v in base_var_dict: - new_v = base_cfg - for new_k in base_var_dict[v].split('.'): - new_v = new_v[new_k] - cfg[k] = new_v - elif isinstance(v, (list, tuple, dict)): - cfg[k] = Config._substitute_base_vars( - v, base_var_dict, base_cfg) - elif isinstance(cfg, tuple): - cfg = tuple( - Config._substitute_base_vars(c, base_var_dict, base_cfg) - for c in cfg) - elif isinstance(cfg, list): - cfg = [ - Config._substitute_base_vars(c, base_var_dict, base_cfg) - for c in cfg - ] - elif isinstance(cfg, str) and cfg in base_var_dict: - new_v = base_cfg - for new_k in base_var_dict[cfg].split('.'): - new_v = new_v[new_k] - cfg = new_v - - return cfg - - @staticmethod - def _file2dict( - filename: str, - use_predefined_variables: bool = True, - use_environment_variables: bool = True, - lazy_import: Optional[bool] = None) -> Tuple[dict, str, dict]: - """Transform file to variables dictionary. - - Args: - filename (str): Name of config file. - use_predefined_variables (bool, optional): Whether to use - predefined variables. Defaults to True. - use_environment_variables (bool, optional): Whether to use - environment variables. Defaults to True. - lazy_import (bool): Whether to load config in `lazy_import` mode. - If it is `None`, it will be deduced by the content of the - config file. Defaults to None. - - Returns: - Tuple[dict, str]: Variables dictionary and text of Config. - """ - if lazy_import is None and Config._is_lazy_import(filename): - raise RuntimeError( - 'The configuration file type in the inheritance chain ' - 'must match the current configuration file type, either ' - '"lazy_import" or non-"lazy_import". You got this error ' - 'since you use the syntax like `with read_base(): ...` ' - f'or import non-builtin module in {filename}. 
See more ' - 'information in https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html' # noqa: E501 - ) - - filename = osp.abspath(osp.expanduser(filename)) - check_file_exist(filename) - fileExtname = osp.splitext(filename)[1] - if fileExtname not in ['.py', '.json', '.yaml', '.yml']: - raise OSError('Only py/yml/yaml/json type are supported now!') - try: - with tempfile.TemporaryDirectory() as temp_config_dir: - temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix=fileExtname, delete=False) - if platform.system() == 'Windows': - temp_config_file.close() - - # Substitute predefined variables - if use_predefined_variables: - Config._substitute_predefined_vars(filename, - temp_config_file.name) - else: - shutil.copyfile(filename, temp_config_file.name) - # Substitute environment variables - env_variables = dict() - if use_environment_variables: - env_variables = Config._substitute_env_variables( - temp_config_file.name, temp_config_file.name) - # Substitute base variables from placeholders to strings - base_var_dict = Config._pre_substitute_base_vars( - temp_config_file.name, temp_config_file.name) - - # Handle base files - base_cfg_dict = ConfigDict() - cfg_text_list = list() - for base_cfg_path in Config._get_base_files( - temp_config_file.name): - base_cfg_path, scope = Config._get_cfg_path( - base_cfg_path, filename) - _cfg_dict, _cfg_text, _env_variables = Config._file2dict( - filename=base_cfg_path, - use_predefined_variables=use_predefined_variables, - use_environment_variables=use_environment_variables, - lazy_import=lazy_import, - ) - cfg_text_list.append(_cfg_text) - env_variables.update(_env_variables) - duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() - if len(duplicate_keys) > 0: - raise KeyError( - 'Duplicate key is not allowed among bases. ' - f'Duplicate keys: {duplicate_keys}') - - # _dict_to_config_dict will do the following things: - # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. - # 2. Set `_scope_` for the outer dict variable for the base - # config. - # 3. Set `scope` attribute for each base variable. - # Different from `_scope_`, `scope` is not a key of base - # dict, `scope` attribute will be parsed to key `_scope_` - # by function `_parse_scope` only if the base variable is - # accessed by the current config. - _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) - base_cfg_dict.update(_cfg_dict) - - if filename.endswith('.py'): - with open(temp_config_file.name, encoding='utf-8') as f: - parsed_codes = ast.parse(f.read()) - parsed_codes = RemoveAssignFromAST(BASE_KEY).visit( - parsed_codes) - codeobj = compile(parsed_codes, filename, mode='exec') - # Support load global variable in nested function of the - # config. - global_locals_var = {BASE_KEY: base_cfg_dict} - ori_keys = set(global_locals_var.keys()) - eval(codeobj, global_locals_var, global_locals_var) - cfg_dict = { - key: value - for key, value in global_locals_var.items() - if (key not in ori_keys and not key.startswith('__')) - } - elif filename.endswith(('.yml', '.yaml', '.json')): - cfg_dict = load(temp_config_file.name) - # close temp file - for key, value in list(cfg_dict.items()): - if isinstance(value, - (types.FunctionType, types.ModuleType)): - cfg_dict.pop(key) - temp_config_file.close() - - # If the current config accesses a base variable of base - # configs, The ``scope`` attribute of corresponding variable - # will be converted to the `_scope_`. 
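Taken together, the predefined-variable and environment-variable substitutions performed in ``_file2dict`` above allow config files written like the sketch below; the file name, paths and variable names are purely illustrative, and the ``{{...}}`` placeholders are replaced textually before the file is parsed.

# demo_cfg.py -- illustrative only; placeholders are substituted before exec
work_dir = './work_dir/{{ fileBasenameNoExtension }}'   # -> './work_dir/demo_cfg'
data_root = '{{$DATASET_ROOT:/data/default}}/images'    # env var with a default value
model = dict(bbox_head=dict(num_classes={{'$NUM_CLASSES:80'}}))  # non-string substitution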
- Config._parse_scope(cfg_dict) - except Exception as e: - if osp.exists(temp_config_dir): - shutil.rmtree(temp_config_dir) - raise e - - # check deprecation information - if DEPRECATION_KEY in cfg_dict: - deprecation_info = cfg_dict.pop(DEPRECATION_KEY) - warning_msg = f'The config file {filename} will be deprecated ' \ - 'in the future.' - if 'expected' in deprecation_info: - warning_msg += f' Please use {deprecation_info["expected"]} ' \ - 'instead.' - if 'reference' in deprecation_info: - warning_msg += ' More information can be found at ' \ - f'{deprecation_info["reference"]}' - warnings.warn(warning_msg, DeprecationWarning) - - cfg_text = filename + '\n' - with open(filename, encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - cfg_text += f.read() - - # Substitute base variables from strings to their actual values - cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, - base_cfg_dict) - cfg_dict.pop(BASE_KEY, None) - - cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) - cfg_dict = { - k: v - for k, v in cfg_dict.items() if not k.startswith('__') - } - - # merge cfg_text - cfg_text_list.append(cfg_text) - cfg_text = '\n'.join(cfg_text_list) - - return cfg_dict, cfg_text, env_variables - - @staticmethod - def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]: - """Transform file to variables dictionary. - - Args: - filename (str): Name of config file. - - Returns: - Tuple[dict, dict]: ``cfg_dict`` and ``imported_names``. - - - cfg_dict (dict): Variables dictionary of parsed config. - - imported_names (set): Used to mark the names of - imported object. - """ - # In lazy import mode, users can use the Python syntax `import` to - # implement inheritance between configuration files, which is easier - # for users to understand the hierarchical relationships between - # different configuration files. - - # Besides, users can also using `import` syntax to import corresponding - # module which will be filled in the `type` field. It means users - # can directly navigate to the source of the module in the - # configuration file by clicking the `type` field. - - # To avoid really importing the third party package like `torch` - # during import `type` object, we use `_parse_lazy_import` to parse the - # configuration file, which will not actually trigger the import - # process, but simply parse the imported `type`s as LazyObject objects. - - # The overall pipeline of _parse_lazy_import is: - # 1. Parse the base module from the config file. - # || - # \/ - # base_module = ['mmdet.configs.default_runtime'] - # || - # \/ - # 2. recursively parse the base module and gather imported objects to - # a dict. - # || - # \/ - # The base_dict will be: - # { - # 'mmdet.configs.default_runtime': {...} - # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} - # ... - # }, each item in base_dict is a dict of `LazyObject` - # 3. parse the current config file filling the imported variable - # with the base_dict. - # - # 4. During the parsing process, all imported variable will be - # recorded in the `imported_names` set. These variables can be - # accessed, but will not be dumped by default. 
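The pipeline sketched in the comments above corresponds to configs written in the "lazy import" style; a minimal illustrative example (module paths and field names are assumptions):

# Illustrative lazy-import config. _parse_lazy_import only records these
# imports as LazyObject placeholders; nothing is actually imported at parse time.
from mmengine.config import read_base

with read_base():
    from .._base_.default_runtime import *  # inherit another config file

from torch.optim import AdamW  # recorded as a LazyObject, usable in `type`

optimizer = dict(type=AdamW, lr=1e-4)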
- - with open(filename, encoding='utf-8') as f: - global_dict = {'LazyObject': LazyObject, '__file__': filename} - base_dict = {} - - parsed_codes = ast.parse(f.read()) - # get the names of base modules, and remove the - # `with read_base():'` statement - base_modules = Config._get_base_modules(parsed_codes.body) - base_imported_names = set() - for base_module in base_modules: - # If base_module means a relative import, assuming the level is - # 2, which means the module is imported like - # "from ..a.b import c". we must ensure that c is an - # object `defined` in module b, and module b should not be a - # package including `__init__` file but a single python file. - level = len(re.match(r'\.*', base_module).group()) - if level > 0: - # Relative import - base_dir = osp.dirname(filename) - module_path = osp.join( - base_dir, *(['..'] * (level - 1)), - f'{base_module[level:].replace(".", "/")}.py') - else: - # Absolute import - module_list = base_module.split('.') - if len(module_list) == 1: - raise ConfigParsingError( - 'The imported configuration file should not be ' - f'an independent package {module_list[0]}. Here ' - 'is an example: ' - '`with read_base(): from mmdet.configs.retinanet_r50_fpn_1x_coco import *`' # noqa: E501 - ) - else: - package = module_list[0] - root_path = get_installed_path(package) - module_path = f'{osp.join(root_path, *module_list[1:])}.py' # noqa: E501 - if not osp.isfile(module_path): - raise ConfigParsingError( - f'{module_path} not found! It means that incorrect ' - 'module is defined in ' - f'`with read_base(): = from {base_module} import ...`, please ' # noqa: E501 - 'make sure the base config module is valid ' - 'and is consistent with the prior import ' - 'logic') - _base_cfg_dict, _base_imported_names = Config._parse_lazy_import( # noqa: E501 - module_path) - base_imported_names |= _base_imported_names - # The base_dict will be: - # { - # 'mmdet.configs.default_runtime': {...} - # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} - # ... - # } - base_dict[base_module] = _base_cfg_dict - - # `base_dict` contains all the imported modules from `base_cfg`. - # In order to collect the specific imported module from `base_cfg` - # before parse the current file, we using AST Transform to - # transverse the imported module from base_cfg and merge then into - # the global dict. After the ast transformation, most of import - # syntax will be removed (except for the builtin import) and - # replaced with the `LazyObject` - transform = ImportTransformer( - global_dict=global_dict, - base_dict=base_dict, - filename=filename) - modified_code = transform.visit(parsed_codes) - modified_code, abs_imported = _gather_abs_import_lazyobj( - modified_code, filename=filename) - imported_names = transform.imported_obj | abs_imported - imported_names |= base_imported_names - modified_code = ast.fix_missing_locations(modified_code) - exec( - compile(modified_code, filename, mode='exec'), global_dict, - global_dict) - - ret: dict = {} - for key, value in global_dict.items(): - if key.startswith('__') or key in ['LazyObject']: - continue - ret[key] = value - # convert dict to ConfigDict - cfg_dict = Config._dict_to_config_dict_lazy(ret) - - return cfg_dict, imported_names - - @staticmethod - def _dict_to_config_dict_lazy(cfg: dict): - """Recursively converts ``dict`` to :obj:`ConfigDict`. 
The only - difference between ``_dict_to_config_dict_lazy`` and - ``_dict_to_config_dict_lazy`` is that the former one does not consider - the scope, and will not trigger the building of ``LazyObject``. - - Args: - cfg (dict): Config dict. - - Returns: - ConfigDict: Converted dict. - """ - # Only the outer dict with key `type` should have the key `_scope_`. - if isinstance(cfg, dict): - cfg_dict = ConfigDict() - for key, value in cfg.items(): - cfg_dict[key] = Config._dict_to_config_dict_lazy(value) - return cfg_dict - if isinstance(cfg, (tuple, list)): - return type(cfg)( - Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) - return cfg - - @staticmethod - def _dict_to_config_dict(cfg: dict, - scope: Optional[str] = None, - has_scope=True): - """Recursively converts ``dict`` to :obj:`ConfigDict`. - - Args: - cfg (dict): Config dict. - scope (str, optional): Scope of instance. - has_scope (bool): Whether to add `_scope_` key to config dict. - - Returns: - ConfigDict: Converted dict. - """ - # Only the outer dict with key `type` should have the key `_scope_`. - if isinstance(cfg, dict): - if has_scope and 'type' in cfg: - has_scope = False - if scope is not None and cfg.get('_scope_', None) is None: - cfg._scope_ = scope # type: ignore - cfg = ConfigDict(cfg) - dict.__setattr__(cfg, 'scope', scope) - for key, value in cfg.items(): - cfg[key] = Config._dict_to_config_dict( - value, scope=scope, has_scope=has_scope) - elif isinstance(cfg, tuple): - cfg = tuple( - Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) - for _cfg in cfg) - elif isinstance(cfg, list): - cfg = [ - Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) - for _cfg in cfg - ] - return cfg - - @staticmethod - def _parse_scope(cfg: dict) -> None: - """Adds ``_scope_`` to :obj:`ConfigDict` instance, which means a base - variable. - - If the config dict already has the scope, scope will not be - overwritten. - - Args: - cfg (dict): Config needs to be parsed with scope. - """ - if isinstance(cfg, ConfigDict): - cfg._scope_ = cfg.scope - elif isinstance(cfg, (tuple, list)): - [Config._parse_scope(value) for value in cfg] - else: - return - - @staticmethod - def _get_base_files(filename: str) -> list: - """Get the base config file. - - Args: - filename (str): The config file. - - Raises: - TypeError: Name of config file. - - Returns: - list: A list of base config. - """ - file_format = osp.splitext(filename)[1] - if file_format == '.py': - Config._validate_py_syntax(filename) - with open(filename, encoding='utf-8') as f: - parsed_codes = ast.parse(f.read()).body - - def is_base_line(c): - return (isinstance(c, ast.Assign) - and isinstance(c.targets[0], ast.Name) - and c.targets[0].id == BASE_KEY) - - base_code = next((c for c in parsed_codes if is_base_line(c)), - None) - if base_code is not None: - base_code = ast.Expression( # type: ignore - body=base_code.value) # type: ignore - base_files = eval(compile(base_code, '', - mode='eval')) # type: ignore - else: - base_files = [] - elif file_format in ('.yml', '.yaml', '.json'): - import mmengine - cfg_dict = mmengine.load(filename) - base_files = cfg_dict.get(BASE_KEY, []) - else: - raise ConfigParsingError( - 'The config type should be py, json, yaml or ' - f'yml, but got {file_format}') - base_files = base_files if isinstance(base_files, - list) else [base_files] - return base_files - - @staticmethod - def _get_cfg_path(cfg_path: str, - filename: str) -> Tuple[str, Optional[str]]: - """Get the config path from the current or external package. 
- - Args: - cfg_path (str): Relative path of config. - filename (str): The config file being parsed. - - Returns: - Tuple[str, str or None]: Path and scope of config. If the config - is not an external config, the scope will be `None`. - """ - if '::' in cfg_path: - # `cfg_path` startswith '::' means an external config path. - # Get package name and relative config path. - scope = cfg_path.partition('::')[0] - package, cfg_path = _get_package_and_cfg_path(cfg_path) - - if not is_installed(package): - raise ModuleNotFoundError( - f'{package} is not installed, please install {package} ' - f'manually') - - # Get installed package path. - package_path = get_installed_path(package) - try: - # Get config path from meta file. - cfg_path = _get_external_cfg_path(package_path, cfg_path) - except ValueError: - # Since base config does not have a metafile, it should be - # concatenated with package path and relative config path. - cfg_path = _get_external_cfg_base_path(package_path, cfg_path) - except FileNotFoundError as e: - raise e - return cfg_path, scope - else: - # Get local config path. - cfg_dir = osp.dirname(filename) - cfg_path = osp.join(cfg_dir, cfg_path) - return cfg_path, None - - @staticmethod - def _merge_a_into_b(a: dict, - b: dict, - allow_list_keys: bool = False) -> dict: - """Merge dict ``a`` into dict ``b`` (non-inplace). - - Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid - in-place modifications. - - Args: - a (dict): The source dict to be merged into ``b``. - b (dict): The origin dict to be fetch keys from ``a``. - allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in source ``a`` and will replace the element of the - corresponding index in b if b is a list. Defaults to False. - - Returns: - dict: The modified dict of ``b`` using ``a``. - - Examples: - # Normally merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # Delete b first and merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # b is a list - >>> Config._merge_a_into_b( - ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) - [{'a': 2}, {'b': 2}] - """ - b = b.copy() - for k, v in a.items(): - if allow_list_keys and k.isdigit() and isinstance(b, list): - k = int(k) - if len(b) <= k: - raise KeyError(f'Index {k} exceeds the length of list {b}') - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - elif isinstance(v, dict): - if k in b and not v.pop(DELETE_KEY, False): - allowed_types: Union[Tuple, type] = ( - dict, list) if allow_list_keys else dict - if not isinstance(b[k], allowed_types): - raise TypeError( - f'{k}={v} in child config cannot inherit from ' - f'base because {k} is a dict in the child config ' - f'but is of type {type(b[k])} in base config. 
' - f'You may set `{DELETE_KEY}=True` to ignore the ' - f'base config.') - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - else: - b[k] = ConfigDict(v) - else: - b[k] = v - return b - - @staticmethod - def auto_argparser(description=None): - """Generate argparser from config file automatically (experimental)""" - partial_parser = ArgumentParser(description=description) - partial_parser.add_argument('config', help='config file path') - cfg_file = partial_parser.parse_known_args()[0].config - cfg = Config.fromfile(cfg_file) - parser = ArgumentParser(description=description) - parser.add_argument('config', help='config file path') - add_args(parser, cfg) - return parser, cfg - - @property - def filename(self) -> str: - """Get file name of config.""" - return self._filename - - @property - def text(self) -> str: - """Get config text.""" - return self._text - - @property - def env_variables(self) -> dict: - """Get used environment variables.""" - return self._env_variables - - @property - def pretty_text(self) -> str: - """Get formatted python config text.""" - import yapf - from yapf.yapflib.yapf_api import FormatCode - - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split('\n') - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = repr(v) - else: - v_str = str(v) - - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) - - return attr_str - - def _format_list_tuple(k, v, use_mapping=False): - if isinstance(v, list): - left = '[' - right = ']' - else: - left = '(' - right = ')' - - v_str = f'{left}\n' - # check if all items in the list are dict - for item in v: - if isinstance(item, dict): - v_str += f'dict({_indent(_format_dict(item), indent)}),\n' - elif isinstance(item, tuple): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, list): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, str): - v_str += f'{_indent(repr(item), indent)},\n' - else: - v_str += str(item) + ',\n' - if k is None: - return _indent(v_str, indent) + right - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + right - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= \ - (not str(key_name).isidentifier()) - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = '' - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - if use_mapping: - r += '{' - for idx, (k, v) in enumerate( - sorted(input_dict.items(), key=lambda x: str(x[0]))): - is_last = idx >= len(input_dict) - 1 - end = '' if outest_level or is_last else ',' - if isinstance(v, dict): - v_str = '\n' + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: dict({v_str}' - else: - attr_str = f'{str(k)}=dict({v_str}' - attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, (list, tuple)): - attr_str = _format_list_tuple(k, v, use_mapping) + end - else: - attr_str = 
_format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += '\n'.join(s) - if use_mapping: - r += '}' - return r - - cfg_dict = self.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - if self._format_python_code: - # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) - try: - if digit_version(yapf.__version__) >= digit_version('0.40.2'): - text, _ = FormatCode(text, style_config=yapf_style) - else: - text, _ = FormatCode( - text, style_config=yapf_style, verify=True) - except: # noqa: E722 - raise SyntaxError('Failed to format the config file, please ' - f'check the syntax of: \n{text}') - return text - - def __repr__(self): - return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' - - def __len__(self): - return len(self._cfg_dict) - - def __getattr__(self, name: str) -> Any: - return getattr(self._cfg_dict, name) - - def __getitem__(self, name): - return self._cfg_dict.__getitem__(name) - - def __setattr__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setattr__(name, value) - - def __setitem__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setitem__(name, value) - - def __iter__(self): - return iter(self._cfg_dict) - - def __getstate__( - self - ) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]: - state = (self._cfg_dict, self._filename, self._text, - self._env_variables, self._format_python_code, - self._imported_names) - return state - - def __deepcopy__(self, memo): - cls = self.__class__ - other = cls.__new__(cls) - memo[id(self)] = other - - for key, value in self.__dict__.items(): - super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) - - return other - - def __copy__(self): - cls = self.__class__ - other = cls.__new__(cls) - other.__dict__.update(self.__dict__) - super(Config, other).__setattr__('_cfg_dict', self._cfg_dict.copy()) - - return other - - copy = __copy__ - - def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], - dict, bool, set]): - super().__setattr__('_cfg_dict', state[0]) - super().__setattr__('_filename', state[1]) - super().__setattr__('_text', state[2]) - super().__setattr__('_env_variables', state[3]) - super().__setattr__('_format_python_code', state[4]) - super().__setattr__('_imported_names', state[5]) - - def dump(self, file: Optional[Union[str, Path]] = None): - """Dump config to file or return config text. - - Args: - file (str or Path, optional): If not specified, then the object - is dumped to a str, otherwise to a file specified by the filename. - Defaults to None. - - Returns: - str or None: Config text. - """ - file = str(file) if isinstance(file, Path) else file - cfg_dict = self.to_dict() - if file is None: - if self.filename is None or self.filename.endswith('.py'): - return self.pretty_text - else: - file_format = self.filename.split('.')[-1] - return dump(cfg_dict, file_format=file_format) - elif file.endswith('.py'): - with open(file, 'w', encoding='utf-8') as f: - f.write(self.pretty_text) - else: - file_format = file.split('.')[-1] - return dump(cfg_dict, file=file, file_format=file_format) - - def merge_from_dict(self, - options: dict, - allow_list_keys: bool = True) -> None: - """Merge list into cfg_dict. - - Merge the dict parsed by MultipleKVAction into this cfg. - - Args: - options (dict): dict of configs to merge from. 
- allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in ``options`` and will replace the element of the - corresponding index in the config if the config is a list. - Defaults to True. - - Examples: - >>> from mmengine import Config - >>> # Merge dictionary element - >>> options = {'model.backbone.depth': 50, 'model.backbone.with_cp': True} - >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) - >>> cfg.merge_from_dict(options) - >>> cfg._cfg_dict - {'model': {'backbone': {'type': 'ResNet', 'depth': 50, 'with_cp': True}}} - >>> # Merge list element - >>> cfg = Config( - >>> dict(pipeline=[dict(type='LoadImage'), - >>> dict(type='LoadAnnotations')])) - >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) - >>> cfg.merge_from_dict(options, allow_list_keys=True) - >>> cfg._cfg_dict - {'pipeline': [{'type': 'SelfLoadImage'}, {'type': 'LoadAnnotations'}]} - """ # noqa: E501 - option_cfg_dict: dict = {} - for full_key, v in options.items(): - d = option_cfg_dict - key_list = full_key.split('.') - for subkey in key_list[:-1]: - d.setdefault(subkey, ConfigDict()) - d = d[subkey] - subkey = key_list[-1] - d[subkey] = v - - cfg_dict = super().__getattribute__('_cfg_dict') - super().__setattr__( - '_cfg_dict', - Config._merge_a_into_b( - option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) - - @staticmethod - def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str: - if isinstance(cfg1, str): - cfg1 = Config.fromfile(cfg1) - - if isinstance(cfg2, str): - cfg2 = Config.fromfile(cfg2) - - res = difflib.unified_diff( - cfg1.pretty_text.split('\n'), cfg2.pretty_text.split('\n')) - - # Convert into rich format for better visualization - console = Console() - text = Text() - for line in res: - if line.startswith('+'): - color = 'bright_green' - elif line.startswith('-'): - color = 'bright_red' - else: - color = 'bright_white' - _text = Text(line + '\n') - _text.stylize(color) - text.append(_text) - - with console.capture() as capture: - console.print(text) - - return capture.get() - - @staticmethod - def _is_lazy_import(filename: str) -> bool: - if not filename.endswith('.py'): - return False - with open(filename, encoding='utf-8') as f: - codes_str = f.read() - parsed_codes = ast.parse(codes_str) - for node in ast.walk(parsed_codes): - if (isinstance(node, ast.Assign) - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id == BASE_KEY): - return False - - if isinstance(node, ast.With): - expr = node.items[0].context_expr - if (not isinstance(expr, ast.Call) - or not expr.func.id == 'read_base'): # type: ignore - raise ConfigParsingError( - 'Only `read_base` context manager can be used in the ' - 'config') - return True - if isinstance(node, ast.ImportFrom): - # relative import -> lazy_import - if node.level != 0: - return True - # Skip checking when using `mmengine.config` in cfg file - if (node.module == 'mmengine' and len(node.names) == 1 - and node.names[0].name == 'Config'): - continue - if not isinstance(node.module, str): - continue - # non-builtin module -> lazy_import - if not _is_builtin_module(node.module): - return True - if isinstance(node, ast.Import): - for alias_node in node.names: - if not _is_builtin_module(alias_node.name): - return True - return False - - def _to_lazy_dict(self, keep_imported: bool = False) -> dict: - """Convert config object to dictionary with lazy object, and filter the - imported object.""" - res = self._cfg_dict._to_lazy_dict() - if hasattr(self, '_imported_names') and not 
keep_imported: - res = { - key: value - for key, value in res.items() - if key not in self._imported_names - } - return res - - def to_dict(self, keep_imported: bool = False): - """Convert all data in the config to a builtin ``dict``. - - Args: - keep_imported (bool): Whether to keep the imported field. - Defaults to False - - If you import third-party objects in the config file, all imported - objects will be converted to a string like ``torch.optim.SGD`` - """ - cfg_dict = self._cfg_dict.to_dict() - if hasattr(self, '_imported_names') and not keep_imported: - cfg_dict = { - key: value - for key, value in cfg_dict.items() - if key not in self._imported_names - } - return cfg_dict - - -class DictAction(Action): - """Argparse action to split an argument into KEY=VALUE form on the first = - and append to a dictionary. - - List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3', - or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested - brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' - """ - - @staticmethod - def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: - """Parse int/float/bool value in the string.""" - try: - return int(val) - except ValueError: - pass - try: - return float(val) - except ValueError: - pass - if val.lower() in ['true', 'false']: - return True if val.lower() == 'true' else False - if val == 'None': - return None - return val - - @staticmethod - def _parse_iterable(val: str) -> Union[list, tuple, Any]: - """Parse iterable values in the string. - - All elements inside '()' or '[]' are treated as iterable values. - - Args: - val (str): Value string. - - Returns: - list | tuple | Any: The expanded list or tuple from the string, - or single value if no iterable values are found. - - Examples: - >>> DictAction._parse_iterable('1,2,3') - [1, 2, 3] - >>> DictAction._parse_iterable('[a, b, c]') - ['a', 'b', 'c'] - >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') - [(1, 2, 3), ['a', 'b'], 'c'] - """ - - def find_next_comma(string): - """Find the position of next comma in the string. - - If no ',' is found in the string, return the string length. All - chars inside '()' and '[]' are treated as one element and thus ',' - inside these brackets are ignored. - """ - assert (string.count('(') == string.count(')')) and ( - string.count('[') == string.count(']')), \ - f'Imbalanced brackets exist in {string}' - end = len(string) - for idx, char in enumerate(string): - pre = string[:idx] - # The string before this ',' is balanced - if ((char == ',') and (pre.count('(') == pre.count(')')) - and (pre.count('[') == pre.count(']'))): - end = idx - break - return end - - # Strip ' and " characters and replace whitespace. - val = val.strip('\'\"').replace(' ', '') - is_tuple = False - if val.startswith('(') and val.endswith(')'): - is_tuple = True - val = val[1:-1] - elif val.startswith('[') and val.endswith(']'): - val = val[1:-1] - elif ',' not in val: - # val is a single value - return DictAction._parse_int_float_bool(val) - - values = [] - while len(val) > 0: - comma_idx = find_next_comma(val) - element = DictAction._parse_iterable(val[:comma_idx]) - values.append(element) - val = val[comma_idx + 1:] - - if is_tuple: - return tuple(values) - - return values - - def __call__(self, - parser: ArgumentParser, - namespace: Namespace, - values: Union[str, Sequence[Any], None], - option_string: str = None): # type: ignore - """Parse Variables in string and add them into argparser. 
- - Args: - parser (ArgumentParser): Argument parser. - namespace (Namespace): Argument namespace. - values (Union[str, Sequence[Any], None]): Argument string. - option_string (list[str], optional): Option string. - Defaults to None. - """ - # Copied behavior from `argparse._ExtendAction`. - options = copy.copy(getattr(namespace, self.dest, None) or {}) - if values is not None: - for kv in values: - key, val = kv.split('=', maxsplit=1) - options[key] = self._parse_iterable(val) - setattr(namespace, self.dest, options) - - -@contextmanager -def read_base(): - """Context manager to mark the base config. - - The pure Python-style configuration file allows you to use the import - syntax. However, it is important to note that you need to import the base - configuration file within the context of ``read_base``, and import other - dependencies outside of it. - - You can see more usage of Python-style configuration in the `tutorial`_ - - .. _tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta - """ # noqa: E501 - yield diff --git a/mmengine/config/lazy.py b/mmengine/config/lazy.py deleted file mode 100644 index e83cce7c89..0000000000 --- a/mmengine/config/lazy.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import importlib -from typing import Any, Optional, Union - -from mmengine.utils import is_seq_of - - -class LazyObject: - """LazyObject is used to lazily initialize the imported module during - parsing the configuration file. - - During parsing process, the syntax like: - - Examples: - >>> import torch.nn as nn - >>> from mmdet.models import RetinaNet - >>> import mmcls.models - >>> import mmcls.datasets - >>> import mmcls - - Will be parsed as: - - Examples: - >>> # import torch.nn as nn - >>> nn = lazyObject('torch.nn') - >>> # from mmdet.models import RetinaNet - >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') - >>> # import mmcls.models; import mmcls.datasets; import mmcls - >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) - - ``LazyObject`` records all module information and will be further - referenced by the configuration file. - - Args: - module (str or list or tuple): The module name to be imported. - imported (str, optional): The imported module name. Defaults to None. - location (str, optional): The filename and line number of the imported - module statement happened. - """ - - def __init__(self, - module: Union[str, list, tuple], - imported: Optional[str] = None, - location: Optional[str] = None): - if not isinstance(module, str) and not is_seq_of(module, str): - raise TypeError('module should be `str`, `list`, or `tuple`' - f'but got {type(module)}, this might be ' - 'a bug of MMEngine, please report it to ' - 'https://github.com/open-mmlab/mmengine/issues') - self._module: Union[str, list, tuple] = module - - if not isinstance(imported, str) and imported is not None: - raise TypeError('imported should be `str` or None, but got ' - f'{type(imported)}, this might be ' - 'a bug of MMEngine, please report it to ' - 'https://github.com/open-mmlab/mmengine/issues') - self._imported = imported - self.location = location - - def build(self) -> Any: - """Return imported object. 
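For context, a minimal command-line sketch of the ``DictAction`` shown above, mirroring how a ``--cfg-options`` flag is typically wired up; the option keys and values here are illustrative:

>>> import argparse
>>> from mmengine.config import DictAction
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--cfg-options', nargs='+', action=DictAction)
>>> args = parser.parse_args(['--cfg-options', 'model.depth=50', 'lr=[0.01,0.001]'])
>>> args.cfg_options
{'model.depth': 50, 'lr': [0.01, 0.001]}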
- - Returns: - Any: Imported object - """ - if isinstance(self._module, str): - try: - module = importlib.import_module(self._module) - except Exception as e: - raise type(e)(f'Failed to import {self._module} ' - f'in {self.location} for {e}') - - if self._imported is not None: - if hasattr(module, self._imported): - module = getattr(module, self._imported) - else: - raise ImportError( - f'Failed to import {self._imported} ' - f'from {self._module} in {self.location}') - - return module - else: - # import xxx.xxx - # import xxx.yyy - # import xxx.zzz - # return imported xxx - try: - for module in self._module: - importlib.import_module(module) # type: ignore - module_name = self._module[0].split('.')[0] - return importlib.import_module(module_name) - except Exception as e: - raise type(e)(f'Failed to import {self.module} ' - f'in {self.location} for {e}') - - @property - def module(self): - if isinstance(self._module, str): - return self._module - return self._module[0].split('.')[0] - - def __call__(self, *args, **kwargs): - raise RuntimeError() - - def __deepcopy__(self, memo): - return LazyObject(self._module, self._imported, self.location) - - def __getattr__(self, name): - # Cannot locate the line number of the getting attribute. - # Therefore only record the filename. - if self.location is not None: - location = self.location.split(', line')[0] - else: - location = self.location - return LazyAttr(name, self, location) - - def __str__(self) -> str: - if self._imported is not None: - return self._imported - return self.module - - __repr__ = __str__ - - # `pickle.dump` will try to get the `__getstate__` and `__setstate__` - # methods of the dumped object. If these two methods are not defined, - # LazyObject will return a `__getstate__` LazyObject` or `__setstate__` - # LazyObject. - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, state): - self.__dict__ = state - - -class LazyAttr: - """The attribute of the LazyObject. - - When parsing the configuration file, the imported syntax will be - parsed as the assignment ``LazyObject``. During the subsequent parsing - process, users may reference the attributes of the LazyObject. - To ensure that these attributes also contain information needed to - reconstruct the attribute itself, LazyAttr was introduced. - - Examples: - >>> models = LazyObject(['mmdet.models']) - >>> model = dict(type=models.RetinaNet) - >>> print(type(model['type'])) # - >>> print(model['type'].build()) # - """ # noqa: E501 - - def __init__(self, - name: str, - source: Union['LazyObject', 'LazyAttr'], - location=None): - self.name = name - self.source: Union[LazyAttr, LazyObject] = source - - if isinstance(self.source, LazyObject): - if isinstance(self.source._module, str): - if self.source._imported is None: - # source code: - # from xxx.yyy import zzz - # equivalent code: - # zzz = LazyObject('xxx.yyy', 'zzz') - # The source code of get attribute: - # eee = zzz.eee - # Then, `eee._module` should be "xxx.yyy.zzz" - self._module = self.source._module - else: - # source code: - # import xxx.yyy as zzz - # equivalent code: - # zzz = LazyObject('xxx.yyy') - # The source code of get attribute: - # eee = zzz.eee - # Then, `eee._module` should be "xxx.yyy" - self._module = f'{self.source._module}.{self.source}' - else: - # The source code of LazyObject should be - # 1. import xxx.yyy - # 2. 
import xxx.zzz - # Equivalent to - # xxx = LazyObject(['xxx.yyy', 'xxx.zzz']) - - # The source code of LazyAttr should be - # eee = xxx.eee - # Then, eee._module = xxx - self._module = str(self.source) - elif isinstance(self.source, LazyAttr): - # 1. import xxx - # 2. zzz = xxx.yyy.zzz - - # Equivalent to: - # xxx = LazyObject('xxx') - # zzz = xxx.yyy.zzz - # zzz._module = xxx.yyy._module + zzz.name - self._module = f'{self.source._module}.{self.source.name}' - self.location = location - - @property - def module(self): - return self._module - - def __call__(self, *args, **kwargs: Any) -> Any: - raise RuntimeError() - - def __getattr__(self, name: str) -> 'LazyAttr': - return LazyAttr(name, self) - - def __deepcopy__(self, memo): - return LazyAttr(self.name, self.source) - - def build(self) -> Any: - """Return the attribute of the imported object. - - Returns: - Any: attribute of the imported object. - """ - obj = self.source.build() - try: - return getattr(obj, self.name) - except AttributeError: - raise ImportError(f'Failed to import {self.module}.{self.name} in ' - f'{self.location}') - except ImportError as e: - raise e - - def __str__(self) -> str: - return self.name - - __repr__ = __str__ - - # `pickle.dump` will try to get the `__getstate__` and `__setstate__` - # methods of the dumped object. If these two methods are not defined, - # LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__` - # LazyAttr. - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, state): - self.__dict__ = state diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py deleted file mode 100644 index 81b58fb49a..0000000000 --- a/mmengine/config/utils.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import ast -import os.path as osp -import re -import sys -import warnings -from collections import defaultdict -from importlib.util import find_spec -from typing import List, Optional, Tuple, Union - -from mmengine.fileio import load -from mmengine.utils import check_file_exist - -PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable)) -SYSTEM_PYTHON_PREFIX = '/usr/lib/python' - -MODULE2PACKAGE = { - 'mmcls': 'mmcls', - 'mmdet': 'mmdet', - 'mmdet3d': 'mmdet3d', - 'mmseg': 'mmsegmentation', - 'mmaction': 'mmaction2', - 'mmtrack': 'mmtrack', - 'mmpose': 'mmpose', - 'mmedit': 'mmedit', - 'mmocr': 'mmocr', - 'mmgen': 'mmgen', - 'mmfewshot': 'mmfewshot', - 'mmrazor': 'mmrazor', - 'mmflow': 'mmflow', - 'mmhuman3d': 'mmhuman3d', - 'mmrotate': 'mmrotate', - 'mmselfsup': 'mmselfsup', - 'mmyolo': 'mmyolo', - 'mmpretrain': 'mmpretrain', - 'mmagic': 'mmagic', -} - -# PKG2PROJECT is not a proper name to represent the mapping between module name -# (module import from) and package name (used by pip install). Therefore, -# PKG2PROJECT will be deprecated and this alias will only be kept until -# MMEngine v1.0.0 -PKG2PROJECT = MODULE2PACKAGE - - -class ConfigParsingError(RuntimeError): - """Raise error when failed to parse pure Python style config files.""" - - -def _get_cfg_metainfo(package_path: str, cfg_path: str) -> dict: - """Get target meta information from all 'metafile.yml' defined in `mode- - index.yml` of external package. - - Args: - package_path (str): Path of external package. - cfg_path (str): Name of experiment config. - - Returns: - dict: Meta information of target experiment. 
- """ - meta_index_path = osp.join(package_path, '.mim', 'model-index.yml') - meta_index = load(meta_index_path) - cfg_dict = dict() - for meta_path in meta_index['Import']: - meta_path = osp.join(package_path, '.mim', meta_path) - cfg_meta = load(meta_path) - for model_cfg in cfg_meta['Models']: - if 'Config' not in model_cfg: - warnings.warn(f'There is not `Config` define in {model_cfg}') - continue - cfg_name = model_cfg['Config'].partition('/')[-1] - # Some config could have multiple weights, we only pick the - # first one. - if cfg_name in cfg_dict: - continue - cfg_dict[cfg_name] = model_cfg - if cfg_path not in cfg_dict: - raise ValueError(f'Expected configs: {cfg_dict.keys()}, but got ' - f'{cfg_path}') - return cfg_dict[cfg_path] - - -def _get_external_cfg_path(package_path: str, cfg_file: str) -> str: - """Get config path of external package. - - Args: - package_path (str): Path of external package. - cfg_file (str): Name of experiment config. - - Returns: - str: Absolute config path from external package. - """ - cfg_file = cfg_file.split('.')[0] - model_cfg = _get_cfg_metainfo(package_path, cfg_file) - cfg_path = osp.join(package_path, model_cfg['Config']) - check_file_exist(cfg_path) - return cfg_path - - -def _get_external_cfg_base_path(package_path: str, cfg_name: str) -> str: - """Get base config path of external package. - - Args: - package_path (str): Path of external package. - cfg_name (str): External relative config path with 'package::'. - - Returns: - str: Absolute config path from external package. - """ - cfg_path = osp.join(package_path, '.mim', 'configs', cfg_name) - check_file_exist(cfg_path) - return cfg_path - - -def _get_package_and_cfg_path(cfg_path: str) -> Tuple[str, str]: - """Get package name and relative config path. - - Args: - cfg_path (str): External relative config path with 'package::'. - - Returns: - Tuple[str, str]: Package name and config path. - """ - if re.match(r'\w*::\w*/\w*', cfg_path) is None: - raise ValueError( - '`_get_package_and_cfg_path` is used for get external package, ' - 'please specify the package name and relative config path, just ' - 'like `mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py`') - package_cfg = cfg_path.split('::') - if len(package_cfg) > 2: - raise ValueError('`::` should only be used to separate package and ' - 'config name, but found multiple `::` in ' - f'{cfg_path}') - package, cfg_path = package_cfg - assert package in MODULE2PACKAGE, ( - f'mmengine does not support to load {package} config.') - package = MODULE2PACKAGE[package] - return package, cfg_path - - -class RemoveAssignFromAST(ast.NodeTransformer): - """Remove Assign node if the target's name match the key. - - Args: - key (str): The target name of the Assign node. - """ - - def __init__(self, key): - self.key = key - - def visit_Assign(self, node): - if (isinstance(node.targets[0], ast.Name) - and node.targets[0].id == self.key): - return None - else: - return node - - -def _is_builtin_module(module_name: str) -> bool: - """Check if a module is a built-in module. - - Arg: - module_name: name of module. 
- """ - if module_name.startswith('.'): - return False - if module_name.startswith('mmengine.config'): - return True - if module_name in sys.builtin_module_names: - return True - spec = find_spec(module_name.split('.')[0]) - # Module not found - if spec is None: - return False - origin_path = getattr(spec, 'origin', None) - if origin_path is None: - return False - origin_path = osp.abspath(origin_path) - if ('site-package' in origin_path or 'dist-package' in origin_path - or not origin_path.startswith( - (PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))): - return False - else: - return True - - -class ImportTransformer(ast.NodeTransformer): - """Convert the import syntax to the assignment of - :class:`mmengine.config.LazyObject` and preload the base variable before - parsing the configuration file. - - Since you are already looking at this part of the code, I believe you must - be interested in the mechanism of the ``lazy_import`` feature of - :class:`Config`. In this docstring, we will dive deeper into its - principles. - - Most of OpenMMLab users maybe bothered with that: - - * In most of popular IDEs, they cannot navigate to the source code in - configuration file - * In most of popular IDEs, they cannot jump to the base file in current - configuration file, which is much painful when the inheritance - relationship is complex. - - In order to solve this problem, we introduce the ``lazy_import`` mode. - - A very intuitive idea for solving this problem is to import the module - corresponding to the "type" field using the ``import`` syntax. Similarly, - we can also ``import`` base file. - - However, this approach has a significant drawback. It requires triggering - the import logic to parse the configuration file, which can be - time-consuming. Additionally, it implies downloading numerous dependencies - solely for the purpose of parsing the configuration file. - However, it's possible that only a portion of the config will actually be - used. For instance, the package used in the ``train_pipeline`` may not - be necessary for an evaluation task. Forcing users to download these - unused packages is not a desirable solution. - - To avoid this problem, we introduce :class:`mmengine.config.LazyObject` and - :class:`mmengine.config.LazyAttr`. Before we proceed with further - explanations, you may refer to the documentation of these two modules to - gain an understanding of their functionalities. - - Actually, one of the functions of ``ImportTransformer`` is to hack the - ``import`` syntax. It will replace the import syntax - (exclude import the base files) with the assignment of ``LazyObject``. - - As for the import syntax of the base file, we cannot lazy import it since - we're eager to merge the fields of current file and base files. Therefore, - another function of the ``ImportTransformer`` is to collaborate with - ``Config._parse_lazy_import`` to parse the base files. - - Args: - global_dict (dict): The global dict of the current configuration file. - If we divide ordinary Python syntax into two parts, namely the - import section and the non-import section (assuming a simple case - with imports at the beginning and the rest of the code following), - the variables generated by the import statements are stored in - global variables for subsequent code use. In this context, - the ``global_dict`` represents the global variables required when - executing the non-import code. ``global_dict`` will be filled - during visiting the parsed code. - base_dict (dict): All variables defined in base files. 
- - Examples: - >>> from mmengine.config import read_base - >>> - >>> - >>> with read_base(): - >>> from .._base_.default_runtime import * - >>> from .._base_.datasets.coco_detection import dataset - - In this case, the base_dict will be: - - Examples: - >>> base_dict = { - >>> '.._base_.default_runtime': ... - >>> '.._base_.datasets.coco_detection': dataset} - - and `global_dict` will be updated like this: - - Examples: - >>> global_dict.update(base_dict['.._base_.default_runtime']) # `import *` means update all data - >>> global_dict.update(dataset=base_dict['.._base_.datasets.coco_detection']['dataset']) # only update `dataset` - """ # noqa: E501 - - def __init__(self, - global_dict: dict, - base_dict: Optional[dict] = None, - filename: Optional[str] = None): - self.base_dict = base_dict if base_dict is not None else {} - self.global_dict = global_dict - # In Windows, the filename could be like this: - # "C:\\Users\\runneradmin\\AppData\\Local\\" - # Although it has been an raw string, ast.parse will firstly escape - # it as the executed code: - # "C:\Users\runneradmin\AppData\Local\\\" - # As you see, the `\U` will be treated as a part of - # the escape sequence during code parsing, leading to an - # parsing error - # Here we use `encode('unicode_escape').decode()` for double escaping - if isinstance(filename, str): - filename = filename.encode('unicode_escape').decode() - self.filename = filename - self.imported_obj: set = set() - super().__init__() - - def visit_ImportFrom( - self, node: ast.ImportFrom - ) -> Optional[Union[List[ast.Assign], ast.ImportFrom]]: - """Hack the ``from ... import ...`` syntax and update the global_dict. - - Examples: - >>> from mmdet.models import RetinaNet - - Will be parsed as: - - Examples: - >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') - - ``global_dict`` will also be updated by ``base_dict`` as the - class docstring says. - - Args: - node (ast.AST): The node of the current import statement. - - Returns: - Optional[List[ast.Assign]]: There three cases: - - * If the node is a statement of importing base files. - None will be returned. - * If the node is a statement of importing a builtin module, - node will be directly returned - * Otherwise, it will return the assignment statements of - ``LazyObject``. 
- """ - # Built-in modules will not be parsed as LazyObject - module = f'{node.level*"."}{node.module}' - if _is_builtin_module(module): - # Make sure builtin module will be added into `self.imported_obj` - for alias in node.names: - if alias.asname is not None: - self.imported_obj.add(alias.asname) - elif alias.name == '*': - raise ConfigParsingError( - 'Cannot import * from non-base config') - else: - self.imported_obj.add(alias.name) - return node - - if module in self.base_dict: - for alias_node in node.names: - if alias_node.name == '*': - self.global_dict.update(self.base_dict[module]) - return None - if alias_node.asname is not None: - base_key = alias_node.asname - else: - base_key = alias_node.name - self.global_dict[base_key] = self.base_dict[module][ - alias_node.name] - return None - - nodes: List[ast.Assign] = [] - for alias_node in node.names: - # `ast.alias` has lineno attr after Python 3.10, - if hasattr(alias_node, 'lineno'): - lineno = alias_node.lineno - else: - lineno = node.lineno - if alias_node.name == '*': - # TODO: If users import * from a non-config module, it should - # fallback to import the real module and raise a warning to - # remind users the real module will be imported which will slow - # down the parsing speed. - raise ConfigParsingError( - 'Illegal syntax in config! `from xxx import *` is not ' - 'allowed to appear outside the `if base:` statement') - elif alias_node.asname is not None: - # case1: - # from mmengine.dataset import BaseDataset as Dataset -> - # Dataset = LazyObject('mmengine.dataset', 'BaseDataset') - code = f'{alias_node.asname} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 - self.imported_obj.add(alias_node.asname) - else: - # case2: - # from mmengine.model import BaseModel - # BaseModel = LazyObject('mmengine.model', 'BaseModel') - code = f'{alias_node.name} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 - self.imported_obj.add(alias_node.name) - try: - nodes.append(ast.parse(code).body[0]) # type: ignore - except Exception as e: - raise ConfigParsingError( - f'Cannot import {alias_node} from {module}' - '1. Cannot import * from 3rd party lib in the config ' - 'file\n' - '2. Please check if the module is a base config which ' - 'should be added to `_base_`\n') from e - return nodes - - def visit_Import(self, node) -> Union[ast.Assign, ast.Import]: - """Work with ``_gather_abs_import_lazyobj`` to hack the ``import ...`` - syntax. - - Examples: - >>> import mmcls.models - >>> import mmcls.datasets - >>> import mmcls - - Will be parsed as: - - Examples: - >>> # import mmcls.models; import mmcls.datasets; import mmcls - >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) - - Args: - node (ast.AST): The node of the current import statement. - - Returns: - ast.Assign: If the import statement is ``import ... as ...``, - ast.Assign will be returned, otherwise node will be directly - returned. - """ - # For absolute import like: `import mmdet.configs as configs`. - # It will be parsed as: - # configs = LazyObject('mmdet.configs') - # For absolute import like: - # `import mmdet.configs` - # `import mmdet.configs.default_runtime` - # This will be parsed as - # mmdet = LazyObject(['mmdet.configs.default_runtime', 'mmdet.configs]) - # However, visit_Import cannot gather other import information, so - # `_gather_abs_import_LazyObject` will gather all import information - # from the same module and construct the LazyObject. 
- alias_list = node.names - assert len(alias_list) == 1, ( - 'Illegal syntax in config! import multiple modules in one line is ' - 'not supported') - # TODO Support multiline import - alias = alias_list[0] - if alias.asname is not None: - self.imported_obj.add(alias.asname) - if _is_builtin_module(alias.name.split('.')[0]): - return node - return ast.parse( # type: ignore - f'{alias.asname} = LazyObject(' - f'"{alias.name}",' - f'location="{self.filename}, line {node.lineno}")').body[0] - return node - - -def _gather_abs_import_lazyobj(tree: ast.Module, - filename: Optional[str] = None): - """Experimental implementation of gathering absolute import information.""" - if isinstance(filename, str): - filename = filename.encode('unicode_escape').decode() - imported = defaultdict(list) - abs_imported = set() - new_body: List[ast.stmt] = [] - # module2node is used to get lineno when Python < 3.10 - module2node: dict = dict() - for node in tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - # Skip converting built-in module to LazyObject - if _is_builtin_module(alias.name): - new_body.append(node) - continue - module = alias.name.split('.')[0] - module2node.setdefault(module, node) - imported[module].append(alias) - continue - new_body.append(node) - - for key, value in imported.items(): - names = [_value.name for _value in value] - if hasattr(value[0], 'lineno'): - lineno = value[0].lineno - else: - lineno = module2node[key].lineno - lazy_module_assign = ast.parse( - f'{key} = LazyObject({names}, location="{filename}, line {lineno}")' # noqa: E501 - ) # noqa: E501 - abs_imported.add(key) - new_body.insert(0, lazy_module_assign.body[0]) - tree.body = new_body - return tree, abs_imported diff --git a/mmengine/dataset/__init__.py b/mmengine/dataset/__init__.py deleted file mode 100644 index c58ef983f4..0000000000 --- a/mmengine/dataset/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base_dataset import BaseDataset, Compose, force_full_init -from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset -from .sampler import DefaultSampler, InfiniteSampler -from .utils import (COLLATE_FUNCTIONS, default_collate, pseudo_collate, - worker_init_fn) - -__all__ = [ - 'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset', - 'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler', - 'worker_init_fn', 'pseudo_collate', 'COLLATE_FUNCTIONS', 'default_collate' -] diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py deleted file mode 100644 index 4622f146a5..0000000000 --- a/mmengine/dataset/base_dataset.py +++ /dev/null @@ -1,826 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import functools -import gc -import logging -import pickle -from collections.abc import Mapping -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union - -import numpy as np -from torch.utils.data import Dataset - -from mmengine.config import Config -from mmengine.fileio import join_path, list_from_file, load -from mmengine.logging import print_log -from mmengine.registry import TRANSFORMS -from mmengine.utils import is_abs - - -class Compose: - """Compose multiple transforms sequentially. - - Args: - transforms (Sequence[dict, callable], optional): Sequence of transform - object or config dict to be composed. 
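For illustration, a minimal ``Compose`` sketch using a plain callable; in practice the entries are usually config dicts built through the ``TRANSFORMS`` registry, and the transform used here is made up for the example:

>>> pipeline = Compose([lambda data: {**data, 'flipped': True}])
>>> pipeline(dict(img_path='test_img.jpg'))
{'img_path': 'test_img.jpg', 'flipped': True}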
- """ - - def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]): - self.transforms: List[Callable] = [] - - if transforms is None: - transforms = [] - - for transform in transforms: - # `Compose` can be built with config dict with type and - # corresponding arguments. - if isinstance(transform, dict): - transform = TRANSFORMS.build(transform) - if not callable(transform): - raise TypeError(f'transform should be a callable object, ' - f'but got {type(transform)}') - self.transforms.append(transform) - elif callable(transform): - self.transforms.append(transform) - else: - raise TypeError( - f'transform must be a callable object or dict, ' - f'but got {type(transform)}') - - def __call__(self, data: dict) -> Optional[dict]: - """Call function to apply transforms sequentially. - - Args: - data (dict): A result dict contains the data to transform. - - Returns: - dict: Transformed data. - """ - for t in self.transforms: - data = t(data) - # The transform will return None when it failed to load images or - # cannot find suitable augmentation parameters to augment the data. - # Here we simply return None if the transform returns None and the - # dataset will handle it by randomly selecting another data sample. - if data is None: - return None - return data - - def __repr__(self): - """Print ``self.transforms`` in sequence. - - Returns: - str: Formatted string. - """ - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += f' {t}' - format_string += '\n)' - return format_string - - -def force_full_init(old_func: Callable) -> Any: - """Those methods decorated by ``force_full_init`` will be forced to call - ``full_init`` if the instance has not been fully initiated. - - Args: - old_func (Callable): Decorated function, make sure the first arg is an - instance with ``full_init`` method. - - Returns: - Any: Depends on old_func. - """ - - @functools.wraps(old_func) - def wrapper(obj: object, *args, **kwargs): - # The instance must have `full_init` method. - if not hasattr(obj, 'full_init'): - raise AttributeError(f'{type(obj)} does not have full_init ' - 'method.') - # If instance does not have `_fully_initialized` attribute or - # `_fully_initialized` is False, call `full_init` and set - # `_fully_initialized` to True - if not getattr(obj, '_fully_initialized', False): - print_log( - f'Attribute `_fully_initialized` is not defined in ' - f'{type(obj)} or `type(obj)._fully_initialized is ' - 'False, `full_init` will be called and ' - f'{type(obj)}._fully_initialized will be set to True', - logger='current', - level=logging.WARNING) - obj.full_init() # type: ignore - obj._fully_initialized = True # type: ignore - - return old_func(obj, *args, **kwargs) - - return wrapper - - -class BaseDataset(Dataset): - r"""BaseDataset for open source projects in OpenMMLab. - - The annotation format is shown as follows. - - .. code-block:: none - - { - "metainfo": - { - "dataset_type": "test_dataset", - "task_name": "test_task" - }, - "data_list": - [ - { - "img_path": "test_img.jpg", - "height": 604, - "width": 640, - "instances": - [ - { - "bbox": [0, 0, 10, 20], - "bbox_label": 1, - "mask": [[0,0],[0,10],[10,20],[20,0]], - "extra_anns": [1,2,3] - }, - { - "bbox": [10, 10, 110, 120], - "bbox_label": 2, - "mask": [[10,10],[10,110],[110,120],[120,10]], - "extra_anns": [4,5,6] - } - ] - }, - ] - } - - Args: - ann_file (str, optional): Annotation file path. Defaults to ''. 
- metainfo (Mapping or Config, optional): Meta information for - dataset, such as class information. Defaults to None. - data_root (str, optional): The root directory for ``data_prefix`` and - ``ann_file``. Defaults to ''. - data_prefix (dict): Prefix for training data. Defaults to - dict(img_path=''). - filter_cfg (dict, optional): Config for filter data. Defaults to None. - indices (int or Sequence[int], optional): Support using first few - data in annotation file to facilitate training/testing on a smaller - serialize_data (bool, optional): Whether to hold memory using - serialized objects, when enabled, data loader workers can use - shared RAM from master process instead of making a copy. Defaults - to True. - pipeline (list, optional): Processing pipeline. Defaults to []. - test_mode (bool, optional): ``test_mode=True`` means in test phase. - Defaults to False. - lazy_init (bool, optional): Whether to load annotation during - instantiation. In some cases, such as visualization, only the meta - information of the dataset is needed, which is not necessary to - load annotation file. ``Basedataset`` can skip load annotations to - save time by set ``lazy_init=True``. Defaults to False. - max_refetch (int, optional): If ``Basedataset.prepare_data`` get a - None img. The maximum extra number of cycles to get a valid - image. Defaults to 1000. - - Note: - BaseDataset collects meta information from ``annotation file`` (the - lowest priority), ``BaseDataset.METAINFO``(medium) and ``metainfo - parameter`` (highest) passed to constructors. The lower priority meta - information will be overwritten by higher one. - - Note: - Dataset wrapper such as ``ConcatDataset``, ``RepeatDataset`` .etc. - should not inherit from ``BaseDataset`` since ``get_subset`` and - ``get_subset_`` could produce ambiguous meaning sub-dataset which - conflicts with original dataset. - - Examples: - >>> # Assume the annotation file is given above. - >>> class CustomDataset(BaseDataset): - >>> METAINFO: dict = dict(task_name='custom_task', - >>> dataset_type='custom_type') - >>> metainfo=dict(task_name='custom_task_name') - >>> custom_dataset = CustomDataset( - >>> 'path/to/ann_file', - >>> metainfo=metainfo) - >>> # meta information of annotation file will be overwritten by - >>> # `CustomDataset.METAINFO`. The merged meta information will - >>> # further be overwritten by argument `metainfo`. - >>> custom_dataset.metainfo - {'task_name': custom_task_name, dataset_type: custom_type} - """ - - METAINFO: dict = dict() - _fully_initialized: bool = False - - def __init__(self, - ann_file: Optional[str] = '', - metainfo: Union[Mapping, Config, None] = None, - data_root: Optional[str] = '', - data_prefix: dict = dict(img_path=''), - filter_cfg: Optional[dict] = None, - indices: Optional[Union[int, Sequence[int]]] = None, - serialize_data: bool = True, - pipeline: List[Union[dict, Callable]] = [], - test_mode: bool = False, - lazy_init: bool = False, - max_refetch: int = 1000): - self.ann_file = ann_file - self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) - self.data_root = data_root - self.data_prefix = copy.copy(data_prefix) - self.filter_cfg = copy.deepcopy(filter_cfg) - self._indices = indices - self.serialize_data = serialize_data - self.test_mode = test_mode - self.max_refetch = max_refetch - self.data_list: List[dict] = [] - self.data_bytes: np.ndarray - - # Join paths. - self._join_prefix() - - # Build pipeline. - self.pipeline = Compose(pipeline) - # Full initialize the dataset. 
- if not lazy_init: - self.full_init() - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index and automatically call ``full_init`` if the - dataset has not been fully initialized. - - Args: - idx (int): The index of data. - - Returns: - dict: The idx-th annotation of the dataset. - """ - if self.serialize_data: - start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() - end_addr = self.data_address[idx].item() - bytes = memoryview( - self.data_bytes[start_addr:end_addr]) # type: ignore - data_info = pickle.loads(bytes) # type: ignore - else: - data_info = copy.deepcopy(self.data_list[idx]) - # Some codebase needs `sample_idx` of data information. Here we convert - # the idx to a positive number and save it in data information. - if idx >= 0: - data_info['sample_idx'] = idx - else: - data_info['sample_idx'] = len(self) + idx - - return data_info - - def full_init(self): - """Load annotation file and set ``BaseDataset._fully_initialized`` to - True. - - If ``lazy_init=False``, ``full_init`` will be called during the - instantiation and ``self._fully_initialized`` will be set to True. If - ``obj._fully_initialized=False``, the class method decorated by - ``force_full_init`` will call ``full_init`` automatically. - - Several steps to initialize annotation: - - - load_data_list: Load annotations from annotation file. - - filter data information: Filter annotations according to - filter_cfg. - - slice_data: Slice dataset according to ``self._indices`` - - serialize_data: Serialize ``self.data_list`` if - ``self.serialize_data`` is True. - """ - if self._fully_initialized: - return - # load data information - self.data_list = self.load_data_list() - # filter illegal data, such as data that has no annotations. - self.data_list = self.filter_data() - # Get subset data according to indices. - if self._indices is not None: - self.data_list = self._get_unserialized_subset(self._indices) - - # serialize data_list - if self.serialize_data: - self.data_bytes, self.data_address = self._serialize_data() - - self._fully_initialized = True - - @property - def metainfo(self) -> dict: - """Get meta information of dataset. - - Returns: - dict: meta information collected from ``BaseDataset.METAINFO``, - annotation file and metainfo argument during instantiation. - """ - return copy.deepcopy(self._metainfo) - - def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: - """Parse raw annotation to target format. - - This method should return dict or list of dict. Each dict or list - contains the data information of a training sample. If the protocol of - the sample annotations is changed, this function can be overridden to - update the parsing logic while keeping compatibility. - - Args: - raw_data_info (dict): Raw data information load from ``ann_file`` - - Returns: - list or list[dict]: Parsed annotation. - """ - for prefix_key, prefix in self.data_prefix.items(): - assert prefix_key in raw_data_info, ( - f'raw_data_info: {raw_data_info} dose not contain prefix key' - f'{prefix_key}, please check your data_prefix.') - raw_data_info[prefix_key] = join_path(prefix, - raw_data_info[prefix_key]) - return raw_data_info - - def filter_data(self) -> List[dict]: - """Filter annotations according to filter_cfg. Defaults return all - ``data_list``. - - If some ``data_list`` could be filtered according to specific logic, - the subclass should override this method. - - Returns: - list[int]: Filtered results. 
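A hypothetical subclass sketch of the ``filter_data`` override described above; the filtering rule and the ``instances`` key are assumptions chosen to match the annotation format example earlier in this file:

>>> class DropEmptyDataset(BaseDataset):
...     def filter_data(self):
...         # keep only samples that carry at least one instance annotation
...         return [info for info in self.data_list if info.get('instances')]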
- """ - return self.data_list - - def get_cat_ids(self, idx: int) -> List[int]: - """Get category ids by index. Dataset wrapped by ClassBalancedDataset - must implement this method. - - The ``ClassBalancedDataset`` requires a subclass which implements this - method. - - Args: - idx (int): The index of data. - - Returns: - list[int]: All categories in the image of specified index. - """ - raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` ' - 'method') - - def __getitem__(self, idx: int) -> dict: - """Get the idx-th image and data information of dataset after - ``self.pipeline``, and ``full_init`` will be called if the dataset has - not been fully initialized. - - During training phase, if ``self.pipeline`` get ``None``, - ``self._rand_another`` will be called until a valid image is fetched or - the maximum limit of refetech is reached. - - Args: - idx (int): The index of self.data_list. - - Returns: - dict: The idx-th image and data information of dataset after - ``self.pipeline``. - """ - # Performing full initialization by calling `__getitem__` will consume - # extra memory. If a dataset is not fully initialized by setting - # `lazy_init=True` and then fed into the dataloader. Different workers - # will simultaneously read and parse the annotation. It will cost more - # time and memory, although this may work. Therefore, it is recommended - # to manually call `full_init` before dataset fed into dataloader to - # ensure all workers use shared RAM from master process. - if not self._fully_initialized: - print_log( - 'Please call `full_init()` method manually to accelerate ' - 'the speed.', - logger='current', - level=logging.WARNING) - self.full_init() - - if self.test_mode: - data = self.prepare_data(idx) - if data is None: - raise Exception('Test time pipline should not get `None` ' - 'data_sample') - return data - - for _ in range(self.max_refetch + 1): - data = self.prepare_data(idx) - # Broken images or random augmentations may cause the returned data - # to be None - if data is None: - idx = self._rand_another() - continue - return data - - raise Exception(f'Cannot find valid image after {self.max_refetch}! ' - 'Please check your image path and pipeline') - - def load_data_list(self) -> List[dict]: - """Load annotations from an annotation file named as ``self.ann_file`` - - If the annotation file does not follow `OpenMMLab 2.0 format dataset - `_ . - The subclass must override this method for load annotations. The meta - information of annotation file will be overwritten :attr:`METAINFO` - and ``metainfo`` argument of constructor. - - Returns: - list[dict]: A list of annotation. - """ # noqa: E501 - # `self.ann_file` denotes the absolute annotation file path if - # `self.root=None` or relative path if `self.root=/path/to/data/`. - annotations = load(self.ann_file) - if not isinstance(annotations, dict): - raise TypeError(f'The annotations loaded from annotation file ' - f'should be a dict, but got {type(annotations)}!') - if 'data_list' not in annotations or 'metainfo' not in annotations: - raise ValueError('Annotation must have data_list and metainfo ' - 'keys') - metainfo = annotations['metainfo'] - raw_data_list = annotations['data_list'] - - # Meta information load from annotation file will not influence the - # existed meta information load from `BaseDataset.METAINFO` and - # `metainfo` arguments defined in constructor. - for k, v in metainfo.items(): - self._metainfo.setdefault(k, v) - - # load and parse data_infos. 
- data_list = [] - for raw_data_info in raw_data_list: - # parse raw data information to target format - data_info = self.parse_data_info(raw_data_info) - if isinstance(data_info, dict): - # For image tasks, `data_info` should information if single - # image, such as dict(img_path='xxx', width=360, ...) - data_list.append(data_info) - elif isinstance(data_info, list): - # For video tasks, `data_info` could contain image - # information of multiple frames, such as - # [dict(video_path='xxx', timestamps=...), - # dict(video_path='xxx', timestamps=...)] - for item in data_info: - if not isinstance(item, dict): - raise TypeError('data_info must be list of dict, but ' - f'got {type(item)}') - data_list.extend(data_info) - else: - raise TypeError('data_info should be a dict or list of dict, ' - f'but got {type(data_info)}') - - return data_list - - @classmethod - def _load_metainfo(cls, - metainfo: Union[Mapping, Config, None] = None) -> dict: - """Collect meta information from the dictionary of meta. - - Args: - metainfo (Mapping or Config, optional): Meta information dict. - If ``metainfo`` contains existed filename, it will be - parsed by ``list_from_file``. - - Returns: - dict: Parsed meta information. - """ - # avoid `cls.METAINFO` being overwritten by `metainfo` - cls_metainfo = copy.deepcopy(cls.METAINFO) - if metainfo is None: - return cls_metainfo - if not isinstance(metainfo, (Mapping, Config)): - raise TypeError('metainfo should be a Mapping or Config, ' - f'but got {type(metainfo)}') - - for k, v in metainfo.items(): - if isinstance(v, str): - # If type of value is string, and can be loaded from - # corresponding backend. it means the file name of meta file. - try: - cls_metainfo[k] = list_from_file(v) - except (TypeError, FileNotFoundError): - print_log( - f'{v} is not a meta file, simply parsed as meta ' - 'information', - logger='current', - level=logging.WARNING) - cls_metainfo[k] = v - else: - cls_metainfo[k] = v - return cls_metainfo - - def _join_prefix(self): - """Join ``self.data_root`` with ``self.data_prefix`` and - ``self.ann_file``. - - Examples: - >>> # self.data_prefix contains relative paths - >>> self.data_root = 'a/b/c' - >>> self.data_prefix = dict(img='d/e/') - >>> self.ann_file = 'f' - >>> self._join_prefix() - >>> self.data_prefix - dict(img='a/b/c/d/e') - >>> self.ann_file - 'a/b/c/f' - >>> # self.data_prefix contains absolute paths - >>> self.data_root = 'a/b/c' - >>> self.data_prefix = dict(img='/d/e/') - >>> self.ann_file = 'f' - >>> self._join_prefix() - >>> self.data_prefix - dict(img='/d/e') - >>> self.ann_file - 'a/b/c/f' - """ - # Automatically join annotation file path with `self.root` if - # `self.ann_file` is not an absolute path. - if self.ann_file and not is_abs(self.ann_file) and self.data_root: - self.ann_file = join_path(self.data_root, self.ann_file) - # Automatically join data directory with `self.root` if path value in - # `self.data_prefix` is not an absolute path. - for data_key, prefix in self.data_prefix.items(): - if not isinstance(prefix, str): - raise TypeError('prefix should be a string, but got ' - f'{type(prefix)}') - if not is_abs(prefix) and self.data_root: - self.data_prefix[data_key] = join_path(self.data_root, prefix) - else: - self.data_prefix[data_key] = prefix - - @force_full_init - def get_subset_(self, indices: Union[Sequence[int], int]) -> None: - """The in-place version of ``get_subset`` to convert dataset to a - subset of original dataset. - - This method will convert the original dataset to a subset of dataset. 
- If type of indices is int, ``get_subset_`` will return a subdataset - which contains the first or last few data information according to - indices is positive or negative. If type of indices is a sequence of - int, the subdataset will extract the data information according to - the index given in indices. - - Examples: - >>> dataset = BaseDataset('path/to/ann_file') - >>> len(dataset) - 100 - >>> dataset.get_subset_(90) - >>> len(dataset) - 90 - >>> # if type of indices is sequence, extract the corresponding - >>> # index data information - >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - >>> len(dataset) - 10 - >>> dataset.get_subset_(-3) - >>> len(dataset) # Get the latest few data information. - 3 - - Args: - indices (int or Sequence[int]): If type of indices is int, indices - represents the first or last few data of dataset according to - indices is positive or negative. If type of indices is - Sequence, indices represents the target data information - index of dataset. - """ - # Get subset of data from serialized data or data information sequence - # according to `self.serialize_data`. - if self.serialize_data: - self.data_bytes, self.data_address = \ - self._get_serialized_subset(indices) - else: - self.data_list = self._get_unserialized_subset(indices) - - @force_full_init - def get_subset(self, indices: Union[Sequence[int], int]) -> 'BaseDataset': - """Return a subset of dataset. - - This method will return a subset of original dataset. If type of - indices is int, ``get_subset_`` will return a subdataset which - contains the first or last few data information according to - indices is positive or negative. If type of indices is a sequence of - int, the subdataset will extract the information according to the index - given in indices. - - Examples: - >>> dataset = BaseDataset('path/to/ann_file') - >>> len(dataset) - 100 - >>> subdataset = dataset.get_subset(90) - >>> len(sub_dataset) - 90 - >>> # if type of indices is list, extract the corresponding - >>> # index data information - >>> subdataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7, - >>> 8, 9]) - >>> len(sub_dataset) - 10 - >>> subdataset = dataset.get_subset(-3) - >>> len(subdataset) # Get the latest few data information. - 3 - - Args: - indices (int or Sequence[int]): If type of indices is int, indices - represents the first or last few data of dataset according to - indices is positive or negative. If type of indices is - Sequence, indices represents the target data information - index of dataset. - - Returns: - BaseDataset: A subset of dataset. - """ - # Get subset of data from serialized data or data information list - # according to `self.serialize_data`. Since `_get_serialized_subset` - # will recalculate the subset data information, - # `_copy_without_annotation` will copy all attributes except data - # information. - sub_dataset = self._copy_without_annotation() - # Get subset of dataset with serialize and unserialized data. - if self.serialize_data: - data_bytes, data_address = \ - self._get_serialized_subset(indices) - sub_dataset.data_bytes = data_bytes.copy() - sub_dataset.data_address = data_address.copy() - else: - data_list = self._get_unserialized_subset(indices) - sub_dataset.data_list = copy.deepcopy(data_list) - return sub_dataset - - def _get_serialized_subset(self, indices: Union[Sequence[int], int]) \ - -> Tuple[np.ndarray, np.ndarray]: - """Get subset of serialized data information list. 
- - Args: - indices (int or Sequence[int]): If type of indices is int, - indices represents the first or last few data of serialized - data information list. If type of indices is Sequence, indices - represents the target data information index which consist of - subset data information. - - Returns: - Tuple[np.ndarray, np.ndarray]: subset of serialized data - information. - """ - sub_data_bytes: Union[List, np.ndarray] - sub_data_address: Union[List, np.ndarray] - if isinstance(indices, int): - if indices >= 0: - assert indices < len(self.data_address), \ - f'{indices} is out of dataset length({len(self)}' - # Return the first few data information. - end_addr = self.data_address[indices - 1].item() \ - if indices > 0 else 0 - # Slicing operation of `np.ndarray` does not trigger a memory - # copy. - sub_data_bytes = self.data_bytes[:end_addr] - # Since the buffer size of first few data information is not - # changed, - sub_data_address = self.data_address[:indices] - else: - assert -indices <= len(self.data_address), \ - f'{indices} is out of dataset length({len(self)}' - # Return the last few data information. - ignored_bytes_size = self.data_address[indices - 1] - start_addr = self.data_address[indices - 1].item() - sub_data_bytes = self.data_bytes[start_addr:] - sub_data_address = self.data_address[indices:] - sub_data_address = sub_data_address - ignored_bytes_size - elif isinstance(indices, Sequence): - sub_data_bytes = [] - sub_data_address = [] - for idx in indices: - assert len(self) > idx >= -len(self) - start_addr = 0 if idx == 0 else \ - self.data_address[idx - 1].item() - end_addr = self.data_address[idx].item() - # Get data information by address. - sub_data_bytes.append(self.data_bytes[start_addr:end_addr]) - # Get data information size. - sub_data_address.append(end_addr - start_addr) - # Handle indices is an empty list. - if sub_data_bytes: - sub_data_bytes = np.concatenate(sub_data_bytes) - sub_data_address = np.cumsum(sub_data_address) - else: - sub_data_bytes = np.array([]) - sub_data_address = np.array([]) - else: - raise TypeError('indices should be a int or sequence of int, ' - f'but got {type(indices)}') - return sub_data_bytes, sub_data_address # type: ignore - - def _get_unserialized_subset(self, indices: Union[Sequence[int], - int]) -> list: - """Get subset of data information list. - - Args: - indices (int or Sequence[int]): If type of indices is int, - indices represents the first or last few data of data - information. If type of indices is Sequence, indices represents - the target data information index which consist of subset data - information. - - Returns: - Tuple[np.ndarray, np.ndarray]: subset of data information. - """ - if isinstance(indices, int): - if indices >= 0: - # Return the first few data information. - sub_data_list = self.data_list[:indices] - else: - # Return the last few data information. - sub_data_list = self.data_list[indices:] - elif isinstance(indices, Sequence): - # Return the data information according to given indices. - sub_data_list = [] - for idx in indices: - sub_data_list.append(self.data_list[idx]) - else: - raise TypeError('indices should be a int or sequence of int, ' - f'but got {type(indices)}') - return sub_data_list - - def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]: - """Serialize ``self.data_list`` to save memory when launching multiple - workers in data loading. This function will be called in ``full_init``. 
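A toy illustration of the ``(data_bytes, data_address)`` layout that ``_serialize_data`` builds and ``get_data_info`` reads back; the two records are made up for the example:

>>> import pickle
>>> import numpy as np
>>> records = [dict(img_path='a.jpg'), dict(img_path='b.jpg')]
>>> blobs = [np.frombuffer(pickle.dumps(r, protocol=4), dtype=np.uint8) for r in records]
>>> address = np.cumsum([len(b) for b in blobs])       # end offset of each record
>>> data_bytes = np.concatenate(blobs)
>>> pickle.loads(memoryview(data_bytes[:address[0]]))  # recover the first record
{'img_path': 'a.jpg'}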
- - Hold memory using serialized objects, and data loader workers can use - shared RAM from master process instead of making a copy. - - Returns: - Tuple[np.ndarray, np.ndarray]: Serialized result and corresponding - address. - """ - - def _serialize(data): - buffer = pickle.dumps(data, protocol=4) - return np.frombuffer(buffer, dtype=np.uint8) - - # Serialize data information list avoid making multiple copies of - # `self.data_list` when iterate `import torch.utils.data.dataloader` - # with multiple workers. - data_list = [_serialize(x) for x in self.data_list] - address_list = np.asarray([len(x) for x in data_list], dtype=np.int64) - data_address: np.ndarray = np.cumsum(address_list) - # TODO Check if np.concatenate is necessary - data_bytes = np.concatenate(data_list) - # Empty cache for preventing making multiple copies of - # `self.data_info` when loading data multi-processes. - self.data_list.clear() - gc.collect() - return data_bytes, data_address - - def _rand_another(self) -> int: - """Get random index. - - Returns: - int: Random index from 0 to ``len(self)-1`` - """ - return np.random.randint(0, len(self)) - - def prepare_data(self, idx) -> Any: - """Get data processed by ``self.pipeline``. - - Args: - idx (int): The index of ``data_info``. - - Returns: - Any: Depends on ``self.pipeline``. - """ - data_info = self.get_data_info(idx) - return self.pipeline(data_info) - - @force_full_init - def __len__(self) -> int: - """Get the length of filtered dataset and automatically call - ``full_init`` if the dataset has not been fully init. - - Returns: - int: The length of filtered dataset. - """ - if self.serialize_data: - return len(self.data_address) - else: - return len(self.data_list) - - def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset': - """Deepcopy for all attributes other than ``data_list``, - ``data_address`` and ``data_bytes``. - - Args: - memo: Memory dict which used to reconstruct complex object - correctly. - """ - cls = self.__class__ - other = cls.__new__(cls) - memo[id(self)] = other - - for key, value in self.__dict__.items(): - if key in ['data_list', 'data_address', 'data_bytes']: - continue - super(BaseDataset, other).__setattr__(key, - copy.deepcopy(value, memo)) - - return other diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py deleted file mode 100644 index e63860bee0..0000000000 --- a/mmengine/dataset/dataset_wrapper.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import bisect -import copy -import logging -import math -from collections import defaultdict -from typing import List, Sequence, Tuple, Union - -import numpy as np -from torch.utils.data.dataset import ConcatDataset as _ConcatDataset - -from mmengine.logging import print_log -from mmengine.registry import DATASETS -from .base_dataset import BaseDataset, force_full_init - - -@DATASETS.register_module() -class ConcatDataset(_ConcatDataset): - """A wrapper of concatenated dataset. - - Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init. - - Note: - ``ConcatDataset`` should not inherit from ``BaseDataset`` since - ``get_subset`` and ``get_subset_`` could produce ambiguous meaning - sub-dataset which conflicts with original dataset. If you want to use - a sub-dataset of ``ConcatDataset``, you should set ``indices`` - arguments for wrapped dataset which inherit from ``BaseDataset``. - - Args: - datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets - which will be concatenated. 
- lazy_init (bool, optional): Whether to load annotation during - instantiation. Defaults to False. - ignore_keys (List[str] or str): Ignore the keys that can be - unequal in `dataset.metainfo`. Defaults to None. - `New in version 0.3.0.` - """ - - def __init__(self, - datasets: Sequence[Union[BaseDataset, dict]], - lazy_init: bool = False, - ignore_keys: Union[str, List[str], None] = None): - self.datasets: List[BaseDataset] = [] - for i, dataset in enumerate(datasets): - if isinstance(dataset, dict): - self.datasets.append(DATASETS.build(dataset)) - elif isinstance(dataset, BaseDataset): - self.datasets.append(dataset) - else: - raise TypeError( - 'elements in datasets sequence should be config or ' - f'`BaseDataset` instance, but got {type(dataset)}') - if ignore_keys is None: - self.ignore_keys = [] - elif isinstance(ignore_keys, str): - self.ignore_keys = [ignore_keys] - elif isinstance(ignore_keys, list): - self.ignore_keys = ignore_keys - else: - raise TypeError('ignore_keys should be a list or str, ' - f'but got {type(ignore_keys)}') - - meta_keys: set = set() - for dataset in self.datasets: - meta_keys |= dataset.metainfo.keys() - # Only use metainfo of first dataset. - self._metainfo = self.datasets[0].metainfo - for i, dataset in enumerate(self.datasets, 1): - for key in meta_keys: - if key in self.ignore_keys: - continue - if key not in dataset.metainfo: - raise ValueError( - f'{key} does not in the meta information of ' - f'the {i}-th dataset') - first_type = type(self._metainfo[key]) - cur_type = type(dataset.metainfo[key]) - if first_type is not cur_type: # type: ignore - raise TypeError( - f'The type {cur_type} of {key} in the {i}-th dataset ' - 'should be the same with the first dataset ' - f'{first_type}') - if (isinstance(self._metainfo[key], np.ndarray) - and not np.array_equal(self._metainfo[key], - dataset.metainfo[key]) - or (not isinstance(self._metainfo[key], np.ndarray) - and self._metainfo[key] != dataset.metainfo[key])): - raise ValueError( - f'The meta information of the {i}-th dataset does not ' - 'match meta information of the first dataset') - - self._fully_initialized = False - if not lazy_init: - self.full_init() - - @property - def metainfo(self) -> dict: - """Get the meta information of the first dataset in ``self.datasets``. - - Returns: - dict: Meta information of first dataset. - """ - # Prevent `self._metainfo` from being modified by outside. - return copy.deepcopy(self._metainfo) - - def full_init(self): - """Loop to ``full_init`` each dataset.""" - if self._fully_initialized: - return - for d in self.datasets: - d.full_init() - # Get the cumulative sizes of `self.datasets`. For example, the length - # of `self.datasets` is [2, 3, 4], the cumulative sizes is [2, 5, 9] - super().__init__(self.datasets) - self._fully_initialized = True - - @force_full_init - def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]: - """Convert global idx to local index. - - Args: - idx (int): Global index of ``RepeatDataset``. - - Returns: - Tuple[int, int]: The index of ``self.datasets`` and the local - index of data. - """ - if idx < 0: - if -idx > len(self): - raise ValueError( - f'absolute value of index({idx}) should not exceed dataset' - f'length({len(self)}).') - idx = len(self) + idx - # Get `dataset_idx` to tell idx belongs to which dataset. - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - # Get the inner index of single dataset. 
- if dataset_idx == 0: - sample_idx = idx - else: - sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] - - return dataset_idx, sample_idx - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index. - - Args: - idx (int): Global index of ``ConcatDataset``. - - Returns: - dict: The idx-th annotation of the datasets. - """ - dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) - return self.datasets[dataset_idx].get_data_info(sample_idx) - - @force_full_init - def __len__(self): - return super().__len__() - - def __getitem__(self, idx): - if not self._fully_initialized: - print_log( - 'Please call `full_init` method manually to ' - 'accelerate the speed.', - logger='current', - level=logging.WARNING) - self.full_init() - dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) - return self.datasets[dataset_idx][sample_idx] - - def get_subset_(self, indices: Union[List[int], int]) -> None: - """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - '`ConcatDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `ConcatDataset`.') - - def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': - """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - '`ConcatDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `ConcatDataset`.') - - -@DATASETS.register_module() -class RepeatDataset: - """A wrapper of repeated dataset. - - The length of repeated dataset will be `times` larger than the original - dataset. This is useful when the data loading time is long but the dataset - is small. Using RepeatDataset can reduce the data loading time between - epochs. - - Note: - ``RepeatDataset`` should not inherit from ``BaseDataset`` since - ``get_subset`` and ``get_subset_`` could produce ambiguous meaning - sub-dataset which conflicts with original dataset. If you want to use - a sub-dataset of ``RepeatDataset``, you should set ``indices`` - arguments for wrapped dataset which inherit from ``BaseDataset``. - - Args: - dataset (BaseDataset or dict): The dataset to be repeated. - times (int): Repeat times. - lazy_init (bool): Whether to load annotation during - instantiation. Defaults to False. - """ - - def __init__(self, - dataset: Union[BaseDataset, dict], - times: int, - lazy_init: bool = False): - self.dataset: BaseDataset - if isinstance(dataset, dict): - self.dataset = DATASETS.build(dataset) - elif isinstance(dataset, BaseDataset): - self.dataset = dataset - else: - raise TypeError( - 'elements in datasets sequence should be config or ' - f'`BaseDataset` instance, but got {type(dataset)}') - self.times = times - self._metainfo = self.dataset.metainfo - - self._fully_initialized = False - if not lazy_init: - self.full_init() - - @property - def metainfo(self) -> dict: - """Get the meta information of the repeated dataset. - - Returns: - dict: The meta information of repeated dataset. 
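Editor's note (illustrative only): the global-to-local index mapping used by ``ConcatDataset`` above boils down to a bisect over the cumulative sizes, as in this self-contained sketch using the [2, 3, 4] example from the comments.

import bisect
from itertools import accumulate

lengths = [2, 3, 4]                            # lengths of the wrapped datasets
cumulative_sizes = list(accumulate(lengths))   # [2, 5, 9]

def to_local(idx: int):
    # Find which wrapped dataset the global index falls into, then offset into it.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx

assert to_local(0) == (0, 0)
assert to_local(2) == (1, 0)
assert to_local(8) == (2, 3)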
- """ - return copy.deepcopy(self._metainfo) - - def full_init(self): - """Loop to ``full_init`` each dataset.""" - if self._fully_initialized: - return - - self.dataset.full_init() - self._ori_len = len(self.dataset) - self._fully_initialized = True - - @force_full_init - def _get_ori_dataset_idx(self, idx: int) -> int: - """Convert global index to local index. - - Args: - idx: Global index of ``RepeatDataset``. - - Returns: - idx (int): Local index of data. - """ - return idx % self._ori_len - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index. - - Args: - idx (int): Global index of ``ConcatDataset``. - - Returns: - dict: The idx-th annotation of the datasets. - """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_data_info(sample_idx) - - def __getitem__(self, idx): - if not self._fully_initialized: - print_log( - 'Please call `full_init` method manually to accelerate the ' - 'speed.', - logger='current', - level=logging.WARNING) - self.full_init() - - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset[sample_idx] - - @force_full_init - def __len__(self): - return self.times * self._ori_len - - def get_subset_(self, indices: Union[List[int], int]) -> None: - """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - '`RepeatDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `RepeatDataset`.') - - def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': - """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - '`RepeatDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `RepeatDataset`.') - - -@DATASETS.register_module() -class ClassBalancedDataset: - """A wrapper of class balanced dataset. - - Suitable for training on class imbalanced datasets like LVIS. Following - the sampling strategy in the `paper `_, - in each epoch, an image may appear multiple times based on its - "repeat factor". - The repeat factor for an image is a function of the frequency the rarest - category labeled in that image. The "frequency of category c" in [0, 1] - is defined by the fraction of images in the training set (without repeats) - in which category c appears. - The dataset needs to instantiate :meth:`get_cat_ids` to support - ClassBalancedDataset. - - The repeat factor is computed as followed. - - 1. For each category c, compute the fraction # of images - that contain it: :math:`f(c)` - 2. For each category c, compute the category-level repeat factor: - :math:`r(c) = max(1, sqrt(t/f(c)))` - 3. For each image I, compute the image-level repeat factor: - :math:`r(I) = max_{c in I} r(c)` - - Note: - ``ClassBalancedDataset`` should not inherit from ``BaseDataset`` - since ``get_subset`` and ``get_subset_`` could produce ambiguous - meaning sub-dataset which conflicts with original dataset. If you - want to use a sub-dataset of ``ClassBalancedDataset``, you should set - ``indices`` arguments for wrapped dataset which inherit from - ``BaseDataset``. 
- - Args: - dataset (BaseDataset or dict): The dataset to be repeated. - oversample_thr (float): frequency threshold below which data is - repeated. For categories with ``f_c >= oversample_thr``, there is - no oversampling. For categories with ``f_c < oversample_thr``, the - degree of oversampling following the square-root inverse frequency - heuristic above. - lazy_init (bool, optional): whether to load annotation during - instantiation. Defaults to False - """ - - def __init__(self, - dataset: Union[BaseDataset, dict], - oversample_thr: float, - lazy_init: bool = False): - if isinstance(dataset, dict): - self.dataset = DATASETS.build(dataset) - elif isinstance(dataset, BaseDataset): - self.dataset = dataset - else: - raise TypeError( - 'elements in datasets sequence should be config or ' - f'`BaseDataset` instance, but got {type(dataset)}') - self.oversample_thr = oversample_thr - self._metainfo = self.dataset.metainfo - - self._fully_initialized = False - if not lazy_init: - self.full_init() - - @property - def metainfo(self) -> dict: - """Get the meta information of the repeated dataset. - - Returns: - dict: The meta information of repeated dataset. - """ - return copy.deepcopy(self._metainfo) - - def full_init(self): - """Loop to ``full_init`` each dataset.""" - if self._fully_initialized: - return - - self.dataset.full_init() - # Get repeat factors for each image. - repeat_factors = self._get_repeat_factors(self.dataset, - self.oversample_thr) - # Repeat dataset's indices according to repeat_factors. For example, - # if `repeat_factors = [1, 2, 3]`, and the `len(dataset) == 3`, - # the repeated indices will be [1, 2, 2, 3, 3, 3]. - repeat_indices = [] - for dataset_index, repeat_factor in enumerate(repeat_factors): - repeat_indices.extend([dataset_index] * math.ceil(repeat_factor)) - self.repeat_indices = repeat_indices - - self._fully_initialized = True - - def _get_repeat_factors(self, dataset: BaseDataset, - repeat_thr: float) -> List[float]: - """Get repeat factor for each images in the dataset. - - Args: - dataset (BaseDataset): The dataset. - repeat_thr (float): The threshold of frequency. If an image - contains the categories whose frequency below the threshold, - it would be repeated. - - Returns: - List[float]: The repeat factors for each images in the dataset. - """ - # 1. For each category c, compute the fraction # of images - # that contain it: f(c) - category_freq: defaultdict = defaultdict(float) - num_images = len(dataset) - for idx in range(num_images): - cat_ids = set(self.dataset.get_cat_ids(idx)) - for cat_id in cat_ids: - category_freq[cat_id] += 1 - for k, v in category_freq.items(): - assert v > 0, f'caterogy {k} does not contain any images' - category_freq[k] = v / num_images - - # 2. For each category c, compute the category-level repeat factor: - # r(c) = max(1, sqrt(t/f(c))) - category_repeat = { - cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) - for cat_id, cat_freq in category_freq.items() - } - - # 3. For each image I and its labels L(I), compute the image-level - # repeat factor: - # r(I) = max_{c in L(I)} r(c) - repeat_factors = [] - for idx in range(num_images): - # the length of `repeat_factors` need equal to the length of - # dataset. Hence, if the `cat_ids` is empty, - # the repeat_factor should be 1. - repeat_factor: float = 1. 
- cat_ids = set(self.dataset.get_cat_ids(idx)) - if len(cat_ids) != 0: - repeat_factor = max( - {category_repeat[cat_id] - for cat_id in cat_ids}) - repeat_factors.append(repeat_factor) - - return repeat_factors - - @force_full_init - def _get_ori_dataset_idx(self, idx: int) -> int: - """Convert global index to local index. - - Args: - idx (int): Global index of ``RepeatDataset``. - - Returns: - int: Local index of data. - """ - return self.repeat_indices[idx] - - @force_full_init - def get_cat_ids(self, idx: int) -> List[int]: - """Get category ids of class balanced dataset by index. - - Args: - idx (int): Index of data. - - Returns: - List[int]: All categories in the image of specified index. - """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_cat_ids(sample_idx) - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index. - - Args: - idx (int): Global index of ``ConcatDataset``. - - Returns: - dict: The idx-th annotation of the dataset. - """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_data_info(sample_idx) - - def __getitem__(self, idx): - if not self._fully_initialized: - print_log( - 'Please call `full_init` method manually to accelerate ' - 'the speed.', - logger='current', - level=logging.WARNING) - self.full_init() - - ori_index = self._get_ori_dataset_idx(idx) - return self.dataset[ori_index] - - @force_full_init - def __len__(self): - return len(self.repeat_indices) - - def get_subset_(self, indices: Union[List[int], int]) -> None: - """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning - of sub-dataset.""" - raise NotImplementedError( - '`ClassBalancedDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `ClassBalancedDataset`.') - - def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': - """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning - of sub-dataset.""" - raise NotImplementedError( - '`ClassBalancedDataset` dose not support `get_subset` and ' - '`get_subset_` interfaces because this will lead to ambiguous ' - 'implementation of some methods. If you want to use `get_subset` ' - 'or `get_subset_` interfaces, please use them in the wrapped ' - 'dataset first and then use `ClassBalancedDataset`.') diff --git a/mmengine/dataset/sampler.py b/mmengine/dataset/sampler.py deleted file mode 100644 index 95e8e2da6b..0000000000 --- a/mmengine/dataset/sampler.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import itertools -import math -from typing import Iterator, Optional, Sized - -import torch -from torch.utils.data import Sampler - -from mmengine.dist import get_dist_info, sync_random_seed -from mmengine.registry import DATA_SAMPLERS - - -@DATA_SAMPLERS.register_module() -class DefaultSampler(Sampler): - """The default data sampler for both distributed and non-distributed - environment. - - It has several differences from the PyTorch ``DistributedSampler`` as - below: - - 1. This sampler supports non-distributed environment. - - 2. The round up behaviors are a little different. - - - If ``round_up=True``, this sampler will add extra samples to make the - number of samples is evenly divisible by the world size. 
And - this behavior is the same as the ``DistributedSampler`` with - ``drop_last=False``. - - If ``round_up=False``, this sampler won't remove or add any samples - while the ``DistributedSampler`` with ``drop_last=True`` will remove - tail samples. - - Args: - dataset (Sized): The dataset. - shuffle (bool): Whether shuffle the dataset or not. Defaults to True. - seed (int, optional): Random seed used to shuffle the sampler if - :attr:`shuffle=True`. This number should be identical across all - processes in the distributed group. Defaults to None. - round_up (bool): Whether to add extra samples to make the number of - samples evenly divisible by the world size. Defaults to True. - """ - - def __init__(self, - dataset: Sized, - shuffle: bool = True, - seed: Optional[int] = None, - round_up: bool = True) -> None: - rank, world_size = get_dist_info() - self.rank = rank - self.world_size = world_size - - self.dataset = dataset - self.shuffle = shuffle - if seed is None: - seed = sync_random_seed() - self.seed = seed - self.epoch = 0 - self.round_up = round_up - - if self.round_up: - self.num_samples = math.ceil(len(self.dataset) / world_size) - self.total_size = self.num_samples * self.world_size - else: - self.num_samples = math.ceil( - (len(self.dataset) - rank) / world_size) - self.total_size = len(self.dataset) - - def __iter__(self) -> Iterator[int]: - """Iterate the indices.""" - # deterministically shuffle based on epoch and seed - if self.shuffle: - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = torch.arange(len(self.dataset)).tolist() - - # add extra samples to make it evenly divisible - if self.round_up: - indices = ( - indices * - int(self.total_size / len(indices) + 1))[:self.total_size] - - # subsample - indices = indices[self.rank:self.total_size:self.world_size] - - return iter(indices) - - def __len__(self) -> int: - """The number of samples in this rank.""" - return self.num_samples - - def set_epoch(self, epoch: int) -> None: - """Sets the epoch for this sampler. - - When :attr:`shuffle=True`, this ensures all replicas use a different - random ordering for each epoch. Otherwise, the next iteration of this - sampler will yield the same ordering. - - Args: - epoch (int): Epoch number. - """ - self.epoch = epoch - - -@DATA_SAMPLERS.register_module() -class InfiniteSampler(Sampler): - """It's designed for iteration-based runner and yields a mini-batch indices - each time. - - The implementation logic is referred to - https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py - - Args: - dataset (Sized): The dataset. - shuffle (bool): Whether shuffle the dataset or not. Defaults to True. - seed (int, optional): Random seed. If None, set a random seed. - Defaults to None. 
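Editor's note (illustrative sketch, not part of the patch): the round-up and striding arithmetic of ``DefaultSampler`` above can be replayed with plain lists; the sizes below are arbitrary.

dataset_len, world_size = 10, 4
indices = list(range(dataset_len))

num_samples = -(-dataset_len // world_size)   # ceil(10 / 4) = 3 samples per rank
total_size = num_samples * world_size         # 12

# Repeat the index list until it covers total_size, then truncate (round_up=True).
padded = (indices * (total_size // len(indices) + 1))[:total_size]

# Each rank takes a strided slice of the padded indices.
shards = [padded[rank:total_size:world_size] for rank in range(world_size)]
assert all(len(shard) == num_samples for shard in shards)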
- """ # noqa: W605 - - def __init__(self, - dataset: Sized, - shuffle: bool = True, - seed: Optional[int] = None) -> None: - rank, world_size = get_dist_info() - self.rank = rank - self.world_size = world_size - - self.dataset = dataset - self.world_size = world_size - self.rank = rank - self.shuffle = shuffle - if seed is None: - seed = sync_random_seed() - self.seed = seed - self.size = len(dataset) - self.indices = self._indices_of_rank() - - def _infinite_indices(self) -> Iterator[int]: - """Infinitely yield a sequence of indices.""" - g = torch.Generator() - g.manual_seed(self.seed) - while True: - if self.shuffle: - yield from torch.randperm(self.size, generator=g).tolist() - - else: - yield from torch.arange(self.size).tolist() - - def _indices_of_rank(self) -> Iterator[int]: - """Slice the infinite indices by rank.""" - yield from itertools.islice(self._infinite_indices(), self.rank, None, - self.world_size) - - def __iter__(self) -> Iterator[int]: - """Iterate the indices.""" - yield from self.indices - - def __len__(self) -> int: - """Length of base dataset.""" - return self.size - - def set_epoch(self, epoch: int) -> None: - """Not supported in iteration-based runner.""" - pass diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py deleted file mode 100644 index 2c9cf96497..0000000000 --- a/mmengine/dataset/utils.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import random -import warnings -from typing import Any, Mapping, Sequence - -import numpy as np -import torch -from torch.utils.data._utils.collate import \ - default_collate as torch_default_collate - -from mmengine.registry import FUNCTIONS -from mmengine.structures import BaseDataElement - -# FUNCTIONS is new in MMEngine v0.7.0. Reserve the `COLLATE_FUNCTIONS` to keep -# the compatibility. -COLLATE_FUNCTIONS = FUNCTIONS - - -def worker_init_fn(worker_id: int, - num_workers: int, - rank: int, - seed: int, - disable_subprocess_warning: bool = False) -> None: - """This function will be called on each worker subprocess after seeding and - before data loading. - - Args: - worker_id (int): Worker id in [0, num_workers - 1]. - num_workers (int): How many subprocesses to use for data loading. - rank (int): Rank of process in distributed environment. If in - non-distributed environment, it is a constant number `0`. - seed (int): Random seed. - """ - # The seed of each worker equals to - # num_worker * rank + worker_id + user_seed - worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) - torch.manual_seed(worker_seed) - if disable_subprocess_warning and worker_id != 0: - warnings.simplefilter('ignore') - - -@FUNCTIONS.register_module() -def pseudo_collate(data_batch: Sequence) -> Any: - """Convert list of data sampled from dataset into a batch of data, of which - type consistent with the type of each data_itement in ``data_batch``. - - The default behavior of dataloader is to merge a list of samples to form - a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` - will not stack tensors to batch tensors, and convert int, float, ndarray to - tensors. - - This code is referenced from: - `Pytorch default_collate `_. - - Args: - data_batch (Sequence): Batch of data from dataloader. - - Returns: - Any: Transversed Data in the same format as the data_itement of - ``data_batch``. 
- """ # noqa: E501 - data_item = data_batch[0] - data_item_type = type(data_item) - if isinstance(data_item, (str, bytes)): - return data_batch - elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'): - # named tuple - return data_item_type(*(pseudo_collate(samples) - for samples in zip(*data_batch))) - elif isinstance(data_item, Sequence): - # check to make sure that the data_itements in batch have - # consistent size - it = iter(data_batch) - data_item_size = len(next(it)) - if not all(len(data_item) == data_item_size for data_item in it): - raise RuntimeError( - 'each data_itement in list of batch should be of equal size') - transposed = list(zip(*data_batch)) - - if isinstance(data_item, tuple): - return [pseudo_collate(samples) - for samples in transposed] # Compat with Pytorch. - else: - try: - return data_item_type( - [pseudo_collate(samples) for samples in transposed]) - except TypeError: - # The sequence type may not support `__init__(iterable)` - # (e.g., `range`). - return [pseudo_collate(samples) for samples in transposed] - elif isinstance(data_item, Mapping): - return data_item_type({ - key: pseudo_collate([d[key] for d in data_batch]) - for key in data_item - }) - else: - return data_batch - - -@FUNCTIONS.register_module() -def default_collate(data_batch: Sequence) -> Any: - """Convert list of data sampled from dataset into a batch of data, of which - type consistent with the type of each data_itement in ``data_batch``. - - Different from :func:`pseudo_collate`, ``default_collate`` will stack - tensor contained in ``data_batch`` into a batched tensor with the - first dimension batch size, and then move input tensor to the target - device. - - Different from ``default_collate`` in pytorch, ``default_collate`` will - not process ``BaseDataElement``. - - This code is referenced from: - `Pytorch default_collate `_. - - Note: - ``default_collate`` only accept input tensor with the same shape. - - Args: - data_batch (Sequence): Data sampled from dataset. - - Returns: - Any: Data in the same format as the data_itement of ``data_batch``, of which - tensors have been stacked, and ndarray, int, float have been - converted to tensors. - """ # noqa: E501 - data_item = data_batch[0] - data_item_type = type(data_item) - - if isinstance(data_item, (BaseDataElement, str, bytes)): - return data_batch - elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'): - # named_tuple - return data_item_type(*(default_collate(samples) - for samples in zip(*data_batch))) - elif isinstance(data_item, Sequence): - # check to make sure that the data_itements in batch have - # consistent size - it = iter(data_batch) - data_item_size = len(next(it)) - if not all(len(data_item) == data_item_size for data_item in it): - raise RuntimeError( - 'each data_itement in list of batch should be of equal size') - transposed = list(zip(*data_batch)) - - if isinstance(data_item, tuple): - return [default_collate(samples) - for samples in transposed] # Compat with Pytorch. - else: - try: - return data_item_type( - [default_collate(samples) for samples in transposed]) - except TypeError: - # The sequence type may not support `__init__(iterable)` - # (e.g., `range`). 
- return [default_collate(samples) for samples in transposed] - elif isinstance(data_item, Mapping): - return data_item_type({ - key: default_collate([d[key] for d in data_batch]) - for key in data_item - }) - else: - return torch_default_collate(data_batch) diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py deleted file mode 100644 index 88937d5592..0000000000 --- a/mmengine/device/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory, - is_cuda_available, is_dipu_available, is_mlu_available, - is_mps_available, is_musa_available, is_npu_available, - is_npu_support_full_precision) - -__all__ = [ - 'get_max_cuda_memory', 'get_device', 'is_cuda_available', - 'is_mlu_available', 'is_mps_available', 'is_npu_available', - 'is_dipu_available', 'get_max_musa_memory', 'is_musa_available', - 'is_npu_support_full_precision' -] diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py deleted file mode 100644 index 8fe6e0c156..0000000000 --- a/mmengine/device/utils.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -from typing import Optional - -import torch - -try: - import torch_npu # noqa: F401 - import torch_npu.npu.utils as npu_utils - - # Enable operator support for dynamic shape and - # binary operator support on the NPU. - npu_jit_compile = bool(os.getenv('NPUJITCompile', False)) - torch.npu.set_compile_mode(jit_compile=npu_jit_compile) - IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available() -except Exception: - IS_NPU_AVAILABLE = False - -try: - import torch_mlu # noqa: F401 - IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available() -except Exception: - IS_MLU_AVAILABLE = False - -try: - import torch_dipu # noqa: F401 - IS_DIPU_AVAILABLE = True -except Exception: - IS_DIPU_AVAILABLE = False - -try: - import torch_musa # noqa: F401 - IS_MUSA_AVAILABLE = True -except Exception: - IS_MUSA_AVAILABLE = False - - -def get_max_cuda_memory(device: Optional[torch.device] = None) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for - a given device. By default, this returns the peak allocated memory since - the beginning of this program. - - Args: - device (torch.device, optional): selected device. Returns - statistic for the current device, given by - :func:`~torch.cuda.current_device`, if ``device`` is None. - Defaults to None. - - Returns: - int: The maximum GPU memory occupied by tensors in megabytes - for a given device. - """ - mem = torch.cuda.max_memory_allocated(device=device) - mem_mb = torch.tensor([int(mem) // (1024 * 1024)], - dtype=torch.int, - device=device) - torch.cuda.reset_peak_memory_stats() - return int(mem_mb.item()) - - -def is_cuda_available() -> bool: - """Returns True if cuda devices exist.""" - return torch.cuda.is_available() - - -def is_npu_available() -> bool: - """Returns True if Ascend PyTorch and npu devices exist.""" - return IS_NPU_AVAILABLE - - -def is_mlu_available() -> bool: - """Returns True if Cambricon PyTorch and mlu devices exist.""" - return IS_MLU_AVAILABLE - - -def is_mps_available() -> bool: - """Return True if mps devices exist. - - It's specialized for mac m1 chips and require torch version 1.12 or higher. 
- """ - return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() - - -def is_dipu_available() -> bool: - return IS_DIPU_AVAILABLE - - -def get_max_musa_memory(device: Optional[torch.device] = None) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for - a given device. By default, this returns the peak allocated memory since - the beginning of this program. - - Args: - device (torch.device, optional): selected device. Returns - statistic for the current device, given by - :func:`~torch.musa.current_device`, if ``device`` is None. - Defaults to None. - - Returns: - int: The maximum GPU memory occupied by tensors in megabytes - for a given device. - """ - mem = torch.musa.max_memory_allocated(device=device) - mem_mb = torch.tensor([int(mem) // (1024 * 1024)], - dtype=torch.int, - device=device) - # TODO:haowen.han@mthreads.com: This function is not supported by musa yet. - # torch.musa.reset_peak_memory_stats() - return int(mem_mb.item()) - - -def is_musa_available() -> bool: - return IS_MUSA_AVAILABLE - - -def is_npu_support_full_precision() -> bool: - """Returns True if npu devices support full precision training.""" - version_of_support_full_precision = 220 - return IS_NPU_AVAILABLE and npu_utils.get_soc_version( - ) >= version_of_support_full_precision - - -DEVICE = 'cpu' -if is_npu_available(): - DEVICE = 'npu' -elif is_cuda_available(): - DEVICE = 'cuda' -elif is_mlu_available(): - DEVICE = 'mlu' -elif is_mps_available(): - DEVICE = 'mps' -elif is_dipu_available(): - DEVICE = 'dipu' -elif is_musa_available(): - DEVICE = 'musa' - - -def get_device() -> str: - """Returns the currently existing device type. - - Returns: - str: cuda | npu | mlu | mps | musa | cpu. - """ - return DEVICE diff --git a/mmengine/dist/__init__.py b/mmengine/dist/__init__.py deleted file mode 100644 index c70e181d5d..0000000000 --- a/mmengine/dist/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict, - collect_results, gather, broadcast, gather_object, - sync_random_seed, broadcast_object_list, - collect_results_cpu, collect_results_gpu, all_reduce_params) -from .utils import (get_dist_info, init_dist, init_local_group, get_backend, - get_world_size, get_rank, get_local_size, get_local_rank, - is_main_process, master_only, barrier, get_local_group, - is_distributed, get_default_group, get_data_device, - get_comm_device, cast_data_device, infer_launcher) - -__all__ = [ - 'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict', - 'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather', - 'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list', - 'get_dist_info', 'init_dist', 'init_local_group', 'get_backend', - 'get_world_size', 'get_rank', 'get_local_size', 'get_local_group', - 'get_local_rank', 'is_main_process', 'master_only', 'barrier', - 'is_distributed', 'get_default_group', 'all_reduce_params', - 'get_data_device', 'get_comm_device', 'cast_data_device', 'infer_launcher' -] diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py deleted file mode 100644 index f70cc3ef46..0000000000 --- a/mmengine/dist/dist.py +++ /dev/null @@ -1,1184 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import os.path as osp -import pickle -import shutil -import tempfile -from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import Tensor -from torch import distributed as torch_dist -from torch._utils import (_flatten_dense_tensors, _take_tensors, - _unflatten_dense_tensors) -from torch.distributed import ProcessGroup -from itertools import zip_longest, chain -import mmengine -from .utils import (get_world_size, get_rank, get_backend, get_dist_info, - get_default_group, barrier, get_data_device, - get_comm_device, cast_data_device) -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION -from mmengine.device import is_npu_available - - -def _get_reduce_op(name: str) -> torch_dist.ReduceOp: - op_mappings = { - 'sum': torch_dist.ReduceOp.SUM, - 'product': torch_dist.ReduceOp.PRODUCT, - 'min': torch_dist.ReduceOp.MIN, - 'max': torch_dist.ReduceOp.MAX, - 'band': torch_dist.ReduceOp.BAND, - 'bor': torch_dist.ReduceOp.BOR, - 'bxor': torch_dist.ReduceOp.BXOR, - } - - if name.lower() not in op_mappings: - raise ValueError( - f'reduce op should be one of {op_mappings.keys()}, bug got {name}') - - return op_mappings[name.lower()] - - -def all_reduce(data: Tensor, - op: str = 'sum', - group: Optional[ProcessGroup] = None) -> None: - """Reduces the tensor data across all machines in such a way that all get - the final result. - - After the call ``data`` is going to be bitwise identical in all - processes. - - Note: - Calling ``all_reduce`` in non-distributed environment does nothing. - - Args: - data (Tensor): Input and output of the collective. The function - operates in-place. - op (str): Operation to reduce data. Defaults to 'sum'. Optional values - are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and - 'bxor'. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = torch.arange(2, dtype=torch.int64) - >>> dist.all_reduce(data) - >>> data - tensor([0, 1]) - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank - >>> data - tensor([1, 2]) # Rank 0 - tensor([3, 4]) # Rank 1 - >>> dist.all_reduce(data, op=dist.ReduceOp.SUM) - >>> data - tensor([4, 6]) # Rank 0 - tensor([4, 6]) # Rank 1 - """ - world_size = get_world_size(group) - if world_size > 1: - if group is None: - group = get_default_group() - - input_device = get_data_device(data) - backend_device = get_comm_device(group) - data_on_device = cast_data_device(data, backend_device) - - # pytorch does not support 'mean' operation so we fall back to support - # it with 'sum' operation. - if op.lower() == 'mean': - torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group) - - # use true_divide to handle torch1.6.0 throws an RuntimeError when - # the type of `data_on_device` is int64 - data_on_device = torch.true_divide(data_on_device, world_size) - else: - torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group) - - cast_data_device(data_on_device, input_device, out=data) - - -def all_gather(data: Tensor, - group: Optional[ProcessGroup] = None) -> List[Tensor]: - """Gather data from the whole group in a list. 
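Editor's note (illustrative only): the 'mean' fallback in ``all_reduce`` above is a SUM reduction followed by a division by the world size; with stand-in tensors for two ranks it amounts to:

import torch

world_size = 2
per_rank = [torch.tensor([1., 2.]), torch.tensor([3., 4.])]   # stand-ins for rank 0 / rank 1

summed = per_rank[0] + per_rank[1]             # what every rank holds after a SUM all-reduce
mean = torch.true_divide(summed, world_size)   # tensor([2., 3.])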
- - Note: - Calling ``all_gather`` in non-distributed environment does nothing - and just returns a list containing :attr:`data` itself. - - Note: - Unlike PyTorch ``torch.distributed.all_gather``, :meth:`all_gather` in - MMEngine does not pass in an empty list ``gather_list`` and returns - the ``gather_list`` directly, which is more convenient. The difference - between their interfaces is as below: - - - MMEngine: all_gather(data, group) -> gather_list - - PyTorch: all_gather(gather_list, data, group) -> None - - Args: - data (Tensor): Tensor to be gathered. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - list[Tensor]: Return a list containing data from the whole group if - in distributed environment, otherwise a list only containing - :attr:`data` itself. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = torch.arange(2, dtype=torch.int64) - >>> data - tensor([0, 1]) - >>> output = dist.all_gather(data) - >>> output - [tensor([0, 1])] - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank - >>> data - tensor([1, 2]) # Rank 0 - tensor([3, 4]) # Rank 1 - >>> output = dist.all_gather(data) - >>> output - [tensor([1, 2]), tensor([3, 4])] # Rank 0 - [tensor([1, 2]), tensor([3, 4])] # Rank 1 - """ - world_size = get_world_size(group) - if world_size == 1: - return [data] - - if group is None: - group = get_default_group() - - input_device = get_data_device(data) - backend_device = get_comm_device(group) - data_on_device = cast_data_device(data, backend_device) - - gather_list = [ - torch.empty_like(data, device=backend_device) - for _ in range(world_size) - ] - - torch_dist.all_gather(gather_list, data_on_device, group) - - return cast_data_device(gather_list, input_device) # type: ignore - - -def gather(data: Tensor, - dst: int = 0, - group: Optional[ProcessGroup] = None) -> List[Optional[Tensor]]: - """Gather data from the whole group to ``dst`` process. - - Note: - Calling ``gather`` in non-distributed environment dose nothing - and just returns a list containing :attr:`data` itself. - - Note: - ``NCCL`` backend does not support ``gather``. - - Note: - Unlike PyTorch ``torch.distributed.gather``, :meth:`gather` in - MMEngine does not pass in an empty list ``gather_list`` and returns - the ``gather_list`` directly, which is more convenient. The difference - between their interfaces is as below: - - - MMEngine: gather(data, dst, group) -> gather_list - - PyTorch: gather(data, gather_list, dst, group) -> None - - Args: - data (Tensor): Tensor to be gathered. CUDA tensor is not supported. - dst (int): Destination rank. Defaults to 0. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - list[Tensor]: ``dst`` process will get a list of tensor gathering from - the whole group. Other process will get a empty list. If in - non-distributed environment, just return a list containing - :attr:`data` itself. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = torch.arange(2, dtype=torch.int64) - >>> data - tensor([0, 1]) - >>> output = dist.gather(data) - >>> output - [tensor([0, 1])] - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. 
- >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank - >>> data - tensor([1, 2]) # Rank 0 - tensor([3, 4]) # Rank 1 - >>> output = dist.gather(data) - >>> output - [tensor([1, 2]), tensor([3, 4])] # Rank 0 - [] # Rank 1 - """ - world_size = get_world_size(group) - if world_size == 1: - return [data] - - if group is None: - group = get_default_group() - - input_device = get_data_device(data) - backend_device = get_comm_device(group) - - if get_rank(group) == dst: - gather_list = [ - torch.empty_like(data, device=backend_device) - for _ in range(world_size) - ] - else: - gather_list = [] - - torch_dist.gather(data, gather_list, dst, group) - - if get_rank(group) == dst: - return cast_data_device(gather_list, input_device) # type: ignore - else: - return gather_list - - -def broadcast(data: Tensor, - src: int = 0, - group: Optional[ProcessGroup] = None) -> None: - """Broadcast the data from ``src`` process to the whole group. - - ``data`` must have the same number of elements in all processes - participating in the collective. - - Note: - Calling ``broadcast`` in non-distributed environment does nothing. - - Args: - data (Tensor): Data to be sent if ``src`` is the rank of current - process, and data to be used to save received data otherwise. - src (int): Source rank. Defaults to 0. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = torch.arange(2, dtype=torch.int64) - >>> data - tensor([0, 1]) - >>> dist.broadcast(data) - >>> data - tensor([0, 1]) - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank - >>> data - tensor([1, 2]) # Rank 0 - tensor([3, 4]) # Rank 1 - >>> dist.broadcast(data) - >>> data - tensor([1, 2]) # Rank 0 - tensor([1, 2]) # Rank 1 - """ - if get_world_size(group) > 1: - if group is None: - group = get_default_group() - - input_device = get_data_device(data) - backend_device = get_comm_device(group) - data_on_device = cast_data_device(data, backend_device) - # broadcast requires tensor is contiguous - data_on_device = data_on_device.contiguous() # type: ignore - torch_dist.broadcast(data_on_device, src, group) - - if get_rank(group) != src: - cast_data_device(data_on_device, input_device, data) - - -def sync_random_seed(group: Optional[ProcessGroup] = None) -> int: - """Synchronize a random seed to all processes. - - In distributed sampling, different ranks should sample non-overlapped - data in the dataset. Therefore, this function is used to make sure that - each rank shuffles the data indices in the same order based - on the same seed. Then different ranks could use different indices - to select non-overlapped data from the same data list. - - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - int: Random seed. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> seed = dist.sync_random_seed() - >>> seed # which a random number - 587791752 - - >>> distributed environment - >>> # We have 2 process groups, 2 ranks. 
- >>> seed = dist.sync_random_seed() - >>> seed - 587791752 # Rank 0 - 587791752 # Rank 1 - """ - seed = np.random.randint(2**31) - if get_world_size(group) == 1: - return seed - - if group is None: - group = get_default_group() - - backend_device = get_comm_device(group) - - if get_rank(group) == 0: - random_num = torch.tensor(seed, dtype=torch.int32).to(backend_device) - else: - random_num = torch.tensor(0, dtype=torch.int32).to(backend_device) - - torch_dist.broadcast(random_num, src=0, group=group) - - return random_num.item() - - -def _object_to_tensor(obj: Any) -> Tuple[Tensor, Tensor]: - """Serialize picklable python object to tensor.""" - byte_storage = torch.ByteStorage.from_buffer(pickle.dumps(obj)) - # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor - # and specifying dtype. Otherwise, it will cause 100X slowdown. - # See: https://github.com/pytorch/pytorch/issues/65696 - byte_tensor = torch.ByteTensor(byte_storage) - local_size = torch.LongTensor([byte_tensor.numel()]) - return byte_tensor, local_size - - -def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any: - """Deserialize tensor to picklable python object.""" - buf = tensor.cpu().numpy().tobytes()[:tensor_size] - return pickle.loads(buf) - - -def _broadcast_object_list(object_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None) -> None: - """Broadcast picklable objects in ``object_list`` to the whole group. - - Similar to :func:`broadcast`, but Python objects can be passed in. Note - that all objects in ``object_list`` must be picklable in order to be - broadcasted. - """ - if torch_dist.distributed_c10d._rank_not_in_group(group): - return - - my_rank = get_rank() - # Serialize object_list elements to tensors on src rank. - if my_rank == src: - tensor_list, size_list = zip( - *[_object_to_tensor(obj) for obj in object_list]) - object_sizes_tensor = torch.cat(size_list) - else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) - - # Current device selection. - # To preserve backwards compatibility, ``device`` is ``None`` by default. - # in which case we run current logic of device selection, i.e. - # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In - # the case it is not ``None`` we move the size and object tensors to be - # broadcasted to this device. - group_backend = get_backend(group) - is_nccl_backend = group_backend == torch_dist.Backend.NCCL - current_device = torch.device('cpu') - is_hccl_backend = group_backend == 'hccl' - is_cncl_backend = group_backend == 'cncl' - is_mccl_backend = group_backend == 'mccl' - if is_hccl_backend: - current_device = torch.device('npu', torch.npu.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - elif is_cncl_backend: - current_device = torch.device('mlu', torch.mlu.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - elif is_mccl_backend: - current_device = torch.device('musa', torch.musa.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - elif is_nccl_backend: - # See note about using torch.cuda.current_device() here in - # docstring. We cannot simply use my_rank since rank == device is - # not necessarily true. 
- current_device = torch.device('cuda', torch.cuda.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - - # Broadcast object sizes - torch_dist.broadcast(object_sizes_tensor, src=src, group=group) - - # Concatenate and broadcast serialized object tensors - if my_rank == src: - object_tensor = torch.cat(tensor_list) - else: - object_tensor = torch.empty( - torch.sum(object_sizes_tensor).int().item(), - dtype=torch.uint8, - ) - - if is_nccl_backend or is_hccl_backend or is_cncl_backend: - object_tensor = object_tensor.to(current_device) - torch_dist.broadcast(object_tensor, src=src, group=group) - # Deserialize objects using their stored sizes. - offset = 0 - if my_rank != src: - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] - obj_view = obj_view.type(torch.uint8) - if obj_view.device != torch.device('cpu'): - obj_view = obj_view.cpu() - offset += obj_size - object_list[i] = _tensor_to_object(obj_view, obj_size) - - -def broadcast_object_list(data: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None) -> None: - """Broadcasts picklable objects in ``object_list`` to the whole group. - Similar to :func:`broadcast`, but Python objects can be passed in. Note - that all objects in ``object_list`` must be picklable in order to be - broadcasted. - - Note: - Calling ``broadcast_object_list`` in non-distributed environment does - nothing. - - Args: - data (List[Any]): List of input objects to broadcast. - Each object must be picklable. Only objects on the ``src`` rank - will be broadcast, but each rank must provide lists of equal sizes. - src (int): Source rank from which to broadcast ``object_list``. - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - device (``torch.device``, optional): If not None, the objects are - serialized and converted to tensors which are moved to the - ``device`` before broadcasting. Default is ``None``. - - Note: - For NCCL-based process groups, internal tensor representations of - objects must be moved to the GPU device before communication starts. - In this case, the used device is given by - ``torch.cuda.current_device()`` and it is the user's responsibility to - ensure that this is correctly set so that each rank has an individual - GPU, via ``torch.cuda.set_device()``. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = ['foo', 12, {1: 2}] - >>> dist.broadcast_object_list(data) - >>> data - ['foo', 12, {1: 2}] - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> if dist.get_rank() == 0: - >>> # Assumes world_size of 3. - >>> data = ["foo", 12, {1: 2}] # any picklable object - >>> else: - >>> data = [None, None, None] - >>> dist.broadcast_object_list(data) - >>> data - ["foo", 12, {1: 2}] # Rank 0 - ["foo", 12, {1: 2}] # Rank 1 - """ - assert isinstance(data, list) - - if get_world_size(group) > 1: - if group is None: - group = get_default_group() - - if digit_version(TORCH_VERSION) >= digit_version( - '1.8.0') and not is_npu_available(): - torch_dist.broadcast_object_list(data, src, group) - else: - _broadcast_object_list(data, src, group) - - -def all_reduce_dict(data: Dict[str, Tensor], - op: str = 'sum', - group: Optional[ProcessGroup] = None) -> None: - """Reduces the dict across all machines in such a way that all get the - final result. 
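Editor's note (a plain-tensor preview, not part of the patch): the dict reduction documented below flattens all values into one tensor so a single all-reduce suffices, then splits and reshapes the result back; the sample dict is an assumption.

import torch

data = {'key1': torch.arange(2), 'key2': torch.arange(3)}

# Flatten every value into one tensor (keys sorted for a consistent order across ranks).
keys = sorted(data)
shapes = [data[k].shape for k in keys]
sizes = [data[k].numel() for k in keys]
flat = torch.cat([data[k].flatten() for k in keys])

# ... in a distributed run, all_reduce(flat) would happen here ...

# Split by the recorded sizes and reshape back into the original dict layout.
for k, chunk, shape in zip(keys, torch.split(flat, sizes), shapes):
    data[k] = chunk.reshape(shape)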
- - The code is modified from https://github.com/Megvii- - BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. - - Args: - data (dict[str, Tensor]): Data to be reduced. - op (str): Operation to reduce data. Defaults to 'sum'. Optional values - are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and - 'bxor'. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = { - 'key1': torch.arange(2, dtype=torch.int64), - 'key2': torch.arange(3, dtype=torch.int64) - } - >>> dist.all_reduce_dict(data) - >>> data - {'key1': tensor([0, 1]), 'key2': tensor([0, 1, 2])} - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> data = { - 'key1': torch.arange(2, dtype=torch.int64), - 'key2': torch.arange(3, dtype=torch.int64) - } - >>> dist.all_reduce_dict(data) - >>> data - {'key1': tensor([0, 2]), 'key2': tensor([0, 2, 4])} # Rank 0 - {'key1': tensor([0, 2]), 'key2': tensor([0, 2, 4])} # Rank 1 - """ - assert isinstance(data, dict) - - world_size = get_world_size(group) - if world_size > 1: - - if group is None: - group = get_default_group() - - # ensure keys are consistent across processes - keys = sorted(data.keys()) - tensor_shapes = [data[k].shape for k in keys] - tensor_sizes = [data[k].numel() for k in keys] - - if digit_version(TORCH_VERSION) == digit_version('1.5.0'): - # `torch.cat` in torch1.5 can not concatenate different types so - # we fallback to convert them all to float type. - flatten_tensor = torch.cat( - [data[k].flatten().float() for k in keys]) - else: - flatten_tensor = torch.cat([data[k].flatten() for k in keys]) - - all_reduce(flatten_tensor, op=op, group=group) - - split_tensors = [ - x.reshape(shape) for x, shape in zip( - torch.split(flatten_tensor, tensor_sizes), tensor_shapes) - ] - - for k, v in zip(keys, split_tensors): - data[k] = v - - -def _all_gather_object(object_list: List[Any], - obj: Any, - group: Optional[ProcessGroup] = None) -> None: - """Gather picklable objects from the whole group into a list. - - Similar to :func:`all_gather`, but Python objects can be passed in. - Note that the object must be picklable in order to be gathered. - - Args: - object_list (list[Any]): Output list. It should be correctly sized as - the size of the group for this collective and will contain the - output. - object (Any): Pickable Python object to be broadcast from current - process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - None. If the calling rank is part of this group, the output of the - collective will be populated into the input ``object_list``. If the - calling rank is not part of the group, the passed in ``object_list`` - will be unmodified. - """ - if torch_dist.distributed_c10d._rank_not_in_group(group): - return - - input_tensor, local_size = _object_to_tensor(obj) - group_backend = get_backend(group) - current_device = torch.device('cpu') - is_nccl_backend = group_backend == torch_dist.Backend.NCCL - is_mccl_backend = group_backend == 'mccl' - if is_nccl_backend: - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. 
- current_device = torch.device('cuda', torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - elif is_mccl_backend: - # See note about using torch.musa.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device('musa', torch.musa.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - # Gather all local sizes. This is so that we can find the max size, and - # index until the correct size when deserializing the tensors. - group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - # Allgather tensor sizes - torch_dist.all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) - # Resize tensor to max size across all ranks. - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device) - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i:max_object_size * (i + 1)] - for i in range(group_size) - ] - torch_dist.all_gather(output_tensors, input_tensor, group=group) - # Deserialize outputs back to object. - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - if tensor.device != torch.device('cpu'): - tensor = tensor.cpu() - tensor_size = object_size_list[i] - object_list[i] = _tensor_to_object(tensor, tensor_size) - - -def all_gather_object(data: Any, - group: Optional[ProcessGroup] = None) -> List[Any]: - """Gather picklable objects from the whole group into a list. Similar to - :func:`all_gather`, but Python objects can be passed in. Note that the - object must be picklable in order to be gathered. - - Note: - Calling ``all_gather_object`` in non-distributed environment does - nothing and just returns a list containing :attr:`data` itself. - - Note: - Unlike PyTorch ``torch.distributed.all_gather_object``, - :meth:`all_gather_object` in MMEngine does not pass in an empty list - ``gather_list`` and returns the ``gather_list`` directly, which is - more convenient. The difference between their interfaces is as below: - - - MMEngine: all_gather_object(data, group) -> gather_list - - PyTorch: all_gather_object(gather_list, data, group) -> None - - Args: - data (Any): Pickable Python object to be broadcast from current - process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - list[Tensor]: Return a list containing data from the whole group if - in distributed environment, otherwise a list only containing - :attr:`data` itself. - - Note: - For NCCL-based process groups, internal tensor representations - of objects must be moved to the GPU device before communication starts. - In this case, the used device is given by - ``torch.cuda.current_device()`` and it is the user's responsibility to - ensure that this is correctly set so that each rank has an individual - GPU, via ``torch.cuda.set_device()``. 
- - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = ['foo', 12, {1: 2}] # any picklable object - >>> gather_objects = dist.all_gather_object(data[dist.get_rank()]) - >>> output - ['foo'] - - >>> # distributed environment - >>> # We have 3 process groups, 3 ranks. - >>> output = dist.all_gather_object(data[dist.get_rank()]) - >>> output - ['foo', 12, {1: 2}] # Rank 0 - ['foo', 12, {1: 2}] # Rank 1 - ['foo', 12, {1: 2}] # Rank 2 - """ - world_size = get_world_size(group) - if world_size == 1: - return [data] - - if group is None: - group = get_default_group() - - gather_list = [None] * world_size - - if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): - torch_dist.all_gather_object(gather_list, data, group) - else: - _all_gather_object(gather_list, data, group) - - return gather_list - - -def _validate_output_list_for_rank(my_rank: int, dst: int, - gather_list: Optional[list]) -> None: - """Validate whether ``gather_list`` is None in non-dst ranks.""" - if dst == my_rank: - if not gather_list: - raise ValueError( - 'Argument ``gather_list`` must be specified on destination ' - 'rank.') - elif gather_list: - raise ValueError('Argument ``gather_list`` must NOT be specified ' - 'on non-destination ranks.') - - -def _gather_object(obj: Any, - object_gather_list=None, - dst: int = 0, - group: Optional[ProcessGroup] = None) -> None: - """Gathers picklable objects from the whole group in a single process. - - Similar to :func:`gather`, but Python objects can be passed in. Note that - the object must be picklable in order to be gathered. - - Args: - obj (Any): Input object. Must be picklable. - object_gather_list (list[Any], optional): Output list. On the ``dst`` - rank, it should be correctly sized as the size of the group for - this collective and will contain the output. Must be ``None`` on - non-dst ranks. Defaults to None. - dst (int): Destination rank. Defaults to 0. - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - """ - if torch_dist.distributed_c10d._rank_not_in_group(group): - return - - # Ensure object_gather_list is specified appopriately. - my_rank = get_rank() - _validate_output_list_for_rank(my_rank, dst, object_gather_list) - input_tensor, local_size = _object_to_tensor(obj) - group_backend = get_backend(group) - current_device = torch.device('cpu') - is_nccl_backend = group_backend == torch_dist.Backend.NCCL - is_mccl_backend = group_backend == 'mccl' - if is_nccl_backend: - current_device = torch.device('cuda', torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - elif is_mccl_backend: - current_device = torch.device('musa', torch.musa.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - # Gather all local sizes. This is so that we can find the max size, and - # index until the correct size when deserializing the tensors. - group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - # Allgather tensor sizes. An all-gather is needed here despite this being a - # gather, since each rank needs to broadcast a tensor of the same (maximal) - # size. 
- torch_dist.all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) - # Resize tensor to max size across all ranks. - input_tensor.resize_(max_object_size) - # Avoid populating output tensors if the result won't be gathered on this - # rank. - if my_rank == dst: - coalesced_output_tensor = torch.empty( - max_object_size * group_size, - dtype=torch.uint8, - device=current_device) - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i:max_object_size * - (i + 1)] for i in range(group_size) - ] - # All ranks call gather with equal-sized tensors. - torch_dist.gather( - input_tensor, - gather_list=output_tensors if my_rank == dst else None, - dst=dst, - group=group, - ) - if my_rank != dst: - return - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_gather_list[i] = _tensor_to_object(tensor, tensor_size) - - -def gather_object(data: Any, - dst: int = 0, - group: Optional[ProcessGroup] = None) -> Optional[List[Any]]: - """Gathers picklable objects from the whole group in a single process. - Similar to :func:`gather`, but Python objects can be passed in. Note that - the object must be picklable in order to be gathered. - - Note: - ``NCCL backend`` does not support ``gather_object``. - - Note: - Unlike PyTorch ``torch.distributed.gather_object``, - :meth:`gather_object` in MMEngine does not pass in an empty list - ``gather_list`` and returns the ``gather_list`` directly, which is - more convenient. The difference between their interfaces is as below: - - - MMEngine: gather_object(data, dst, group) -> gather_list - - PyTorch: gather_object(data, gather_list, data, group) -> None - - Args: - data (Any): Input object. Must be picklable. - dst (int): Destination rank. Defaults to 0. - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - list[Any]. On the ``dst`` rank, return ``gather_list`` which contains - the output of the collective. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = ['foo', 12, {1: 2}] # any picklable object - >>> gather_objects = dist.gather_object(data[dist.get_rank()]) - >>> output - ['foo'] - - >>> # distributed environment - >>> # We have 3 process groups, 3 ranks. - >>> dist.gather_object(gather_objects[dist.get_rank()], dst=0) - >>> output - ['foo', 12, {1: 2}] # Rank 0 - None # Rank 1 - None # Rank 2 - """ - world_size = get_world_size(group) - if world_size == 1: - return [data] - - if group is None: - group = get_default_group() - - gather_list = [None] * world_size if get_rank(group) == dst else None - - if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): - torch_dist.gather_object(data, gather_list, dst, group) - else: - _gather_object(data, gather_list, dst, group) - - return gather_list - - -def collect_results(results: list, - size: int, - device: str = 'cpu', - tmpdir: Optional[str] = None) -> Optional[list]: - """Collected results in distributed environments. - - Args: - results (list[object]): Result list containing result parts to be - collected. Each item of ``result_part`` should be a picklable - object. - size (int): Size of the results, commonly equal to length of - the results. - device (str): Device name. Optional values are 'cpu', 'gpu' or 'npu'. 
- tmpdir (str | None): Temporal directory for collected results to - store. If set to None, it will create a temporal directory for it. - ``tmpdir`` should be None when device is 'gpu' or 'npu'. - Defaults to None. - - Returns: - list or None: The collected results. - - Examples: - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> import mmengine.dist as dist - >>> if dist.get_rank() == 0: - data = ['foo', {1: 2}] - else: - data = [24, {'a': 'b'}] - >>> size = 4 - >>> output = dist.collect_results(data, size, device='cpu') - >>> output - ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 - None # rank 1 - """ - if device not in ['gpu', 'cpu', 'npu']: - raise NotImplementedError( - f"device must be 'cpu' , 'gpu' or 'npu', but got {device}") - - if device == 'gpu' or device == 'npu': - assert tmpdir is None, f'tmpdir should be None when device is {device}' - return _collect_results_device(results, size) - else: - return collect_results_cpu(results, size, tmpdir) - - -def collect_results_cpu(result_part: list, - size: int, - tmpdir: Optional[str] = None) -> Optional[list]: - """Collect results under cpu mode. - - On cpu mode, this function will save the results on different gpus to - ``tmpdir`` and collect them by the rank 0 worker. - - Args: - result_part (list): Result list containing result parts - to be collected. Each item of ``result_part`` should be a picklable - object. - size (int): Size of the results, commonly equal to length of - the results. - tmpdir (str | None): Temporal directory for collected results to - store. If set to None, it will create a random temporal directory - for it. Defaults to None. - - Returns: - list or None: The collected results. - - Examples: - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. 
- >>> import mmengine.dist as dist - >>> if dist.get_rank() == 0: - data = ['foo', {1: 2}] - else: - data = [24, {'a': 'b'}] - >>> size = 4 - >>> output = dist.collect_results_cpu(data, size) - >>> output - ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 - None # rank 1 - """ - rank, world_size = get_dist_info() - if world_size == 1: - return result_part[:size] - - # create a tmp dir if it is not specified - if tmpdir is None: - MAX_LEN = 512 - # 32 is whitespace - dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8) - if rank == 0: - mmengine.mkdir_or_exist('.dist_test') - tmpdir = tempfile.mkdtemp(dir='.dist_test') - tmpdir = torch.tensor( - bytearray(tmpdir.encode()), dtype=torch.uint8) - dir_tensor[:len(tmpdir)] = tmpdir - broadcast(dir_tensor, 0) - tmpdir = dir_tensor.numpy().tobytes().decode().rstrip() - else: - mmengine.mkdir_or_exist(tmpdir) - - # dump the part result to the dir - with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: # type: ignore - pickle.dump(result_part, f, protocol=2) - - barrier() - - # collect all parts - if rank != 0: - return None - else: - # load results of all parts from tmp dir - part_list = [] - for i in range(world_size): - path = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore - if not osp.exists(path): - raise FileNotFoundError( - f'{tmpdir} is not an shared directory for ' - f'rank {i}, please make sure {tmpdir} is a shared ' - 'directory for all ranks!') - with open(path, 'rb') as f: - part_list.append(pickle.load(f)) - # sort the results - ordered_results = [] - zipped_results = zip_longest(*part_list) - ordered_results = [ - i for i in chain.from_iterable(zipped_results) if i is not None - ] - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - # remove tmp dir - shutil.rmtree(tmpdir) # type: ignore - return ordered_results - - -def _collect_results_device(result_part: list, size: int) -> Optional[list]: - """Collect results under gpu or npu mode.""" - rank, world_size = get_dist_info() - if world_size == 1: - return result_part[:size] - - # gather all result part. Note that NCCL does not support gather so use - # all_gather_object instead. - part_list = all_gather_object(result_part) - - if rank == 0: - # sort the results - ordered_results = [] - zipped_results = zip_longest(*part_list) - ordered_results = [ - i for i in chain.from_iterable(zipped_results) if i is not None - ] - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - return ordered_results - else: - return None - - -def collect_results_gpu(result_part: list, size: int) -> Optional[list]: - """Collect results under gpu mode. - - On gpu mode, this function will encode results to gpu tensors and use gpu - communication for results collection. - - Args: - result_part (list[object]): Result list containing result parts - to be collected. Each item of ``result_part`` should be a picklable - object. - size (int): Size of the results, commonly equal to length of - the results. - - Returns: - list or None: The collected results. - - Examples: - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. 
- >>> import mmengine.dist as dist - >>> if dist.get_rank() == 0: - data = ['foo', {1: 2}] - else: - data = [24, {'a': 'b'}] - >>> size = 4 - >>> output = dist.collect_results_gpu(data, size) - >>> output - ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 - None # rank 1 - """ - return _collect_results_device(result_part, size) - - -def _all_reduce_coalesced(tensors: List[torch.Tensor], - bucket_size_mb: int = -1, - op: str = 'sum', - group: Optional[ProcessGroup] = None) -> None: - """All-reduce a sequence of tensors as a whole. - - Args: - tensors (List[torch.Tensor]): A sequence of tensors to be - all-reduced. - bucket_size_mb (int): The limit of each chunk in megabytes - for grouping tensors into chunks. Defaults to -1. - op (str): Operation to reduce data. Defaults to 'sum'. Optional values - are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and - 'bxor'. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - """ - if bucket_size_mb > 0: - bucket_size_bytes = bucket_size_mb * 1024 * 1024 - buckets = _take_tensors(tensors, bucket_size_bytes) - else: - buckets = OrderedDict() - for tensor in tensors: - tp = tensor.type() - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(tensor) - buckets = buckets.values() - - for bucket in buckets: - flat_tensors = _flatten_dense_tensors(bucket) - all_reduce(flat_tensors, op=op, group=group) - for tensor, synced in zip( - bucket, _unflatten_dense_tensors(flat_tensors, bucket)): - tensor.copy_(synced) - - -def all_reduce_params(params: Union[List, Generator[torch.Tensor, None, None]], - coalesce: bool = True, - bucket_size_mb: int = -1, - op: str = 'sum', - group: Optional[ProcessGroup] = None) -> None: - """All-reduce parameters. - - Args: - params (List or Generator[torch.Tensor, None, None]): List of - parameters or buffers of a model. - coalesce (bool, optional): Whether to reduce parameters as a whole. - Defaults to True. - bucket_size_mb (int, optional): Size of bucket, the unit is MB. - Defaults to -1. - op (str): Operation to reduce data. Defaults to 'sum'. Optional values - are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and - 'bxor'. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Examples: - >>> import torch - >>> import mmengine.dist as dist - - >>> # non-distributed environment - >>> data = [torch.arange(2), torch.arange(3)] - >>> dist.all_reduce_params(data) - >>> data - [tensor([0, 1]), tensor([0, 1, 2])] - - >>> # distributed environment - >>> # We have 2 process groups, 2 ranks. - >>> if dist.get_rank() == 0: - ... data = [torch.tensor([1, 2]), torch.tensor([3, 4])] - ... else: - ... data = [torch.tensor([2, 3]), torch.tensor([4, 5])] - - >>> dist.all_reduce_params(data) - >>> data - [torch.tensor([3, 5]), torch.tensor([7, 9])] - """ - world_size = get_world_size(group) - if world_size == 1: - return - params_data = [param.data for param in params] - if coalesce: - _all_reduce_coalesced(params_data, bucket_size_mb, op=op, group=group) - else: - for tensor in params_data: - all_reduce(tensor, op=op, group=group) diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py deleted file mode 100644 index 5d32cec36b..0000000000 --- a/mmengine/dist/utils.py +++ /dev/null @@ -1,623 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
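For orientation, the object-collective helpers deleted above (broadcast_object_list, all_gather_object, all_reduce_params) are typically combined as sketched below. This is a minimal illustration, not part of the patch; the data values are made up, and in a single-process run these calls fall back to no-ops exactly as the docstrings above state, so the snippet also runs without a launcher.

# Illustrative sketch of the mmengine.dist object collectives removed above.
import torch
import mmengine.dist as dist

rank = dist.get_rank()

# Share a config dict decided on rank 0 with every rank.
cfg = [{'lr': 0.01}] if rank == 0 else [None]
dist.broadcast_object_list(cfg, src=0)

# Collect one picklable result per rank into a list on every rank.
results = dist.all_gather_object({'rank': rank, 'loss': 0.5 * (rank + 1)})

# Average tensors in place across ranks (no-op when world_size == 1).
params = [torch.ones(2), torch.arange(3.0)]
dist.all_reduce_params(params, op='mean')

if dist.is_main_process():
    print(cfg, results, params)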
-import datetime -import functools -import os -import subprocess -from typing import Callable, Optional, Tuple, Union - -import numpy as np -import torch -import torch.multiprocessing as mp -from torch import Tensor -from torch import distributed as torch_dist -from torch.distributed import ProcessGroup -from mmengine.device import (is_mlu_available, is_npu_available, - is_musa_available) - -from collections.abc import Iterable, Mapping - -_LOCAL_PROCESS_GROUP = None - - -def is_distributed() -> bool: - """Return True if distributed environment has been initialized.""" - return torch_dist.is_available() and torch_dist.is_initialized() - - -def get_local_group() -> Optional[ProcessGroup]: - """Return local process group.""" - if not is_distributed(): - return None - - if _LOCAL_PROCESS_GROUP is None: - raise RuntimeError('Local process group is not created, please use ' - '`init_local_group` to setup local process group.') - - return _LOCAL_PROCESS_GROUP - - -def get_default_group() -> Optional[ProcessGroup]: - """Return default process group.""" - - return torch_dist.distributed_c10d._get_default_group() - - -def infer_launcher(): - if 'WORLD_SIZE' in os.environ: - return 'pytorch' - elif 'SLURM_NTASKS' in os.environ: - return 'slurm' - elif 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ: - return 'mpi' - else: - return 'none' - - -def init_dist(launcher, - backend='nccl', - init_backend='torch', - **kwargs) -> None: - """Initialize distributed environment. - - Args: - launcher (str): Way to launcher multi processes. Supported launchers - are 'pytorch', 'mpi' and 'slurm'. - backend (str): Communication Backends. Supported backends are 'nccl', - 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: keyword arguments are passed to ``init_process_group``. - """ - timeout = kwargs.get('timeout', None) - if timeout is not None: - # If a timeout (in seconds) is specified, it must be converted - # to a timedelta object before forwarding the call to - # the respective backend, because they expect a timedelta object. - try: - kwargs['timeout'] = datetime.timedelta(seconds=timeout) - except TypeError as exception: - raise TypeError( - f'Timeout for distributed training must be provided as ' - f"timeout in seconds, but we've received the type " - f'{type(timeout)}. Please specify the timeout like this: ' - f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception - if mp.get_start_method(allow_none=True) is None: - mp.set_start_method('spawn') - if launcher == 'pytorch': - _init_dist_pytorch(backend, init_backend=init_backend, **kwargs) - elif launcher == 'mpi': - _init_dist_mpi(backend, **kwargs) - elif launcher == 'slurm': - _init_dist_slurm(backend, init_backend=init_backend, **kwargs) - else: - raise ValueError(f'Invalid launcher type: {launcher}') - - -def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None: - """Initialize distributed environment with PyTorch launcher. - - Args: - backend (str): Backend of torch.distributed. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: keyword arguments are passed to ``init_process_group``. 
- """ - rank = int(os.environ['RANK']) - # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1 - local_rank = int(os.environ['LOCAL_RANK']) - if is_mlu_available(): - import torch_mlu # noqa: F401 - torch.mlu.set_device(local_rank) - torch_dist.init_process_group( - backend='cncl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) - elif is_npu_available(): - import torch_npu # noqa: F401 - torch.npu.set_device(local_rank) - torch_dist.init_process_group( - backend='hccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) - elif is_musa_available(): - import torch_musa # noqa: F401 - torch.musa.set_device(rank) - torch_dist.init_process_group( - backend='mccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) - else: - torch.cuda.set_device(local_rank) - - if init_backend == 'torch': - torch_dist.init_process_group(backend=backend, **kwargs) - elif init_backend == 'deepspeed': - import deepspeed - deepspeed.init_distributed(dist_backend=backend, **kwargs) - elif init_backend == 'colossalai': - import colossalai - colossalai.launch_from_torch(backend=backend, **kwargs) - else: - raise ValueError( - 'supported "init_backend" is "torch" or "deepspeed", ' - f'but got {init_backend}') - - -def _init_dist_mpi(backend, **kwargs) -> None: - """Initialize distributed environment with MPI launcher. - - Args: - backend (str): Backend of torch.distributed. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: keyword arguments are passed to ``init_process_group``. - """ - if backend == 'smddp': - try: - import smdistributed.dataparallel.torch.torch_smddp # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - 'Please use an Amazon SageMaker DLC to access smdistributed: ' - 'https://github.com/aws/deep-learning-containers/blob/master' - '/available_images.md#sagemaker-framework-containers' - '-sm-support-only') from e - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - torch.cuda.set_device(local_rank) - if 'MASTER_PORT' not in os.environ: - # 29500 is torch.distributed default port - os.environ['MASTER_PORT'] = '29500' - if 'MASTER_ADDR' not in os.environ: - raise KeyError('The environment variable MASTER_ADDR is not set') - os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] - os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] - torch_dist.init_process_group(backend=backend, **kwargs) - - -def _init_dist_slurm(backend, - port=None, - init_backend='torch', - **kwargs) -> None: - """Initialize slurm distributed training environment. - - If argument ``port`` is not specified, then the master port will be system - environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system - environment variable, then a default port ``29500`` will be used. - - Args: - backend (str): Backend of torch.distributed. - port (int, optional): Master port. Defaults to None. 
- """ - proc_id = int(os.environ['SLURM_PROCID']) - ntasks = int(os.environ['SLURM_NTASKS']) - node_list = os.environ['SLURM_NODELIST'] - # Not sure when this environment variable could be None, so use a fallback - local_rank_env = os.environ.get('SLURM_LOCALID', None) - if local_rank_env is not None: - local_rank = int(local_rank_env) - else: - num_gpus = torch.cuda.device_count() - local_rank = proc_id % num_gpus - addr = subprocess.getoutput( - f'scontrol show hostname {node_list} | head -n1') - # specify master port - if port is not None: - os.environ['MASTER_PORT'] = str(port) - elif 'MASTER_PORT' in os.environ: - pass # use MASTER_PORT in the environment variable - else: - # 29500 is torch.distributed default port - os.environ['MASTER_PORT'] = '29500' - # use MASTER_ADDR in the environment variable if it already exists - if 'MASTER_ADDR' not in os.environ: - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(ntasks) - os.environ['LOCAL_RANK'] = str(local_rank) - os.environ['RANK'] = str(proc_id) - - if is_mlu_available(): - import torch_mlu # noqa: F401 - torch.mlu.set_device(local_rank) - torch_dist.init_process_group(backend='cncl', **kwargs) - else: - torch.cuda.set_device(local_rank) - - if init_backend == 'torch': - torch_dist.init_process_group(backend=backend, **kwargs) - elif init_backend == 'deepspeed': - import deepspeed - deepspeed.init_distributed(dist_backend=backend, **kwargs) - elif init_backend == 'colossalai': - import colossalai - colossalai.launch_from_slurm( - backend=backend, - host=os.environ['MASTER_ADDR'], - port=os.environ['MASTER_PORT'], - **kwargs, - ) - else: - raise ValueError( - 'supported "init_backend" is "torch" or "deepspeed", ' - f'but got {init_backend}') - - -def init_local_group(node_rank: int, num_gpus_per_node: int): - """Setup the local process group. - - Setup a process group which only includes processes that on the same - machine as the current process. - - The code is modified from - https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py - - Args: - node_rank (int): Rank of machines used for training. - num_gpus_per_node (int): Number of gpus used for training in a single - machine. - """ # noqa: W501 - global _LOCAL_PROCESS_GROUP - assert _LOCAL_PROCESS_GROUP is None - - ranks = list( - range(node_rank * num_gpus_per_node, - (node_rank + 1) * num_gpus_per_node)) - _LOCAL_PROCESS_GROUP = torch_dist.new_group(ranks) - - -def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]: - """Return the backend of the given process group. - - Note: - Calling ``get_backend`` in non-distributed environment will return - None. - - Args: - group (ProcessGroup, optional): The process group to work on. The - default is the general main process group. If another specific - group is specified, the calling process must be part of - :attr:`group`. Defaults to None. - - Returns: - str or None: Return the backend of the given process group as a lower - case string if in distributed environment, otherwise None. - """ - if is_distributed(): - # handle low versions of torch like 1.5.0 which does not support - # passing in None for group argument - if group is None: - group = get_default_group() - return torch_dist.get_backend(group) - else: - return None - - -def get_world_size(group: Optional[ProcessGroup] = None) -> int: - """Return the number of the given process group. - - Note: - Calling ``get_world_size`` in non-distributed environment will return - 1. 
- - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - int: Return the number of processes of the given process group if in - distributed environment, otherwise 1. - """ - if is_distributed(): - # handle low versions of torch like 1.5.0 which does not support - # passing in None for group argument - if group is None: - group = get_default_group() - return torch_dist.get_world_size(group) - else: - return 1 - - -def get_rank(group: Optional[ProcessGroup] = None) -> int: - """Return the rank of the given process group. - - Rank is a unique identifier assigned to each process within a distributed - process group. They are always consecutive integers ranging from 0 to - ``world_size``. - - Note: - Calling ``get_rank`` in non-distributed environment will return 0. - - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - int: Return the rank of the process group if in distributed - environment, otherwise 0. - """ - - if is_distributed(): - # handle low versions of torch like 1.5.0 which does not support - # passing in None for group argument - if group is None: - group = get_default_group() - return torch_dist.get_rank(group) - else: - return 0 - - -def get_local_size() -> int: - """Return the number of the current node. - - Returns: - int: Return the number of processes in the current node if in - distributed environment, otherwise 1. - """ - if not is_distributed(): - return 1 - - if _LOCAL_PROCESS_GROUP is None: - raise RuntimeError('Local process group is not created, please use ' - '`init_local_group` to setup local process group.') - - return torch_dist.get_world_size(_LOCAL_PROCESS_GROUP) - - -def get_local_rank() -> int: - """Return the rank of current process in the current node. - - Returns: - int: Return the rank of current process in the current node if in - distributed environment, otherwise 0 - """ - if not is_distributed(): - return 0 - - if _LOCAL_PROCESS_GROUP is None: - raise RuntimeError('Local process group is not created, please use ' - '`init_local_group` to setup local process group.') - - return torch_dist.get_rank(_LOCAL_PROCESS_GROUP) - - -def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]: - """Get distributed information of the given process group. - - Note: - Calling ``get_dist_info`` in non-distributed environment will return - (0, 1). - - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - tuple[int, int]: Return a tuple containing the ``rank`` and - ``world_size``. - """ - world_size = get_world_size(group) - rank = get_rank(group) - return rank, world_size - - -def is_main_process(group: Optional[ProcessGroup] = None) -> bool: - """Whether the current rank of the given process group is equal to 0. - - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - - Returns: - bool: Return True if the current rank of the given process group is - equal to 0, otherwise False. - """ - return get_rank(group) == 0 - - -def master_only(func: Callable) -> Callable: - """Decorate those methods which should be executed in master process. - - Args: - func (callable): Function to be decorated. - - Returns: - callable: Return decorated function. 
- """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - if is_main_process(): - return func(*args, **kwargs) - - return wrapper - - -def barrier(group: Optional[ProcessGroup] = None) -> None: - """Synchronize all processes from the given process group. - - This collective blocks processes until the whole group enters this - function. - - Note: - Calling ``barrier`` in non-distributed environment will do nothing. - - Args: - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Defaults to None. - """ - if is_distributed(): - # handle low versions of torch like 1.5.0 which does not support - # passing in None for group argument - if group is None: - group = get_default_group() - torch_dist.barrier(group) - - -def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device: - """Return the device of ``data``. - - If ``data`` is a sequence of Tensor, all items in ``data`` should have a - same device type. - - If ``data`` is a dict whose values are Tensor, all values should have a - same device type. - - Args: - data (Tensor or Sequence or dict): Inputs to be inferred the device. - - Returns: - torch.device: The device of ``data``. - - Examples: - >>> import torch - >>> from mmengine.dist import cast_data_device - >>> # data is a Tensor - >>> data = torch.tensor([0, 1]) - >>> get_data_device(data) - device(type='cpu') - >>> # data is a list of Tensor - >>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])] - >>> get_data_device(data) - device(type='cpu') - >>> # data is a dict - >>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])} - >>> get_data_device(data) - device(type='cpu') - """ - if isinstance(data, Tensor): - return data.device - elif isinstance(data, Mapping): - pre = None - for v in data.values(): - cur = get_data_device(v) - if pre is None: - pre = cur - else: - if cur != pre: - raise ValueError( - 'device type in data should be consistent, but got ' - f'{cur} and {pre}') - if pre is None: - raise ValueError('data should not be empty.') - return pre - elif isinstance(data, Iterable) and not isinstance(data, str): - pre = None - for item in data: - cur = get_data_device(item) - if pre is None: - pre = cur - else: - if cur != pre: - raise ValueError( - 'device type in data should be consistent, but got ' - f'{cur} and {pre}') - if pre is None: - raise ValueError('data should not be empty.') - return pre - else: - raise TypeError('data should be a Tensor, sequence of tensor or dict, ' - f'but got {data}') - - -def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device: - """Return the device for communication among groups. - - Args: - group (ProcessGroup, optional): The process group to work on. - - Returns: - torch.device: The device of backend. 
- """ - backend = get_backend(group) - if backend == 'hccl': - import torch_npu # noqa: F401 - return torch.device('npu', torch.npu.current_device()) - elif backend == torch_dist.Backend.NCCL: - return torch.device('cuda', torch.cuda.current_device()) - elif backend == 'cncl': - import torch_mlu # noqa: F401 - return torch.device('mlu', torch.mlu.current_device()) - elif backend == 'smddp': - return torch.device('cuda', torch.cuda.current_device()) - elif backend == 'mccl': - import torch_musa - return torch.device('musa', torch_musa.current_device()) - else: - # GLOO and MPI backends use cpu device by default - return torch.device('cpu') - - -def cast_data_device( - data: Union[Tensor, Mapping, Iterable], - device: torch.device, - out: Optional[Union[Tensor, Mapping, Iterable]] = None -) -> Union[Tensor, Mapping, Iterable]: - """Recursively convert Tensor in ``data`` to ``device``. - - If ``data`` has already on the ``device``, it will not be casted again. - - Args: - data (Tensor or list or dict): Inputs to be casted. - device (torch.device): Destination device type. - out (Tensor or list or dict, optional): If ``out`` is specified, its - value will be equal to ``data``. Defaults to None. - - Returns: - Tensor or list or dict: ``data`` was casted to ``device``. - """ - if out is not None: - if type(data) is not type(out): - raise TypeError( - 'out should be the same type with data, but got data is ' - f'{type(data)} and out is {type(data)}') - - if isinstance(out, set): - raise TypeError('out should not be a set') - - if isinstance(data, Tensor): - if get_data_device(data) == device: - data_on_device = data - else: - data_on_device = data.to(device) - - if out is not None: - # modify the value of out inplace - out.copy_(data_on_device) # type: ignore - - return data_on_device - elif isinstance(data, Mapping): - data_on_device = {} - if out is not None: - data_len = len(data) - out_len = len(out) # type: ignore - if data_len != out_len: - raise ValueError('length of data and out should be same, ' - f'but got {data_len} and {out_len}') - - for k, v in data.items(): - data_on_device[k] = cast_data_device(v, device, - out[k]) # type: ignore - else: - for k, v in data.items(): - data_on_device[k] = cast_data_device(v, device) - - if len(data_on_device) == 0: - raise ValueError('data should not be empty') - - # To ensure the type of output as same as input, we use `type(data)` - # to wrap the output - return type(data)(data_on_device) # type: ignore - elif isinstance(data, Iterable) and not isinstance( - data, str) and not isinstance(data, np.ndarray): - data_on_device = [] - if out is not None: - for v1, v2 in zip(data, out): - data_on_device.append(cast_data_device(v1, device, v2)) - else: - for v in data: - data_on_device.append(cast_data_device(v, device)) - - if len(data_on_device) == 0: - raise ValueError('data should not be empty') - - return type(data)(data_on_device) # type: ignore - else: - raise TypeError('data should be a Tensor, list of tensor or dict, ' - f'but got {data}') diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py deleted file mode 100644 index e6bc78425e..0000000000 --- a/mmengine/evaluator/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from .evaluator import Evaluator -from .metric import BaseMetric, DumpResults -from .utils import get_metric_value - -__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults'] diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py deleted file mode 100644 index 930ce93028..0000000000 --- a/mmengine/evaluator/evaluator.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterator, List, Optional, Sequence, Union - -from mmengine.dataset import pseudo_collate -from mmengine.registry import EVALUATOR, METRICS -from mmengine.structures import BaseDataElement -from .metric import BaseMetric - - -@EVALUATOR.register_module() -class Evaluator: - """Wrapper class to compose multiple :class:`BaseMetric` instances. - - Args: - metrics (dict or BaseMetric or Sequence): The config of metrics. - """ - - def __init__(self, metrics: Union[dict, BaseMetric, Sequence]): - self._dataset_meta: Optional[dict] = None - if not isinstance(metrics, Sequence): - metrics = [metrics] - self.metrics: List[BaseMetric] = [] - for metric in metrics: - if isinstance(metric, dict): - self.metrics.append(METRICS.build(metric)) - else: - self.metrics.append(metric) - - @property - def dataset_meta(self) -> Optional[dict]: - """Optional[dict]: Meta info of the dataset.""" - return self._dataset_meta - - @dataset_meta.setter - def dataset_meta(self, dataset_meta: dict) -> None: - """Set the dataset meta info to the evaluator and it's metrics.""" - self._dataset_meta = dataset_meta - for metric in self.metrics: - metric.dataset_meta = dataset_meta - - def process(self, - data_samples: Sequence[BaseDataElement], - data_batch: Optional[Any] = None): - """Convert ``BaseDataSample`` to dict and invoke process method of each - metric. - - Args: - data_samples (Sequence[BaseDataElement]): predictions of the model, - and the ground truth of the validation set. - data_batch (Any, optional): A batch of data from the dataloader. - """ - _data_samples = [] - for data_sample in data_samples: - if isinstance(data_sample, BaseDataElement): - _data_samples.append(data_sample.to_dict()) - else: - _data_samples.append(data_sample) - - for metric in self.metrics: - metric.process(data_batch, _data_samples) - - def evaluate(self, size: int) -> dict: - """Invoke ``evaluate`` method of each metric and collect the metrics - dictionary. - - Args: - size (int): Length of the entire validation dataset. When batch - size > 1, the dataloader may pad some data samples to make - sure all ranks have the same length of dataset slice. The - ``collect_results`` function will drop the padded data based on - this size. - - Returns: - dict: Evaluation results of all metrics. The keys are the names - of the metrics, and the values are corresponding results. - """ - metrics = {} - for metric in self.metrics: - _results = metric.evaluate(size) - - # Check metric name conflicts - for name in _results.keys(): - if name in metrics: - raise ValueError( - 'There are multiple evaluation results with the same ' - f'metric name {name}. Please make sure all metrics ' - 'have different prefixes.') - - metrics.update(_results) - return metrics - - def offline_evaluate(self, - data_samples: Sequence, - data: Optional[Sequence] = None, - chunk_size: int = 1): - """Offline evaluate the dumped predictions on the given data . - - Args: - data_samples (Sequence): All predictions and ground truth of the - model and the validation set. 
- data (Sequence, optional): All data of the validation set. - chunk_size (int): The number of data samples and predictions to be - processed in a batch. - """ - - # support chunking iterable objects - def get_chunks(seq: Iterator, chunk_size=1): - stop = False - while not stop: - chunk = [] - for _ in range(chunk_size): - try: - chunk.append(next(seq)) - except StopIteration: - stop = True - break - if chunk: - yield chunk - - if data is not None: - assert len(data_samples) == len(data), ( - 'data_samples and data should have the same length, but got ' - f'data_samples length: {len(data_samples)} ' - f'data length: {len(data)}') - data = get_chunks(iter(data), chunk_size) - - size = 0 - for output_chunk in get_chunks(iter(data_samples), chunk_size): - if data is not None: - data_chunk = pseudo_collate(next(data)) # type: ignore - else: - data_chunk = None - size += len(output_chunk) - self.process(output_chunk, data_chunk) - return self.evaluate(size) diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py deleted file mode 100644 index 1292ce61ec..0000000000 --- a/mmengine/evaluator/metric.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Sequence, Union - -from torch import Tensor - -from mmengine.dist import (broadcast_object_list, collect_results, - is_main_process) -from mmengine.fileio import dump -from mmengine.logging import print_log -from mmengine.registry import METRICS -from mmengine.structures import BaseDataElement - - -class BaseMetric(metaclass=ABCMeta): - """Base class for a metric. - - The metric first processes each batch of data_samples and predictions, - and appends the processed results to the results list. Then it - collects all results together from all ranks if distributed training - is used. Finally, it computes the metrics of the entire dataset. - - A subclass of class:`BaseMetric` should assign a meaningful value to the - class attribute `default_prefix`. See the argument `prefix` for details. - - Args: - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, self.default_prefix - will be used instead. Default: None - collect_dir: (str, optional): Synchronize directory for collecting data - from different ranks. This argument should only be configured when - ``collect_device`` is 'cpu'. Defaults to None. 
- `New in version 0.7.3.` - """ - - default_prefix: Optional[str] = None - - def __init__(self, - collect_device: str = 'cpu', - prefix: Optional[str] = None, - collect_dir: Optional[str] = None) -> None: - if collect_dir is not None and collect_device != 'cpu': - raise ValueError('`collec_dir` could only be configured when ' - "`collect_device='cpu'`") - - self._dataset_meta: Union[None, dict] = None - self.collect_device = collect_device - self.results: List[Any] = [] - self.prefix = prefix or self.default_prefix - self.collect_dir = collect_dir - - if self.prefix is None: - print_log( - 'The prefix is not set in metric class ' - f'{self.__class__.__name__}.', - logger='current', - level=logging.WARNING) - - @property - def dataset_meta(self) -> Optional[dict]: - """Optional[dict]: Meta info of the dataset.""" - return self._dataset_meta - - @dataset_meta.setter - def dataset_meta(self, dataset_meta: dict) -> None: - """Set the dataset meta info to the metric.""" - self._dataset_meta = dataset_meta - - @abstractmethod - def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None: - """Process one batch of data samples and predictions. The processed - results should be stored in ``self.results``, which will be used to - compute the metrics when all batches have been processed. - - Args: - data_batch (Any): A batch of data from the dataloader. - data_samples (Sequence[dict]): A batch of outputs from - the model. - """ - - @abstractmethod - def compute_metrics(self, results: list) -> dict: - """Compute the metrics from processed results. - - Args: - results (list): The processed results of each batch. - - Returns: - dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. - """ - - def evaluate(self, size: int) -> dict: - """Evaluate the model performance of the whole dataset after processing - all batches. - - Args: - size (int): Length of the entire validation dataset. When batch - size > 1, the dataloader may pad some data samples to make - sure all ranks have the same length of dataset slice. The - ``collect_results`` function will drop the padded data based on - this size. - - Returns: - dict: Evaluation metrics dict on the val dataset. The keys are the - names of the metrics, and the values are corresponding results. - """ - if len(self.results) == 0: - print_log( - f'{self.__class__.__name__} got empty `self.results`. Please ' - 'ensure that the processed results are properly added into ' - '`self.results` in `process` method.', - logger='current', - level=logging.WARNING) - - if self.collect_device == 'cpu': - results = collect_results( - self.results, - size, - self.collect_device, - tmpdir=self.collect_dir) - else: - results = collect_results(self.results, size, self.collect_device) - - if is_main_process(): - # cast all tensors in results list to cpu - results = _to_cpu(results) - _metrics = self.compute_metrics(results) # type: ignore - # Add prefix to metric names - if self.prefix: - _metrics = { - '/'.join((self.prefix, k)): v - for k, v in _metrics.items() - } - metrics = [_metrics] - else: - metrics = [None] # type: ignore - - broadcast_object_list(metrics) - - # reset the results list - self.results.clear() - return metrics[0] - - -@METRICS.register_module() -class DumpResults(BaseMetric): - """Dump model predictions to a pickle file for offline evaluation. - - Args: - out_file_path (str): Path of the dumped file. Must end with '.pkl' - or '.pickle'. 
- collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - collect_dir: (str, optional): Synchronize directory for collecting data - from different ranks. This argument should only be configured when - ``collect_device`` is 'cpu'. Defaults to None. - `New in version 0.7.3.` - """ - - def __init__(self, - out_file_path: str, - collect_device: str = 'cpu', - collect_dir: Optional[str] = None) -> None: - super().__init__( - collect_device=collect_device, collect_dir=collect_dir) - if not out_file_path.endswith(('.pkl', '.pickle')): - raise ValueError('The output file must be a pkl file.') - self.out_file_path = out_file_path - - def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: - """Transfer tensors in predictions to CPU.""" - self.results.extend(_to_cpu(predictions)) - - def compute_metrics(self, results: list) -> dict: - """Dump the prediction results to a pickle file.""" - dump(results, self.out_file_path) - print_log( - f'Results has been saved to {self.out_file_path}.', - logger='current') - return {} - - -def _to_cpu(data: Any) -> Any: - """Transfer all tensors and BaseDataElement to cpu.""" - if isinstance(data, (Tensor, BaseDataElement)): - return data.to('cpu') - elif isinstance(data, list): - return [_to_cpu(d) for d in data] - elif isinstance(data, tuple): - return tuple(_to_cpu(d) for d in data) - elif isinstance(data, dict): - return {k: _to_cpu(v) for k, v in data.items()} - else: - return data diff --git a/mmengine/evaluator/utils.py b/mmengine/evaluator/utils.py deleted file mode 100644 index 6981c881b9..0000000000 --- a/mmengine/evaluator/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict - - -def get_metric_value(indicator: str, metrics: Dict) -> Any: - """Get the metric value specified by an indicator, which can be either a - metric name or a full name with evaluator prefix. - - Args: - indicator (str): The metric indicator, which can be the metric name - (e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP') - metrics (dict): The evaluation results output by the evaluator - - Returns: - Any: The specified metric value - """ - - if '/' in indicator: - # The indicator is a full name - if indicator in metrics: - return metrics[indicator] - else: - raise ValueError( - f'The indicator "{indicator}" can not match any metric in ' - f'{list(metrics.keys())}') - else: - # The indicator is metric name without prefix - matched = [k for k in metrics.keys() if k.split('/')[-1] == indicator] - - if not matched: - raise ValueError( - f'The indicator {indicator} can not match any metric in ' - f'{list(metrics.keys())}') - elif len(matched) > 1: - raise ValueError(f'The indicator "{indicator}" matches multiple ' - f'metrics {matched}') - else: - return metrics[matched[0]] diff --git a/mmengine/fileio/__init__.py b/mmengine/fileio/__init__.py deleted file mode 100644 index 81adcd4c02..0000000000 --- a/mmengine/fileio/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
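The evaluator API deleted above is used by subclassing BaseMetric and composing metrics in an Evaluator. The sketch below is illustrative only: the metric name, the 'demo' prefix and the 'pred'/'gt' sample keys are hypothetical, but the process/compute_metrics/evaluate flow follows the docstrings removed above and runs as-is in a non-distributed environment.

# Illustrative sketch of the evaluator API removed above.
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value


class Accuracy(BaseMetric):
    default_prefix = 'demo'  # hypothetical prefix for this example

    def process(self, data_batch, data_samples):
        # Store per-sample correctness; evaluate() later gathers
        # self.results across ranks via collect_results.
        for sample in data_samples:
            self.results.append(sample['pred'] == sample['gt'])

    def compute_metrics(self, results):
        return {'accuracy': sum(results) / len(results)}


evaluator = Evaluator([Accuracy()])
evaluator.process(data_samples=[{'pred': 1, 'gt': 1}, {'pred': 0, 'gt': 1}])
metrics = evaluator.evaluate(size=2)
print(get_metric_value('demo/accuracy', metrics))  # 0.5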
-from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, - LocalBackend, MemcachedBackend, PetrelBackend, - register_backend) -from .file_client import FileClient, HardDiskBackend -from .handlers import (BaseFileHandler, JsonHandler, PickleHandler, - YamlHandler, register_handler) -from .io import (copy_if_symlink_fails, copyfile, copyfile_from_local, - copyfile_to_local, copytree, copytree_from_local, - copytree_to_local, dump, exists, generate_presigned_url, get, - get_file_backend, get_local_path, get_text, isdir, isfile, - join_path, list_dir_or_file, load, put, put_text, remove, - rmtree) -from .parse import dict_from_file, list_from_file - -__all__ = [ - 'BaseStorageBackend', 'FileClient', 'PetrelBackend', 'MemcachedBackend', - 'LmdbBackend', 'HardDiskBackend', 'LocalBackend', 'HTTPBackend', - 'copy_if_symlink_fails', 'copyfile', 'copyfile_from_local', - 'copyfile_to_local', 'copytree', 'copytree_from_local', - 'copytree_to_local', 'exists', 'generate_presigned_url', 'get', - 'get_file_backend', 'get_local_path', 'get_text', 'isdir', 'isfile', - 'join_path', 'list_dir_or_file', 'put', 'put_text', 'remove', 'rmtree', - 'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler', - 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file', - 'register_backend' -] diff --git a/mmengine/fileio/backends/__init__.py b/mmengine/fileio/backends/__init__.py deleted file mode 100644 index fa0008977f..0000000000 --- a/mmengine/fileio/backends/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseStorageBackend -from .http_backend import HTTPBackend -from .lmdb_backend import LmdbBackend -from .local_backend import LocalBackend -from .memcached_backend import MemcachedBackend -from .petrel_backend import PetrelBackend -from .registry_utils import backends, prefix_to_backends, register_backend - -__all__ = [ - 'BaseStorageBackend', 'LocalBackend', 'HTTPBackend', 'LmdbBackend', - 'MemcachedBackend', 'PetrelBackend', 'register_backend', 'backends', - 'prefix_to_backends' -] diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py deleted file mode 100644 index 9331edf598..0000000000 --- a/mmengine/fileio/backends/base.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from abc import ABCMeta, abstractmethod - -from mmengine.logging import print_log - - -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: :meth:`get()` and - :meth:`get_text()`. - - - :meth:`get()` reads the file as a byte stream. - - :meth:`get_text()` reads the file as texts. - """ - - # a flag to indicate whether the backend can create a symlink for a file - # This attribute will be deprecated in future. - _allow_symlink = False - - @property - def allow_symlink(self): - print_log( - 'allow_symlink will be deprecated in future', - logger='current', - level=logging.WARNING) - return self._allow_symlink - - @property - def name(self): - return self.__class__.__name__ - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py deleted file mode 100644 index b3e65bbdbb..0000000000 --- a/mmengine/fileio/backends/http_backend.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
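A custom storage backend built on the abstraction deleted above only needs to implement get() and get_text(). The sketch below is an assumption-laden illustration: the 'dummy' name and 'dummy://' prefix are invented for the example, and it assumes prefix-based dispatch through register_backend/prefix_to_backends as exported in the removed __init__ files.

# Illustrative sketch of a custom backend for the removed fileio abstraction.
from mmengine.fileio import BaseStorageBackend, get_text, register_backend


class DummyBackend(BaseStorageBackend):
    """Serve every path from an in-memory dict (example only)."""

    _store = {'dummy://hello.txt': b'hello world'}

    def get(self, filepath) -> bytes:
        return self._store[str(filepath)]

    def get_text(self, filepath, encoding='utf-8') -> str:
        return self.get(filepath).decode(encoding)


# Register the backend and bind it to the hypothetical 'dummy://' prefix.
register_backend('dummy', DummyBackend, prefixes='dummy://')

print(get_text('dummy://hello.txt'))  # 'hello world'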
-import os -import tempfile -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Union -from urllib.request import urlopen - -from .base import BaseStorageBackend - - -class HTTPBackend(BaseStorageBackend): - """HTTP and HTTPS storage bachend.""" - - def get(self, filepath: str) -> bytes: - """Read bytes from a given ``filepath``. - - Args: - filepath (str): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get('http://path/of/file') - b'hello world' - """ - return urlopen(filepath).read() - - def get_text(self, filepath, encoding='utf-8') -> str: - """Read text from a given ``filepath``. - - Args: - filepath (str): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get_text('http://path/of/file') - 'hello world' - """ - return urlopen(filepath).read().decode(encoding) - - @contextmanager - def get_local_path( - self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` to a local temporary directory, - and return the temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str): Download a file from ``filepath``. - - Yields: - Iterable[str]: Only yield one temporary path. - - Examples: - >>> backend = HTTPBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with backend.get_local_path('http://path/of/file') as path: - ... # do something here - """ - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py deleted file mode 100644 index eb47923e56..0000000000 --- a/mmengine/fileio/backends/lmdb_backend.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from pathlib import Path -from typing import Union - -from .base import BaseStorageBackend - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_path (str): Lmdb database path. - readonly (bool): Lmdb environment parameter. If True, disallow any - write operations. Defaults to True. - lock (bool): Lmdb environment parameter. If False, when concurrent - access occurs, do not lock the database. Defaults to False. - readahead (bool): Lmdb environment parameter. If False, disable the OS - filesystem readahead mechanism, which may improve random read - performance when a database is larger than RAM. Defaults to False. - **kwargs: Keyword arguments passed to `lmdb.open`. - - Attributes: - db_path (str): Lmdb database path. - """ - - def __init__(self, - db_path, - readonly=True, - lock=False, - readahead=False, - **kwargs): - try: - import lmdb # noqa: F401 - except ImportError: - raise ImportError( - 'Please run "pip install lmdb" to enable LmdbBackend.') - - self.db_path = str(db_path) - self.readonly = readonly - self.lock = lock - self.readahead = readahead - self.kwargs = kwargs - self._client = None - - def get(self, filepath: Union[str, Path]) -> bytes: - """Get values according to the filepath. 
- - Args: - filepath (str or Path): Here, filepath is the lmdb key. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LmdbBackend('path/to/lmdb') - >>> backend.get('key') - b'hello world' - """ - if self._client is None: - self._client = self._get_client() - - filepath = str(filepath) - with self._client.begin(write=False) as txn: - value_buf = txn.get(filepath.encode('ascii')) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - def _get_client(self): - import lmdb - - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs) - - def __del__(self): - if self._client is not None: - self._client.close() diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py deleted file mode 100644 index c7d5f04621..0000000000 --- a/mmengine/fileio/backends/local_backend.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -import shutil -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Iterator, Optional, Tuple, Union - -import mmengine -from .base import BaseStorageBackend - - -class LocalBackend(BaseStorageBackend): - """Raw local storage backend.""" - - _allow_symlink = True - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - with open(filepath, 'rb') as f: - value = f.read() - return value - - def get_text(self, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - with open(filepath, encoding=encoding) as f: - text = f.read() - return text - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put(b'hello world', filepath) - """ - mmengine.mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'wb') as f: - f.write(obj) - - def put_text(self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. 
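A usage sketch for the removed ``LmdbBackend`` (requires the ``lmdb`` package; the database path and key below are placeholders):

from mmengine.fileio import LmdbBackend

backend = LmdbBackend('path/to/lmdb')  # readonly=True, lock=False by default
value = backend.get('key')             # bytes stored under the LMDB key
# get_text() is intentionally not implemented for this backend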
- - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put_text('hello world', filepath) - """ - mmengine.mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'w', encoding=encoding) as f: - f.write(obj) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.exists(filepath) - True - """ - return osp.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/dir' - >>> backend.isdir(filepath) - True - """ - return osp.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.isfile(filepath) - True - """ - return osp.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - - Examples: - >>> backend = LocalBackend() - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> backend.join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - # TODO, if filepath or filepaths are Path, should return Path - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, Path], - ) -> Generator[Union[str, Path], None, None]: - """Only for unified API and do nothing. - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> backend = LocalBackend() - >>> with backend.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - """ - yield filepath - - def copyfile( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. 
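A self-contained round trip with the removed ``LocalBackend``, exercising the read/write helpers documented above (a temporary directory keeps the sketch side-effect free):

import os.path as osp
import tempfile

from mmengine.fileio import LocalBackend

backend = LocalBackend()
with tempfile.TemporaryDirectory() as tmp_dir:
    filepath = osp.join(tmp_dir, 'demo.txt')
    backend.put_text('hello world', filepath)  # creates parent dirs if needed
    assert backend.exists(filepath) and backend.isfile(filepath)
    print(backend.get_text(filepath))          # hello world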
- - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> backend.copyfile(src, dst) - '/path1/of/dir/file' - """ - return shutil.copy(src, dst) - - def copytree( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - TODO: Whether to support dirs_exist_ok parameter. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree(src, dst) - '/path/of/dir2' - """ - return shutil.copytree(src, dst) - - def copyfile_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a local file src to dst and return the destination file. Same - as :meth:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_from_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_from_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. Same as - :meth:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def copyfile_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy the file src to local dst and return the destination file. Same - as :meth:`copyfile`. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. 
- - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_to_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_to_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.remove(filepath) - """ - if not self.exists(filepath): - raise FileNotFoundError(f'filepath {filepath} does not exist') - - if self.isdir(filepath): - raise IsADirectoryError('filepath should be a file') - - os.remove(filepath) - - def rmtree(self, dir_path: Union[str, Path]) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> dir_path = '/path/of/dir' - >>> backend.rmtree(dir_path) - """ - shutil.rmtree(dir_path) - - def copy_if_symlink_fails( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directly copy src - to dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - - Returns: - bool: Return True if successfully create a symbolic link pointing - to src. Otherwise, return False. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> backend.copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> backend.copy_if_symlink_fails(src, dst) - True - """ - try: - os.symlink(src, dst) - return True - except Exception: - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. 
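The removed ``copy_if_symlink_fails`` tries ``os.symlink`` first and silently falls back to a plain copy; a small sketch of the observable behaviour on local files:

import os
import os.path as osp
import tempfile

from mmengine.fileio import LocalBackend

backend = LocalBackend()
with tempfile.TemporaryDirectory() as tmp_dir:
    src = osp.join(tmp_dir, 'src.txt')
    dst = osp.join(tmp_dir, 'dst.txt')
    backend.put_text('hello', src)
    linked = backend.copy_if_symlink_fails(src, dst)
    # True if a symlink was created, False if it fell back to copying
    print(linked, os.path.islink(dst))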
- suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = LocalBackend() - >>> dir_path = '/path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ # noqa: E501 - if list_dir and suffix is not None: - raise TypeError('`suffix` should be None when `list_dir` is True') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, - list_file, suffix, - recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) diff --git a/mmengine/fileio/backends/memcached_backend.py b/mmengine/fileio/backends/memcached_backend.py deleted file mode 100644 index 2458e7c6ec..0000000000 --- a/mmengine/fileio/backends/memcached_backend.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from pathlib import Path -from typing import Union - -from .base import BaseStorageBackend - - -class MemcachedBackend(BaseStorageBackend): - """Memcached storage backend. - - Attributes: - server_list_cfg (str): Config file for memcached server list. - client_cfg (str): Config file for memcached client. - sys_path (str, optional): Additional path to be appended to `sys.path`. - Defaults to None. - """ - - def __init__(self, server_list_cfg, client_cfg, sys_path=None): - if sys_path is not None: - import sys - sys.path.append(sys_path) - try: - import mc - except ImportError: - raise ImportError( - 'Please install memcached to enable MemcachedBackend.') - - self.server_list_cfg = server_list_cfg - self.client_cfg = client_cfg - self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, - self.client_cfg) - # mc.pyvector servers as a point which points to a memory cache - self._mc_buffer = mc.pyvector() - - def get(self, filepath: Union[str, Path]): - """Get values according to the filepath. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. 
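Beyond the doctest examples above, a self-contained sketch of ``list_dir_or_file`` on a small temporary tree, listing only files with a given suffix relative to ``dir_path``:

import os.path as osp
import tempfile

from mmengine.fileio import LocalBackend

backend = LocalBackend()
with tempfile.TemporaryDirectory() as tmp_dir:
    backend.put_text('a', osp.join(tmp_dir, 'a.txt'))
    backend.put_text('b', osp.join(tmp_dir, 'sub', 'b.txt'))
    backend.put_text('c', osp.join(tmp_dir, 'sub', 'c.log'))
    txt_files = sorted(
        backend.list_dir_or_file(
            tmp_dir, list_dir=False, suffix='.txt', recursive=True))
    print(txt_files)  # ['a.txt', 'sub/b.txt'] on POSIX systems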
- - Examples: - >>> server_list_cfg = '/path/of/server_list.conf' - >>> client_cfg = '/path/of/mc.conf' - >>> backend = MemcachedBackend(server_list_cfg, client_cfg) - >>> backend.get('/path/of/file') - b'hello world' - """ - filepath = str(filepath) - import mc - self._client.Get(filepath, self._mc_buffer) - value_buf = mc.ConvertBuffer(self._mc_buffer) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py deleted file mode 100644 index 3994372f66..0000000000 --- a/mmengine/fileio/backends/petrel_backend.py +++ /dev/null @@ -1,771 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -import re -import tempfile -from contextlib import contextmanager -from pathlib import Path -from shutil import SameFileError -from typing import Generator, Iterator, Optional, Tuple, Union - -import mmengine -from mmengine.utils import has_method -from .base import BaseStorageBackend - - -class PetrelBackend(BaseStorageBackend): - """Petrel storage backend (for internal usage). - - PetrelBackend supports reading and writing data to multiple clusters. - If the file path contains the cluster name, PetrelBackend will read data - from specified cluster or write data to it. Otherwise, PetrelBackend will - access the default cluster. - - Args: - path_mapping (dict, optional): Path mapping dict from local path to - Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in - ``filepath`` will be replaced by ``dst``. Defaults to None. - enable_mc (bool, optional): Whether to enable memcached support. - Defaults to True. - conf_path (str, optional): Config path of Petrel client. Default: None. - `New in version 0.3.3`. - - Examples: - >>> backend = PetrelBackend() - >>> filepath1 = 'petrel://path/of/file' - >>> filepath2 = 'cluster-name:petrel://path/of/file' - >>> backend.get(filepath1) # get data from default cluster - >>> client.get(filepath2) # get data from 'cluster-name' cluster - """ - - def __init__(self, - path_mapping: Optional[dict] = None, - enable_mc: bool = True, - conf_path: Optional[str] = None): - try: - from petrel_client import client - except ImportError: - raise ImportError('Please install petrel_client to enable ' - 'PetrelBackend.') - - self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc) - assert isinstance(path_mapping, dict) or path_mapping is None - self.path_mapping = path_mapping - - def _map_path(self, filepath: Union[str, Path]) -> str: - """Map ``filepath`` to a string path whose prefix will be replaced by - :attr:`self.path_mapping`. - - Args: - filepath (str or Path): Path to be mapped. - """ - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v, 1) - return filepath - - def _format_path(self, filepath: str) -> str: - """Convert a ``filepath`` to standard format of petrel oss. - - If the ``filepath`` is concatenated by ``os.path.join``, in a Windows - environment, the ``filepath`` will be the format of - 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the - above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. - - Args: - filepath (str): Path to be formatted. 
- """ - return re.sub(r'\\+', '/', filepath) - - def _replace_prefix(self, filepath: Union[str, Path]) -> str: - filepath = str(filepath) - return filepath.replace('petrel://', 's3://') - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Return bytes read from filepath. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - value = self._client.Get(filepath) - return value - - def get_text( - self, - filepath: Union[str, Path], - encoding: str = 'utf-8', - ) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - return str(self.get(filepath), encoding=encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write bytes to a given ``filepath``. - - Args: - obj (bytes): Data to be saved. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.put(b'hello world', filepath) - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.put(filepath, obj) - - def put_text( - self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8', - ) -> None: - """Write text to a given ``filepath``. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to encode the ``obj``. - Defaults to 'utf-8'. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.put_text('hello world', filepath) - """ - self.put(bytes(obj, encoding=encoding), filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.exists(filepath) - True - """ - if not (has_method(self._client, 'contains') - and has_method(self._client, 'isdir')): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` and `isdir` methods, please use a higher' - 'version or dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.contains(filepath) or self._client.isdir(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. 
- - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/dir' - >>> backend.isdir(filepath) - True - """ - if not has_method(self._client, 'isdir'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `isdir` method, please use a higher version or dev' - ' branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.isfile(filepath) - True - """ - if not has_method(self._client, 'contains'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` method, please use a higher version or ' - 'dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.contains(filepath) - - def join_path( - self, - filepath: Union[str, Path], - *filepaths: Union[str, Path], - ) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result after concatenation. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.join_path(filepath, 'another/path') - 'petrel://path/of/file/another/path' - >>> backend.join_path(filepath, '/another/path') - 'petrel://path/of/file/another/path' - """ - filepath = self._format_path(self._map_path(filepath)) - if filepath.endswith('/'): - filepath = filepath[:-1] - formatted_paths = [filepath] - for path in filepaths: - formatted_path = self._format_path(self._map_path(path)) - formatted_paths.append(formatted_path.lstrip('/')) - - return '/'.join(formatted_paths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, Path], - ) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` to a local temporary directory, - and return the temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str or Path): Download a file from ``filepath``. - - Yields: - Iterable[str]: Only yield one temporary path. - - Examples: - >>> backend = PetrelBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> filepath = 'petrel://path/of/file' - >>> with backend.get_local_path(filepath) as path: - ... # do something here - """ - assert self.isfile(filepath) - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - - def copyfile( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. 
If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = PetrelBackend() - >>> # dst is a file - >>> src = 'petrel://path/of/file' - >>> dst = 'petrel://path/of/file1' - >>> backend.copyfile(src, dst) - 'petrel://path/of/file1' - - >>> # dst is a directory - >>> dst = 'petrel://path/of/dir' - >>> backend.copyfile(src, dst) - 'petrel://path/of/dir/file' - """ - src = self._format_path(self._map_path(src)) - dst = self._format_path(self._map_path(dst)) - if self.isdir(dst): - dst = self.join_path(dst, src.split('/')[-1]) - - if src == dst: - raise SameFileError('src and dst should not be same') - - self.put(self.get(src), dst) - return dst - - def copytree( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = PetrelBackend() - >>> src = 'petrel://path/of/dir' - >>> dst = 'petrel://path/of/dir1' - >>> backend.copytree(src, dst) - 'petrel://path/of/dir1' - """ - src = self._format_path(self._map_path(src)) - dst = self._format_path(self._map_path(dst)) - - if self.exists(dst): - raise FileExistsError('dst should not exist') - - for path in self.list_dir_or_file(src, list_dir=False, recursive=True): - src_path = self.join_path(src, path) - dst_path = self.join_path(dst, path) - self.put(self.get(src_path), dst_path) - - return dst - - def copyfile_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Upload a local file src to dst and return the destination file. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = PetrelBackend() - >>> # dst is a file - >>> src = 'path/of/your/file' - >>> dst = 'petrel://path/of/file1' - >>> backend.copyfile_from_local(src, dst) - 'petrel://path/of/file1' - - >>> # dst is a directory - >>> dst = 'petrel://path/of/dir' - >>> backend.copyfile_from_local(src, dst) - 'petrel://path/of/dir/file' - """ - dst = self._format_path(self._map_path(dst)) - if self.isdir(dst): - dst = self.join_path(dst, osp.basename(src)) - - with open(src, 'rb') as f: - self.put(f.read(), dst) - - return dst - - def copytree_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - Args: - src (str or Path): A local directory to be copied. 
- dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = PetrelBackend() - >>> src = 'path/of/your/dir' - >>> dst = 'petrel://path/of/dir1' - >>> backend.copytree_from_local(src, dst) - 'petrel://path/of/dir1' - """ - dst = self._format_path(self._map_path(dst)) - if self.exists(dst): - raise FileExistsError('dst should not exist') - - src = str(src) - - for cur_dir, _, files in os.walk(src): - for f in files: - src_path = osp.join(cur_dir, f) - dst_path = self.join_path(dst, src_path.replace(src, '')) - self.copyfile_from_local(src_path, dst_path) - - return dst - - def copyfile_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> Union[str, Path]: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = PetrelBackend() - >>> # dst is a file - >>> src = 'petrel://path/of/file' - >>> dst = 'path/of/your/file' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/file' - - >>> # dst is a directory - >>> dst = 'path/of/your/dir' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/dir/file' - """ - if osp.isdir(dst): - basename = osp.basename(src) - if isinstance(dst, str): - dst = osp.join(dst, basename) - else: - assert isinstance(dst, Path) - dst = dst / basename - - with open(dst, 'wb') as f: - f.write(self.get(src)) - - return dst - - def copytree_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> backend = PetrelBackend() - >>> src = 'petrel://path/of/dir' - >>> dst = 'path/of/your/dir' - >>> backend.copytree_to_local(src, dst) - 'path/of/your/dir' - """ - for path in self.list_dir_or_file(src, list_dir=False, recursive=True): - dst_path = osp.join(dst, path) - mmengine.mkdir_or_exist(osp.dirname(dst_path)) - with open(dst_path, 'wb') as f: - f.write(self.get(self.join_path(src, path))) - - return dst - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. 
- - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.remove(filepath) - """ - if not has_method(self._client, 'delete'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `delete` method, please use a higher version or dev ' - 'branch instead.') - - if not self.exists(filepath): - raise FileNotFoundError(f'filepath {filepath} does not exist') - - if self.isdir(filepath): - raise IsADirectoryError('filepath should be a file') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.delete(filepath) - - def rmtree(self, dir_path: Union[str, Path]) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> backend = PetrelBackend() - >>> dir_path = 'petrel://path/of/dir' - >>> backend.rmtree(dir_path) - """ - for path in self.list_dir_or_file( - dir_path, list_dir=False, recursive=True): - filepath = self.join_path(dir_path, path) - self.remove(filepath) - - def copy_if_symlink_fails( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> bool: - """Create a symbolic link pointing to src named dst. - - Directly copy src to dst because PetrelBacekend does not support create - a symbolic link. - - Args: - src (str or Path): A file or directory to be copied. - dst (str or Path): Copy a file or directory to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - bool: Return False because PetrelBackend does not support create - a symbolic link. - - Examples: - >>> backend = PetrelBackend() - >>> src = 'petrel://path/of/file' - >>> dst = 'petrel://path/of/your/file' - >>> backend.copy_if_symlink_fails(src, dst) - False - >>> src = 'petrel://path/of/dir' - >>> dst = 'petrel://path/of/your/dir' - >>> backend.copy_if_symlink_fails(src, dst) - False - """ - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - Petrel has no concept of directories but it simulates the directory - hierarchy in the filesystem through public prefixes. In addition, - if the returned path ends with '/', it means the path is a public - prefix which is a logical directory. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - In addition, the returned path of directory will not contains the - suffix '/' which is consistent with other backends. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = PetrelBackend() - >>> dir_path = 'petrel://path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... 
print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ # noqa: E501 - if not has_method(self._client, 'list'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `list` method, please use a higher version or dev' - ' branch instead.') - - dir_path = self._map_path(dir_path) - dir_path = self._format_path(dir_path) - dir_path = self._replace_prefix(dir_path) - if list_dir and suffix is not None: - raise TypeError( - '`list_dir` should be False when `suffix` is not None') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - # Petrel's simulated directory hierarchy assumes that directory paths - # should end with `/` - if not dir_path.endswith('/'): - dir_path += '/' - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for path in self._client.list(dir_path): - # the `self.isdir` is not used here to determine whether path - # is a directory, because `self.isdir` relies on - # `self._client.list` - if path.endswith('/'): # a directory path - next_dir_path = self.join_path(dir_path, path) - if list_dir: - # get the relative path and exclude the last - # character '/' - rel_dir = next_dir_path[len(root):-1] - yield rel_dir - if recursive: - yield from _list_dir_or_file(next_dir_path, list_dir, - list_file, suffix, - recursive) - else: # a file path - absolute_path = self.join_path(dir_path, path) - rel_path = absolute_path[len(root):] - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - def generate_presigned_url(self, - url: str, - client_method: str = 'get_object', - expires_in: int = 3600) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on Petrel backend. - - Note: - Now only work on Petrel backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. Default: 'get_object'. - expires_in (int): expires, in seconds. Default: 3600. - - Returns: - str: Generated presigned url. - """ - return self._client.generate_presigned_url(url, client_method, - expires_in) diff --git a/mmengine/fileio/backends/registry_utils.py b/mmengine/fileio/backends/registry_utils.py deleted file mode 100644 index 4578a4ca76..0000000000 --- a/mmengine/fileio/backends/registry_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import inspect -from typing import Optional, Type, Union - -from .base import BaseStorageBackend -from .http_backend import HTTPBackend -from .lmdb_backend import LmdbBackend -from .local_backend import LocalBackend -from .memcached_backend import MemcachedBackend -from .petrel_backend import PetrelBackend - -backends: dict = {} -prefix_to_backends: dict = {} - - -def _register_backend(name: str, - backend: Type[BaseStorageBackend], - force: bool = False, - prefixes: Union[str, list, tuple, None] = None): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (BaseStorageBackend): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - """ - global backends, prefix_to_backends - - if not isinstance(name, str): - raise TypeError('the backend name should be a string, ' - f'but got {type(name)}') - - if not inspect.isclass(backend): - raise TypeError(f'backend should be a class, but got {type(backend)}') - if not issubclass(backend, BaseStorageBackend): - raise TypeError( - f'backend {backend} is not a subclass of BaseStorageBackend') - - if name in backends and not force: - raise ValueError(f'{name} is already registered as a storage backend, ' - 'add "force=True" if you want to override it') - backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - - for prefix in prefixes: - if prefix in prefix_to_backends and not force: - raise ValueError( - f'{prefix} is already registered as a storage backend,' - ' add "force=True" if you want to override it') - - prefix_to_backends[prefix] = backend - - -def register_backend(name: str, - backend: Optional[Type[BaseStorageBackend]] = None, - force: bool = False, - prefixes: Union[str, list, tuple, None] = None): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - - This method can be used as a normal method or a decorator. - - Examples: - - >>> class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - >>> register_backend('new', NewBackend) - - >>> @register_backend('new') - ... class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... 
return filepath - """ - if backend is not None: - _register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - _register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - -register_backend('local', LocalBackend, prefixes='') -register_backend('memcached', MemcachedBackend) -register_backend('lmdb', LmdbBackend) -# To avoid breaking backward Compatibility, 's3' is also used as a -# prefix for PetrelBackend -register_backend('petrel', PetrelBackend, prefixes=['petrel', 's3']) -register_backend('http', HTTPBackend, prefixes=['http', 'https']) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py deleted file mode 100644 index 61551d3d1d..0000000000 --- a/mmengine/fileio/file_client.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Generator, Iterator, Optional, Tuple, Union - -from mmengine.logging import print_log -from mmengine.utils import is_filepath -from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, - LocalBackend, MemcachedBackend, PetrelBackend) - - -class HardDiskBackend(LocalBackend): - """Raw hard disks storage backend.""" - - def __init__(self) -> None: - print_log( - '"HardDiskBackend" is the alias of "LocalBackend" ' - 'and the former will be deprecated in future.', - logger='current', - level=logging.WARNING) - - @property - def name(self): - return self.__class__.__name__ - - -class FileClient: - """A general file client to access files in different backends. - - The client loads a file or text in a specified backend from its path - and returns it as a binary or text file. There are two ways to choose a - backend, the name of backend and the prefix of path. Although both of them - can be used to choose a storage backend, ``backend`` has a higher priority - that is if they are all set, the storage backend will be chosen by the - backend argument. If they are all `None`, the disk backend will be chosen. - Note that It can also register other backend accessor with a given name, - prefixes, and backend class. In addition, We use the singleton pattern to - avoid repeated object creation. If the arguments are the same, the same - object will be returned. - - Warning: - `FileClient` will be deprecated in future. Please use io functions - in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io - - Args: - backend (str, optional): The storage backend type. Options are "disk", - "memcached", "lmdb", "http" and "petrel". Defaults to None. - prefix (str, optional): The prefix of the registered storage backend. - Options are "s3", "http", "https". Defaults to None. - - Examples: - >>> # only set backend - >>> file_client = FileClient(backend='petrel') - >>> # only set prefix - >>> file_client = FileClient(prefix='s3') - >>> # set both backend and prefix but use backend to choose client - >>> file_client = FileClient(backend='petrel', prefix='s3') - >>> # if the arguments are the same, the same object is returned - >>> file_client1 = FileClient(backend='petrel') - >>> file_client1 is file_client - True - - Attributes: - client (:obj:`BaseStorageBackend`): The backend object. 
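Besides the decorator form shown in the docstring above, the removed ``register_backend`` could also bind a backend to a URI prefix; a sketch with a hypothetical class and prefix, using the pre-removal import path:

from mmengine.fileio import BaseStorageBackend, register_backend


@register_backend('dummy', prefixes='dummy')
class DummyBackend(BaseStorageBackend):
    """Hypothetical backend; 'dummy://' paths can now be dispatched to it."""

    def get(self, filepath):
        return b'dummy data'

    def get_text(self, filepath):
        return 'dummy data'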
- """ - - _backends = { - 'disk': HardDiskBackend, - 'memcached': MemcachedBackend, - 'lmdb': LmdbBackend, - 'petrel': PetrelBackend, - 'http': HTTPBackend, - } - - _prefix_to_backends: dict = { - 's3': PetrelBackend, - 'petrel': PetrelBackend, - 'http': HTTPBackend, - 'https': HTTPBackend, - } - - _instances: dict = {} - - client: Any - - def __new__(cls, backend=None, prefix=None, **kwargs): - print_log( - '"FileClient" will be deprecated in future. Please use io ' - 'functions in ' - 'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', # noqa: E501 - logger='current', - level=logging.WARNING) - if backend is None and prefix is None: - backend = 'disk' - if backend is not None and backend not in cls._backends: - raise ValueError( - f'Backend {backend} is not supported. Currently supported ones' - f' are {list(cls._backends.keys())}') - if prefix is not None and prefix not in cls._prefix_to_backends: - raise ValueError( - f'prefix {prefix} is not supported. Currently supported ones ' - f'are {list(cls._prefix_to_backends.keys())}') - - # concatenate the arguments to a unique key for determining whether - # objects with the same arguments were created - arg_key = f'{backend}:{prefix}' - for key, value in kwargs.items(): - arg_key += f':{key}:{value}' - - # if a backend was overridden, it will create a new object - if arg_key in cls._instances: - _instance = cls._instances[arg_key] - else: - # create a new object and put it to _instance - _instance = super().__new__(cls) - if backend is not None: - _instance.client = cls._backends[backend](**kwargs) - else: - _instance.client = cls._prefix_to_backends[prefix](**kwargs) - - cls._instances[arg_key] = _instance - - return _instance - - @property - def name(self): - return self.client.name - - @property - def allow_symlink(self): - return self.client.allow_symlink - - @staticmethod - def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: - """Parse the prefix of a uri. - - Args: - uri (str | Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> FileClient.parse_uri_prefix('s3://path/of/your/file') - 's3' - - Returns: - str | None: Return the prefix of uri if the uri contains '://' else - ``None``. - """ - assert is_filepath(uri) - uri = str(uri) - if '://' not in uri: - return None - else: - prefix, _ = uri.split('://') - # In the case of PetrelBackend, the prefix may contains the cluster - # name like clusterName:s3 - if ':' in prefix: - _, prefix = prefix.split(':') - return prefix - - @classmethod - def infer_client(cls, - file_client_args: Optional[dict] = None, - uri: Optional[Union[str, Path]] = None) -> 'FileClient': - """Infer a suitable file client based on the URI and arguments. - - Args: - file_client_args (dict, optional): Arguments to instantiate a - FileClient. Defaults to None. - uri (str | Path, optional): Uri to be parsed that contains the file - prefix. Defaults to None. - - Examples: - >>> uri = 's3://path/of/your/file' - >>> file_client = FileClient.infer_client(uri=uri) - >>> file_client_args = {'backend': 'petrel'} - >>> file_client = FileClient.infer_client(file_client_args) - - Returns: - FileClient: Instantiated FileClient object. 
- """ - assert file_client_args is not None or uri is not None - if file_client_args is None: - file_prefix = cls.parse_uri_prefix(uri) # type: ignore - return cls(prefix=file_prefix) - else: - return cls(**file_client_args) - - @classmethod - def _register_backend(cls, name, backend, force=False, prefixes=None): - if not isinstance(name, str): - raise TypeError('the backend name should be a string, ' - f'but got {type(name)}') - if not inspect.isclass(backend): - raise TypeError( - f'backend should be a class but got {type(backend)}') - if not issubclass(backend, BaseStorageBackend): - raise TypeError( - f'backend {backend} is not a subclass of BaseStorageBackend') - if not force and name in cls._backends: - raise KeyError( - f'{name} is already registered as a storage backend, ' - 'add "force=True" if you want to override it') - - if name in cls._backends and force: - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, cls._backends[name]): - cls._instances.pop(arg_key) - cls._backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - for prefix in prefixes: - if prefix not in cls._prefix_to_backends: - cls._prefix_to_backends[prefix] = backend - elif (prefix in cls._prefix_to_backends) and force: - overridden_backend = cls._prefix_to_backends[prefix] - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, overridden_backend): - cls._instances.pop(arg_key) - else: - raise KeyError( - f'{prefix} is already registered as a storage backend,' - ' add "force=True" if you want to override it') - - @classmethod - def register_backend(cls, name, backend=None, force=False, prefixes=None): - """Register a backend to FileClient. - - This method can be used as a normal class method or a decorator. - - .. code-block:: python - - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - FileClient.register_backend('new', NewBackend) - - or - - .. code-block:: python - - @FileClient.register_backend('new') - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool, optional): Whether to override the backend if the name - has already been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefixes - of the registered storage backend. Defaults to None. - `New in version 1.3.15.` - """ - if backend is not None: - cls._register_backend( - name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - cls._register_backend( - name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: - """Read data from a given ``filepath`` with 'rb' mode. - - Note: - There are two types of return values for ``get``, one is ``bytes`` - and the other is ``memoryview``. The advantage of using memoryview - is that you can avoid copying, and if you want to convert it to - ``bytes``, you can use ``.tobytes()``. 
- - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes | memoryview: Expected bytes object or a memory view of the - bytes object. - """ - return self.client.get(filepath) - - def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return self.client.get_text(filepath, encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - """ - self.client.put(obj, filepath) - - def put_text(self, obj: str, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - `filepath`. Defaults to 'utf-8'. - """ - self.client.put_text(obj, filepath) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - """ - self.client.remove(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return self.client.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return self.client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return self.client.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - """ - return self.client.join_path(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, - Path]) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself. - - .. warning:: - ``get_local_path`` is an experimental interface that may change in - the future. 
- - Args: - filepath (str or Path): Path to be read data. - - Examples: - >>> file_client = FileClient(prefix='s3') - >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one path. - """ - with self.client.get_local_path(str(filepath)) as local_path: - yield local_path - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - """ - yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, - suffix, recursive) diff --git a/mmengine/fileio/handlers/__init__.py b/mmengine/fileio/handlers/__init__.py deleted file mode 100644 index 391a60c36b..0000000000 --- a/mmengine/fileio/handlers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseFileHandler -from .json_handler import JsonHandler -from .pickle_handler import PickleHandler -from .registry_utils import file_handlers, register_handler -from .yaml_handler import YamlHandler - -__all__ = [ - 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', - 'register_handler', 'file_handlers' -] diff --git a/mmengine/fileio/handlers/base.py b/mmengine/fileio/handlers/base.py deleted file mode 100644 index 288878bc57..0000000000 --- a/mmengine/fileio/handlers/base.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod - - -class BaseFileHandler(metaclass=ABCMeta): - # `str_like` is a flag to indicate whether the type of file object is - # str-like object or bytes-like object. Pickle only processes bytes-like - # objects but json only processes str-like object. If it is str-like - # object, `StringIO` will be used to process the buffer. - str_like = True - - @abstractmethod - def load_from_fileobj(self, file, **kwargs): - pass - - @abstractmethod - def dump_to_fileobj(self, obj, file, **kwargs): - pass - - @abstractmethod - def dump_to_str(self, obj, **kwargs): - pass - - def load_from_path(self, filepath, mode='r', **kwargs): - with open(filepath, mode) as f: - return self.load_from_fileobj(f, **kwargs) - - def dump_to_path(self, obj, filepath, mode='w', **kwargs): - with open(filepath, mode) as f: - self.dump_to_fileobj(obj, f, **kwargs) diff --git a/mmengine/fileio/handlers/json_handler.py b/mmengine/fileio/handlers/json_handler.py deleted file mode 100644 index 18d4f15f74..0000000000 --- a/mmengine/fileio/handlers/json_handler.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import json - -import numpy as np - -from .base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. 
- It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, (set, range)): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f'{type(obj)} is unsupported for json dump') - - -class JsonHandler(BaseFileHandler): - - def load_from_fileobj(self, file): - return json.load(file) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault('default', set_default) - json.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault('default', set_default) - return json.dumps(obj, **kwargs) diff --git a/mmengine/fileio/handlers/pickle_handler.py b/mmengine/fileio/handlers/pickle_handler.py deleted file mode 100644 index 073856fd25..0000000000 --- a/mmengine/fileio/handlers/pickle_handler.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pickle - -from .base import BaseFileHandler - - -class PickleHandler(BaseFileHandler): - - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return pickle.load(file, **kwargs) - - def load_from_path(self, filepath, **kwargs): - return super().load_from_path(filepath, mode='rb', **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault('protocol', 2) - return pickle.dumps(obj, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault('protocol', 2) - pickle.dump(obj, file, **kwargs) - - def dump_to_path(self, obj, filepath, **kwargs): - super().dump_to_path(obj, filepath, mode='wb', **kwargs) diff --git a/mmengine/fileio/handlers/registry_utils.py b/mmengine/fileio/handlers/registry_utils.py deleted file mode 100644 index 106fc881f2..0000000000 --- a/mmengine/fileio/handlers/registry_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import is_list_of -from .base import BaseFileHandler -from .json_handler import JsonHandler -from .pickle_handler import PickleHandler -from .yaml_handler import YamlHandler - -file_handlers = { - 'json': JsonHandler(), - 'yaml': YamlHandler(), - 'yml': YamlHandler(), - 'pickle': PickleHandler(), - 'pkl': PickleHandler(), -} - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. - """ - if not isinstance(handler, BaseFileHandler): - raise TypeError( - f'handler must be a child of BaseFileHandler, not {type(handler)}') - if isinstance(file_formats, str): - file_formats = [file_formats] - if not is_list_of(file_formats, str): - raise TypeError('file_formats must be a str or a list of str') - for ext in file_formats: - file_handlers[ext] = handler - - -def register_handler(file_formats, **kwargs): - - def wrap(cls): - _register_handler(cls(**kwargs), file_formats) - return cls - - return wrap diff --git a/mmengine/fileio/handlers/yaml_handler.py b/mmengine/fileio/handlers/yaml_handler.py deleted file mode 100644 index 22c2607ae4..0000000000 --- a/mmengine/fileio/handlers/yaml_handler.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
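# [Editorial aside, not part of the diff] The registry_utils module deleted just
# above exposed a ``register_handler`` decorator for adding new file formats to
# the unified ``load``/``dump`` API. Below is a minimal sketch of the pre-change
# usage; the ``TxtHandler`` class and the ``txt`` extension are hypothetical
# illustrations, not MMEngine code.
from mmengine.fileio.handlers import BaseFileHandler, register_handler


@register_handler('txt')
class TxtHandler(BaseFileHandler):
    # str_like stays True (inherited), so load()/dump() buffer through StringIO.

    def load_from_fileobj(self, file, **kwargs):
        # Return the whole text file as a single string.
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        # Serialize the object with str() and write it out.
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)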
-import yaml - -try: - from yaml import CDumper as Dumper # type: ignore - from yaml import CLoader as Loader # type: ignore -except ImportError: - from yaml import Loader, Dumper # type: ignore - -from .base import BaseFileHandler # isort:skip - - -class YamlHandler(BaseFileHandler): - - def load_from_fileobj(self, file, **kwargs): - kwargs.setdefault('Loader', Loader) - return yaml.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault('Dumper', Dumper) - yaml.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault('Dumper', Dumper) - return yaml.dump(obj, **kwargs) diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py deleted file mode 100644 index fdeb4dc6df..0000000000 --- a/mmengine/fileio/io.py +++ /dev/null @@ -1,940 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""This module provides unified file I/O related functions, which support -operating I/O with different file backends based on the specified filepath or -backend_args. - -MMEngine currently supports five file backends: - -- LocalBackend -- PetrelBackend -- HTTPBackend -- LmdbBackend -- MemcacheBackend - -Note that this module provide a union of all of the above file backends so -NotImplementedError will be raised if the interface in the file backend is not -implemented. - -There are two ways to call a method of a file backend: - -- Initialize a file backend with ``get_file_backend`` and call its methods. -- Directory call unified I/O functions, which will call ``get_file_backend`` - first and then call the corresponding backend method. - -Examples: - >>> # Initialize a file backend and call its methods - >>> import mmengine.fileio as fileio - >>> backend = fileio.get_file_backend(backend_args={'backend': 'petrel'}) - >>> backend.get('s3://path/of/your/file') - - >>> # Directory call unified I/O functions - >>> fileio.get('s3://path/of/your/file') -""" -import json -import warnings -from contextlib import contextmanager -from io import BytesIO, StringIO -from pathlib import Path -from typing import Generator, Iterator, Optional, Tuple, Union - -from mmengine.utils import is_filepath, is_str -from .backends import backends, prefix_to_backends -from .file_client import FileClient -# file_handlers and register_handler had been moved to -# mmengine/fileio/handlers/registry_utis. Import them -# in this file to keep backward compatibility. -from .handlers import file_handlers, register_handler # noqa: F401 - -backend_instances: dict = {} - - -def _parse_uri_prefix(uri: Union[str, Path]) -> str: - """Parse the prefix of uri. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> _parse_uri_prefix('/home/path/of/your/file') - '' - >>> _parse_uri_prefix('s3://path/of/your/file') - 's3' - >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') - 's3' - - Returns: - str: Return the prefix of uri if the uri contains '://'. Otherwise, - return ''. - """ - assert is_filepath(uri) - uri = str(uri) - # if uri does not contains '://', the uri will be handled by - # LocalBackend by default - if '://' not in uri: - return '' - else: - prefix, _ = uri.split('://') - # In the case of PetrelBackend, the prefix may contain the cluster - # name like clusterName:s3://path/of/your/file - if ':' in prefix: - _, prefix = prefix.split(':') - return prefix - - -def _get_file_backend(prefix: str, backend_args: dict): - """Return a file backend based on the prefix or backend_args. - - Args: - prefix (str): Prefix of uri. 
- backend_args (dict): Arguments to instantiate the corresponding - backend. - """ - # backend name has a higher priority - if 'backend' in backend_args: - # backend_args should not be modified - backend_args_bak = backend_args.copy() - backend_name = backend_args_bak.pop('backend') - backend = backends[backend_name](**backend_args_bak) - else: - backend = prefix_to_backends[prefix](**backend_args) - return backend - - -def get_file_backend( - uri: Union[str, Path, None] = None, - *, - backend_args: Optional[dict] = None, - enable_singleton: bool = False, -): - """Return a file backend based on the prefix of uri or backend_args. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - enable_singleton (bool): Whether to enable the singleton pattern. - If it is True, the backend created will be reused if the - signature is same with the previous one. Defaults to False. - - Returns: - BaseStorageBackend: Instantiated Backend object. - - Examples: - >>> # get file backend based on the prefix of uri - >>> uri = 's3://path/of/your/file' - >>> backend = get_file_backend(uri) - >>> # get file backend based on the backend_args - >>> backend = get_file_backend(backend_args={'backend': 'petrel'}) - >>> # backend name has a higher priority if 'backend' in backend_args - >>> backend = get_file_backend(uri, backend_args={'backend': 'petrel'}) - """ - global backend_instances - - if backend_args is None: - backend_args = {} - - if uri is None and 'backend' not in backend_args: - raise ValueError( - 'uri should not be None when "backend" does not exist in ' - 'backend_args') - - if uri is not None: - prefix = _parse_uri_prefix(uri) - else: - prefix = '' - - if enable_singleton: - # TODO: whether to pass sort_key to json.dumps - unique_key = f'{prefix}:{json.dumps(backend_args)}' - if unique_key in backend_instances: - return backend_instances[unique_key] - - backend = _get_file_backend(prefix, backend_args) - backend_instances[unique_key] = backend - return backend - else: - backend = _get_file_backend(prefix, backend_args) - return backend - - -def get( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> filepath = '/path/of/file' - >>> get(filepath) - b'hello world' - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.get(filepath) - - -def get_text( - filepath: Union[str, Path], - encoding='utf-8', - backend_args: Optional[dict] = None, -) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: Expected text reading from ``filepath``. 
- - Examples: - >>> filepath = '/path/of/file' - >>> get_text(filepath) - 'hello world' - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.get_text(filepath, encoding) - - -def put( - obj: bytes, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> filepath = '/path/of/file' - >>> put(b'hello world', filepath) - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - backend.put(obj, filepath) - - -def put_text( - obj: str, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> filepath = '/path/of/file' - >>> put_text('hello world', filepath) - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - backend.put_text(obj, filepath) - - -def exists( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/file' - >>> exists(filepath) - True - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.exists(filepath) - - -def isdir( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/dir' - >>> isdir(filepath) - True - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.isdir(filepath) - - -def isfile( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. 
- - Examples: - >>> filepath = '/path/of/file' - >>> isfile(filepath) - True - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.isfile(filepath) - - -def join_path( - filepath: Union[str, Path], - *filepaths: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - *filepaths (str or Path): Other paths to be concatenated. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The result of concatenation. - - Examples: - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - return backend.join_path(filepath, *filepaths) - - -@contextmanager -def get_local_path( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself and it will - not be released (removed). - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: Only yield one path. - - Examples: - >>> with get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - with backend.get_local_path(str(filepath)) as local_path: - yield local_path - - -def copyfile( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError will - be raised. 
- - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> copyfile(src, dst) - '/path1/of/dir/file' - """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) - return backend.copyfile(src, dst) - - -def copytree( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will be - raised. - - Examples: - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> copytree(src, dst) - '/path/of/dir2' - """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) - return backend.copytree(src, dst) - - -def copyfile_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Copy a local file src to dst and return the destination file. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = 's3://openmmlab/mmengine/file1' - >>> # src will be copied to 's3://openmmlab/mmengine/file1' - >>> copyfile_from_local(src, dst) - s3://openmmlab/mmengine/file1 - - >>> # dst is a directory - >>> dst = 's3://openmmlab/mmengine' - >>> # src will be copied to 's3://openmmlab/mmengine/file'' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/file' - """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) - return backend.copyfile_from_local(src, dst) - - -def copytree_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. 
- - Examples: - >>> src = '/path/of/dir' - >>> dst = 's3://openmmlab/mmengine/dir' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/dir' - """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) - return backend.copytree_from_local(src, dst) - - -def copyfile_to_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = 's3://openmmlab/mmengine/file' - >>> dst = '/path/of/file' - >>> # src will be copied to '/path/of/file' - >>> copyfile_to_local(src, dst) - '/path/of/file' - - >>> # dst is a directory - >>> dst = '/path/of/dir' - >>> # src will be copied to '/path/of/dir/file' - >>> copyfile_to_local(src, dst) - '/path/of/dir/file' - """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) - return backend.copyfile_to_local(src, dst) - - -def copytree_to_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> src = 's3://openmmlab/mmengine/dir' - >>> dst = '/path/of/dir' - >>> copytree_to_local(src, dst) - '/path/of/dir' - """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) - return backend.copytree_to_local(src, dst) - - -def remove( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, -) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Raises: - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - - Examples: - >>> filepath = '/path/of/file' - >>> remove(filepath) - """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) - backend.remove(filepath) - - -def rmtree( - dir_path: Union[str, Path], - backend_args: Optional[dict] = None, -) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. 
- - Examples: - >>> dir_path = '/path/of/dir' - >>> rmtree(dir_path) - """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) - backend.rmtree(dir_path) - - -def copy_if_symlink_fails( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, -) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directory copy src to - dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return True if successfully create a symbolic link pointing to - src. Otherwise, return False. - - Examples: - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> copy_if_symlink_fails(src, dst) - True - """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) - return backend.copy_if_symlink_fails(src, dst) - - -def list_dir_or_file( - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - backend_args: Optional[dict] = None, -) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # list those files and directories in current directory - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) - yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - -def generate_presigned_url( - url: str, - client_method: str = 'get_object', - expires_in: int = 3600, - backend_args: Optional[dict] = None, -) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on Petrel backend. - - Note: - Now only work on Petrel backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. 
Defaults to 'get_object'. - expires_in (int): expires, in seconds. Defaults to 3600. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: Generated presigned url. - """ - backend = get_file_backend( - url, backend_args=backend_args, enable_singleton=True) - return backend.generate_presigned_url(url, client_method, expires_in) - - -def load(file, - file_format=None, - file_client_args=None, - backend_args=None, - **kwargs): - """Load data from json/yaml/pickle files. - - This method provides a unified api for loading data from serialized files. - - ``load`` supports loading data from serialized files those can be storaged - in different backends. - - Args: - file (str or :obj:`Path` or file-like object): Filename or a file-like - object. - file_format (str, optional): If not specified, the file format will be - inferred from the file extension, otherwise use the specified one. - Currently supported formats include "json", "yaml/yml" and - "pickle/pkl". - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> load('/path/of/your/file') # file is storaged in disk - >>> load('https://path/of/your/file') # file is storaged in Internet - >>> load('s3://path/of/your/file') # file is storaged in petrel - - Returns: - The content from the file. - """ - if isinstance(file, Path): - file = str(file) - if file_format is None and is_str(file): - file_format = file.split('.')[-1] - if file_format not in file_handlers: - raise TypeError(f'Unsupported format: {file_format}') - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args and "backend_args" cannot be set at the ' - 'same time.') - - handler = file_handlers[file_format] - if is_str(file): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend(file, backend_args=backend_args) - - if handler.str_like: - with StringIO(file_backend.get_text(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - with BytesIO(file_backend.get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - elif hasattr(file, 'read'): - obj = handler.load_from_fileobj(file, **kwargs) - else: - raise TypeError('"file" must be a filepath str or a file-object') - return obj - - -def dump(obj, - file=None, - file_format=None, - file_client_args=None, - backend_args=None, - **kwargs): - """Dump data to json/yaml/pickle strings or files. - - This method provides a unified api for dumping data as strings or to files, - and also supports custom arguments for each file format. - - ``dump`` supports dumping data as strings or to files which is saved to - different backends. - - Args: - obj (any): The python object to be dumped. - file (str or :obj:`Path` or file-like object, optional): If not - specified, then the object is dumped to a str, otherwise to a file - specified by the filename or file-like object. - file_format (str, optional): Same as :func:`load`. 
- file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> dump('hello world', '/path/of/your/file') # disk - >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel - - Returns: - bool: True for success, False otherwise. - """ - if isinstance(file, Path): - file = str(file) - if file_format is None: - if is_str(file): - file_format = file.split('.')[-1] - elif file is None: - raise ValueError( - 'file_format must be specified since file is None') - if file_format not in file_handlers: - raise TypeError(f'Unsupported format: {file_format}') - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set at the ' - 'same time.') - - handler = file_handlers[file_format] - if file is None: - return handler.dump_to_str(obj, **kwargs) - elif is_str(file): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend(file, backend_args=backend_args) - - if handler.str_like: - with StringIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - file_backend.put_text(f.getvalue(), file) - else: - with BytesIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - file_backend.put(f.getvalue(), file) - elif hasattr(file, 'write'): - handler.dump_to_fileobj(obj, file, **kwargs) - else: - raise TypeError('"file" must be a filename str or a file-object') diff --git a/mmengine/fileio/parse.py b/mmengine/fileio/parse.py deleted file mode 100644 index 781d899a04..0000000000 --- a/mmengine/fileio/parse.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from io import StringIO - -from .file_client import FileClient -from .io import get_text - - -def list_from_file(filename, - prefix='', - offset=0, - max_num=0, - encoding='utf-8', - file_client_args=None, - backend_args=None): - """Load a text file and parse the content as a list of strings. - - ``list_from_file`` supports loading a text file which can be storaged in - different backends and parsing the content as a list for strings. - - Args: - filename (str): Filename. - prefix (str): The prefix to be inserted to the beginning of each item. - offset (int): The offset of lines. - max_num (int): The maximum number of lines to be read, - zeros and negatives mean no limitation. - encoding (str): Encoding used to open the file. Defaults to utf-8. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> list_from_file('/path/of/your/file') # disk - ['hello', 'world'] - >>> list_from_file('s3://path/of/your/file') # ceph or petrel - ['hello', 'world'] - - Returns: - list[str]: A list of strings. 
- """ - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set at the ' - 'same time.') - cnt = 0 - item_list = [] - - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, filename) - text = file_client.get_text(filename, encoding) - else: - text = get_text(filename, encoding, backend_args=backend_args) - - with StringIO(text) as f: - for _ in range(offset): - f.readline() - for line in f: - if 0 < max_num <= cnt: - break - item_list.append(prefix + line.rstrip('\n\r')) - cnt += 1 - return item_list - - -def dict_from_file(filename, - key_type=str, - encoding='utf-8', - file_client_args=None, - backend_args=None): - """Load a text file and parse the content as a dict. - - Each line of the text file will be two or more columns split by - whitespaces or tabs. The first column will be parsed as dict keys, and - the following columns will be parsed as dict values. - - ``dict_from_file`` supports loading a text file which can be storaged in - different backends and parsing the content as a dict. - - Args: - filename(str): Filename. - key_type(type): Type of the dict keys. str is user by default and - type conversion will be performed if specified. - encoding (str): Encoding used to open the file. Defaults to utf-8. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> dict_from_file('/path/of/your/file') # disk - {'key1': 'value1', 'key2': 'value2'} - >>> dict_from_file('s3://path/of/your/file') # ceph or petrel - {'key1': 'value1', 'key2': 'value2'} - - Returns: - dict: The parsed contents. - """ - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set at the ' - 'same time.') - - mapping = {} - - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, filename) - text = file_client.get_text(filename, encoding) - else: - text = get_text(filename, encoding, backend_args=backend_args) - - with StringIO(text) as f: - for line in f: - items = line.rstrip('\n').split() - assert len(items) >= 2 - key = key_type(items[0]) - val = items[1:] if len(items) > 2 else items[1] - mapping[key] = val - return mapping diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py deleted file mode 100644 index 746be6b02a..0000000000 --- a/mmengine/hooks/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from .checkpoint_hook import CheckpointHook -from .early_stopping_hook import EarlyStoppingHook -from .ema_hook import EMAHook -from .empty_cache_hook import EmptyCacheHook -from .hook import Hook -from .iter_timer_hook import IterTimerHook -from .logger_hook import LoggerHook -from .naive_visualization_hook import NaiveVisualizationHook -from .param_scheduler_hook import ParamSchedulerHook -from .profiler_hook import NPUProfilerHook, ProfilerHook -from .runtime_info_hook import RuntimeInfoHook -from .sampler_seed_hook import DistSamplerSeedHook -from .sync_buffer_hook import SyncBuffersHook -from .test_time_aug_hook import PrepareTTAHook - -__all__ = [ - 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', - 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook', - 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook', - 'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook' -] diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py deleted file mode 100644 index 92a4867bb9..0000000000 --- a/mmengine/hooks/checkpoint_hook.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import hashlib -import logging -import os.path as osp -import pickle -from collections import deque -from math import inf -from pathlib import Path -from typing import Callable, Dict, List, Optional, Sequence, Union - -from mmengine.dist import is_main_process, master_only -from mmengine.fileio import FileClient, get_file_backend -from mmengine.logging import print_log -from mmengine.registry import HOOKS -from mmengine.utils import is_list_of, is_seq_of -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -@HOOKS.register_module() -class CheckpointHook(Hook): - """Save checkpoints periodically. - - Args: - interval (int): The saving period. If ``by_epoch=True``, interval - indicates epochs, otherwise it indicates iterations. - Defaults to -1, which means "never". - by_epoch (bool): Saving checkpoints by epoch or by iteration. - Defaults to True. - save_optimizer (bool): Whether to save optimizer state_dict in the - checkpoint. It is usually used for resuming experiments. - Defaults to True. - save_param_scheduler (bool): Whether to save param_scheduler state_dict - in the checkpoint. It is usually used for resuming experiments. - Defaults to True. - out_dir (str, Path, Optional): The root directory to save checkpoints. - If not specified, ``runner.work_dir`` will be used by default. If - specified, the ``out_dir`` will be the concatenation of ``out_dir`` - and the last level directory of ``runner.work_dir``. For example, - if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is - ``./work_dir/cur_exp``, then the ckpt will be saved in - ``./tmp/cur_exp``. Defaults to None. - max_keep_ckpts (int): The maximum checkpoints to keep. - In some cases we want only the latest few checkpoints and would - like to delete old ones to save the disk space. - Defaults to -1, which means unlimited. - save_last (bool): Whether to force the last checkpoint to be - saved regardless of interval. Defaults to True. - save_best (str, List[str], optional): If a metric is specified, it - would measure the best checkpoint during evaluation. If a list of - metrics is passed, it would measure a group of best checkpoints - corresponding to the passed metrics. 
The information about best - checkpoint(s) would be saved in ``runner.message_hub`` to keep - best score value and best checkpoint path, which will be also - loaded when resuming checkpoint. Options are the evaluation metrics - on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox - detection and instance segmentation. ``AR@100`` for proposal - recall. If ``save_best`` is ``auto``, the first key of the returned - ``OrderedDict`` result will be used. Defaults to None. - rule (str, List[str], optional): Comparison rule for best score. If - set to None, it will infer a reasonable rule. Keys such as 'acc', - 'top' .etc will be inferred by 'greater' rule. Keys contain 'loss' - will be inferred by 'less' rule. If ``save_best`` is a list of - metrics and ``rule`` is a str, all metrics in ``save_best`` will - share the comparison rule. If ``save_best`` and ``rule`` are both - lists, their length must be the same, and metrics in ``save_best`` - will use the corresponding comparison rule in ``rule``. Options - are 'greater', 'less', None and list which contains 'greater' and - 'less'. Defaults to None. - greater_keys (List[str], optional): Metric keys that will be - inferred by 'greater' comparison rule. If ``None``, - _default_greater_keys will be used. Defaults to None. - less_keys (List[str], optional): Metric keys that will be - inferred by 'less' comparison rule. If ``None``, _default_less_keys - will be used. Defaults to None. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - filename_tmpl (str, optional): String template to indicate checkpoint - name. If specified, must contain one and only one "{}", which will - be replaced with ``epoch + 1`` if ``by_epoch=True`` else - ``iteration + 1``. - Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth" - accordingly. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - `New in version 0.2.0.` - published_keys (str, List[str], optional): If ``save_last`` is ``True`` - or ``save_best`` is not ``None``, it will automatically - publish model with keys in the list after training. - Defaults to None. - `New in version 0.7.1.` - save_begin (int): Control the epoch number or iteration number - at which checkpoint saving begins. Defaults to 0, which means - saving at the beginning. - `New in version 0.8.3.` - - Examples: - >>> # Save best based on single metric - >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', - >>> rule='less') - >>> # Save best based on multi metrics with the same comparison rule - >>> CheckpointHook(interval=2, by_epoch=True, - >>> save_best=['acc', 'mIoU'], rule='greater') - >>> # Save best based on multi metrics with different comparison rule - >>> CheckpointHook(interval=2, by_epoch=True, - >>> save_best=['FID', 'IS'], rule=['less', 'greater']) - >>> # Save best based on single metric and publish model after training - >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', - >>> rule='less', published_keys=['meta', 'state_dict']) - """ - out_dir: str - - priority = 'VERY_LOW' - - # logic to save best checkpoints - # Since the key for determining greater or less is related to the - # downstream tasks, downstream repositories may need to overwrite - # the following inner variables accordingly. 
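# [Editorial aside, not part of the diff] A hedged sketch of what the comment
# above alludes to: a downstream repository overriding the default key lists so
# that rule inference recognizes its own metrics. The hook name and metric keys
# below are hypothetical, not part of any OpenMMLab repository.
from mmengine.hooks import CheckpointHook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class MyCheckpointHook(CheckpointHook):
    # Metrics whose keys contain these substrings are treated as "higher is better".
    _default_greater_keys = CheckpointHook._default_greater_keys + ['PSNR', 'SSIM']
    # Metrics whose keys contain these substrings are treated as "lower is better".
    _default_less_keys = CheckpointHook._default_less_keys + ['FID']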
- - rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} - init_value_map = {'greater': -inf, 'less': inf} - _default_greater_keys = [ - 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', - 'mAcc', 'aAcc' - ] - _default_less_keys = ['loss'] - - def __init__(self, - interval: int = -1, - by_epoch: bool = True, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - out_dir: Optional[Union[str, Path]] = None, - max_keep_ckpts: int = -1, - save_last: bool = True, - save_best: Union[str, List[str], None] = None, - rule: Union[str, List[str], None] = None, - greater_keys: Optional[Sequence[str]] = None, - less_keys: Optional[Sequence[str]] = None, - file_client_args: Optional[dict] = None, - filename_tmpl: Optional[str] = None, - backend_args: Optional[dict] = None, - published_keys: Union[str, List[str], None] = None, - save_begin: int = 0, - **kwargs) -> None: - self.interval = interval - self.by_epoch = by_epoch - self.save_optimizer = save_optimizer - self.save_param_scheduler = save_param_scheduler - self.out_dir = out_dir # type: ignore - self.max_keep_ckpts = max_keep_ckpts - self.save_last = save_last - self.args = kwargs - - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', - logger='current', - level=logging.WARNING) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set ' - 'at the same time.') - - self.file_client_args = file_client_args - self.backend_args = backend_args - - if filename_tmpl is None: - if self.by_epoch: - self.filename_tmpl = 'epoch_{}.pth' - else: - self.filename_tmpl = 'iter_{}.pth' - else: - self.filename_tmpl = filename_tmpl - - # save best logic - assert (isinstance(save_best, str) or is_list_of(save_best, str) - or (save_best is None)), ( - '"save_best" should be a str or list of str or None, ' - f'but got {type(save_best)}') - - if isinstance(save_best, list): - if 'auto' in save_best: - assert len(save_best) == 1, ( - 'Only support one "auto" in "save_best" list.') - assert len(save_best) == len( - set(save_best)), ('Find duplicate element in "save_best".') - else: - # convert str to list[str] - if save_best is not None: - save_best = [save_best] # type: ignore # noqa: F401 - self.save_best = save_best - - # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) - or (rule is None)), ( - '"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') - if isinstance(rule, list): - # check the length of rule list - assert len(rule) in [ - 1, - len(self.save_best) # type: ignore - ], ('Number of "rule" must be 1 or the same as number of ' - f'"save_best", but got {len(rule)}.') - else: - # convert str/None to list - rule = [rule] # type: ignore # noqa: F401 - - if greater_keys is None: - self.greater_keys = self._default_greater_keys - else: - if not isinstance(greater_keys, (list, tuple)): - greater_keys = (greater_keys, ) # type: ignore - assert is_seq_of(greater_keys, str) - self.greater_keys = greater_keys # type: ignore - - if less_keys is None: - self.less_keys = self._default_less_keys - else: - if not isinstance(less_keys, (list, tuple)): - less_keys = (less_keys, ) # type: ignore - assert is_seq_of(less_keys, str) - self.less_keys = less_keys # type: ignore - - if self.save_best is not None: - self.is_better_than: Dict[str, Callable] = dict() - self._init_rule(rule, self.save_best) - if len(self.key_indicators) == 1: - self.best_ckpt_path: 
Optional[str] = None - else: - self.best_ckpt_path_dict: Dict = dict() - - # published keys - if not (isinstance(published_keys, str) - or is_seq_of(published_keys, str) or published_keys is None): - raise TypeError( - '"published_keys" should be a str or a sequence of str or ' - f'None, but got {type(published_keys)}') - - if isinstance(published_keys, str): - published_keys = [published_keys] - elif isinstance(published_keys, (list, tuple)): - assert len(published_keys) == len(set(published_keys)), ( - 'Find duplicate elements in "published_keys".') - self.published_keys = published_keys - - self.last_ckpt = None - if save_begin < 0: - raise ValueError( - 'save_begin should not be less than 0, but got {save_begin}') - self.save_begin = save_begin - - def before_train(self, runner) -> None: - """Finish all operations, related to checkpoint. - - This function will get the appropriate file client, and the directory - to save these checkpoints of the model. - - Args: - runner (Runner): The runner of the training process. - """ - if self.out_dir is None: - self.out_dir = runner.work_dir - - # If self.file_client_args is None, self.file_client will not - # used in CheckpointHook. To avoid breaking backward compatibility, - # it will not be removed util the release of MMEngine1.0 - self.file_client = FileClient.infer_client(self.file_client_args, - self.out_dir) - - if self.file_client_args is None: - self.file_backend = get_file_backend( - self.out_dir, backend_args=self.backend_args) - else: - self.file_backend = self.file_client - - # if `self.out_dir` is not equal to `runner.work_dir`, it means that - # `self.out_dir` is set so the final `self.out_dir` is the - # concatenation of `self.out_dir` and the last level directory of - # `runner.work_dir` - if self.out_dir != runner.work_dir: - basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_backend.join_path( - self.out_dir, basename) # type: ignore # noqa: E501 - - runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.') - - if self.save_best is not None: - if len(self.key_indicators) == 1: - if 'best_ckpt' not in runner.message_hub.runtime_info: - self.best_ckpt_path = None - else: - self.best_ckpt_path = runner.message_hub.get_info( - 'best_ckpt') - else: - for key_indicator in self.key_indicators: - best_ckpt_name = f'best_ckpt_{key_indicator}' - if best_ckpt_name not in runner.message_hub.runtime_info: - self.best_ckpt_path_dict[key_indicator] = None - else: - self.best_ckpt_path_dict[ - key_indicator] = runner.message_hub.get_info( - best_ckpt_name) - - if self.max_keep_ckpts > 0: - keep_ckpt_ids = [] - if 'keep_ckpt_ids' in runner.message_hub.runtime_info: - keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids') - - while len(keep_ckpt_ids) > self.max_keep_ckpts: - step = keep_ckpt_ids.pop(0) - if is_main_process(): - path = self.file_backend.join_path( - self.out_dir, self.filename_tmpl.format(step)) - if self.file_backend.isfile(path): - self.file_backend.remove(path) - elif self.file_backend.isdir(path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(path) - - self.keep_ckpt_ids: deque = deque(keep_ckpt_ids, - self.max_keep_ckpts) - - def after_train_epoch(self, runner) -> None: - """Save the checkpoint and synchronize buffers after each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - if not self.by_epoch: - return - - # save checkpoint for following cases: - # 1. 
every ``self.interval`` epochs which start at ``self.save_begin`` - # 2. reach the last epoch of training - if self.every_n_epochs(runner, self.interval, self.save_begin) or ( - self.save_last and self.is_last_train_epoch(runner)): - runner.logger.info( - f'Saving checkpoint at {runner.epoch + 1} epochs') - self._save_checkpoint(runner) - - def after_val_epoch(self, runner, metrics): - """Save the checkpoint and synchronize buffers after each evaluation - epoch. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics - """ - if len(metrics) == 0: - runner.logger.warning( - 'Since `metrics` is an empty dict, the behavior to save ' - 'the best checkpoint will be skipped in this evaluation.') - return - - self._save_best_checkpoint(runner, metrics) - - def after_train(self, runner) -> None: - """Publish the checkpoint after training. - - Args: - runner (Runner): The runner of the training process. - """ - if self.published_keys is None: - return - - if self.save_last and self.last_ckpt is not None: - self._publish_model(runner, self.last_ckpt) - - if getattr(self, 'best_ckpt_path', None) is not None: - self._publish_model(runner, str(self.best_ckpt_path)) - if getattr(self, 'best_ckpt_path_dict', None) is not None: - for best_ckpt in self.best_ckpt_path_dict.values(): - self._publish_model(runner, best_ckpt) - - @master_only - def _publish_model(self, runner, ckpt_path: str) -> None: - """Remove unnecessary keys from ckpt_path and save the new checkpoint. - - Args: - runner (Runner): The runner of the training process. - ckpt_path (str): The checkpoint path that ought to be published. - """ - from mmengine.runner import save_checkpoint - from mmengine.runner.checkpoint import _load_checkpoint - checkpoint = _load_checkpoint(ckpt_path) - assert self.published_keys is not None - removed_keys = [] - for key in list(checkpoint.keys()): - if key not in self.published_keys: - removed_keys.append(key) - checkpoint.pop(key) - if removed_keys: - print_log( - f'Key {removed_keys} will be removed because they are not ' - 'found in published_keys. 
If you want to keep them, ' - f'please set `{removed_keys}` in published_keys', - logger='current') - checkpoint_data = pickle.dumps(checkpoint) - sha = hashlib.sha256(checkpoint_data).hexdigest() - final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth' - save_checkpoint(checkpoint, final_path) - print_log( - f'The checkpoint ({ckpt_path}) is published to ' - f'{final_path}.', - logger='current') - - def _save_checkpoint_with_step(self, runner, step, meta): - # remove other checkpoints before save checkpoint to make the - # self.keep_ckpt_ids are saved as expected - if self.max_keep_ckpts > 0: - # _save_checkpoint and _save_best_checkpoint may call this - # _save_checkpoint_with_step in one epoch - if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step: - pass - else: - if len(self.keep_ckpt_ids) == self.max_keep_ckpts: - _step = self.keep_ckpt_ids.popleft() - if is_main_process(): - ckpt_path = self.file_backend.join_path( - self.out_dir, self.filename_tmpl.format(_step)) - - if self.file_backend.isfile(ckpt_path): - self.file_backend.remove(ckpt_path) - elif self.file_backend.isdir(ckpt_path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(ckpt_path) - - self.keep_ckpt_ids.append(step) - runner.message_hub.update_info('keep_ckpt_ids', - list(self.keep_ckpt_ids)) - - ckpt_filename = self.filename_tmpl.format(step) - self.last_ckpt = self.file_backend.join_path(self.out_dir, - ckpt_filename) - runner.message_hub.update_info('last_ckpt', self.last_ckpt) - - runner.save_checkpoint( - self.out_dir, - ckpt_filename, - self.file_client_args, - save_optimizer=self.save_optimizer, - save_param_scheduler=self.save_param_scheduler, - meta=meta, - by_epoch=self.by_epoch, - backend_args=self.backend_args, - **self.args) - - # Model parallel-like training should involve pulling sharded states - # from all ranks, but skip the following procedure. - if not is_main_process(): - return - - save_file = osp.join(runner.work_dir, 'last_checkpoint') - with open(save_file, 'w') as f: - f.write(self.last_ckpt) # type: ignore - - def _save_checkpoint(self, runner) -> None: - """Save the current checkpoint and delete outdated checkpoint. - - Args: - runner (Runner): The runner of the training process. - """ - if self.by_epoch: - step = runner.epoch + 1 - meta = dict(epoch=step, iter=runner.iter) - else: - step = runner.iter + 1 - meta = dict(epoch=runner.epoch, iter=step) - - self._save_checkpoint_with_step(runner, step, meta=meta) - - def _save_best_checkpoint(self, runner, metrics) -> None: - """Save the current checkpoint and delete outdated checkpoint. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics. 
- """ - if not self.save_best: - return - - if self.by_epoch: - ckpt_filename = self.filename_tmpl.format(runner.epoch) - cur_type, cur_time = 'epoch', runner.epoch - else: - ckpt_filename = self.filename_tmpl.format(runner.iter) - cur_type, cur_time = 'iter', runner.iter - - meta = dict(epoch=runner.epoch, iter=runner.iter) - - # handle auto in self.key_indicators and self.rules before the loop - if 'auto' in self.key_indicators: - self._init_rule(self.rules, [list(metrics.keys())[0]]) - - best_ckpt_updated = False - # save best logic - # get score from messagehub - for key_indicator, rule in zip(self.key_indicators, self.rules): - key_score = metrics[key_indicator] - - if len(self.key_indicators) == 1: - best_score_key = 'best_score' - runtime_best_ckpt_key = 'best_ckpt' - best_ckpt_path = self.best_ckpt_path - else: - best_score_key = f'best_score_{key_indicator}' - runtime_best_ckpt_key = f'best_ckpt_{key_indicator}' - best_ckpt_path = self.best_ckpt_path_dict[key_indicator] - - if best_score_key not in runner.message_hub.runtime_info: - best_score = self.init_value_map[rule] - else: - best_score = runner.message_hub.get_info(best_score_key) - - if key_score is None or not self.is_better_than[key_indicator]( - key_score, best_score): - continue - - best_ckpt_updated = True - - best_score = key_score - runner.message_hub.update_info(best_score_key, best_score) - - if best_ckpt_path and is_main_process(): - is_removed = False - if self.file_backend.isfile(best_ckpt_path): - self.file_backend.remove(best_ckpt_path) - is_removed = True - elif self.file_backend.isdir(best_ckpt_path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(best_ckpt_path) - is_removed = True - - if is_removed: - runner.logger.info( - f'The previous best checkpoint {best_ckpt_path} ' - 'is removed') - - best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}' - # Replace illegal characters for filename with `_` - best_ckpt_name = best_ckpt_name.replace('/', '_') - if len(self.key_indicators) == 1: - self.best_ckpt_path = self.file_backend.join_path( # type: ignore # noqa: E501 - self.out_dir, best_ckpt_name) - runner.message_hub.update_info(runtime_best_ckpt_key, - self.best_ckpt_path) - else: - self.best_ckpt_path_dict[ - key_indicator] = self.file_backend.join_path( # type: ignore # noqa: E501 - self.out_dir, best_ckpt_name) - runner.message_hub.update_info( - runtime_best_ckpt_key, - self.best_ckpt_path_dict[key_indicator]) - runner.save_checkpoint( - self.out_dir, - filename=best_ckpt_name, - file_client_args=self.file_client_args, - save_optimizer=False, - save_param_scheduler=False, - meta=meta, - by_epoch=False, - backend_args=self.backend_args) - runner.logger.info( - f'The best checkpoint with {best_score:0.4f} {key_indicator} ' - f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') - - # save checkpoint again to update the best_score and best_ckpt stored - # in message_hub because the checkpoint saved in `after_train_epoch` - # or `after_train_iter` stage only keep the previous best checkpoint - # not the current best checkpoint which causes the current best - # checkpoint can not be removed when resuming training. - if best_ckpt_updated and self.last_ckpt is not None: - self._save_checkpoint_with_step(runner, cur_time, meta) - - def _init_rule(self, rules, key_indicators) -> None: - """Initialize rule, key_indicator, comparison_func, and best score. If - key_indicator is a list of string and rule is a string, all metric in - the key_indicator will share the same rule. 
- - Here is the rule to determine which rule is used for key indicator when - the rule is not specific (note that the key indicator matching is case- - insensitive): - - 1. If the key indicator is in ``self.greater_keys``, the rule - will be specified as 'greater'. - 2. Or if the key indicator is in ``self.less_keys``, the rule - will be specified as 'less'. - 3. Or if any one item in ``self.greater_keys`` is a substring of - key_indicator, the rule will be specified as 'greater'. - 4. Or if any one item in ``self.less_keys`` is a substring of - key_indicator, the rule will be specified as 'less'. - - Args: - rule (List[Optional[str]]): Comparison rule for best score. - key_indicator (List[str]): Key indicator to determine - the comparison rule. - """ - if len(rules) == 1: - rules = rules * len(key_indicators) - - self.rules = [] - for rule, key_indicator in zip(rules, key_indicators): - - if rule not in self.rule_map and rule is not None: - raise KeyError('rule must be greater, less or None, ' - f'but got {rule}.') - - if rule is None and key_indicator != 'auto': - # `_lc` here means we use the lower case of keys for - # case-insensitive matching - key_indicator_lc = key_indicator.lower() - greater_keys = {key.lower() for key in self.greater_keys} - less_keys = {key.lower() for key in self.less_keys} - - if key_indicator_lc in greater_keys: - rule = 'greater' - elif key_indicator_lc in less_keys: - rule = 'less' - elif any(key in key_indicator_lc for key in greater_keys): - rule = 'greater' - elif any(key in key_indicator_lc for key in less_keys): - rule = 'less' - else: - raise ValueError('Cannot infer the rule for key ' - f'{key_indicator}, thus a specific rule ' - 'must be specified.') - if rule is not None: - self.is_better_than[key_indicator] = self.rule_map[rule] - self.rules.append(rule) - - self.key_indicators = key_indicators - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs=Optional[dict]) -> None: - """Save the checkpoint and synchronize buffers after each iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict, optional): Outputs from model. - """ - if self.by_epoch: - return - - # save checkpoint for following cases: - # 1. every ``self.interval`` iterations - # which start at ``self.save_begin`` - # 2. reach the last iteration of training - if self.every_n_train_iters(runner, self.interval, - self.save_begin) or \ - (self.save_last and - self.is_last_train_iter(runner)): - runner.logger.info( - f'Saving checkpoint at {runner.iter + 1} iterations') - self._save_checkpoint(runner) diff --git a/mmengine/hooks/early_stopping_hook.py b/mmengine/hooks/early_stopping_hook.py deleted file mode 100644 index 5533ebc84c..0000000000 --- a/mmengine/hooks/early_stopping_hook.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from math import inf, isfinite -from typing import Optional, Tuple, Union - -from mmengine.registry import HOOKS -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -@HOOKS.register_module() -class EarlyStoppingHook(Hook): - """Early stop the training when the monitored metric reached a plateau. - - Args: - monitor (str): The monitored metric key to decide early stopping. - rule (str, optional): Comparison rule. Options are 'greater', - 'less'. Defaults to None. 
- min_delta (float, optional): Minimum difference to continue the - training. Defaults to 0.01. - strict (bool, optional): Whether to crash the training when `monitor` - is not found in the `metrics`. Defaults to False. - check_finite: Whether to stop training when the monitor becomes NaN or - infinite. Defaults to True. - patience (int, optional): The times of validation with no improvement - after which training will be stopped. Defaults to 5. - stopping_threshold (float, optional): Stop training immediately once - the monitored quantity reaches this threshold. Defaults to None. - - Note: - `New in version 0.7.0.` - """ - priority = 'LOWEST' - - rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} - _default_greater_keys = [ - 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', - 'mAcc', 'aAcc' - ] - _default_less_keys = ['loss'] - - def __init__( - self, - monitor: str, - rule: Optional[str] = None, - min_delta: float = 0.1, - strict: bool = False, - check_finite: bool = True, - patience: int = 5, - stopping_threshold: Optional[float] = None, - ): - - self.monitor = monitor - if rule is not None: - if rule not in ['greater', 'less']: - raise ValueError( - '`rule` should be either "greater" or "less", ' - f'but got {rule}') - else: - rule = self._init_rule(monitor) - self.rule = rule - self.min_delta = min_delta if rule == 'greater' else -1 * min_delta - self.strict = strict - self.check_finite = check_finite - self.patience = patience - self.stopping_threshold = stopping_threshold - - self.wait_count = 0 - self.best_score = -inf if rule == 'greater' else inf - - def _init_rule(self, monitor: str) -> str: - greater_keys = {key.lower() for key in self._default_greater_keys} - less_keys = {key.lower() for key in self._default_less_keys} - monitor_lc = monitor.lower() - if monitor_lc in greater_keys: - rule = 'greater' - elif monitor_lc in less_keys: - rule = 'less' - elif any(key in monitor_lc for key in greater_keys): - rule = 'greater' - elif any(key in monitor_lc for key in less_keys): - rule = 'less' - else: - raise ValueError(f'Cannot infer the rule for {monitor}, thus rule ' - 'must be specified.') - return rule - - def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]: - compare = self.rule_map[self.rule] - stop_training = False - reason_message = '' - - if self.check_finite and not isfinite(current_score): - stop_training = True - reason_message = (f'Monitored metric {self.monitor} = ' - f'{current_score} is infinite. ' - f'Previous best value was ' - f'{self.best_score:.3f}.') - - elif self.stopping_threshold is not None and compare( - current_score, self.stopping_threshold): - stop_training = True - self.best_score = current_score - reason_message = (f'Stopping threshold reached: ' - f'`{self.monitor}` = {current_score} is ' - f'{self.rule} than {self.stopping_threshold}.') - elif compare(self.best_score + self.min_delta, current_score): - - self.wait_count += 1 - - if self.wait_count >= self.patience: - reason_message = (f'the monitored metric did not improve ' - f'in the last {self.wait_count} records. ' - f'best score: {self.best_score:.3f}. ') - stop_training = True - else: - self.best_score = current_score - self.wait_count = 0 - - return stop_training, reason_message - - def before_run(self, runner) -> None: - """Check `stop_training` variable in `runner.train_loop`. - - Args: - runner (Runner): The runner of the training process. 
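To make the stop condition above concrete, here is a minimal sketch of how `EarlyStoppingHook` is typically configured; the monitored key is an assumption and the numbers simply feed `_check_stop_condition`.

    # Illustrative only: stop when the (assumed) 'coco/bbox_mAP' metric fails to
    # improve by at least 0.005 for 10 consecutive validation rounds.
    early_stopping_cfg = dict(
        type='EarlyStoppingHook',
        monitor='coco/bbox_mAP',   # hypothetical metric key; rule inferred as 'greater'
        min_delta=0.005,
        patience=10,
    )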
- """ - - assert hasattr(runner.train_loop, 'stop_training'), \ - '`train_loop` should contain `stop_training` variable.' - - def after_val_epoch(self, runner, metrics): - """Decide whether to stop the training process. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics - """ - - if self.monitor not in metrics: - if self.strict: - raise RuntimeError( - 'Early stopping conditioned on metric ' - f'`{self.monitor} is not available. Please check available' - f' metrics {metrics}, or set `strict=False` in ' - '`EarlyStoppingHook`.') - warnings.warn( - 'Skip early stopping process since the evaluation ' - f'results ({metrics.keys()}) do not include `monitor` ' - f'({self.monitor}).') - return - - current_score = metrics[self.monitor] - - stop_training, message = self._check_stop_condition(current_score) - if stop_training: - runner.train_loop.stop_training = True - runner.logger.info(message) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py deleted file mode 100644 index 5bc1051d0b..0000000000 --- a/mmengine/hooks/ema_hook.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import itertools -import logging -from typing import Dict, Optional - -from mmengine.logging import print_log -from mmengine.model import is_model_wrapper -from mmengine.registry import HOOKS, MODELS -from .hook import DATA_BATCH, Hook - - -@HOOKS.register_module() -class EMAHook(Hook): - """A Hook to apply Exponential Moving Average (EMA) on the model during - training. - - Note: - - EMAHook takes priority over CheckpointHook. - - The original model parameters are actually saved in ema field after - train. - - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. - - Args: - ema_type (str): The type of EMA strategy to use. You can find the - supported strategies in :mod:`mmengine.model.averaged_model`. - Defaults to 'ExponentialMovingAverage'. - strict_load (bool): Whether to strictly enforce that the keys of - ``state_dict`` in checkpoint match the keys returned by - ``self.module.state_dict``. Defaults to False. - Changed in v0.3.0. - begin_iter (int): The number of iteration to enable ``EMAHook``. - Defaults to 0. - begin_epoch (int): The number of epoch to enable ``EMAHook``. - Defaults to 0. - **kwargs: Keyword arguments passed to subclasses of - :obj:`BaseAveragedModel` - """ - - priority = 'NORMAL' - - def __init__(self, - ema_type: str = 'ExponentialMovingAverage', - strict_load: bool = False, - begin_iter: int = 0, - begin_epoch: int = 0, - **kwargs): - self.strict_load = strict_load - self.ema_cfg = dict(type=ema_type, **kwargs) - assert not (begin_iter != 0 and begin_epoch != 0), ( - '`begin_iter` and `begin_epoch` should not be both set.') - assert begin_iter >= 0, ( - '`begin_iter` must larger than or equal to 0, ' - f'but got begin_iter: {begin_iter}') - assert begin_epoch >= 0, ( - '`begin_epoch` must larger than or equal to 0, ' - f'but got begin_epoch: {begin_epoch}') - self.begin_iter = begin_iter - self.begin_epoch = begin_epoch - # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be - # enabled at 0 iteration. - self.enabled_by_epoch = self.begin_epoch > 0 - - def before_run(self, runner) -> None: - """Create an ema copy of the model. - - Args: - runner (Runner): The runner of the training process. 
- """ - model = runner.model - if is_model_wrapper(model): - model = model.module - self.src_model = model - self.ema_model = MODELS.build( - self.ema_cfg, default_args=dict(model=self.src_model)) - - def before_train(self, runner) -> None: - """Check the begin_epoch/iter is smaller than max_epochs/iters. - - Args: - runner (Runner): The runner of the training process. - """ - if self.enabled_by_epoch: - assert self.begin_epoch <= runner.max_epochs, ( - 'self.begin_epoch should be smaller than or equal to ' - f'runner.max_epochs: {runner.max_epochs}, but got ' - f'begin_epoch: {self.begin_epoch}') - else: - assert self.begin_iter <= runner.max_iters, ( - 'self.begin_iter should be smaller than or equal to ' - f'runner.max_iters: {runner.max_iters}, but got ' - f'begin_iter: {self.begin_iter}') - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """Update ema parameter. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. - """ - if self._ema_started(runner): - self.ema_model.update_parameters(self.src_model) - else: - ema_params = self.ema_model.module.state_dict() - src_params = self.src_model.state_dict() - for k, p in ema_params.items(): - p.data.copy_(src_params[k].data) - - def before_val_epoch(self, runner) -> None: - """We load parameter values from ema model to source model before - validation. - - Args: - runner (Runner): The runner of the training process. - """ - self._swap_ema_parameters() - - def after_val_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """We recover source model's parameter from ema model after validation. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._swap_ema_parameters() - - def before_test_epoch(self, runner) -> None: - """We load parameter values from ema model to source model before test. - - Args: - runner (Runner): The runner of the training process. - """ - self._swap_ema_parameters() - - def after_test_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """We recover source model's parameter from ema model after test. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._swap_ema_parameters() - - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: - """Save ema parameters to checkpoint. - - Args: - runner (Runner): The runner of the testing process. - """ - checkpoint['ema_state_dict'] = self.ema_model.state_dict() - # Save ema parameters to the source model's state dict so that we - # can directly load the averaged model weights for deployment. - # Swapping the state_dict key-values instead of swapping model - # parameters because the state_dict is a shallow copy of model - # parameters. - self._swap_ema_state_dict(checkpoint) - - def after_load_checkpoint(self, runner, checkpoint: dict) -> None: - """Resume ema parameters from checkpoint. 
- - Args: - runner (Runner): The runner of the testing process. - """ - from mmengine.runner.checkpoint import load_state_dict - if 'ema_state_dict' in checkpoint and runner._resume: - # The original model parameters are actually saved in ema - # field swap the weights back to resume ema state. - self._swap_ema_state_dict(checkpoint) - self.ema_model.load_state_dict( - checkpoint['ema_state_dict'], strict=self.strict_load) - - # Support load checkpoint without ema state dict. - else: - if runner._resume: - print_log( - 'There is no `ema_state_dict` in checkpoint. ' - '`EMAHook` will make a copy of `state_dict` as the ' - 'initial `ema_state_dict`', 'current', logging.WARNING) - load_state_dict( - self.ema_model.module, - copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) - - def _swap_ema_parameters(self) -> None: - """Swap the parameter of model with ema_model.""" - avg_param = ( - itertools.chain(self.ema_model.module.parameters(), - self.ema_model.module.buffers()) - if self.ema_model.update_buffers else - self.ema_model.module.parameters()) - src_param = ( - itertools.chain(self.src_model.parameters(), - self.src_model.buffers()) - if self.ema_model.update_buffers else self.src_model.parameters()) - for p_avg, p_src in zip(avg_param, src_param): - tmp = p_avg.data.clone() - p_avg.data.copy_(p_src.data) - p_src.data.copy_(tmp) - - def _swap_ema_state_dict(self, checkpoint): - """Swap the state dict values of model with ema_model.""" - model_state = checkpoint['state_dict'] - ema_state = checkpoint['ema_state_dict'] - for k in ema_state: - if k[:7] == 'module.': - tmp = ema_state[k] - ema_state[k] = model_state[k[7:]] - model_state[k[7:]] = tmp - - def _ema_started(self, runner) -> bool: - """Whether ``EMAHook`` has been initialized at current iteration or - epoch. - - :attr:`ema_model` will be initialized when ``runner.iter`` or - ``runner.epoch`` is greater than ``self.begin`` for the first time. - - Args: - runner (Runner): Runner of the training, validation process. - - Returns: - bool: Whether ``EMAHook`` has been initialized. - """ - if self.enabled_by_epoch: - return runner.epoch + 1 >= self.begin_epoch - else: - return runner.iter + 1 >= self.begin_iter diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py deleted file mode 100644 index 9a92cdebfe..0000000000 --- a/mmengine/hooks/empty_cache_hook.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Sequence, Union - -import torch - -from mmengine.registry import HOOKS -from ..device import is_cuda_available, is_musa_available -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -@HOOKS.register_module() -class EmptyCacheHook(Hook): - """Releases all unoccupied cached GPU memory during the process of - training. - - Args: - before_epoch (bool): Whether to release cache before an epoch. Defaults - to False. - after_epoch (bool): Whether to release cache after an epoch. Defaults - to True. - after_iter (bool): Whether to release cache after an iteration. - Defaults to False. 
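The three flags of `EmptyCacheHook` map one-to-one onto `_before_epoch`, `_after_epoch` and `_after_iter` below; an illustrative configuration, with assumed values:

    empty_cache_cfg = dict(
        type='EmptyCacheHook',
        before_epoch=False,   # do not empty the cache before each epoch
        after_epoch=True,     # default: release cached memory after every epoch
        after_iter=False,     # per-iteration emptying is usually too costly
    )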
- """ - - priority = 'NORMAL' - - def __init__(self, - before_epoch: bool = False, - after_epoch: bool = True, - after_iter: bool = False) -> None: - self._do_before_epoch = before_epoch - self._do_after_epoch = after_epoch - self._do_after_iter = after_iter - - def _after_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Union[dict, Sequence]] = None, - mode: str = 'train') -> None: - """Empty cache after an iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_after_iter: - if is_cuda_available(): - torch.cuda.empty_cache() - elif is_musa_available(): - torch.musa.empty_cache() - - def _before_epoch(self, runner, mode: str = 'train') -> None: - """Empty cache before an epoch. - - Args: - runner (Runner): The runner of the training process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_before_epoch: - if is_cuda_available(): - torch.cuda.empty_cache() - elif is_musa_available(): - torch.musa.empty_cache() - - def _after_epoch(self, runner, mode: str = 'train') -> None: - """Empty cache after an epoch. - - Args: - runner (Runner): The runner of the training process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_after_epoch: - if is_cuda_available(): - torch.cuda.empty_cache() - elif is_musa_available(): - torch.musa.empty_cache() diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py deleted file mode 100644 index 4e1c4ce8bd..0000000000 --- a/mmengine/hooks/hook.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Optional, Sequence, Union - -from mmengine import is_method_overridden - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -class Hook: - """Base hook class. - - All hooks should inherit from this class. - """ - - priority = 'NORMAL' - stages = ('before_run', 'after_load_checkpoint', 'before_train', - 'before_train_epoch', 'before_train_iter', 'after_train_iter', - 'after_train_epoch', 'before_val', 'before_val_epoch', - 'before_val_iter', 'after_val_iter', 'after_val_epoch', - 'after_val', 'before_save_checkpoint', 'after_train', - 'before_test', 'before_test_epoch', 'before_test_iter', - 'after_test_iter', 'after_test_epoch', 'after_test', 'after_run') - - def before_run(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before the training validation or testing process. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - """ - - def after_run(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before the training validation or testing process. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - """ - - def before_train(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before train. - - Args: - runner (Runner): The runner of the training process. - """ - - def after_train(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after train. - - Args: - runner (Runner): The runner of the training process. 
- """ - - def before_val(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before validation. - - Args: - runner (Runner): The runner of the validation process. - """ - - def after_val(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after validation. - - Args: - runner (Runner): The runner of the validation process. - """ - - def before_test(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before testing. - - Args: - runner (Runner): The runner of the testing process. - """ - - def after_test(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after testing. - - Args: - runner (Runner): The runner of the testing process. - """ - - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: - """All subclasses should override this method, if they need any - operations before saving the checkpoint. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - checkpoint (dict): Model's checkpoint. - """ - - def after_load_checkpoint(self, runner, checkpoint: dict) -> None: - """All subclasses should override this method, if they need any - operations after loading the checkpoint. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - checkpoint (dict): Model's checkpoint. - """ - - def before_train_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each training epoch. - - Args: - runner (Runner): The runner of the training process. - """ - self._before_epoch(runner, mode='train') - - def before_val_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - """ - self._before_epoch(runner, mode='val') - - def before_test_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each test epoch. - - Args: - runner (Runner): The runner of the testing process. - """ - self._before_epoch(runner, mode='test') - - def after_train_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after each training epoch. - - Args: - runner (Runner): The runner of the training process. - """ - self._after_epoch(runner, mode='train') - - def after_val_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._after_epoch(runner, mode='val') - - def after_test_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. 
- """ - self._after_epoch(runner, mode='test') - - def before_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') - - def before_val_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each validation iteration. - - Args: - runner (Runner): The runner of the validation process. - batch_idx (int): The index of the current batch in the val loop. - data_batch (dict, optional): Data from dataloader. - Defaults to None. - """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') - - def before_test_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each test iteration. - - Args: - runner (Runner): The runner of the testing process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - Defaults to None. - """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """All subclasses should override this method, if they need any - operations after each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict tuple or list, optional): Data from dataloader. - outputs (dict, optional): Outputs from model. - """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='train') - - def after_val_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence] = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation iteration. - - Args: - runner (Runner): The runner of the validation process. - batch_idx (int): The index of the current batch in the val loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. - """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='val') - - def after_test_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence] = None) -> None: - """All subclasses should override this method, if they need any - operations after each test iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. 
- """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='test') - - def _before_epoch(self, runner, mode: str = 'train') -> None: - """All subclasses should override this method, if they need any - operations before each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _after_epoch(self, runner, mode: str = 'train') -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _before_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - mode: str = 'train') -> None: - """All subclasses should override this method, if they need any - operations before each iter. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _after_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Union[Sequence, dict]] = None, - mode: str = 'train') -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or Sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: - """Test whether current epoch can be evenly divided by n. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - n (int): Whether current epoch can be evenly divided by n. - start (int): Starting from `start` to check the logic for - every n epochs. Defaults to 0. - - Returns: - bool: Whether current epoch can be evenly divided by n. - """ - dividend = runner.epoch + 1 - start - return dividend % n == 0 if dividend >= 0 and n > 0 else False - - def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: - """Test whether current inner iteration can be evenly divided by n. - - Args: - batch_idx (int): Current batch index of the training, validation - or testing loop. - n (int): Whether current inner iteration can be evenly - divided by n. - - Returns: - bool: Whether current inner iteration can be evenly - divided by n. - """ - return (batch_idx + 1) % n == 0 if n > 0 else False - - def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool: - """Test whether current training iteration can be evenly divided by n. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - n (int): Whether current iteration can be evenly divided by n. - start (int): Starting from `start` to check the logic for - every n iterations. Defaults to 0. - - Returns: - bool: Return True if the current iteration can be evenly divided - by n, otherwise False. 
- """ - dividend = runner.iter + 1 - start - return dividend % n == 0 if dividend >= 0 and n > 0 else False - - def end_of_epoch(self, dataloader, batch_idx: int) -> bool: - """Check whether the current iteration reaches the last iteration of - the dataloader. - - Args: - dataloader (Dataloader): The dataloader of the training, - validation or testing process. - batch_idx (int): The index of the current batch in the loop. - Returns: - bool: Whether reaches the end of current epoch or not. - """ - return batch_idx + 1 == len(dataloader) - - def is_last_train_epoch(self, runner) -> bool: - """Test whether current epoch is the last train epoch. - - Args: - runner (Runner): The runner of the training process. - - Returns: - bool: Whether reaches the end of training epoch. - """ - return runner.epoch + 1 == runner.max_epochs - - def is_last_train_iter(self, runner) -> bool: - """Test whether current iteration is the last train iteration. - - Args: - runner (Runner): The runner of the training process. - - Returns: - bool: Whether current iteration is the last train iteration. - """ - return runner.iter + 1 == runner.max_iters - - def get_triggered_stages(self) -> list: - """Get all triggered stages with method name of the hook. - - Returns: - list: List of triggered stages. - """ - trigger_stages = set() - for stage in Hook.stages: - if is_method_overridden(stage, Hook, self): - trigger_stages.add(stage) - - # some methods will be triggered in multi stages - # use this dict to map method to stages. - method_stages_map = { - '_before_epoch': - ['before_train_epoch', 'before_val_epoch', 'before_test_epoch'], - '_after_epoch': - ['after_train_epoch', 'after_val_epoch', 'after_test_epoch'], - '_before_iter': - ['before_train_iter', 'before_val_iter', 'before_test_iter'], - '_after_iter': - ['after_train_iter', 'after_val_iter', 'after_test_iter'], - } - - for method, map_stages in method_stages_map.items(): - if is_method_overridden(method, Hook, self): - trigger_stages.update(map_stages) - - return list(trigger_stages) diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py deleted file mode 100644 index 5632c2b25e..0000000000 --- a/mmengine/hooks/iter_timer_hook.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import time -from typing import Optional, Sequence, Union - -from mmengine.registry import HOOKS -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -@HOOKS.register_module() -class IterTimerHook(Hook): - """A hook that logs the time spent during iteration. - - E.g. ``data_time`` for loading data and ``time`` for a model train step. - """ - - priority = 'NORMAL' - - def __init__(self): - self.time_sec_tot = 0 - self.time_sec_test_val = 0 - self.start_iter = 0 - - def before_train(self, runner) -> None: - """Synchronize the number of iterations with the runner after resuming - from checkpoints. - - Args: - runner: The runner of the training, validation or testing - process. - """ - self.start_iter = runner.iter - - def _before_epoch(self, runner, mode: str = 'train') -> None: - """Record timestamp before start an epoch. - - Args: - runner (Runner): The runner of the training validation and - testing process. - mode (str): Current mode of runner. Defaults to 'train'. 
- """ - self.t = time.time() - - def _after_epoch(self, runner, mode: str = 'train') -> None: - self.time_sec_test_val = 0 - - def _before_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - mode: str = 'train') -> None: - """Calculating time for loading data and updating "data_time" - ``HistoryBuffer`` of ``runner.message_hub``. - - Args: - runner (Runner): The runner of the training, validation and - testing process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from - dataloader. - mode (str): Current mode of runner. Defaults to 'train'. - """ - # Update data loading time in `runner.message_hub`. - runner.message_hub.update_scalar(f'{mode}/data_time', - time.time() - self.t) - - def _after_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Union[dict, Sequence]] = None, - mode: str = 'train') -> None: - """Calculating time for an iteration and updating "time" - ``HistoryBuffer`` of ``runner.message_hub``. - - Args: - runner (Runner): The runner of the training validation and - testing process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - # Update iteration time in `runner.message_hub`. - message_hub = runner.message_hub - message_hub.update_scalar(f'{mode}/time', time.time() - self.t) - self.t = time.time() - iter_time = message_hub.get_scalar(f'{mode}/time') - if mode == 'train': - self.time_sec_tot += iter_time.current() - # Calculate average iterative time. - time_sec_avg = self.time_sec_tot / ( - runner.iter - self.start_iter + 1) - # Calculate eta. - eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) - runner.message_hub.update_info('eta', eta_sec) - else: - if mode == 'val': - cur_dataloader = runner.val_dataloader - else: - cur_dataloader = runner.test_dataloader - - self.time_sec_test_val += iter_time.current() - time_sec_avg = self.time_sec_test_val / (batch_idx + 1) - eta_sec = time_sec_avg * (len(cur_dataloader) - batch_idx - 1) - runner.message_hub.update_info('eta', eta_sec) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py deleted file mode 100644 index fa0b79dcf9..0000000000 --- a/mmengine/hooks/logger_hook.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import os -import os.path as osp -from collections import OrderedDict -from pathlib import Path -from typing import Dict, Optional, Sequence, Union - -import numpy as np -import torch - -from mmengine.fileio import FileClient, dump -from mmengine.fileio.io import get_file_backend -from mmengine.hooks import Hook -from mmengine.logging import print_log -from mmengine.registry import HOOKS -from mmengine.utils import is_seq_of, scandir - -DATA_BATCH = Optional[Union[dict, tuple, list]] -SUFFIX_TYPE = Union[Sequence[str], str] - - -@HOOKS.register_module() -class LoggerHook(Hook): - """Collect logs from different components of ``Runner`` and write them to - terminal, JSON file, tensorboard and wandb .etc. - - ``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during - training/validation/testing phase. It is used to control following - behaviors: - - - The frequency of logs update in terminal, local, tensorboad wandb.etc. 
- - The frequency of show experiment information in terminal. - - The work directory to save logs. - - Args: - interval (int): Logging interval (every k iterations). - Defaults to 10. - ignore_last (bool): Ignore the log of last iterations in each epoch if - the number of remaining iterations is less than :attr:`interval`. - Defaults to True. - interval_exp_name (int): Logging interval for experiment name. This - feature is to help users conveniently get the experiment - information from screen or log file. Defaults to 1000. - out_dir (str or Path, optional): The root directory to save - checkpoints. If not specified, ``runner.work_dir`` will be used - by default. If specified, the ``out_dir`` will be the concatenation - of ``out_dir`` and the last level directory of ``runner.work_dir``. - For example, if the input ``out_dir`` is ``./tmp`` and - ``runner.work_dir`` is ``./work_dir/cur_exp``, then the log will be - saved in ``./tmp/cur_exp``. Defaults to None. - out_suffix (Tuple[str] or str): Those files in ``runner._log_dir`` - ending with ``out_suffix`` will be copied to ``out_dir``. Defaults - to ('json', '.log', '.py'). - keep_local (bool): Whether to keep local logs in the local machine - when :attr:`out_dir` is specified. If False, the local log will be - removed. Defaults to True. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - `backend_args` instead. - log_metric_by_epoch (bool): Whether to output metric in validation step - by epoch. It can be true when running in epoch based runner. - If set to True, `after_val_epoch` will set `step` to self.epoch in - `runner.visualizer.add_scalars`. Otherwise `step` will be - self.iter. Defaults to True. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> # The simplest LoggerHook config. - >>> logger_hook_cfg = dict(interval=20) - """ - priority = 'BELOW_NORMAL' - - def __init__(self, - interval: int = 10, - ignore_last: bool = True, - interval_exp_name: int = 1000, - out_dir: Optional[Union[str, Path]] = None, - out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), - keep_local: bool = True, - file_client_args: Optional[dict] = None, - log_metric_by_epoch: bool = True, - backend_args: Optional[dict] = None): - - if not isinstance(interval, int): - raise TypeError('interval must be an integer') - if interval <= 0: - raise ValueError('interval must be greater than 0') - - if not isinstance(ignore_last, bool): - raise TypeError('ignore_last must be a boolean') - - if not isinstance(interval_exp_name, int): - raise TypeError('interval_exp_name must be an integer') - if interval_exp_name <= 0: - raise ValueError('interval_exp_name must be greater than 0') - - if out_dir is not None and not isinstance(out_dir, (str, Path)): - raise TypeError('out_dir must be a str or Path object') - - if not isinstance(keep_local, bool): - raise TypeError('keep_local must be a boolean') - - if out_dir is None and file_client_args is not None: - raise ValueError( - 'file_client_args should be "None" when `out_dir` is not' - 'specified.') - - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. 
' - 'Please use "backend_args" instead', - logger='current', - level=logging.WARNING) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set ' - 'at the same time.') - - if not (isinstance(out_suffix, str) or is_seq_of(out_suffix, str)): - raise TypeError('out_suffix should be a string or a sequence of ' - f'string, but got {type(out_suffix)}') - - self.out_suffix = out_suffix - self.out_dir = out_dir - self.interval = interval - self.ignore_last = ignore_last - self.interval_exp_name = interval_exp_name - self.keep_local = keep_local - self.file_client_args = file_client_args - self.json_log_path: Optional[str] = None - - if self.out_dir is not None: - self.file_client = FileClient.infer_client(file_client_args, - self.out_dir) - if file_client_args is None: - self.file_backend = get_file_backend( - self.out_dir, backend_args=backend_args) - else: - self.file_backend = self.file_client - - self.log_metric_by_epoch = log_metric_by_epoch - - def before_run(self, runner) -> None: - """Infer ``self.file_client`` from ``self.out_dir``. Initialize the - ``self.start_iter`` and record the meta information. - - Args: - runner (Runner): The runner of the training process. - """ - if self.out_dir is not None: - # The final `self.out_dir` is the concatenation of `self.out_dir` - # and the last level directory of `runner.work_dir` - basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_backend.join_path(self.out_dir, basename) - runner.logger.info( - f'Text logs will be saved to {self.out_dir} after the ' - 'training process.') - - self.json_log_path = f'{runner.timestamp}.json' - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """Record logs after training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict tuple or list, optional): Data from dataloader. - outputs (dict, optional): Outputs from model. - """ - # Print experiment name every n iterations. - if self.every_n_train_iters( - runner, self.interval_exp_name) or (self.end_of_epoch( - runner.train_dataloader, batch_idx)): - exp_info = f'Exp name: {runner.experiment_name}' - runner.logger.info(exp_info) - if self.every_n_inner_iters(batch_idx, self.interval): - tag, log_str = runner.log_processor.get_log_after_iter( - runner, batch_idx, 'train') - elif (self.end_of_epoch(runner.train_dataloader, batch_idx) - and (not self.ignore_last - or len(runner.train_dataloader) <= self.interval)): - # `runner.max_iters` may not be divisible by `self.interval`. if - # `self.ignore_last==True`, the log of remaining iterations will - # be recorded (Epoch [4][1000/1007], the logs of 998-1007 - # iterations will be recorded). - tag, log_str = runner.log_processor.get_log_after_iter( - runner, batch_idx, 'train') - else: - return - runner.logger.info(log_str) - runner.visualizer.add_scalars( - tag, step=runner.iter + 1, file_path=self.json_log_path) - - def after_val_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence] = None) -> None: - """Record logs after validation iteration. - - Args: - runner (Runner): The runner of the validation process. - batch_idx (int): The index of the current batch in the validation - loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - Defaults to None. 
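Extending the one-line example in the `LoggerHook` docstring above, a hedged configuration that also exercises the remote-output arguments; the `out_dir` path is an assumption.

    logger_hook_cfg = dict(
        type='LoggerHook',
        interval=20,                  # log every 20 training iterations
        out_dir='./tmp/logs',         # hypothetical target, joined with the work_dir basename
        out_suffix=('.json', '.log', '.py'),
        keep_local=False,             # remove local copies after upload, see after_run
        log_metric_by_epoch=True,
    )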
- outputs (sequence, optional): Outputs from model. - """ - if self.every_n_inner_iters(batch_idx, self.interval): - _, log_str = runner.log_processor.get_log_after_iter( - runner, batch_idx, 'val') - runner.logger.info(log_str) - - def after_test_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence] = None) -> None: - """Record logs after testing iteration. - - Args: - runner (Runner): The runner of the testing process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (sequence, optional): Outputs from model. - """ - if self.every_n_inner_iters(batch_idx, self.interval): - _, log_str = runner.log_processor.get_log_after_iter( - runner, batch_idx, 'test') - runner.logger.info(log_str) - - def after_val_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - tag, log_str = runner.log_processor.get_log_after_epoch( - runner, len(runner.val_dataloader), 'val') - runner.logger.info(log_str) - if self.log_metric_by_epoch: - # Accessing the epoch attribute of the runner will trigger - # the construction of the train_loop. Therefore, to avoid - # triggering the construction of the train_loop during - # validation, check before accessing the epoch. - if (isinstance(runner._train_loop, dict) - or runner._train_loop is None): - epoch = 0 - else: - epoch = runner.epoch - runner.visualizer.add_scalars( - tag, step=epoch, file_path=self.json_log_path) - else: - if (isinstance(runner._train_loop, dict) - or runner._train_loop is None): - iter = 0 - else: - iter = runner.iter - runner.visualizer.add_scalars( - tag, step=iter, file_path=self.json_log_path) - - def after_test_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - tag, log_str = runner.log_processor.get_log_after_epoch( - runner, len(runner.test_dataloader), 'test', with_non_scalar=True) - runner.logger.info(log_str) - dump( - self._process_tags(tag), - osp.join(runner.log_dir, self.json_log_path)) # type: ignore - - @staticmethod - def _process_tags(tags: dict): - """Convert tag values to json-friendly type.""" - - def process_val(value): - if isinstance(value, (list, tuple)): - # Array type of json - return [process_val(item) for item in value] - elif isinstance(value, dict): - # Object type of json - return {k: process_val(v) for k, v in value.items()} - elif isinstance(value, (str, int, float, bool)) or value is None: - # Other supported type of json - return value - elif isinstance(value, (torch.Tensor, np.ndarray)): - return value.tolist() - # Drop unsupported values. 
- - processed_tags = OrderedDict(process_val(tags)) - - return processed_tags - - def after_run(self, runner) -> None: - """Copy logs to ``self.out_dir`` if ``self.out_dir is not None`` - - Args: - runner (Runner): The runner of the training/testing/validation - process. - """ - # close the visualizer - runner.visualizer.close() - - # copy or upload logs to self.out_dir - if self.out_dir is None: - return - - removed_files = [] - for filename in scandir(runner._log_dir, self.out_suffix, True): - local_filepath = osp.join(runner._log_dir, filename) - removed_files.append(local_filepath) - out_filepath = self.file_backend.join_path(self.out_dir, filename) - with open(local_filepath) as f: - self.file_backend.put_text(f.read(), out_filepath) - - runner.logger.info( - f'The file {local_filepath} has been uploaded to ' - f'{out_filepath}.') - - if not self.keep_local: - runner.logger.info(f'{local_filepath} was removed due to the ' - '`self.keep_local=False`. You can check ' - f'the running logs in {out_filepath}') - - if not self.keep_local: - # Close file handler to avoid PermissionError on Windows. - for handler in runner.logger.handlers: - if isinstance(handler, logging.FileHandler): - handler.close() - - for file in removed_files: - os.remove(file) diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py deleted file mode 100644 index fcb803a20f..0000000000 --- a/mmengine/hooks/naive_visualization_hook.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -from typing import Optional, Sequence, Tuple, Union - -import cv2 -import numpy as np - -from mmengine.hooks import Hook -from mmengine.registry import HOOKS -from mmengine.utils.dl_utils import tensor2imgs - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -# TODO: Due to interface changes, the current class -# functions incorrectly -@HOOKS.register_module() -class NaiveVisualizationHook(Hook): - """Show or Write the predicted results during the process of testing. - - Args: - interval (int): Visualization interval. Defaults to 1. - draw_gt (bool): Whether to draw the ground truth. Defaults to True. - draw_pred (bool): Whether to draw the predicted result. - Defaults to True. - """ - priority = 'NORMAL' - - def __init__(self, - interval: int = 1, - draw_gt: bool = True, - draw_pred: bool = True): - self.draw_gt = draw_gt - self.draw_pred = draw_pred - self._interval = interval - - def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int, - int]) -> np.ndarray: - """Unpad the input image. - - Args: - input (np.ndarray): The image to unpad. - unpad_shape (tuple): The shape of image before padding. - - Returns: - np.ndarray: The image before padding. - """ - unpad_width, unpad_height = unpad_shape - unpad_image = input[:unpad_height, :unpad_width] - return unpad_image - - def before_train(self, runner) -> None: - """Call add_graph method of visualizer. - - Args: - runner (Runner): The runner of the training process. - """ - runner.visualizer.add_graph(runner.model, None) - - def after_test_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence] = None) -> None: - """Show or Write the predicted results. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. 
- """ - if self.every_n_inner_iters(batch_idx, self._interval): - for data, output in zip(data_batch, outputs): # type: ignore - input = data['inputs'] - data_sample = data['data_sample'] - input = tensor2imgs(input, - **data_sample.get('img_norm_cfg', - dict()))[0] - # TODO We will implement a function to revert the augmentation - # in the future. - ori_shape = (data_sample.ori_width, data_sample.ori_height) - if 'pad_shape' in data_sample: - input = self._unpad(input, - data_sample.get('scale', ori_shape)) - origin_image = cv2.resize(input, ori_shape) - name = osp.basename(data_sample.img_path) - runner.visualizer.add_datasample(name, origin_image, - data_sample, output, - self.draw_gt, self.draw_pred) diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py deleted file mode 100644 index 3b2f1e610a..0000000000 --- a/mmengine/hooks/param_scheduler_hook.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Optional, Union - -from mmengine.optim import _ParamScheduler -from mmengine.registry import HOOKS -from mmengine.utils import is_list_of -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -@HOOKS.register_module() -class ParamSchedulerHook(Hook): - """A hook to update some hyper-parameters in optimizer, e.g., learning rate - and momentum.""" - - priority = 'LOW' - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """Call step function for each scheduler after each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - In order to keep this interface consistent with other hooks, - we keep ``data_batch`` here. - outputs (dict, optional): Outputs from model. - In order to keep this interface consistent with other hooks, we - keep ``data_batch`` here. - """ - - if runner.param_schedulers is None: - return - - def step(param_schedulers): - assert isinstance(param_schedulers, list) - for scheduler in param_schedulers: - if not scheduler.by_epoch: - scheduler.step() - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - 'runner.param_schedulers should be list of ParamScheduler or ' - 'a dict containing list of ParamScheduler, ' - f'but got {runner.param_schedulers}') - - def after_train_epoch(self, runner) -> None: - """Call step function for each scheduler after each training epoch. - - Args: - runner (Runner): The runner of the training process. 
- """ - - if runner.param_schedulers is None: - return - - def step(param_schedulers): - assert isinstance(param_schedulers, list) - for scheduler in param_schedulers: - if scheduler.by_epoch: - scheduler.step() - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - 'runner.param_schedulers should be list of ParamScheduler or ' - 'a dict containing list of ParamScheduler, ' - f'but got {runner.param_schedulers}') - - def after_val_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """Call step function for each scheduler which has attribute - ``need_val_args`` after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - - Note: - if ``runner.param_schedulers`` is not built before, - the hook ``after_val_epoch`` will be skipped. - """ - - if runner.param_schedulers is None: - return - - # avoid counting scheduler._global_step - # it has counted in after_train_* hook - if metrics is None: - return - - def step(param_schedulers): - # check param_schedulers is list and built - if not is_list_of(param_schedulers, _ParamScheduler): - return - - for scheduler in param_schedulers: - if (scheduler.by_epoch - and getattr(scheduler, 'need_val_args', False)): - scheduler.step(metrics) - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - 'runner.param_schedulers should be list of ParamScheduler or ' - 'a dict containing list of ParamScheduler, ' - f'but got {runner.param_schedulers}') diff --git a/mmengine/hooks/profiler_hook.py b/mmengine/hooks/profiler_hook.py deleted file mode 100644 index dae84b85f5..0000000000 --- a/mmengine/hooks/profiler_hook.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import os -import os.path as osp -import sys -from typing import Callable, Optional, Union - -import torch - -from mmengine.dist import master_only -from mmengine.hooks import Hook -from mmengine.logging import print_log -from mmengine.registry import HOOKS - - -def check_kineto() -> bool: # noqa - kineto_exist = False - try: - if torch.autograd.kineto_available(): - kineto_exist = True - except AttributeError: - print_log('NO KINETO', logger='current', level=logging.WARNING) - return kineto_exist - - -@HOOKS.register_module() -class ProfilerHook(Hook): - """A hook to analyze performance during training and inference. - - PyTorch Profiler is a tool that allows the collection of the performance - metrics during the training. More details on Profiler can be found at - `official docs `_ - - Args: - by_epoch (bool): Profile performance by epoch or by iteration. - Defaults to True. - profile_times (int): The period (epoch/iter) recorded by the profiler. - Defaults to 1. For example, profile_iters=10 and by_epoch=False, - indicate that 0-10 iterations are recorded. 
- activity_with_cpu (bool): Activities to be used in the analysis (CPU) - activity_with_cuda (bool): Activities to be used in the analysis (CUDA) - schedule (dict, optional): Key-word arguments passed to - `torch.profile.schedule `_. - Defaults to None, which means profiling without a schedule - on_trace_ready (callable, dict, optional): Either a handler or a dict - of generating handler. Defaults to None, which means profiling - without an on_trace_ready.The Callable type needs to construct its - own function that can handle 'torch.autograd.profiler.profile'. - Two officially recommended ways are provided: - - - ``schedule=dict(type='log_trace')``: Print the profiling result - in the terminal. See more details in the `PyTorch official tutorial`_. - The configurable arguments are the same as - ``prof.key_averages().table`` - - ``scheduler=dict(type='tb_trace')``: Profile the performance - with tensorboard. See more details in the tutorial - `profile with tensorboard`_. - - record_shapes (bool): Save information about operator's input shapes. - Defaults to False. - profile_memory (bool): Track tensor memory allocation/deallocation. - Defaults to False. - with_stack (bool): Record source information (file and line number) - for the ops. Defaults to False. - with_flops (bool): Use formula to estimate the FLOPS of specific - operators (matrix multiplication and 2D convolution). - Defaults to False. - json_trace_path (str, optional): Exports the collected trace in Chrome - JSON format. Chrome use 'chrome://tracing' view json file. - Defaults to None, which means profiling does not store json files. - - Warnings: - The profiler will be closed after ``profile_times`` iterations - automatically. Please make sure the configuration of your scheduler - will not close the profiler before the iteration reach the value of - ``profile_times`` - - Examples: - >>> # tensorboard trace - >>> trace_config = dict(type='tb_trace') - >>> profiler_hook_cfg = dict(on_trace_ready=trace_config) - - .. _PyTorch official tutorial: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-execution-time - .. _profile with tensorboard: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard - """ # noqa: E501 - priority = 'VERY_LOW' - - def __init__(self, - *, - by_epoch: bool = True, - profile_times: int = 1, - activity_with_cpu: bool = True, - activity_with_cuda: bool = False, - schedule: Optional[dict] = None, - on_trace_ready: Union[Callable, dict, None] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - json_trace_path: Optional[str] = None) -> None: - - try: - from torch import profiler - except ImportError: - raise ImportError('please upgrade torch above 1.8.1') - if not check_kineto(): - raise ImportError('Due to Kineto support issues, please upgrade ' - 'pytorch above 1.8.1(windows users above 1.9.1)') - - assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.' 
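For context, these options are normally supplied through a ``custom_hooks`` entry in the runner config; a minimal sketch using only keys from the signature above (the surrounding training config is assumed to exist):

custom_hooks = [
    dict(
        type='ProfilerHook',
        by_epoch=False,
        profile_times=10,                       # record iterations 0-10
        activity_with_cpu=True,
        activity_with_cuda=False,
        on_trace_ready=dict(type='log_trace'),  # print the summary table to the terminal
    )
]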
- self.by_epoch = by_epoch - - if profile_times < 1: - raise ValueError('profile_iters should be greater than 0, ' - f'but got {profile_times}') - if by_epoch and profile_times > 1: - raise ValueError( - f'Profiler will profile 0-{profile_times} epochs.\n' - 'Since profiler will slow down the training, it is recommended' - ' to train 1 epoch with ProfilerHook and adjust your setting ' - 'according to the profiler summary.\n' - 'During normal training(epoch > 1), ' - 'you may disable the ProfilerHook.') - self.profile_times = profile_times - - assert isinstance(activity_with_cpu, bool), \ - '``activity_with_cpu`` should be a boolean.' - assert isinstance(activity_with_cuda, bool), \ - '``activity_with_cuda`` should be a boolean.' - self.activities = [] - if activity_with_cpu: - self.activities.append(profiler.ProfilerActivity.CPU) - if activity_with_cuda: - self.activities.append(profiler.ProfilerActivity.CUDA) - - if schedule is not None: - assert isinstance(schedule, dict), '``schedule`` should be a dict.' - self.schedule = profiler.schedule(**schedule) - else: - self.schedule = None - - self.on_trace_ready = on_trace_ready - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.with_stack = with_stack - self.with_flops = with_flops - - self.json_trace_path = json_trace_path - self._closed = False - - def before_run(self, runner): - """Initialize the profiler. - - Through the runner parameter, the validity of the parameter is further - determined. - """ - max_times = runner.max_epochs if self.by_epoch else runner.max_iters - if max_times < self.profile_times: - raise ValueError( - f'``profile_times`` should not be greater than {max_times}') - - on_trace_ready = self._parse_trace_config(runner) - - self.profiler = torch.profiler.profile( # noqa - activities=self.activities, - schedule=self.schedule, - on_trace_ready=on_trace_ready, - record_shapes=self.record_shapes, - profile_memory=self.profile_memory, - with_stack=self.with_stack, - with_flops=self.with_flops) - - self.profiler.__enter__() - runner.logger.info('profiler is profiling...') - - def _parse_trace_config(self, runner): - """Used to parse the parameter 'on_trace_ready'.""" - if self.on_trace_ready is None: - _on_trace_ready = None - elif callable(self.on_trace_ready): - _on_trace_ready = self.on_trace_ready - elif isinstance(self.on_trace_ready, dict): - trace_cfg = self.on_trace_ready.copy() - trace_type = trace_cfg.pop('type') - - # Build a log printing handle - if trace_type == 'log_trace': - - def _log_handler(_profile): - print(_profile.key_averages().table(**trace_cfg)) - - _on_trace_ready = _log_handler - - elif trace_type == 'tb_trace': # tensorboard_trace handler - try: - import torch_tb_profiler # noqa: F401 - except ImportError: - raise ImportError( - 'please run ``pip install torch-tb-profiler``') - - if 'dir_name' not in trace_cfg: - trace_cfg['dir_name'] = osp.join(runner.log_dir, - 'tf_tracing_logs') - elif not osp.isabs(trace_cfg['dir_name']): - trace_cfg['dir_name'] = osp.join(runner.log_dir, - trace_cfg['dir_name']) - runner.logger.info('trace_files of ProfilerHook will be ' - f'saved to {trace_cfg["dir_name"]}.') - - if self.json_trace_path is not None: - runner.logger.warn( - 'When using tensorboard_trace, it is recommended to ' - 'save json files by setting ``worker_name`` instead of' - ' setting ``json_trace_path``') - _on_trace_ready = torch.profiler.tensorboard_trace_handler( - **trace_cfg) - else: - raise ValueError('trace_type should be "log_trace" or ' - f'"tb_trace", but 
got {trace_type}') - else: - raise ValueError( - '``on_trace_ready`` should be a handler, or dict, or None, ' - f'but got {self.on_trace_ready}') - return _on_trace_ready - - def after_train_epoch(self, runner): - """Determine if the content is exported.""" - # `after_train_epoch` will also be called in IterBasedTrainLoop. - # Here we check `self._closed` to avoid exiting twice. - if not self._closed: - self._export_chrome_trace(runner) - - def after_train_iter(self, runner, batch_idx, data_batch, outputs): - """Profiler will call `step` method if it is not closed.""" - if not self._closed: - self.profiler.step() - if runner.iter == self.profile_times - 1 and not self.by_epoch: - self._export_chrome_trace(runner) - - def _export_chrome_trace(self, runner): - """Exporting content.""" - self._closed = True - runner.logger.info('profiler may take a few minutes...') - self.profiler.__exit__(None, None, None) - if self.json_trace_path is not None: - self.profiler.export_chrome_trace(self.json_trace_path) - - -@HOOKS.register_module() -class NPUProfilerHook(Hook): - """NPUProfiler to analyze performance during training. - - NPU Profiling is used to count the device execution time of all operators. - The torch_npu.npu.profile interface is used to complete the profiling data - collection at each stage of the project, and the data is analyzed by the - msprof tool and the data can be dumped to further manually analyze the - key performance bottlenecks. For more details on the torch_npu.npu.profile - interface, please visit - https://gitee.com/ascend/pytorch/blob/master/torch_npu/npu/profiler.py#profile - - Args: - begin (int): Number of start iterations for profiling. Defaults to 0. - end (int): Number of end iterations for profiling. Defaults to 1. - result_path (str): The path to save the profiling results file. - Defaults to 'cann_profiling'. - exit_after_profiling (bool): Whether to exit the program after - profiling. Defaults to True. - use_e2e_profiler (bool): Turn on E2E profiling, E2E profiling combines - performance data at the Pytorch level and the NPU level to analyze - the bottlenecks of model performance end-to-end, and cannot show - detailed content, and only as an auxiliary analysis. - Defaults to False. - ge_profiling_to_std_out (bool): Turn on GE profiling, GE uses to - collect the profiling data of the host side scheduling of the - Assend device. Defaults to False. - - Examples: - >>> cfg = ... 
- >>> profiler_config = dict(type='NPUProfilerHook', end=2) - >>> cfg.merge_from_dict({'custom_hooks': custom_hooks}) - >>> runner = Runner.from_cfg(cfg) - >>> runner.train() - """ - priority = 'VERY_LOW' - - def __init__(self, - *, - begin: int = 0, - end: int = 1, - result_path: str = 'cann_profiling', - exit_after_profiling: bool = True, - use_e2e_profiler: bool = False, - ge_profiling_to_std_out: bool = False): - - try: - import torch_npu - except ImportError: - raise ImportError('Failed to import torch_npu module') - - if begin >= end: - raise ValueError( - 'The iteration to start profiling should not be greater' - 'than or equal to profile end') - - self.begin = begin - self.end = end - self.result_path = result_path - self.exit_after_profiling = exit_after_profiling - - if ge_profiling_to_std_out: - os.environ['GE_PROFILING_TO_STD_OUT'] = '1' - - if not osp.exists(self.result_path): - os.makedirs(self.result_path, exist_ok=True) - - self.profiler = torch_npu.npu.profile( - self.result_path, use_e2e_profiler=use_e2e_profiler) - - @master_only - def before_run(self, runner): - - if self.end > runner.max_iters: - raise ValueError( - 'The profiling end iteration should not be greater' - 'than the max iteration') - - @master_only - def before_train_iter(self, runner, batch_idx, data_batch=None): - - if runner.iter == self.begin: - self.profiler.__enter__() - runner.logger.info('NPUProfiler starts profiling...') - - @master_only - def after_train_iter(self, - runner, - batch_idx, - data_batch=None, - outputs=None): - - if runner.iter == self.end - 1: - runner.logger.info('profiler may take a few minutes to' - ' save the profiling result.') - self.profiler.__exit__(None, None, None) - if self.exit_after_profiling: - sys.exit() diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py deleted file mode 100644 index 49407e4563..0000000000 --- a/mmengine/hooks/runtime_info_hook.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch - -from mmengine.registry import HOOKS -from mmengine.utils import get_git_hash -from mmengine.version import __version__ -from .hook import Hook - -DATA_BATCH = Optional[Union[dict, tuple, list]] - - -def _is_scalar(value: Any) -> bool: - """Determine the value is a scalar type value. - - Args: - value (Any): value of log. - - Returns: - bool: whether the value is a scalar type value. - """ - if isinstance(value, np.ndarray): - return value.size == 1 - elif isinstance(value, (int, float, np.number)): - return True - elif isinstance(value, torch.Tensor): - return value.numel() == 1 - return False - - -@HOOKS.register_module() -class RuntimeInfoHook(Hook): - """A hook that updates runtime information into message hub. - - E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the - training state. Components that cannot access the runner can get runtime - information through the message hub. - """ - - priority = 'VERY_HIGH' - - def before_run(self, runner) -> None: - """Update metainfo. - - Args: - runner (Runner): The runner of the training process. - """ - metainfo = dict( - cfg=runner.cfg.pretty_text, - seed=runner.seed, - experiment_name=runner.experiment_name, - mmengine_version=__version__ + get_git_hash()) - runner.message_hub.update_info_dict(metainfo) - - self.last_loop_stage = None - - def before_train(self, runner) -> None: - """Update resumed training state. 
- - Args: - runner (Runner): The runner of the training process. - """ - runner.message_hub.update_info('loop_stage', 'train') - runner.message_hub.update_info('epoch', runner.epoch) - runner.message_hub.update_info('iter', runner.iter) - runner.message_hub.update_info('max_epochs', runner.max_epochs) - runner.message_hub.update_info('max_iters', runner.max_iters) - if hasattr(runner.train_dataloader.dataset, 'metainfo'): - runner.message_hub.update_info( - 'dataset_meta', runner.train_dataloader.dataset.metainfo) - - def after_train(self, runner) -> None: - runner.message_hub.pop_info('loop_stage') - - def before_train_epoch(self, runner) -> None: - """Update current epoch information before every epoch. - - Args: - runner (Runner): The runner of the training process. - """ - runner.message_hub.update_info('epoch', runner.epoch) - - def before_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None) -> None: - """Update current iter and learning rate information before every - iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - """ - runner.message_hub.update_info('iter', runner.iter) - lr_dict = runner.optim_wrapper.get_lr() - assert isinstance(lr_dict, dict), ( - '`runner.optim_wrapper.get_lr()` should return a dict ' - 'of learning rate when training with OptimWrapper(single ' - 'optimizer) or OptimWrapperDict(multiple optimizer), ' - f'but got {type(lr_dict)} please check your optimizer ' - 'constructor return an `OptimWrapper` or `OptimWrapperDict` ' - 'instance') - for name, lr in lr_dict.items(): - runner.message_hub.update_scalar(f'train/{name}', lr[0]) - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """Update ``log_vars`` in model outputs every iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. - """ - if outputs is not None: - for key, value in outputs.items(): - runner.message_hub.update_scalar(f'train/{key}', value) - - def before_val(self, runner) -> None: - self.last_loop_stage = runner.message_hub.get_info('loop_stage') - runner.message_hub.update_info('loop_stage', 'val') - - def after_val_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. 
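The runtime values written above can be read back by any component through the message hub, which is the purpose of this hook; a small sketch (the scalar key ``train/lr`` assumes a single optimizer, as noted in ``before_train_iter``):

from mmengine.logging import MessageHub

message_hub = MessageHub.get_current_instance()
cur_iter = message_hub.get_info('iter')                 # updated in before_train_iter
cur_lr = message_hub.get_scalar('train/lr').current()   # scalars are stored as history buffers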
- """ - if metrics is not None: - for key, value in metrics.items(): - if _is_scalar(value): - runner.message_hub.update_scalar(f'val/{key}', value) - else: - runner.message_hub.update_info(f'val/{key}', value) - - def after_val(self, runner) -> None: - # ValLoop may be called within the TrainLoop, so we need to reset - # the loop_stage - # workflow: before_train -> before_val -> after_val -> after_train - if self.last_loop_stage == 'train': - runner.message_hub.update_info('loop_stage', self.last_loop_stage) - self.last_loop_stage = None - else: - runner.message_hub.pop_info('loop_stage') - - def before_test(self, runner) -> None: - runner.message_hub.update_info('loop_stage', 'test') - - def after_test(self, runner) -> None: - runner.message_hub.pop_info('loop_stage') - - def after_test_epoch(self, - runner, - metrics: Optional[Dict[str, float]] = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - if metrics is not None: - for key, value in metrics.items(): - if _is_scalar(value): - runner.message_hub.update_scalar(f'test/{key}', value) - else: - runner.message_hub.update_info(f'test/{key}', value) diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py deleted file mode 100644 index 9aed9dbcf5..0000000000 --- a/mmengine/hooks/sampler_seed_hook.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.registry import HOOKS -from .hook import Hook - - -@HOOKS.register_module() -class DistSamplerSeedHook(Hook): - """Data-loading sampler for distributed training. - - When distributed training, it is only useful in conjunction with - :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same - purpose with :obj:`IterLoader`. - """ - - priority = 'NORMAL' - - def before_train_epoch(self, runner) -> None: - """Set the seed for sampler and batch_sampler. - - Args: - runner (Runner): The runner of the training process. - """ - if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr( - runner.train_loop.dataloader.sampler, 'set_epoch'): - # In case the` _SingleProcessDataLoaderIter` has no sampler, - # or data loader uses `SequentialSampler` in Pytorch. - runner.train_loop.dataloader.sampler.set_epoch(runner.epoch) - - elif hasattr(runner.train_loop.dataloader, - 'batch_sampler') and hasattr( - runner.train_loop.dataloader.batch_sampler.sampler, - 'set_epoch'): - # In case the` _SingleProcessDataLoaderIter` has no batch sampler. - # batch sampler in pytorch warps the sampler as its attributes. - runner.train_loop.dataloader.batch_sampler.sampler.set_epoch( - runner.epoch) diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py deleted file mode 100644 index 7cc75757fe..0000000000 --- a/mmengine/hooks/sync_buffer_hook.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
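The ``set_epoch`` call issued above by ``DistSamplerSeedHook`` is what re-seeds the shuffling of PyTorch's ``DistributedSampler`` each epoch; a minimal sketch of the equivalent manual loop (the dataset, the loader arguments, and an initialized process group are assumed):

from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)
for epoch in range(max_epochs):
    sampler.set_epoch(epoch)  # otherwise every epoch reuses the same shuffle order
    for data_batch in loader:
        ...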
-from mmengine.dist import all_reduce_params, is_distributed -from mmengine.registry import HOOKS -from .hook import Hook - - -@HOOKS.register_module() -class SyncBuffersHook(Hook): - """Synchronize model buffers such as running_mean and running_var in BN at - the end of each epoch.""" - - priority = 'NORMAL' - - def __init__(self) -> None: - self.distributed = is_distributed() - # A flag to mark whether synchronization has been done in - # after_train_epoch - self.called_in_train = False - - def before_val_epoch(self, runner) -> None: - """All-reduce model buffers before each validation epoch. - - Synchronize the buffers before each validation if they have not been - synchronized at the end of the previous training epoch. This method - will be called when using IterBasedTrainLoop. - - Args: - runner (Runner): The runner of the training process. - """ - if self.distributed: - if not self.called_in_train: - all_reduce_params(runner.model.buffers(), op='mean') - self.called_in_train = False - - def after_train_epoch(self, runner) -> None: - """All-reduce model buffers at the end of each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - if self.distributed: - all_reduce_params(runner.model.buffers(), op='mean') - self.called_in_train = True diff --git a/mmengine/hooks/test_time_aug_hook.py b/mmengine/hooks/test_time_aug_hook.py deleted file mode 100644 index 5775736d1f..0000000000 --- a/mmengine/hooks/test_time_aug_hook.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from mmengine.runner import Runner - -from mmengine.hooks import Hook -from mmengine.registry import HOOKS, MODELS, RUNNERS - - -@HOOKS.register_module() -class PrepareTTAHook(Hook): - """Wraps `runner.model` with subclass of :class:`BaseTTAModel` in - `before_test`. - - Note: - This function will only be used with :obj:`MMFullyShardedDataParallel`. - - Args: - tta_cfg (dict): Config dictionary of the test time augmentation model. - """ - - def __init__(self, tta_cfg: dict): - self.tta_cfg = tta_cfg - - def before_test(self, runner: 'Runner') -> None: - """Wraps `runner.model` with the subclass of :class:`BaseTTAModel`. - - Args: - runner (Runner): The runner of the testing process. - """ - self.tta_cfg['module'] = runner.model # type: ignore - model = MODELS.build(self.tta_cfg) - runner.model = model # type: ignore - - -def build_runner_with_tta(cfg: dict) -> 'Runner': - """Builds runner with tta (test time augmentation) transformation and - TTAModel. - - Note: - This function will only be used with :obj:`MMFullyShardedDataParallel`. - - Args: - cfg (dict): cfg with ``tta_pipeline`` and ``tta_model`` - - Notes: - This is only an experimental feature. We may refactor the code in the - future. 
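A hedged sketch of how the helper introduced above is meant to be driven; ``DemoTTAModel``, the config path, and the pipeline contents are hypothetical placeholders, only the required ``tta_model``/``tta_pipeline`` keys come from this function's contract:

from mmengine.config import Config

cfg = Config.fromfile('configs/some_config.py')   # hypothetical config path
cfg.tta_pipeline = [...]                          # placeholder test-time augmentation transforms
cfg.tta_model = dict(type='DemoTTAModel')         # hypothetical BaseTTAModel subclass
runner = build_runner_with_tta(cfg)
runner.test()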
- - Returns: - Runner: Runner with tta transformation and TTAModel - """ - assert hasattr( - cfg, - 'tta_model'), ('please make sure tta_model is defined in your config.') - assert hasattr(cfg, 'tta_pipeline'), ( - 'please make sure tta_pipeline is defined in your config.') - cfg['test_dataloader']['dataset']['pipeline'] = cfg['tta_pipeline'] - - if 'runner_type' in cfg: - runner = RUNNERS.build(cfg) - else: - from mmengine.runner import Runner - runner = Runner.from_cfg(cfg) - - runner.register_hook(PrepareTTAHook(tta_cfg=cfg['tta_model'])) - return runner diff --git a/mmengine/hub/__init__.py b/mmengine/hub/__init__.py deleted file mode 100644 index e6f2add99c..0000000000 --- a/mmengine/hub/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .hub import get_config, get_model - -__all__ = ['get_config', 'get_model'] diff --git a/mmengine/hub/deprecated.json b/mmengine/hub/deprecated.json deleted file mode 100644 index 473a57c0ee..0000000000 --- a/mmengine/hub/deprecated.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "resnet50_caffe": "detectron/resnet50_caffe", - "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", - "resnet101_caffe": "detectron/resnet101_caffe", - "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" - } diff --git a/mmengine/hub/hub.py b/mmengine/hub/hub.py deleted file mode 100644 index b24ac2c125..0000000000 --- a/mmengine/hub/hub.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import importlib -import os.path as osp - -from mmengine.config import Config -from mmengine.config.utils import (_get_cfg_metainfo, - _get_external_cfg_base_path, - _get_package_and_cfg_path) -from mmengine.registry import MODELS, DefaultScope -from mmengine.runner import load_checkpoint -from mmengine.utils import get_installed_path, install_package - - -def get_config(cfg_path: str, pretrained: bool = False) -> Config: - """Get config from external package. - - Args: - cfg_path (str): External relative config path. - pretrained (bool): Whether to save pretrained model path. If - ``pretrained==True``, the url of pretrained model can be accessed - by ``cfg.model_path``. Defaults to False. - - Examples: - >>> cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True) - >>> # Equivalent to - >>> # cfg = Config.fromfile('/path/to/faster-rcnn_r50_fpn_1x_coco.py') - >>> cfg.model_path - https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth - - Returns: - Config: A `Config` parsed from external package. - """ # noqa E301 - # Get package name and relative config path. - package, cfg_path = _get_package_and_cfg_path(cfg_path) - # Install package if it's not installed. - install_package(package) - package_path = get_installed_path(package) - try: - # Use `cfg_path` to search target config file. 
- cfg_meta = _get_cfg_metainfo(package_path, cfg_path) - cfg_path = osp.join(package_path, '.mim', cfg_meta['Config']) - cfg = Config.fromfile(cfg_path) - if pretrained: - assert 'Weights' in cfg_meta, ('Cannot find `Weights` in cfg_file' - '.metafile.yml, please check the' - 'metafile') - cfg.model_path = cfg_meta['Weights'] - except ValueError: - # Since the base config does not contain a metafile, the absolute - # config is `osp.join(package_path, cfg_path_prefix, cfg_name)` - cfg_path = _get_external_cfg_base_path(package_path, cfg_path) - cfg = Config.fromfile(cfg_path) - except Exception as e: - raise e - return cfg - - -def get_model(cfg_path: str, pretrained: bool = False, **kwargs): - """Get built model from external package. - - Args: - cfg_path (str): External relative config path with prefix - 'package::' and without suffix. - pretrained (bool): Whether to load pretrained model. Defaults to False. - kwargs (dict): Default arguments to build model. - - Examples: - >>> model = get_model('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True) - >>> type(model) - - - Returns: - nn.Module: Built model. - """ # noqa E301 - package = cfg_path.split('::')[0] - with DefaultScope.overwrite_default_scope(package): # type: ignore - cfg = get_config(cfg_path, pretrained) - if 'data_preprocessor' in cfg: - cfg.model.data_preprocessor = cfg.data_preprocessor - models_module = importlib.import_module(f'{package}.utils') - models_module.register_all_modules() # type: ignore - model = MODELS.build(cfg.model, default_args=kwargs) - if pretrained: - load_checkpoint(model, cfg.model_path) - # Hack to use pretrained weights. - # If we do not set _is_init here, Runner will call - # `model.init_weights()` to overwrite the pretrained model. - model._is_init = True - return model diff --git a/mmengine/hub/mmcls.json b/mmengine/hub/mmcls.json deleted file mode 100644 index 071db8709c..0000000000 --- a/mmengine/hub/mmcls.json +++ /dev/null @@ -1,59 +0,0 @@ -{ - "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth", - "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth", - "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth", - "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth", - "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth", - "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth", - "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth", - "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth", - "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth", - "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth", - "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth", - "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth", - "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth", - "resnet50_v1d": 
"https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth", - "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth", - "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth", - "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth", - "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth", - "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth", - "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth", - "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth", - "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth", - "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth", - "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth", - "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth", - "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth", - "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth", - "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth", - "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth", - "mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth", - "mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth", - "repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth", - "repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth", - "repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth", - "repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth", - "repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth", - "repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth", - "repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth", - "repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth", - "repvgg_B2g4": 
"https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth", - "repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth", - "repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth", - "repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth", - "res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth", - "res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth", - "res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth", - "swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth", - "swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth", - "swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth", - "swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth", - "t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth", - "t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth", - "t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth", - "tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth", - "vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth", - "vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth", - "vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth" - } diff --git a/mmengine/hub/openmmlab.json b/mmengine/hub/openmmlab.json deleted file mode 100644 index 0966212ef3..0000000000 --- a/mmengine/hub/openmmlab.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth", - "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth", - "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth", - "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth", - "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth", - "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth", - "resnext50_32x4d": 
"https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", - "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth", - "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", - "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", - "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth", - "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth", - "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", - "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", - "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", - "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", - "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", - "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", - "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth", - "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth", - "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", - "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", - "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth", - "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", - "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", - "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", - "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth", - "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth", - "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth", - "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth", - "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth", - "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth", - "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth", - "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth", - "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth", - "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth", - "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth", - "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth", - "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth", - "mmedit/res34_en_nomixup": 
"https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth", - "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth", - "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth", - "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth", - "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth", - "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth", - "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth", - "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth", - "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth" - } diff --git a/mmengine/hub/torchvision_0.12.json b/mmengine/hub/torchvision_0.12.json deleted file mode 100644 index 06defe6748..0000000000 --- a/mmengine/hub/torchvision_0.12.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", - "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - "regnet_x_800mf": 
"https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", - "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - "shufflenetv2_x1.5": null, - "shufflenetv2_x2.0": null, - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" -} diff --git a/mmengine/infer/__init__.py b/mmengine/infer/__init__.py deleted file mode 100644 index a122481f14..0000000000 --- a/mmengine/infer/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .infer import BaseInferencer - -__all__ = ['BaseInferencer'] diff --git a/mmengine/infer/infer.py b/mmengine/infer/infer.py deleted file mode 100644 index 322d885224..0000000000 --- a/mmengine/infer/infer.py +++ /dev/null @@ -1,692 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import copy -import importlib -import os.path as osp -import re -import warnings -from abc import ABCMeta, abstractmethod -from datetime import datetime -from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence, - Tuple, Union) - -import numpy as np -import torch -import torch.nn as nn -from rich.progress import track - -from mmengine.config import Config, ConfigDict -from mmengine.config.utils import MODULE2PACKAGE -from mmengine.dataset import pseudo_collate -from mmengine.device import get_device -from mmengine.fileio import (get_file_backend, isdir, join_path, - list_dir_or_file, load) -from mmengine.logging import print_log -from mmengine.registry import FUNCTIONS, MODELS, VISUALIZERS, DefaultScope -from mmengine.runner.checkpoint import (_load_checkpoint, - _load_checkpoint_to_model) -from mmengine.structures import InstanceData -from mmengine.visualization import Visualizer - -InstanceList = List[InstanceData] -InputType = Union[str, np.ndarray, torch.Tensor] -InputsType = Union[InputType, Sequence[InputType]] -ImgType = Union[np.ndarray, Sequence[np.ndarray]] -ResType = Union[Dict, List[Dict]] -ConfigType = Union[Config, ConfigDict] -ModelType = Union[dict, ConfigType, str] - - -class InferencerMeta(ABCMeta): - """Check the legality of the inferencer. - - All Inferencers should not define duplicated keys for - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` and - ``postprocess_kwargs``. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert isinstance(self.preprocess_kwargs, set) - assert isinstance(self.forward_kwargs, set) - assert isinstance(self.visualize_kwargs, set) - assert isinstance(self.postprocess_kwargs, set) - - all_kwargs = ( - self.preprocess_kwargs | self.forward_kwargs - | self.visualize_kwargs | self.postprocess_kwargs) - - assert len(all_kwargs) == ( - len(self.preprocess_kwargs) + len(self.forward_kwargs) + - len(self.visualize_kwargs) + len(self.postprocess_kwargs)), ( - f'Class define error! {self.__name__} should not ' - 'define duplicated keys for `preprocess_kwargs`, ' - '`forward_kwargs`, `visualize_kwargs` and ' - '`postprocess_kwargs` are not allowed.') - - -class BaseInferencer(metaclass=InferencerMeta): - """Base inferencer for downstream tasks. - - The BaseInferencer provides the standard workflow for inference as follows: - - 1. Preprocess the input data by :meth:`preprocess`. - 2. Forward the data to the model by :meth:`forward`. ``BaseInferencer`` - assumes the model inherits from :class:`mmengine.models.BaseModel` and - will call `model.test_step` in :meth:`forward` by default. - 3. Visualize the results by :meth:`visualize`. - 4. Postprocess and return the results by :meth:`postprocess`. - - When we call the subclasses inherited from BaseInferencer (not overriding - ``__call__``), the workflow will be executed in order. - - All subclasses of BaseInferencer could define the following class - attributes for customization: - - - ``preprocess_kwargs``: The keys of the kwargs that will be passed to - :meth:`preprocess`. - - ``forward_kwargs``: The keys of the kwargs that will be passed to - :meth:`forward` - - ``visualize_kwargs``: The keys of the kwargs that will be passed to - :meth:`visualize` - - ``postprocess_kwargs``: The keys of the kwargs that will be passed to - :meth:`postprocess` - - All attributes mentioned above should be a ``set`` of keys (strings), - and each key should not be duplicated. 
Actually, :meth:`__call__` will - dispatch all the arguments to the corresponding methods according to the - ``xxx_kwargs`` mentioned above, therefore, the key in sets should - be unique to avoid ambiguous dispatching. - - Warning: - If subclasses defined the class attributes mentioned above with - duplicated keys, an ``AssertionError`` will be raised during import - process. - - Subclasses inherited from ``BaseInferencer`` should implement - :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`: - - - _init_pipeline: Return a callable object to preprocess the input data. - - visualize: Visualize the results returned by :meth:`forward`. - - postprocess: Postprocess the results returned by :meth:`forward` and - :meth:`visualize`. - - Args: - model (str, optional): Path to the config file or the model name - defined in metafile. Take the `mmdet metafile `_ - as an example, the `model` could be `retinanet_r18_fpn_1x_coco` or - its alias. If model is not specified, user must provide the - `weights` saved by MMEngine which contains the config string. - Defaults to None. - weights (str, optional): Path to the checkpoint. If it is not specified - and model is a model name of metafile, the weights will be loaded - from metafile. Defaults to None. - device (str, optional): Device to run inference. If None, the available - device will be automatically used. Defaults to None. - scope (str, optional): The scope of the model. Defaults to None. - show_progress (bool): Control whether to display the progress bar during - the inference process. Defaults to True. - `New in version 0.7.4.` - - Note: - Since ``Inferencer`` could be used to infer batch data, - `collate_fn` should be defined. If `collate_fn` is not defined in config - file, the `collate_fn` will be `pseudo_collate` by default. - """ # noqa: E501 - - preprocess_kwargs: set = set() - forward_kwargs: set = set() - visualize_kwargs: set = set() - postprocess_kwargs: set = set() - - def __init__(self, - model: Union[ModelType, str, None] = None, - weights: Optional[str] = None, - device: Optional[str] = None, - scope: Optional[str] = None, - show_progress: bool = True) -> None: - if scope is None: - default_scope = DefaultScope.get_current_instance() - if default_scope is not None: - scope = default_scope.scope_name - self.scope = scope - # Load config to cfg - cfg: ConfigType - if isinstance(model, str): - if osp.isfile(model): - cfg = Config.fromfile(model) - else: - # Load config and weights from metafile. If `weights` is - # assigned, the weights defined in metafile will be ignored. 
- cfg, _weights = self._load_model_from_metafile(model) - if weights is None: - weights = _weights - elif isinstance(model, (Config, ConfigDict)): - cfg = copy.deepcopy(model) - elif isinstance(model, dict): - cfg = copy.deepcopy(ConfigDict(model)) - elif model is None: - if weights is None: - raise ValueError( - 'If model is None, the weights must be specified since ' - 'the config needs to be loaded from the weights') - cfg = ConfigDict() - else: - raise TypeError('model must be a filepath or any ConfigType' - f'object, but got {type(model)}') - - if device is None: - device = get_device() - - self.model = self._init_model(cfg, weights, device) # type: ignore - self.pipeline = self._init_pipeline(cfg) - self.collate_fn = self._init_collate(cfg) - self.visualizer = self._init_visualizer(cfg) - self.cfg = cfg - self.show_progress = show_progress - - def __call__( - self, - inputs: InputsType, - return_datasamples: bool = False, - batch_size: int = 1, - **kwargs, - ) -> dict: - """Call the inferencer. - - Args: - inputs (InputsType): Inputs for the inferencer. - return_datasamples (bool): Whether to return results as - :obj:`BaseDataElement`. Defaults to False. - batch_size (int): Batch size. Defaults to 1. - **kwargs: Key words arguments passed to :meth:`preprocess`, - :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. - Each key in kwargs should be in the corresponding set of - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` - and ``postprocess_kwargs``. - - Returns: - dict: Inference and visualization results. - """ - ( - preprocess_kwargs, - forward_kwargs, - visualize_kwargs, - postprocess_kwargs, - ) = self._dispatch_kwargs(**kwargs) - - ori_inputs = self._inputs_to_list(inputs) - inputs = self.preprocess( - ori_inputs, batch_size=batch_size, **preprocess_kwargs) - preds = [] - for data in (track(inputs, description='Inference') - if self.show_progress else inputs): - preds.extend(self.forward(data, **forward_kwargs)) - visualization = self.visualize( - ori_inputs, preds, - **visualize_kwargs) # type: ignore # noqa: E501 - results = self.postprocess(preds, visualization, return_datasamples, - **postprocess_kwargs) - return results - - def _inputs_to_list(self, inputs: InputsType) -> list: - """Preprocess the inputs to a list. - - Preprocess inputs to a list according to its type: - - - list or tuple: return inputs - - str: - - Directory path: return all files in the directory - - other cases: return a list containing the string. The string - could be a path to file, a url or other types of string according - to the task. - - Args: - inputs (InputsType): Inputs for the inferencer. - - Returns: - list: List of input for the :meth:`preprocess`. - """ - if isinstance(inputs, str): - backend = get_file_backend(inputs) - if hasattr(backend, 'isdir') and isdir(inputs): - # Backends like HttpsBackend do not implement `isdir`, so only - # those backends that implement `isdir` could accept the inputs - # as a directory - filename_list = list_dir_or_file(inputs, list_dir=False) - inputs = [ - join_path(inputs, filename) for filename in filename_list - ] - - if not isinstance(inputs, (list, tuple)): - inputs = [inputs] - - return list(inputs) - - def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): - """Process the inputs into a model-feedable format. - - Customize your preprocess by overriding this method. Preprocess should - return an iterable object, of which each item will be used as the - input of ``model.test_step``. 
- - ``BaseInferencer.preprocess`` will return an iterable chunked data, - which will be used in __call__ like this: - - .. code-block:: python - - def __call__(self, inputs, batch_size=1, **kwargs): - chunked_data = self.preprocess(inputs, batch_size, **kwargs) - for batch in chunked_data: - preds = self.forward(batch, **kwargs) - - Args: - inputs (InputsType): Inputs given by user. - batch_size (int): batch size. Defaults to 1. - - Yields: - Any: Data processed by the ``pipeline`` and ``collate_fn``. - """ - chunked_data = self._get_chunk_data( - map(self.pipeline, inputs), batch_size) - yield from map(self.collate_fn, chunked_data) - - @torch.no_grad() - def forward(self, inputs: Union[dict, tuple], **kwargs) -> Any: - """Feed the inputs to the model.""" - return self.model.test_step(inputs) - - @abstractmethod - def visualize(self, - inputs: list, - preds: Any, - show: bool = False, - **kwargs) -> List[np.ndarray]: - """Visualize predictions. - - Customize your visualization by overriding this method. visualize - should return visualization results, which could be np.ndarray or any - other objects. - - Args: - inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. - preds (Any): Predictions of the model. - show (bool): Whether to display the image in a popup window. - Defaults to False. - - Returns: - List[np.ndarray]: Visualization results. - """ - - @abstractmethod - def postprocess( - self, - preds: Any, - visualization: List[np.ndarray], - return_datasample=False, - **kwargs, - ) -> dict: - """Process the predictions and visualization results from ``forward`` - and ``visualize``. - - This method should be responsible for the following tasks: - - 1. Convert datasamples into a json-serializable dict if needed. - 2. Pack the predictions and visualization results and return them. - 3. Dump or log the predictions. - - Customize your postprocess by overriding this method. Make sure - ``postprocess`` will return a dict with visualization results and - inference results. - - Args: - preds (List[Dict]): Predictions of the model. - visualization (np.ndarray): Visualized predictions. - return_datasample (bool): Whether to return results as datasamples. - Defaults to False. - - Returns: - dict: Inference and visualization results with key ``predictions`` - and ``visualization`` - - - ``visualization (Any)``: Returned by :meth:`visualize` - - ``predictions`` (dict or DataSample): Returned by - :meth:`forward` and processed in :meth:`postprocess`. - If ``return_datasample=False``, it usually should be a - json-serializable dict containing only basic data elements such - as strings and numbers. - """ - - def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]: - """Load config and weights from metafile. - - Args: - model (str): model name defined in metafile. - - Returns: - Tuple[Config, str]: Loaded Config and weights path defined in - metafile. 
- """ - model = model.lower() - - assert self.scope is not None, ( - 'scope should be initialized if you want ' - 'to load config from metafile.') - assert self.scope in MODULE2PACKAGE, ( - f'{self.scope} not in {MODULE2PACKAGE}!,' - 'please pass a valid scope.') - - repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(self.scope) - for model_cfg in BaseInferencer._get_models_from_metafile( - repo_or_mim_dir): - model_name = model_cfg['Name'].lower() - model_aliases = model_cfg.get('Alias', []) - if isinstance(model_aliases, str): - model_aliases = [model_aliases.lower()] - else: - model_aliases = [alias.lower() for alias in model_aliases] - if (model_name == model or model in model_aliases): - cfg = Config.fromfile( - osp.join(repo_or_mim_dir, model_cfg['Config'])) - weights = model_cfg['Weights'] - weights = weights[0] if isinstance(weights, list) else weights - return cfg, weights - raise ValueError(f'Cannot find model: {model} in {self.scope}') - - @staticmethod - def _get_repo_or_mim_dir(scope): - """Get the directory where the ``Configs`` located when the package is - installed or ``PYTHONPATH`` is set. - - Args: - scope (str): The scope of repository. - - Returns: - str: The directory where the ``Configs`` is located. - """ - try: - module = importlib.import_module(scope) - except ImportError: - if scope not in MODULE2PACKAGE: - raise KeyError( - f'{scope} is not a valid scope. The available scopes ' - f'are {MODULE2PACKAGE.keys()}') - else: - project = MODULE2PACKAGE[scope] - raise ImportError( - f'Cannot import {scope} correctly, please try to install ' - f'the {project} by "pip install {project}"') - # Since none of OpenMMLab series packages are namespace packages - # (https://docs.python.org/3/glossary.html#term-namespace-package), - # The first element of module.__path__ means package installation path. - package_path = module.__path__[0] - - if osp.exists(osp.join(osp.dirname(package_path), 'configs')): - repo_dir = osp.dirname(package_path) - return repo_dir - else: - mim_dir = osp.join(package_path, '.mim') - if not osp.exists(osp.join(mim_dir, 'configs')): - raise FileNotFoundError( - f'Cannot find `configs` directory in {package_path}!, ' - f'please check the completeness of the {scope}.') - return mim_dir - - def _init_model( - self, - cfg: ConfigType, - weights: Optional[str], - device: str = 'cpu', - ) -> nn.Module: - """Initialize the model with the given config and checkpoint on the - specific device. - - Args: - cfg (ConfigType): Config containing the model information. - weights (str, optional): Path to the checkpoint. - device (str, optional): Device to run inference. Defaults to 'cpu'. - - Returns: - nn.Module: Model loaded with checkpoint. - """ - checkpoint: Optional[dict] = None - if weights is not None: - checkpoint = _load_checkpoint(weights, map_location='cpu') - - if not cfg: - assert checkpoint is not None - try: - # Prefer to get config from `message_hub` since `message_hub` - # is a more stable module to store all runtime information. - # However, the early version of MMEngine will not save config - # in `message_hub`, so we will try to load config from `meta`. 
- cfg_string = checkpoint['message_hub']['runtime_info']['cfg'] - except KeyError: - assert 'meta' in checkpoint, ( - 'If model(config) is not provided, the checkpoint must' - 'contain the config string in `meta` or `message_hub`, ' - 'but both `meta` and `message_hub` are not found in the ' - 'checkpoint.') - meta = checkpoint['meta'] - if 'cfg' in meta: - cfg_string = meta['cfg'] - else: - raise ValueError( - 'Cannot find the config in the checkpoint.') - cfg.update( - Config.fromstring(cfg_string, file_format='.py')._cfg_dict) - - # Delete the `pretrained` field to prevent model from loading the - # the pretrained weights unnecessarily. - if cfg.model.get('pretrained') is not None: - del cfg.model.pretrained - - model = MODELS.build(cfg.model) - model.cfg = cfg - self._load_weights_to_model(model, checkpoint, cfg) - model.to(device) - model.eval() - return model - - def _load_weights_to_model(self, model: nn.Module, - checkpoint: Optional[dict], - cfg: Optional[ConfigType]) -> None: - """Loading model weights and meta information from cfg and checkpoint. - - Subclasses could override this method to load extra meta information - from ``checkpoint`` and ``cfg`` to model. - - Args: - model (nn.Module): Model to load weights and meta information. - checkpoint (dict, optional): The loaded checkpoint. - cfg (Config or ConfigDict, optional): The loaded config. - """ - if checkpoint is not None: - _load_checkpoint_to_model(model, checkpoint) - else: - warnings.warn('Checkpoint is not loaded, and the inference ' - 'result is calculated by the randomly initialized ' - 'model!') - - def _init_collate(self, cfg: ConfigType) -> Callable: - """Initialize the ``collate_fn`` with the given config. - - The returned ``collate_fn`` will be used to collate the batch data. - If will be used in :meth:`preprocess` like this - - .. code-block:: python - def preprocess(self, inputs, batch_size, **kwargs): - ... - dataloader = map(self.collate_fn, dataloader) - yield from dataloader - - Args: - cfg (ConfigType): Config which could contained the `collate_fn` - information. If `collate_fn` is not defined in config, it will - be :func:`pseudo_collate`. - - Returns: - Callable: Collate function. - """ - try: - with FUNCTIONS.switch_scope_and_registry(self.scope) as registry: - collate_fn = registry.get(cfg.test_dataloader.collate_fn) - except AttributeError: - collate_fn = pseudo_collate - return collate_fn # type: ignore - - @abstractmethod - def _init_pipeline(self, cfg: ConfigType) -> Callable: - """Initialize the test pipeline. - - Return a pipeline to handle various input data, such as ``str``, - ``np.ndarray``. It is an abstract method in BaseInferencer, and should - be implemented in subclasses. - - The returned pipeline will be used to process a single data. - It will be used in :meth:`preprocess` like this: - - .. code-block:: python - def preprocess(self, inputs, batch_size, **kwargs): - ... - dataset = map(self.pipeline, dataset) - ... - """ - - def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]: - """Initialize visualizers. - - Args: - cfg (ConfigType): Config containing the visualizer information. - - Returns: - Visualizer or None: Visualizer initialized with config. 
- """ - if 'visualizer' not in cfg: - return None - timestamp = str(datetime.timestamp(datetime.now())) - name = cfg.visualizer.get('name', timestamp) - if Visualizer.check_instance_created(name): - name = f'{name}-{timestamp}' - cfg.visualizer.name = name - return VISUALIZERS.build(cfg.visualizer) - - def _get_chunk_data(self, inputs: Iterable, chunk_size: int): - """Get batch data from dataset. - - Args: - inputs (Iterable): An iterable dataset. - chunk_size (int): Equivalent to batch size. - - Yields: - list: batch data. - """ - inputs_iter = iter(inputs) - while True: - try: - chunk_data = [] - for _ in range(chunk_size): - processed_data = next(inputs_iter) - chunk_data.append(processed_data) - yield chunk_data - except StopIteration: - if chunk_data: - yield chunk_data - break - - def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]: - """Dispatch kwargs to preprocess(), forward(), visualize() and - postprocess() according to the actual demands. - - Returns: - Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, - forward, visualize and postprocess respectively. - """ - # Ensure each argument only matches one function - method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \ - self.visualize_kwargs | self.postprocess_kwargs - - union_kwargs = method_kwargs | set(kwargs.keys()) - if union_kwargs != method_kwargs: - unknown_kwargs = union_kwargs - method_kwargs - raise ValueError( - f'unknown argument {unknown_kwargs} for `preprocess`, ' - '`forward`, `visualize` and `postprocess`') - - preprocess_kwargs = {} - forward_kwargs = {} - visualize_kwargs = {} - postprocess_kwargs = {} - - for key, value in kwargs.items(): - if key in self.preprocess_kwargs: - preprocess_kwargs[key] = value - elif key in self.forward_kwargs: - forward_kwargs[key] = value - elif key in self.visualize_kwargs: - visualize_kwargs[key] = value - else: - postprocess_kwargs[key] = value - - return ( - preprocess_kwargs, - forward_kwargs, - visualize_kwargs, - postprocess_kwargs, - ) - - @staticmethod - def _get_models_from_metafile(dir: str): - """Load model config defined in metafile from package path. - - Args: - dir (str): Path to the directory of Config. It requires the - directory ``Config``, file ``model-index.yml`` exists in the - ``dir``. - - Yields: - dict: Model config defined in metafile. - """ - meta_indexes = load(osp.join(dir, 'model-index.yml')) - for meta_path in meta_indexes['Import']: - # meta_path example: mmcls/.mim/configs/conformer/metafile.yml - meta_path = osp.join(dir, meta_path) - metainfo = load(meta_path) - yield from metainfo['Models'] - - @staticmethod - def list_models(scope: Optional[str] = None, patterns: str = r'.*'): - """List models defined in metafile of corresponding packages. - - Args: - scope (str, optional): The scope to which the model belongs. - Defaults to None. - patterns (str, optional): Regular expressions for the searched - models. Once matched with ``Alias`` or ``Name`` filed in - metafile, corresponding model will be added to the return list. - Defaults to '.*'. - - Returns: - dict: Model dict with model name and its alias. 
- """ - matched_models = [] - if scope is None: - default_scope = DefaultScope.get_current_instance() - assert default_scope is not None, ( - 'scope should be initialized if you want ' - 'to load config from metafile.') - assert scope in MODULE2PACKAGE, ( - f'{scope} not in {MODULE2PACKAGE}!, please make pass a valid ' - 'scope.') - root_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope) - for model_cfg in BaseInferencer._get_models_from_metafile( - root_or_mim_dir): - model_name = [model_cfg['Name']] - model_name.extend(model_cfg.get('Alias', [])) - for name in model_name: - if re.match(patterns, name) is not None: - matched_models.append(name) - output_str = '' - for name in matched_models: - output_str += f'model_name: {name}\n' - print_log(output_str, logger='current') - return matched_models diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py deleted file mode 100644 index ba5533c236..0000000000 --- a/mmengine/logging/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .history_buffer import HistoryBuffer -from .logger import MMLogger, print_log -from .message_hub import MessageHub - -__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log'] diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py deleted file mode 100644 index a50de22c65..0000000000 --- a/mmengine/logging/history_buffer.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import Any, Callable, Optional, Sequence, Tuple, Union - -import numpy as np - - -class HistoryBuffer: - """Unified storage format for different log types. - - ``HistoryBuffer`` records the history of log for further statistics. - - Examples: - >>> history_buffer = HistoryBuffer() - >>> # Update history_buffer. - >>> history_buffer.update(1) - >>> history_buffer.update(2) - >>> history_buffer.min() # minimum of (1, 2) - 1 - >>> history_buffer.max() # maximum of (1, 2) - 2 - >>> history_buffer.mean() # mean of (1, 2) - 1.5 - >>> history_buffer.statistics('mean') # access method by string. - 1.5 - - Args: - log_history (Sequence): History logs. Defaults to []. - count_history (Sequence): Counts of history logs. Defaults to []. - max_length (int): The max length of history logs. Defaults to 1000000. 
- """ - _statistics_methods: dict = dict() - - def __init__(self, - log_history: Sequence = [], - count_history: Sequence = [], - max_length: int = 1000000): - - self.max_length = max_length - self._set_default_statistics() - assert len(log_history) == len(count_history), \ - 'The lengths of log_history and count_histroy should be equal' - if len(log_history) > max_length: - warnings.warn(f'The length of history buffer({len(log_history)}) ' - f'exceeds the max_length({max_length}), the first ' - 'few elements will be ignored.') - self._log_history = np.array(log_history[-max_length:]) - self._count_history = np.array(count_history[-max_length:]) - else: - self._log_history = np.array(log_history) - self._count_history = np.array(count_history) - - def _set_default_statistics(self) -> None: - """Register default statistic methods: min, max, current and mean.""" - self._statistics_methods.setdefault('min', HistoryBuffer.min) - self._statistics_methods.setdefault('max', HistoryBuffer.max) - self._statistics_methods.setdefault('current', HistoryBuffer.current) - self._statistics_methods.setdefault('mean', HistoryBuffer.mean) - - def update(self, log_val: Union[int, float], count: int = 1) -> None: - """Update the log history. - - If the length of the buffer exceeds ``self._max_length``, the oldest - element will be removed from the buffer. - - Args: - log_val (int or float): The value of log. - count (int): The accumulation times of log, defaults to 1. - ``count`` will be used in smooth statistics. - """ - if (not isinstance(log_val, (int, float)) - or not isinstance(count, (int, float))): - raise TypeError(f'log_val must be int or float but got ' - f'{type(log_val)}, count must be int but got ' - f'{type(count)}') - self._log_history = np.append(self._log_history, log_val) - self._count_history = np.append(self._count_history, count) - if len(self._log_history) > self.max_length: - self._log_history = self._log_history[-self.max_length:] - self._count_history = self._count_history[-self.max_length:] - - @property - def data(self) -> Tuple[np.ndarray, np.ndarray]: - """Get the ``_log_history`` and ``_count_history``. - - Returns: - Tuple[np.ndarray, np.ndarray]: History logs and the counts of - the history logs. - """ - return self._log_history, self._count_history - - @classmethod - def register_statistics(cls, method: Callable) -> Callable: - """Register custom statistics method to ``_statistics_methods``. - - The registered method can be called by ``history_buffer.statistics`` - with corresponding method name and arguments. - - Examples: - >>> @HistoryBuffer.register_statistics - >>> def weighted_mean(self, window_size, weight): - >>> assert len(weight) == window_size - >>> return (self._log_history[-window_size:] * - >>> np.array(weight)).sum() / \ - >>> self._count_history[-window_size:] - - >>> log_buffer = HistoryBuffer([1, 2], [1, 1]) - >>> log_buffer.statistics('weighted_mean', 2, [2, 1]) - 2 - - Args: - method (Callable): Custom statistics method. - Returns: - Callable: Original custom statistics method. - """ - method_name = method.__name__ - assert method_name not in cls._statistics_methods, \ - 'method_name cannot be registered twice!' - cls._statistics_methods[method_name] = method - return method - - def statistics(self, method_name: str, *arg, **kwargs) -> Any: - """Access statistics method by name. - - Args: - method_name (str): Name of method. - - Returns: - Any: Depends on corresponding method. 
- """ - if method_name not in self._statistics_methods: - raise KeyError(f'{method_name} has not been registered in ' - 'HistoryBuffer._statistics_methods') - method = self._statistics_methods[method_name] - # Provide self arguments for registered functions. - return method(self, *arg, **kwargs) - - def mean(self, window_size: Optional[int] = None) -> np.ndarray: - """Return the mean of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global mean value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: Mean value within the window. - """ - if window_size is not None: - assert isinstance(window_size, int), \ - 'The type of window size should be int, but got ' \ - f'{type(window_size)}' - else: - window_size = len(self._log_history) - logs_sum = self._log_history[-window_size:].sum() - counts_sum = self._count_history[-window_size:].sum() - return logs_sum / counts_sum - - def max(self, window_size: Optional[int] = None) -> np.ndarray: - """Return the maximum value of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global maximum value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: The maximum value within the window. - """ - if window_size is not None: - assert isinstance(window_size, int), \ - 'The type of window size should be int, but got ' \ - f'{type(window_size)}' - else: - window_size = len(self._log_history) - return self._log_history[-window_size:].max() - - def min(self, window_size: Optional[int] = None) -> np.ndarray: - """Return the minimum value of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global minimum value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: The minimum value within the window. - """ - if window_size is not None: - assert isinstance(window_size, int), \ - 'The type of window size should be int, but got ' \ - f'{type(window_size)}' - else: - window_size = len(self._log_history) - return self._log_history[-window_size:].min() - - def current(self) -> np.ndarray: - """Return the recently updated values in log histories. - - Returns: - np.ndarray: Recently updated values in log histories. - """ - if len(self._log_history) == 0: - raise ValueError('HistoryBuffer._log_history is an empty array! ' - 'please call update first') - return self._log_history[-1] - - def __getstate__(self) -> dict: - """Make ``_statistics_methods`` can be resumed. - - Returns: - dict: State dict including statistics_methods. - """ - self.__dict__.update(statistics_methods=self._statistics_methods) - return self.__dict__ - - def __setstate__(self, state): - """Try to load ``_statistics_methods`` from state. - - Args: - state (dict): State dict. - """ - statistics_methods = state.pop('statistics_methods', {}) - self._set_default_statistics() - self._statistics_methods.update(statistics_methods) - self.__dict__.update(state) diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py deleted file mode 100644 index e6cf9fe37d..0000000000 --- a/mmengine/logging/logger.py +++ /dev/null @@ -1,462 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import inspect -import logging -import os -import os.path as osp -import sys -import warnings -from getpass import getuser -from logging import Logger, LogRecord, handlers -from socket import gethostname -from typing import Dict, Optional, Union - -from termcolor import colored - -from mmengine.utils import ManagerMixin -from mmengine.utils.manager import _accquire_lock, _release_lock - - -class FilterDuplicateWarning(logging.Filter): - """Filter the repeated warning message. - - Args: - name (str): name of the filter. - """ - - def __init__(self, name: str = 'mmengine'): - super().__init__(name) - self.seen: set = set() - - def filter(self, record: LogRecord) -> bool: - """Filter the repeated warning message. - - Args: - record (LogRecord): The log record. - - Returns: - bool: Whether to output the log record. - """ - if record.levelno != logging.WARNING: - return True - - if record.msg not in self.seen: - self.seen.add(record.msg) - return True - return False - - -class MMFormatter(logging.Formatter): - """Colorful format for MMLogger. If the log level is error, the logger will - additionally output the location of the code. - - Args: - color (bool): Whether to use colorful format. filehandler is not - allowed to use color format, otherwise it will be garbled. - blink (bool): Whether to blink the ``INFO`` and ``DEBUG`` logging - level. - **kwargs: Keyword arguments passed to - :meth:`logging.Formatter.__init__`. - """ - _color_mapping: dict = dict( - ERROR='red', WARNING='yellow', INFO='white', DEBUG='green') - - def __init__(self, color: bool = True, blink: bool = False, **kwargs): - super().__init__(**kwargs) - assert not (not color and blink), ( - 'blink should only be available when color is True') - # Get prefix format according to color. - error_prefix = self._get_prefix('ERROR', color, blink=True) - warn_prefix = self._get_prefix('WARNING', color, blink=True) - info_prefix = self._get_prefix('INFO', color, blink) - debug_prefix = self._get_prefix('DEBUG', color, blink) - - # Config output format. - self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - ' - '%(pathname)s - %(funcName)s - %(lineno)d - ' - '%(message)s') - self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %(' - 'message)s') - self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %(' - 'message)s') - self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %(' - 'message)s') - - def _get_prefix(self, level: str, color: bool, blink=False) -> str: - """Get the prefix of the target log level. - - Args: - level (str): log level. - color (bool): Whether to get colorful prefix. - blink (bool): Whether the prefix will blink. - - Returns: - str: The plain or colorful prefix. - """ - if color: - attrs = ['underline'] - if blink: - attrs.append('blink') - prefix = colored(level, self._color_mapping[level], attrs=attrs) - else: - prefix = level - return prefix - - def format(self, record: LogRecord) -> str: - """Override the `logging.Formatter.format`` method `. Output the - message according to the specified log level. - - Args: - record (LogRecord): A LogRecord instance represents an event being - logged. - - Returns: - str: Formatted result. 
- """ - if record.levelno == logging.ERROR: - self._style._fmt = self.err_format - elif record.levelno == logging.WARNING: - self._style._fmt = self.warn_format - elif record.levelno == logging.INFO: - self._style._fmt = self.info_format - elif record.levelno == logging.DEBUG: - self._style._fmt = self.debug_format - - result = logging.Formatter.format(self, record) - return result - - -class MMLogger(Logger, ManagerMixin): - """Formatted logger used to record messages. - - ``MMLogger`` can create formatted logger to log message with different - log levels and get instance in the same way as ``ManagerMixin``. - ``MMLogger`` has the following features: - - - Distributed log storage, ``MMLogger`` can choose whether to save log of - different ranks according to `log_file`. - - Message with different log levels will have different colors and format - when displayed on terminal. - - Note: - - The `name` of logger and the ``instance_name`` of ``MMLogger`` could - be different. We can only get ``MMLogger`` instance by - ``MMLogger.get_instance`` but not ``logging.getLogger``. This feature - ensures ``MMLogger`` will not be incluenced by third-party logging - config. - - Different from ``logging.Logger``, ``MMLogger`` will not log warning - or error message without ``Handler``. - - Examples: - >>> logger = MMLogger.get_instance(name='MMLogger', - >>> logger_name='Logger') - >>> # Although logger has name attribute just like `logging.Logger` - >>> # We cannot get logger instance by `logging.getLogger`. - >>> assert logger.name == 'Logger' - >>> assert logger.instance_name = 'MMLogger' - >>> assert id(logger) != id(logging.getLogger('Logger')) - >>> # Get logger that do not store logs. - >>> logger1 = MMLogger.get_instance('logger1') - >>> # Get logger only save rank0 logs. - >>> logger2 = MMLogger.get_instance('logger2', log_file='out.log') - >>> # Get logger only save multiple ranks logs. - >>> logger3 = MMLogger.get_instance('logger3', log_file='out.log', - >>> distributed=True) - - Args: - name (str): Global instance name. - logger_name (str): ``name`` attribute of ``Logging.Logger`` instance. - If `logger_name` is not defined, defaults to 'mmengine'. - log_file (str, optional): The log filename. If specified, a - ``FileHandler`` will be added to the logger. Defaults to None. - log_level (str): The log level of the handler. Defaults to - 'INFO'. If log level is 'DEBUG', distributed logs will be saved - during distributed training. - file_mode (str): The file mode used to open log file. Defaults to 'w'. - distributed (bool): Whether to save distributed logs, Defaults to - false. - file_handler_cfg (dict, optional): Configuration of file handler. - Defaults to None. If ``file_handler_cfg`` is not specified, - ``logging.FileHandler`` will be used by default. If it is - specified, the ``type`` key should be set. It can be - ``RotatingFileHandler``, ``TimedRotatingFileHandler``, - ``WatchedFileHandler`` or other file handlers, and the remaining - fields will be used to build the handler. - - Examples: - >>> file_handler_cfg = dict( - >>> type='TimedRotatingFileHandler', - >>> when='MIDNIGHT', - >>> interval=1, - >>> backupCount=365) - - `New in version 0.9.0.` - """ - - def __init__(self, - name: str, - logger_name='mmengine', - log_file: Optional[str] = None, - log_level: Union[int, str] = 'INFO', - file_mode: str = 'w', - distributed=False, - file_handler_cfg: Optional[dict] = None): - Logger.__init__(self, logger_name) - ManagerMixin.__init__(self, name) - # Get rank in DDP mode. 
- if isinstance(log_level, str): - log_level = logging._nameToLevel[log_level] - global_rank = _get_rank() - device_id = _get_device_id() - - # Config stream_handler. If `rank != 0`. stream_handler can only - # export ERROR logs. - stream_handler = logging.StreamHandler(stream=sys.stdout) - # `StreamHandler` record month, day, hour, minute, and second - # timestamp. - stream_handler.setFormatter( - MMFormatter(color=True, datefmt='%m/%d %H:%M:%S')) - # Only rank0 `StreamHandler` will log messages below error level. - if global_rank == 0: - stream_handler.setLevel(log_level) - else: - stream_handler.setLevel(logging.ERROR) - stream_handler.addFilter(FilterDuplicateWarning(logger_name)) - self.handlers.append(stream_handler) - - if log_file is not None: - world_size = _get_world_size() - is_distributed = (log_level <= logging.DEBUG - or distributed) and world_size > 1 - if is_distributed: - filename, suffix = osp.splitext(osp.basename(log_file)) - hostname = _get_host_info() - if hostname: - filename = (f'{filename}_{hostname}_device{device_id}_' - f'rank{global_rank}{suffix}') - else: - # Omit hostname if it is empty - filename = (f'{filename}_device{device_id}_' - f'rank{global_rank}{suffix}') - log_file = osp.join(osp.dirname(log_file), filename) - # Save multi-ranks logs if distributed is True. The logs of rank0 - # will always be saved. - if global_rank == 0 or is_distributed: - if file_handler_cfg is not None: - assert 'type' in file_handler_cfg - file_handler_type = file_handler_cfg.pop('type') - file_handlers_map = _get_logging_file_handlers() - if file_handler_type in file_handlers_map: - file_handler_cls = file_handlers_map[file_handler_type] - file_handler_cfg.setdefault('filename', log_file) - file_handler = file_handler_cls(**file_handler_cfg) - else: - raise ValueError('`logging.handlers` does not ' - f'contain {file_handler_type}') - else: - # Here, the default behavior of the official - # logger is 'a'. Thus, we provide an interface to - # change the file mode to the default behavior. - # `FileHandler` is not supported to have colors, - # otherwise it will appear garbled. - file_handler = logging.FileHandler(log_file, file_mode) - - # `StreamHandler` record year, month, day hour, minute, - # and second timestamp. file_handler will only record logs - # without color to avoid garbled code saved in files. - file_handler.setFormatter( - MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S')) - file_handler.setLevel(log_level) - file_handler.addFilter(FilterDuplicateWarning(logger_name)) - self.handlers.append(file_handler) - self._log_file = log_file - - @property - def log_file(self): - return self._log_file - - @classmethod - def get_current_instance(cls) -> 'MMLogger': - """Get latest created ``MMLogger`` instance. - - :obj:`MMLogger` can call :meth:`get_current_instance` before any - instance has been created, and return a logger with the instance name - "mmengine". - - Returns: - MMLogger: Configured logger instance. - """ - if not cls._instance_dict: - cls.get_instance('mmengine') - return super().get_current_instance() - - def callHandlers(self, record: LogRecord) -> None: - """Pass a record to all relevant handlers. - - Override ``callHandlers`` method in ``logging.Logger`` to avoid - multiple warning messages in DDP mode. Loop through all handlers of - the logger instance and its parents in the logger hierarchy. If no - handler was found, the record will not be output. - - Args: - record (LogRecord): A ``LogRecord`` instance contains logged - message. 
- """ - for handler in self.handlers: - if record.levelno >= handler.level: - handler.handle(record) - - def setLevel(self, level): - """Set the logging level of this logger. - - If ``logging.Logger.selLevel`` is called, all ``logging.Logger`` - instances managed by ``logging.Manager`` will clear the cache. Since - ``MMLogger`` is not managed by ``logging.Manager`` anymore, - ``MMLogger`` should override this method to clear caches of all - ``MMLogger`` instance which is managed by :obj:`ManagerMixin`. - - level must be an int or a str. - """ - self.level = logging._checkLevel(level) - _accquire_lock() - # The same logic as `logging.Manager._clear_cache`. - for logger in MMLogger._instance_dict.values(): - logger._cache.clear() - _release_lock() - - -def print_log(msg, - logger: Optional[Union[Logger, str]] = None, - level=logging.INFO) -> None: - """Print a log message. - - Args: - msg (str): The message to be logged. - logger (Logger or str, optional): If the type of logger is - ``logging.Logger``, we directly use logger to log messages. - Some special loggers are: - - - "silent": No message will be printed. - - "current": Use latest created logger to log message. - - other str: Instance name of logger. The corresponding logger - will log message if it has been created, otherwise ``print_log`` - will raise a `ValueError`. - - None: The `print()` method will be used to print log messages. - level (int): Logging level. Only available when `logger` is a Logger - object, "current", or a created logger instance name. - """ - if logger is None: - print(msg) - elif isinstance(logger, logging.Logger): - logger.log(level, msg) - elif logger == 'silent': - pass - elif logger == 'current': - logger_instance = MMLogger.get_current_instance() - logger_instance.log(level, msg) - elif isinstance(logger, str): - # If the type of `logger` is `str`, but not with value of `current` or - # `silent`, we assume it indicates the name of the logger. If the - # corresponding logger has not been created, `print_log` will raise - # a `ValueError`. - if MMLogger.check_instance_created(logger): - logger_instance = MMLogger.get_instance(logger) - logger_instance.log(level, msg) - else: - raise ValueError(f'MMLogger: {logger} has not been created!') - else: - raise TypeError( - '`logger` should be either a logging.Logger object, str, ' - f'"silent", "current" or None, but got {type(logger)}') - - -def _get_world_size(): - """Support using logging module without torch.""" - try: - # requires torch - from mmengine.dist import get_world_size - except ImportError: - return 1 - else: - return get_world_size() - - -def _get_rank(): - """Support using logging module without torch.""" - try: - # requires torch - from mmengine.dist import get_rank - except ImportError: - return 0 - else: - return get_rank() - - -def _get_device_id(): - """Get device id of current machine.""" - try: - import torch - except ImportError: - return 0 - else: - MUSA_AVAILABLE = False - try: - import torch_musa - MUSA_AVAILABLE = True - except ImportError: - pass - if MUSA_AVAILABLE: - local_rank = int(os.getenv('LOCAL_RANK', '0')) - musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None) - if musa_visible_devices is None: - num_device = torch_musa.device_count() - musa_visible_devices = list(range(num_device)) - else: - musa_visible_devices = musa_visible_devices.split(',') - return int(musa_visible_devices[local_rank]) - else: - local_rank = int(os.getenv('LOCAL_RANK', '0')) - # TODO: return device id of npu and mlu. 
- if not torch.cuda.is_available(): - return local_rank - cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) - if cuda_visible_devices is None: - num_device = torch.cuda.device_count() - cuda_visible_devices = list(range(num_device)) - else: - cuda_visible_devices = cuda_visible_devices.split(',') - try: - return int(cuda_visible_devices[local_rank]) - except ValueError: - # handle case for Multi-Instance GPUs - # see #1148 for details - return cuda_visible_devices[local_rank] - - -def _get_host_info() -> str: - """Get hostname and username. - - Return empty string if exception raised, e.g. ``getpass.getuser()`` will - lead to error in docker container - """ - host = '' - try: - host = f'{getuser()}@{gethostname()}' - except Exception as e: - warnings.warn(f'Host or user not found: {str(e)}') - return host - - -def _get_logging_file_handlers() -> Dict: - """Get additional file_handlers in ``logging.handlers``. - - Returns: - Dict: A map of file_handlers. - """ - file_handlers_map = {} - for module_name in dir(handlers): - if module_name.startswith('__'): - continue - _fh = getattr(handlers, module_name) - if inspect.isclass(_fh) and issubclass(_fh, logging.FileHandler): - file_handlers_map[module_name] = _fh - return file_handlers_map diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py deleted file mode 100644 index 82565d8832..0000000000 --- a/mmengine/logging/message_hub.py +++ /dev/null @@ -1,470 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Optional, Union - -import numpy as np - -from mmengine.utils import ManagerMixin -from .history_buffer import HistoryBuffer -from .logger import print_log - -if TYPE_CHECKING: - import torch - - -class MessageHub(ManagerMixin): - """Message hub for component interaction. MessageHub is created and - accessed in the same way as ManagerMixin. - - ``MessageHub`` will record log information and runtime information. The - log information refers to the learning rate, loss, etc. of the model - during training phase, which will be stored as ``HistoryBuffer``. The - runtime information refers to the iter times, meta information of - runner etc., which will be overwritten by next update. - - Args: - name (str): Name of message hub used to get corresponding instance - globally. - log_scalars (dict, optional): Each key-value pair in the - dictionary is the name of the log information such as "loss", "lr", - "metric" and their corresponding values. The type of value must be - HistoryBuffer. Defaults to None. - runtime_info (dict, optional): Each key-value pair in the - dictionary is the name of the runtime information and their - corresponding values. Defaults to None. - resumed_keys (dict, optional): Each key-value pair in the - dictionary decides whether the key in :attr:`_log_scalars` and - :attr:`_runtime_info` will be serialized. - - Note: - Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or - :attr:`_runtime_info`. The corresponding value cannot be set - repeatedly. - - Examples: - >>> # create empty `MessageHub`. - >>> message_hub1 = MessageHub('name') - >>> log_scalars = dict(loss=HistoryBuffer()) - >>> runtime_info = dict(task='task') - >>> resumed_keys = dict(loss=True) - >>> # create `MessageHub` from data. 
- >>> message_hub2 = MessageHub( - >>> name='name', - >>> log_scalars=log_scalars, - >>> runtime_info=runtime_info, - >>> resumed_keys=resumed_keys) - """ - - def __init__(self, - name: str, - log_scalars: Optional[dict] = None, - runtime_info: Optional[dict] = None, - resumed_keys: Optional[dict] = None): - super().__init__(name) - self._log_scalars = self._parse_input('log_scalars', log_scalars) - self._runtime_info = self._parse_input('runtime_info', runtime_info) - self._resumed_keys = self._parse_input('resumed_keys', resumed_keys) - - for value in self._log_scalars.values(): - assert isinstance(value, HistoryBuffer), \ - ("The type of log_scalars'value must be HistoryBuffer, but " - f'got {type(value)}') - - for key in self._resumed_keys.keys(): - assert key in self._log_scalars or key in self._runtime_info, \ - ('Key in `resumed_keys` must contained in `log_scalars` or ' - f'`runtime_info`, but got {key}') - - @classmethod - def get_current_instance(cls) -> 'MessageHub': - """Get latest created ``MessageHub`` instance. - - :obj:`MessageHub` can call :meth:`get_current_instance` before any - instance has been created, and return a message hub with the instance - name "mmengine". - - Returns: - MessageHub: Empty ``MessageHub`` instance. - """ - if not cls._instance_dict: - cls.get_instance('mmengine') - return super().get_current_instance() - - def update_scalar(self, - key: str, - value: Union[int, float, np.ndarray, 'torch.Tensor'], - count: int = 1, - resumed: bool = True) -> None: - """Update :attr:_log_scalars. - - Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key - ``HistoryBuffer`` has been created, ``value`` and ``count`` is the - argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar`` - will create an ``HistoryBuffer`` with value and count via the - constructor of ``HistoryBuffer``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> # create loss `HistoryBuffer` with value=1, count=1 - >>> message_hub.update_scalar('loss', 1) - >>> # update loss `HistoryBuffer` with value - >>> message_hub.update_scalar('loss', 3) - >>> message_hub.update_scalar('loss', 3, resumed=False) - AssertionError: loss used to be true, but got false now. resumed - keys cannot be modified repeatedly' - - Note: - The ``resumed`` argument needs to be consistent for the same - ``key``. - - Args: - key (str): Key of ``HistoryBuffer``. - value (torch.Tensor or np.ndarray or int or float): Value of log. - count (torch.Tensor or np.ndarray or int or float): Accumulation - times of log, defaults to 1. `count` will be used in smooth - statistics. - resumed (str): Whether the corresponding ``HistoryBuffer`` - could be resumed. Defaults to True. - """ - self._set_resumed_keys(key, resumed) - checked_value = self._get_valid_value(value) - assert isinstance(count, int), ( - f'The type of count must be int. but got {type(count): {count}}') - if key in self._log_scalars: - self._log_scalars[key].update(checked_value, count) - else: - self._log_scalars[key] = HistoryBuffer([checked_value], [count]) - - def update_scalars(self, log_dict: dict, resumed: bool = True) -> None: - """Update :attr:`_log_scalars` with a dict. - - ``update_scalars`` iterates through each pair of log_dict key-value, - and calls ``update_scalar``. If type of value is dict, the value should - be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in - ``log_dict`` has the same resume option. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``log_dict``. 
- - Args: - log_dict (str): Used for batch updating :attr:`_log_scalars`. - resumed (bool): Whether all ``HistoryBuffer`` referred in - log_dict should be resumed. Defaults to True. - - Examples: - >>> message_hub = MessageHub.get_instance('mmengine') - >>> log_dict = dict(a=1, b=2, c=3) - >>> message_hub.update_scalars(log_dict) - >>> # The default count of `a`, `b` and `c` is 1. - >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2)) - >>> message_hub.update_scalars(log_dict) - >>> # The count of `c` is 2. - """ - assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, ' - f'but got {type(log_dict)}') - for log_name, log_val in log_dict.items(): - if isinstance(log_val, dict): - assert 'value' in log_val, \ - f'value must be defined in {log_val}' - count = self._get_valid_value(log_val.get('count', 1)) - value = log_val['value'] - else: - count = 1 - value = log_val - assert isinstance(count, - int), ('The type of count must be int. but got ' - f'{type(count): {count}}') - self.update_scalar(log_name, value, count, resumed) - - def update_info(self, key: str, value: Any, resumed: bool = True) -> None: - """Update runtime information. - - The key corresponding runtime information will be overwritten each - time calling ``update_info``. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``key``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> message_hub.update_info('iter', 100) - - Args: - key (str): Key of runtime information. - value (Any): Value of runtime information. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. - """ - self._set_resumed_keys(key, resumed) - self._runtime_info[key] = value - - def pop_info(self, key: str, default: Optional[Any] = None) -> Any: - """Remove runtime information by key. If the key does not exist, this - method will return the default value. - - Args: - key (str): Key of runtime information. - default (Any, optional): The default returned value for the - given key. - - Returns: - Any: The runtime information if the key exists. - """ - return self._runtime_info.pop(key, default) - - def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None: - """Update runtime information with dictionary. - - The key corresponding runtime information will be overwritten each - time calling ``update_info``. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``info_dict``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> message_hub.update_info({'iter': 100}) - - Args: - info_dict (str): Runtime information dictionary. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. - """ - assert isinstance(info_dict, dict), ('`log_dict` must be a dict!, ' - f'but got {type(info_dict)}') - for key, value in info_dict.items(): - self.update_info(key, value, resumed=resumed) - - def _set_resumed_keys(self, key: str, resumed: bool) -> None: - """Set corresponding resumed keys. - - This method is called by ``update_scalar``, ``update_scalars`` and - ``update_info`` to set the corresponding key is true or false in - :attr:`_resumed_keys`. - - Args: - key (str): Key of :attr:`_log_scalrs` or :attr:`_runtime_info`. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. - """ - if key not in self._resumed_keys: - self._resumed_keys[key] = resumed - else: - assert self._resumed_keys[key] == resumed, \ - f'{key} used to be {self._resumed_keys[key]}, but got ' \ - '{resumed} now. 
resumed keys cannot be modified repeatedly.' - - @property - def log_scalars(self) -> OrderedDict: - """Get all ``HistoryBuffer`` instances. - - Note: - Considering the large memory footprint of history buffers in the - post-training, :meth:`get_scalar` will return a reference of - history buffer rather than a copy. - - Returns: - OrderedDict: All ``HistoryBuffer`` instances. - """ - return self._log_scalars - - @property - def runtime_info(self) -> OrderedDict: - """Get all runtime information. - - Returns: - OrderedDict: A copy of all runtime information. - """ - return self._runtime_info - - def get_scalar(self, key: str) -> HistoryBuffer: - """Get ``HistoryBuffer`` instance by key. - - Note: - Considering the large memory footprint of history buffers in the - post-training, :meth:`get_scalar` will not return a reference of - history buffer rather than a copy. - - Args: - key (str): Key of ``HistoryBuffer``. - - Returns: - HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the - key exists. - """ - if key not in self.log_scalars: - raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' - f'instance name is: {MessageHub.instance_name}') - return self.log_scalars[key] - - def get_info(self, key: str, default: Optional[Any] = None) -> Any: - """Get runtime information by key. If the key does not exist, this - method will return default information. - - Args: - key (str): Key of runtime information. - default (Any, optional): The default returned value for the - given key. - - Returns: - Any: A copy of corresponding runtime information if the key exists. - """ - if key not in self.runtime_info: - return default - else: - # TODO: There are restrictions on objects that can be saved - # return copy.deepcopy(self._runtime_info[key]) - return self._runtime_info[key] - - def _get_valid_value( - self, - value: Union['torch.Tensor', np.ndarray, np.number, int, float], - ) -> Union[int, float]: - """Convert value to python built-in type. - - Args: - value (torch.Tensor or np.ndarray or np.number or int or float): - value of log. - - Returns: - float or int: python built-in type value. - """ - if isinstance(value, (np.ndarray, np.number)): - assert value.size == 1 - value = value.item() - elif isinstance(value, (int, float)): - value = value - else: - # check whether value is torch.Tensor but don't want - # to import torch in this file - assert hasattr(value, 'numel') and value.numel() == 1 - value = value.item() - return value # type: ignore - - def state_dict(self) -> dict: - """Returns a dictionary containing log scalars, runtime information and - resumed keys, which should be resumed. - - The returned ``state_dict`` can be loaded by :meth:`load_state_dict`. - - Returns: - dict: A dictionary contains ``log_scalars``, ``runtime_info`` and - ``resumed_keys``. - """ - saved_scalars = OrderedDict() - saved_info = OrderedDict() - - for key, value in self._log_scalars.items(): - if self._resumed_keys.get(key, False): - saved_scalars[key] = copy.deepcopy(value) - - for key, value in self._runtime_info.items(): - if self._resumed_keys.get(key, False): - try: - saved_info[key] = copy.deepcopy(value) - except: # noqa: E722 - print_log( - f'{key} in message_hub cannot be copied, ' - f'just return its reference. 
', - logger='current', - level=logging.WARNING) - saved_info[key] = value - return dict( - log_scalars=saved_scalars, - runtime_info=saved_info, - resumed_keys=self._resumed_keys) - - def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None: - """Loads log scalars, runtime information and resumed keys from - ``state_dict`` or ``message_hub``. - - If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it - will only make copies of data which should be resumed from the source - ``message_hub``. - - If ``state_dict`` is a ``message_hub`` instance, it will make copies of - all data from the source message_hub. We suggest to load data from - ``dict`` rather than a ``MessageHub`` instance. - - Args: - state_dict (dict or MessageHub): A dictionary contains key - ``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a - MessageHub instance. - """ - if isinstance(state_dict, dict): - for key in ('log_scalars', 'runtime_info', 'resumed_keys'): - assert key in state_dict, ( - 'The loaded `state_dict` of `MessageHub` must contain ' - f'key: `{key}`') - # The old `MessageHub` could save non-HistoryBuffer `log_scalars`, - # therefore the loaded `log_scalars` needs to be filtered. - for key, value in state_dict['log_scalars'].items(): - if not isinstance(value, HistoryBuffer): - print_log( - f'{key} in message_hub is not HistoryBuffer, ' - f'just skip resuming it.', - logger='current', - level=logging.WARNING) - continue - self.log_scalars[key] = value - - for key, value in state_dict['runtime_info'].items(): - try: - self._runtime_info[key] = copy.deepcopy(value) - except: # noqa: E722 - print_log( - f'{key} in message_hub cannot be copied, ' - f'just return its reference.', - logger='current', - level=logging.WARNING) - self._runtime_info[key] = value - - for key, value in state_dict['resumed_keys'].items(): - if key not in set(self.log_scalars.keys()) | \ - set(self._runtime_info.keys()): - print_log( - f'resumed key: {key} is not defined in message_hub, ' - f'just skip resuming this key.', - logger='current', - level=logging.WARNING) - continue - elif not value: - print_log( - f'Although resumed key: {key} is False, {key} ' - 'will still be loaded this time. This key will ' - 'not be saved by the next calling of ' - '`MessageHub.state_dict()`', - logger='current', - level=logging.WARNING) - self._resumed_keys[key] = value - - # Since some checkpoints saved serialized `message_hub` instance, - # `load_state_dict` support loading `message_hub` instance for - # compatibility - else: - self._log_scalars = copy.deepcopy(state_dict._log_scalars) - self._runtime_info = copy.deepcopy(state_dict._runtime_info) - self._resumed_keys = copy.deepcopy(state_dict._resumed_keys) - - def _parse_input(self, name: str, value: Any) -> OrderedDict: - """Parse input value. - - Args: - name (str): name of input value. - value (Any): Input value. - - Returns: - dict: Parsed input value. - """ - if value is None: - return OrderedDict() - elif isinstance(value, dict): - return OrderedDict(value) - else: - raise TypeError(f'{name} should be a dict or `None`, but ' - f'got {type(name)}') diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py deleted file mode 100644 index 033512a985..0000000000 --- a/mmengine/model/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from mmengine.utils.dl_utils import TORCH_VERSION -from mmengine.utils.version_utils import digit_version -from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage, - MomentumAnnealingEMA, StochasticWeightAverage) -from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor -from .base_module import BaseModule, ModuleDict, ModuleList, Sequential -from .test_time_aug import BaseTTAModel -from .utils import (convert_sync_batchnorm, detect_anomalous_params, - merge_dict, revert_sync_batchnorm, stack_batch) -from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit, - KaimingInit, NormalInit, PretrainedInit, - TruncNormalInit, UniformInit, XavierInit, - bias_init_with_prob, caffe2_xavier_init, - constant_init, initialize, kaiming_init, normal_init, - trunc_normal_init, uniform_init, update_init_info, - xavier_init) -from .wrappers import (MMDistributedDataParallel, - MMSeparateDistributedDataParallel, is_model_wrapper) - -__all__ = [ - 'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel', - 'StochasticWeightAverage', 'ExponentialMovingAverage', - 'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor', - 'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule', - 'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList', - 'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info', - 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init', - 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', - 'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit', - 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', - 'Caffe2XavierInit', 'PretrainedInit', 'initialize', - 'convert_sync_batchnorm', 'BaseTTAModel' -] - -if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): - from .wrappers import MMFullyShardedDataParallel # noqa:F401 - __all__.append('MMFullyShardedDataParallel') diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py deleted file mode 100644 index 58457c2a6e..0000000000 --- a/mmengine/model/averaged_model.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from abc import abstractmethod -from copy import deepcopy -from typing import Optional - -import torch -import torch.nn as nn -from torch import Tensor - -from mmengine.logging import print_log -from mmengine.registry import MODELS - - -class BaseAveragedModel(nn.Module): - """A base class for averaging model weights. - - Weight averaging, such as SWA and EMA, is a widely used technique for - training neural networks. This class implements the averaging process - for a model. All subclasses must implement the `avg_func` method. - This class creates a copy of the provided module :attr:`model` - on the :attr:`device` and allows computing running averages of the - parameters of the :attr:`model`. - - The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py. - - Different from the `AveragedModel` in PyTorch, we use in-place operation - to improve the parameter updating speed, which is about 5 times faster - than the non-in-place version. - - In mmengine, we provide two ways to use the model averaging: - - 1. Use the model averaging module in hook: - We provide an :class:`mmengine.hooks.EMAHook` to apply the model - averaging during training. Add ``custom_hooks=[dict(type='EMAHook')]`` - to the config or the runner. - - 2. Use the model averaging module directly in the algorithm. 
Take the ema - teacher in semi-supervise as an example: - - >>> from mmengine.model import ExponentialMovingAverage - >>> student = ResNet(depth=50) - >>> # use ema model as teacher - >>> ema_teacher = ExponentialMovingAverage(student) - - Args: - model (nn.Module): The model to be averaged. - interval (int): Interval between two updates. Defaults to 1. - device (torch.device, optional): If provided, the averaged model will - be stored on the :attr:`device`. Defaults to None. - update_buffers (bool): if True, it will compute running averages for - both the parameters and the buffers of the model. Defaults to - False. - """ # noqa: E501 - - def __init__(self, - model: nn.Module, - interval: int = 1, - device: Optional[torch.device] = None, - update_buffers: bool = False) -> None: - super().__init__() - self.module = deepcopy(model).requires_grad_(False) - self.interval = interval - if device is not None: - self.module = self.module.to(device) - self.register_buffer('steps', - torch.tensor(0, dtype=torch.long, device=device)) - self.update_buffers = update_buffers - if update_buffers: - self.avg_parameters = self.module.state_dict() - else: - self.avg_parameters = dict(self.module.named_parameters()) - - @abstractmethod - def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> None: - """Use in-place operation to compute the average of the parameters. All - subclasses must implement this method. - - Args: - averaged_param (Tensor): The averaged parameters. - source_param (Tensor): The source parameters. - steps (int): The number of times the parameters have been - updated. - """ - - def forward(self, *args, **kwargs): - """Forward method of the averaged model.""" - return self.module(*args, **kwargs) - - def update_parameters(self, model: nn.Module) -> None: - """Update the parameters of the model. This method will execute the - ``avg_func`` to compute the new parameters and update the model's - parameters. - - Args: - model (nn.Module): The model whose parameters will be averaged. - """ - src_parameters = ( - model.state_dict() - if self.update_buffers else dict(model.named_parameters())) - if self.steps == 0: - for k, p_avg in self.avg_parameters.items(): - p_avg.data.copy_(src_parameters[k].data) - elif self.steps % self.interval == 0: - for k, p_avg in self.avg_parameters.items(): - if p_avg.dtype.is_floating_point: - device = p_avg.device - self.avg_func(p_avg.data, - src_parameters[k].data.to(device), - self.steps) - if not self.update_buffers: - # If not update the buffers, - # keep the buffers in sync with the source model. - for b_avg, b_src in zip(self.module.buffers(), model.buffers()): - b_avg.data.copy_(b_src.data.to(b_avg.device)) - self.steps += 1 - - -@MODELS.register_module() -class StochasticWeightAverage(BaseAveragedModel): - """Implements the stochastic weight averaging (SWA) of the model. - - Stochastic Weight Averaging was proposed in `Averaging Weights Leads to - Wider Optima and Better Generalization, UAI 2018. - `_ by Pavel Izmailov, Dmitrii - Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson. - """ - - def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> None: - """Compute the average of the parameters using stochastic weight - average. - - Args: - averaged_param (Tensor): The averaged parameters. - source_param (Tensor): The source parameters. - steps (int): The number of times the parameters have been - updated. 
- """ - averaged_param.add_( - source_param - averaged_param, - alpha=1 / float(steps // self.interval + 1)) - - -@MODELS.register_module() -class ExponentialMovingAverage(BaseAveragedModel): - r"""Implements the exponential moving average (EMA) of the model. - - All parameters are updated by the formula as below: - - .. math:: - - Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t - - .. note:: - This :attr:`momentum` argument is different from one used in optimizer - classes and the conventional notion of momentum. Mathematically, - :math:`Xema_{t+1}` is the moving average and :math:`X_t` is the - new observed value. The value of momentum is usually a small number, - allowing observed values to slowly update the ema parameters. - - Args: - model (nn.Module): The model to be averaged. - momentum (float): The momentum used for updating ema parameter. - Defaults to 0.0002. - Ema's parameter are updated with the formula - :math:`averaged\_param = (1-momentum) * averaged\_param + - momentum * source\_param`. - interval (int): Interval between two updates. Defaults to 1. - device (torch.device, optional): If provided, the averaged model will - be stored on the :attr:`device`. Defaults to None. - update_buffers (bool): if True, it will compute running averages for - both the parameters and the buffers of the model. Defaults to - False. - """ # noqa: W605 - - def __init__(self, - model: nn.Module, - momentum: float = 0.0002, - interval: int = 1, - device: Optional[torch.device] = None, - update_buffers: bool = False) -> None: - super().__init__(model, interval, device, update_buffers) - assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\ - f'but got {momentum}' - if momentum > 0.5: - print_log( - 'The value of momentum in EMA is usually a small number,' - 'which is different from the conventional notion of ' - f'momentum but got {momentum}. Please make sure the ' - f'value is correct.', - logger='current', - level=logging.WARNING) - self.momentum = momentum - - def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> None: - """Compute the moving average of the parameters using exponential - moving average. - - Args: - averaged_param (Tensor): The averaged parameters. - source_param (Tensor): The source parameters. - steps (int): The number of times the parameters have been - updated. - """ - averaged_param.lerp_(source_param, self.momentum) - - -@MODELS.register_module() -class MomentumAnnealingEMA(ExponentialMovingAverage): - r"""Exponential moving average (EMA) with momentum annealing strategy. - - Args: - model (nn.Module): The model to be averaged. - momentum (float): The momentum used for updating ema parameter. - Defaults to 0.0002. - Ema's parameter are updated with the formula - :math:`averaged\_param = (1-momentum) * averaged\_param + - momentum * source\_param`. - gamma (int): Use a larger momentum early in training and gradually - annealing to a smaller value to update the ema model smoothly. The - momentum is calculated as max(momentum, gamma / (gamma + steps)) - Defaults to 100. - interval (int): Interval between two updates. Defaults to 1. - device (torch.device, optional): If provided, the averaged model will - be stored on the :attr:`device`. Defaults to None. - update_buffers (bool): if True, it will compute running averages for - both the parameters and the buffers of the model. Defaults to - False. 
- """ - - def __init__(self, - model: nn.Module, - momentum: float = 0.0002, - gamma: int = 100, - interval: int = 1, - device: Optional[torch.device] = None, - update_buffers: bool = False) -> None: - super().__init__( - model=model, - momentum=momentum, - interval=interval, - device=device, - update_buffers=update_buffers) - assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' - self.gamma = gamma - - def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> None: - """Compute the moving average of the parameters using the linear - momentum strategy. - - Args: - averaged_param (Tensor): The averaged parameters. - source_param (Tensor): The source parameters. - steps (int): The number of times the parameters have been - updated. - """ - momentum = max(self.momentum, - self.gamma / (self.gamma + self.steps.item())) - averaged_param.lerp_(source_param, momentum) diff --git a/mmengine/model/base_model/__init__.py b/mmengine/model/base_model/__init__.py deleted file mode 100644 index 66a3cb89a9..0000000000 --- a/mmengine/model/base_model/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base_model import BaseModel -from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor - -__all__ = ['BaseModel', 'ImgDataPreprocessor', 'BaseDataPreprocessor'] diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py deleted file mode 100644 index 299cd67557..0000000000 --- a/mmengine/model/base_model/base_model.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from collections import OrderedDict -from typing import Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from mmengine.optim import OptimWrapper -from mmengine.registry import MODELS -from mmengine.utils import is_list_of -from ..base_module import BaseModule -from .data_preprocessor import BaseDataPreprocessor - - -class BaseModel(BaseModule): - """Base class for all algorithmic models. - - BaseModel implements the basic functions of the algorithmic model, such as - weights initialize, batch inputs preprocess(see more information in - :class:`BaseDataPreprocessor`), parse losses, and update model parameters. - - Subclasses inherit from BaseModel only need to implement the forward - method, which implements the logic to calculate loss and predictions, - then can be trained in the runner. 
- - Examples: - >>> @MODELS.register_module() - >>> class ToyModel(BaseModel): - >>> - >>> def __init__(self): - >>> super().__init__() - >>> self.backbone = nn.Sequential() - >>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5)) - >>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2)) - >>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5)) - >>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120)) - >>> self.backbone.add_module('fc2', nn.Linear(120, 84)) - >>> self.backbone.add_module('fc3', nn.Linear(84, 10)) - >>> - >>> self.criterion = nn.CrossEntropyLoss() - >>> - >>> def forward(self, batch_inputs, data_samples, mode='tensor'): - >>> data_samples = torch.stack(data_samples) - >>> if mode == 'tensor': - >>> return self.backbone(batch_inputs) - >>> elif mode == 'predict': - >>> feats = self.backbone(batch_inputs) - >>> predictions = torch.argmax(feats, 1) - >>> return predictions - >>> elif mode == 'loss': - >>> feats = self.backbone(batch_inputs) - >>> loss = self.criterion(feats, data_samples) - >>> return dict(loss=loss) - - Args: - data_preprocessor (dict, optional): The pre-process config of - :class:`BaseDataPreprocessor`. - init_cfg (dict, optional): The weight initialized config for - :class:`BaseModule`. - - Attributes: - data_preprocessor (:obj:`BaseDataPreprocessor`): Used for - pre-processing data sampled by dataloader to the format accepted by - :meth:`forward`. - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, - data_preprocessor: Optional[Union[dict, nn.Module]] = None, - init_cfg: Optional[dict] = None): - super().__init__(init_cfg) - if data_preprocessor is None: - data_preprocessor = dict(type='BaseDataPreprocessor') - if isinstance(data_preprocessor, nn.Module): - self.data_preprocessor = data_preprocessor - elif isinstance(data_preprocessor, dict): - self.data_preprocessor = MODELS.build(data_preprocessor) - else: - raise TypeError('data_preprocessor should be a `dict` or ' - f'`nn.Module` instance, but got ' - f'{type(data_preprocessor)}') - - def train_step(self, data: Union[dict, tuple, list], - optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: - """Implements the default model training process including - preprocessing, model forward propagation, loss calculation, - optimization, and back-propagation. - - During non-distributed training. If subclasses do not override the - :meth:`train_step`, :class:`EpochBasedTrainLoop` or - :class:`IterBasedTrainLoop` will call this method to update model - parameters. The default parameter update process is as follows: - - 1. Calls ``self.data_processor(data, training=False)`` to collect - batch_inputs and corresponding data_samples(labels). - 2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw - loss - 3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to - backward and dict of loss tensor used to log messages. - 4. Calls ``optim_wrapper.update_params(loss)`` to update model. - - Args: - data (dict or tuple or list): Data sampled from dataset. - optim_wrapper (OptimWrapper): OptimWrapper instance - used to update model parameters. - - Returns: - Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. - """ - # Enable automatic mixed precision training context. 
- with optim_wrapper.optim_context(self): - data = self.data_preprocessor(data, True) - losses = self._run_forward(data, mode='loss') # type: ignore - parsed_losses, log_vars = self.parse_losses(losses) # type: ignore - optim_wrapper.update_params(parsed_losses) - return log_vars - - def val_step(self, data: Union[tuple, dict, list]) -> list: - """Gets the predictions of given data. - - Calls ``self.data_preprocessor(data, False)`` and - ``self(inputs, data_sample, mode='predict')`` in order. Return the - predictions which will be passed to evaluator. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.data_preprocessor(data, False) - return self._run_forward(data, mode='predict') # type: ignore - - def test_step(self, data: Union[dict, tuple, list]) -> list: - """``BaseModel`` implements ``test_step`` the same as ``val_step``. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.data_preprocessor(data, False) - return self._run_forward(data, mode='predict') # type: ignore - - def parse_losses( - self, losses: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """Parses the raw outputs (losses) of the network. - - Args: - losses (dict): Raw output of the network, which usually contain - losses and other necessary information. - - Returns: - tuple[Tensor, dict]: There are two elements. The first is the - loss tensor passed to optim_wrapper which may be a weighted sum - of all losses, and the second is log_vars which will be sent to - the logger. - """ - log_vars = [] - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - log_vars.append([loss_name, loss_value.mean()]) - elif is_list_of(loss_value, torch.Tensor): - log_vars.append( - [loss_name, - sum(_loss.mean() for _loss in loss_value)]) - else: - raise TypeError( - f'{loss_name} is not a tensor or list of tensors') - - loss = sum(value for key, value in log_vars if 'loss' in key) - log_vars.insert(0, ['loss', loss]) - log_vars = OrderedDict(log_vars) # type: ignore - - return loss, log_vars # type: ignore - - def to(self, *args, **kwargs) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.to` - additionally. - - Returns: - nn.Module: The model itself. - """ - - # Since Torch has not officially merged - # the npu-related fields, using the _parse_to function - # directly will cause the NPU to not be found. - # Here, the input parameters are processed to avoid errors. - if args and isinstance(args[0], str) and 'npu' in args[0]: - import torch_npu - args = tuple([ - list(args)[0].replace( - 'npu', torch_npu.npu.native_device if hasattr( - torch_npu.npu, 'native_device') else 'privateuseone') - ]) - if kwargs and 'npu' in str(kwargs.get('device', '')): - import torch_npu - kwargs['device'] = kwargs['device'].replace( - 'npu', torch_npu.npu.native_device if hasattr( - torch_npu.npu, 'native_device') else 'privateuseone') - - device = torch._C._nn._parse_to(*args, **kwargs)[0] - if device is not None: - self._set_device(torch.device(device)) - return super().to(*args, **kwargs) - - def cuda( - self, - device: Optional[Union[int, str, torch.device]] = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.cuda` - additionally. - - Returns: - nn.Module: The model itself. 
- """ - if device is None or isinstance(device, int): - device = torch.device('cuda', index=device) - self._set_device(torch.device(device)) - return super().cuda(device) - - def musa( - self, - device: Optional[Union[int, str, torch.device]] = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.musa` - additionally. - - Returns: - nn.Module: The model itself. - """ - if device is None or isinstance(device, int): - device = torch.device('musa', index=device) - self._set_device(torch.device(device)) - return super().musa(device) - - def mlu( - self, - device: Union[int, str, torch.device, None] = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.mlu` - additionally. - - Returns: - nn.Module: The model itself. - """ - device = torch.device('mlu', torch.mlu.current_device()) - self._set_device(device) - return super().mlu() - - def npu( - self, - device: Union[int, str, torch.device, None] = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.npu` - additionally. - - Returns: - nn.Module: The model itself. - - Note: - This generation of NPU(Ascend910) does not support - the use of multiple cards in a single process, - so the index here needs to be consistent with the default device - """ - device = torch.npu.current_device() - self._set_device(device) - return super().npu() - - def cpu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.cpu` - additionally. - - Returns: - nn.Module: The model itself. - """ - self._set_device(torch.device('cpu')) - return super().cpu() - - def _set_device(self, device: torch.device) -> None: - """Recursively set device for `BaseDataPreprocessor` instance. - - Args: - device (torch.device): the desired device of the parameters and - buffers in this module. - """ - - def apply_fn(module): - if not isinstance(module, BaseDataPreprocessor): - return - if device is not None: - module._device = device - - self.apply(apply_fn) - - @abstractmethod - def forward(self, - inputs: torch.Tensor, - data_samples: Optional[list] = None, - mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: - """Returns losses or predictions of training, validation, testing, and - simple inference process. - - ``forward`` method of BaseModel is an abstract method, its subclasses - must implement this method. - - Accepts ``batch_inputs`` and ``data_sample`` processed by - :attr:`data_preprocessor`, and returns results according to mode - arguments. - - During non-distributed training, validation, and testing process, - ``forward`` will be called by ``BaseModel.train_step``, - ``BaseModel.val_step`` and ``BaseModel.test_step`` directly. - - During distributed data parallel training process, - ``MMSeparateDistributedDataParallel.train_step`` will first call - ``DistributedDataParallel.forward`` to enable automatic - gradient synchronization, and then call ``forward`` to get training - loss. - - Args: - inputs (torch.Tensor): batch input tensor collated by - :attr:`data_preprocessor`. - data_samples (list, optional): - data samples collated by :attr:`data_preprocessor`. - mode (str): mode should be one of ``loss``, ``predict`` and - ``tensor`` - - - ``loss``: Called by ``train_step`` and return loss ``dict`` - used for logging - - ``predict``: Called by ``val_step`` and ``test_step`` - and return list of results used for computing metric. - - ``tensor``: Called by custom use to get ``Tensor`` type - results. 
- - Returns: - dict or list: - - If ``mode == loss``, return a ``dict`` of loss tensor used - for backward and logging. - - If ``mode == predict``, return a ``list`` of inference - results. - - If ``mode == tensor``, return a tensor or ``tuple`` of tensor - or ``dict`` of tensor for custom use. - """ - - def _run_forward(self, data: Union[dict, tuple, list], - mode: str) -> Union[Dict[str, torch.Tensor], list]: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, (list, tuple)): - results = self(*data, mode=mode) - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list, tuple or dict, but got {type(data)}') - return results diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py deleted file mode 100644 index 4d621851b0..0000000000 --- a/mmengine/model/base_model/data_preprocessor.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math -from typing import Mapping, Optional, Sequence, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from mmengine.registry import MODELS -from mmengine.structures import BaseDataElement -from mmengine.utils import is_seq_of -from ..utils import stack_batch - -CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, - None] - - -@MODELS.register_module() -class BaseDataPreprocessor(nn.Module): - """Base data pre-processor used for copying data to the target device. - - Subclasses inherit from ``BaseDataPreprocessor`` could override the - forward method to implement custom data pre-processing, such as - batch-resize, MixUp, or CutMix. - - Args: - non_blocking (bool): Whether block current process - when transferring data to device. - New in version 0.3.0. - - Note: - Data dictionary returned by dataloader must be a dict and at least - contain the ``inputs`` key. - """ - - def __init__(self, non_blocking: Optional[bool] = False): - super().__init__() - self._non_blocking = non_blocking - self._device = torch.device('cpu') - - def cast_data(self, data: CastData) -> CastData: - """Copying data to the target device. - - Args: - data (dict): Data returned by ``DataLoader``. - - Returns: - CollatedResult: Inputs and data sample at target device. - """ - if isinstance(data, Mapping): - return {key: self.cast_data(data[key]) for key in data} - elif isinstance(data, (str, bytes)) or data is None: - return data - elif isinstance(data, tuple) and hasattr(data, '_fields'): - # namedtuple - return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable - elif isinstance(data, Sequence): - return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable - elif isinstance(data, (torch.Tensor, BaseDataElement)): - return data.to(self.device, non_blocking=self._non_blocking) - else: - return data - - def forward(self, data: dict, training: bool = False) -> Union[dict, list]: - """Preprocesses the data into the model input format. - - After the data pre-processing of :meth:`cast_data`, ``forward`` - will stack the input tensor list to a batch tensor at the first - dimension. - - Args: - data (dict): Data returned by dataloader - training (bool): Whether to enable training time augmentation. 
- - Returns: - dict or list: Data in the same format as the model input. - """ - return self.cast_data(data) # type: ignore - - @property - def device(self): - return self._device - - def to(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - - # Since Torch has not officially merged - # the npu-related fields, using the _parse_to function - # directly will cause the NPU to not be found. - # Here, the input parameters are processed to avoid errors. - if args and isinstance(args[0], str) and 'npu' in args[0]: - args = tuple( - [list(args)[0].replace('npu', torch.npu.native_device)]) - if kwargs and 'npu' in str(kwargs.get('device', '')): - kwargs['device'] = kwargs['device'].replace( - 'npu', torch.npu.native_device) - - device = torch._C._nn._parse_to(*args, **kwargs)[0] - if device is not None: - self._device = torch.device(device) - return super().to(*args, **kwargs) - - def cuda(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.cuda.current_device()) - return super().cuda() - - def musa(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.musa.current_device()) - return super().musa() - - def npu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.npu.current_device()) - return super().npu() - - def mlu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.mlu.current_device()) - return super().mlu() - - def cpu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device('cpu') - return super().cpu() - - -@MODELS.register_module() -class ImgDataPreprocessor(BaseDataPreprocessor): - """Image pre-processor for normalization and bgr to rgb conversion. - - Accepts the data sampled by the dataloader, and preprocesses it into the - format of the model input. ``ImgDataPreprocessor`` provides the - basic data pre-processing as follows - - - Collates and moves data to the target device. - - Converts inputs from bgr to rgb if the shape of input is (3, H, W). - - Normalizes image with defined std and mean. - - Pads inputs to the maximum size of current batch with defined - ``pad_value``. The padding size can be divisible by a defined - ``pad_size_divisor`` - - Stack inputs to batch_inputs. - - For ``ImgDataPreprocessor``, the dimension of the single inputs must be - (3, H, W). - - Note: - ``ImgDataPreprocessor`` and its subclass is built in the - constructor of :class:`BaseDataset`. - - Args: - mean (Sequence[float or int], optional): The pixel mean of image - channels. If ``bgr_to_rgb=True`` it means the mean value of R, - G, B channels. If the length of `mean` is 1, it means all - channels have the same mean value, or the input is a gray image. - If it is not specified, images will not be normalized. Defaults - None. - std (Sequence[float or int], optional): The pixel standard deviation of - image channels. If ``bgr_to_rgb=True`` it means the standard - deviation of R, G, B channels. 
If the length of `std` is 1, - it means all channels have the same standard deviation, or the - input is a gray image. If it is not specified, images will - not be normalized. Defaults None. - pad_size_divisor (int): The size of padded image should be - divisible by ``pad_size_divisor``. Defaults to 1. - pad_value (float or int): The padded pixel value. Defaults to 0. - bgr_to_rgb (bool): whether to convert image from BGR to RGB. - Defaults to False. - rgb_to_bgr (bool): whether to convert image from RGB to RGB. - Defaults to False. - non_blocking (bool): Whether block current process - when transferring data to device. - New in version v0.3.0. - - Note: - if images do not need to be normalized, `std` and `mean` should be - both set to None, otherwise both of them should be set to a tuple of - corresponding values. - """ - - def __init__(self, - mean: Optional[Sequence[Union[float, int]]] = None, - std: Optional[Sequence[Union[float, int]]] = None, - pad_size_divisor: int = 1, - pad_value: Union[float, int] = 0, - bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False, - non_blocking: Optional[bool] = False): - super().__init__(non_blocking) - assert not (bgr_to_rgb and rgb_to_bgr), ( - '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') - assert (mean is None) == (std is None), ( - 'mean and std should be both None or tuple') - if mean is not None: - assert len(mean) == 3 or len(mean) == 1, ( - '`mean` should have 1 or 3 values, to be compatible with ' - f'RGB or gray image, but got {len(mean)} values') - assert len(std) == 3 or len(std) == 1, ( # type: ignore - '`std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501 - f'or gray image, but got {len(std)} values') # type: ignore - self._enable_normalize = True - self.register_buffer('mean', - torch.tensor(mean).view(-1, 1, 1), False) - self.register_buffer('std', - torch.tensor(std).view(-1, 1, 1), False) - else: - self._enable_normalize = False - self._channel_conversion = rgb_to_bgr or bgr_to_rgb - self.pad_size_divisor = pad_size_divisor - self.pad_value = pad_value - - def forward(self, data: dict, training: bool = False) -> Union[dict, list]: - """Performs normalization, padding and bgr2rgb conversion based on - ``BaseDataPreprocessor``. - - Args: - data (dict): Data sampled from dataset. If the collate - function of DataLoader is :obj:`pseudo_collate`, data will be a - list of dict. If collate function is :obj:`default_collate`, - data will be a tuple with batch input tensor and list of data - samples. - training (bool): Whether to enable training time augmentation. If - subclasses override this method, they can perform different - preprocessing strategies for training and testing based on the - value of ``training``. - - Returns: - dict or list: Data in the same format as the model input. - """ - data = self.cast_data(data) # type: ignore - _batch_inputs = data['inputs'] # type: ignore - # Process data with `pseudo_collate`. - if is_seq_of(_batch_inputs, torch.Tensor): - batch_inputs = [] - for _batch_input in _batch_inputs: - # channel transform - if self._channel_conversion: - _batch_input = _batch_input[[2, 1, 0], ...] # type: ignore - # Convert to float after channel conversion to ensure - # efficiency - _batch_input = _batch_input.float() # type: ignore - # Normalization. 
- if self._enable_normalize: - if self.mean.shape[0] == 3: - assert _batch_input.dim( - ) == 3 and _batch_input.shape[0] == 3, ( - 'If the mean has 3 values, the input tensor ' - 'should in shape of (3, H, W), but got the tensor ' - f'with shape {_batch_input.shape}') - _batch_input = (_batch_input - self.mean) / self.std - batch_inputs.append(_batch_input) - # Pad and stack Tensor. - batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor, - self.pad_value) - # Process data with `default_collate`. - elif isinstance(_batch_inputs, torch.Tensor): - assert _batch_inputs.dim() == 4, ( - 'The input of `ImgDataPreprocessor` should be a NCHW tensor ' - 'or a list of tensor, but got a tensor with shape: ' - f'{_batch_inputs.shape}') - if self._channel_conversion: - _batch_inputs = _batch_inputs[:, [2, 1, 0], ...] - # Convert to float after channel conversion to ensure - # efficiency - _batch_inputs = _batch_inputs.float() - if self._enable_normalize: - _batch_inputs = (_batch_inputs - self.mean) / self.std - h, w = _batch_inputs.shape[2:] - target_h = math.ceil( - h / self.pad_size_divisor) * self.pad_size_divisor - target_w = math.ceil( - w / self.pad_size_divisor) * self.pad_size_divisor - pad_h = target_h - h - pad_w = target_w - w - batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h), - 'constant', self.pad_value) - else: - raise TypeError('Output of `cast_data` should be a dict of ' - 'list/tuple with inputs and data_samples, ' - f'but got {type(data)}: {data}') # type: ignore - data['inputs'] = batch_inputs # type: ignore - data.setdefault('data_samples', None) # type: ignore - return data # type: ignore diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py deleted file mode 100644 index 3cfe0b14a8..0000000000 --- a/mmengine/model/base_module.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -from abc import ABCMeta -from collections import defaultdict -from logging import FileHandler -from typing import Iterable, List, Optional, Union - -import torch.nn as nn - -from mmengine.dist import master_only -from mmengine.logging import MMLogger, print_log -from .weight_init import PretrainedInit, initialize, update_init_info -from .wrappers.utils import is_model_wrapper - - -class BaseModule(nn.Module, metaclass=ABCMeta): - """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper of - ``torch.nn.Module`` with additional functionality of parameter - initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly - adds three attributes. - - - ``init_cfg``: the config to control the initialization. - - ``init_weights``: The function of parameter initialization and recording - initialization information. - - ``_params_init_info``: Used to track the parameter initialization - information. This attribute only exists during executing the - ``init_weights``. - - Note: - :obj:`PretrainedInit` has a higher priority than any other - initializer. The loaded pretrained weights will overwrite - the previous initialized weights. - - Args: - init_cfg (dict or List[dict], optional): Initialization config dict. - """ - - def __init__(self, init_cfg: Union[dict, List[dict], None] = None): - """Initialize BaseModule, inherited from `torch.nn.Module`""" - - # NOTE init_cfg can be defined in different levels, but init_cfg - # in low levels has a higher priority. 
- - super().__init__() - # define default value of init_cfg instead of hard code - # in init_weights() function - self._is_init = False - - self.init_cfg = copy.deepcopy(init_cfg) - - # Backward compatibility in derived classes - # if pretrained is not None: - # warnings.warn('DeprecationWarning: pretrained is a deprecated \ - # key, please consider using init_cfg') - # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) - - @property - def is_init(self): - return self._is_init - - @is_init.setter - def is_init(self, value): - self._is_init = value - - def init_weights(self): - """Initialize the weights.""" - - is_top_level_module = False - # check if it is top-level module - if not hasattr(self, '_params_init_info'): - # The `_params_init_info` is used to record the initialization - # information of the parameters - # the key should be the obj:`nn.Parameter` of model and the value - # should be a dict containing - # - init_info (str): The string that describes the initialization. - # - tmp_mean_value (FloatTensor): The mean of the parameter, - # which indicates whether the parameter has been modified. - # this attribute would be deleted after all parameters - # is initialized. - self._params_init_info = defaultdict(dict) - is_top_level_module = True - - # Initialize the `_params_init_info`, - # When detecting the `tmp_mean_value` of - # the corresponding parameter is changed, update related - # initialization information - for name, param in self.named_parameters(): - self._params_init_info[param][ - 'init_info'] = f'The value is the same before and ' \ - f'after calling `init_weights` ' \ - f'of {self.__class__.__name__} ' - self._params_init_info[param][ - 'tmp_mean_value'] = param.data.mean().cpu() - - # pass `params_init_info` to all submodules - # All submodules share the same `params_init_info`, - # so it will be updated when parameters are - # modified at any level of the model. - for sub_module in self.modules(): - sub_module._params_init_info = self._params_init_info - - module_name = self.__class__.__name__ - if not self._is_init: - if self.init_cfg: - print_log( - f'initialize {module_name} with init_cfg {self.init_cfg}', - logger='current', - level=logging.DEBUG) - - init_cfgs = self.init_cfg - if isinstance(self.init_cfg, dict): - init_cfgs = [self.init_cfg] - - # PretrainedInit has higher priority than any other init_cfg. - # Therefore we initialize `pretrained_cfg` last to overwrite - # the previous initialized weights. 
- # See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501 - other_cfgs = [] - pretrained_cfg = [] - for init_cfg in init_cfgs: - assert isinstance(init_cfg, dict) - if (init_cfg['type'] == 'Pretrained' - or init_cfg['type'] is PretrainedInit): - pretrained_cfg.append(init_cfg) - else: - other_cfgs.append(init_cfg) - - initialize(self, other_cfgs) - - for m in self.children(): - if is_model_wrapper(m) and not hasattr(m, 'init_weights'): - m = m.module - if hasattr(m, 'init_weights') and not getattr( - m, 'is_init', False): - m.init_weights() - # users may overload the `init_weights` - update_init_info( - m, - init_info=f'Initialized by ' - f'user-defined `init_weights`' - f' in {m.__class__.__name__} ') - if self.init_cfg and pretrained_cfg: - initialize(self, pretrained_cfg) - self._is_init = True - else: - print_log( - f'init_weights of {self.__class__.__name__} has ' - f'been called more than once.', - logger='current', - level=logging.WARNING) - - if is_top_level_module: - self._dump_init_info() - - for sub_module in self.modules(): - del sub_module._params_init_info - - @master_only - def _dump_init_info(self): - """Dump the initialization information to a file named - `initialization.log.json` in workdir.""" - - logger = MMLogger.get_current_instance() - with_file_handler = False - # dump the information to the logger file if there is a `FileHandler` - for handler in logger.handlers: - if isinstance(handler, FileHandler): - handler.stream.write( - 'Name of parameter - Initialization information\n') - for name, param in self.named_parameters(): - handler.stream.write( - f'\n{name} - {param.shape}: ' - f"\n{self._params_init_info[param]['init_info']} \n") - handler.stream.flush() - with_file_handler = True - if not with_file_handler: - for name, param in self.named_parameters(): - logger.info( - f'\n{name} - {param.shape}: ' - f"\n{self._params_init_info[param]['init_info']} \n ") - - def __repr__(self): - s = super().__repr__() - if self.init_cfg: - s += f'\ninit_cfg={self.init_cfg}' - return s - - -class Sequential(BaseModule, nn.Sequential): - """Sequential module in openmmlab. - - Ensures that all modules in ``Sequential`` have a different initialization - strategy than the outer model - - Args: - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, *args, init_cfg: Optional[dict] = None): - BaseModule.__init__(self, init_cfg) - nn.Sequential.__init__(self, *args) - - -class ModuleList(BaseModule, nn.ModuleList): - """ModuleList in openmmlab. - - Ensures that all modules in ``ModuleList`` have a different initialization - strategy than the outer model - - Args: - modules (iterable, optional): An iterable of modules to add. - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, - modules: Optional[Iterable] = None, - init_cfg: Optional[dict] = None): - BaseModule.__init__(self, init_cfg) - nn.ModuleList.__init__(self, modules) - - -class ModuleDict(BaseModule, nn.ModuleDict): - """ModuleDict in openmmlab. - - Ensures that all modules in ``ModuleDict`` have a different initialization - strategy than the outer model - - Args: - modules (dict, optional): A mapping (dictionary) of (string: module) - or an iterable of key-value pairs of type (string, module). - init_cfg (dict, optional): Initialization config dict. 
- """ - - def __init__(self, - modules: Optional[dict] = None, - init_cfg: Optional[dict] = None): - BaseModule.__init__(self, init_cfg) - nn.ModuleDict.__init__(self, modules) diff --git a/mmengine/model/efficient_conv_bn_eval.py b/mmengine/model/efficient_conv_bn_eval.py deleted file mode 100644 index 9cb2ad6199..0000000000 --- a/mmengine/model/efficient_conv_bn_eval.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from operator import attrgetter -from typing import List, Union - -import torch -import torch.nn as nn - - -def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm, - conv: nn.modules.conv._ConvNd, - x: torch.Tensor): - """Code borrowed from mmcv 2.0.1, so that this feature can be used for old - mmcv versions. - - Implementation based on https://arxiv.org/abs/2305.11624 - "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" - It leverages the associative law between convolution and affine transform, - i.e., normalize (weight conv feature) = (normalize weight) conv feature. - It works for Eval mode of ConvBN blocks during validation, and can be used - for training as well. It reduces memory and computation cost. - Args: - bn (_BatchNorm): a BatchNorm module. - conv (nn._ConvNd): a conv module - x (torch.Tensor): Input feature map. - """ - # These lines of code are designed to deal with various cases - # like bn without affine transform, and conv without bias - weight_on_the_fly = conv.weight - if conv.bias is not None: - bias_on_the_fly = conv.bias - else: - bias_on_the_fly = torch.zeros_like(bn.running_var) - - if bn.weight is not None: - bn_weight = bn.weight - else: - bn_weight = torch.ones_like(bn.running_var) - - if bn.bias is not None: - bn_bias = bn.bias - else: - bn_bias = torch.zeros_like(bn.running_var) - - # shape of [C_out, 1, 1, 1] in Conv2d - weight_coeff = torch.rsqrt(bn.running_var + - bn.eps).reshape([-1] + [1] * - (len(conv.weight.shape) - 1)) - # shape of [C_out, 1, 1, 1] in Conv2d - coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff - - # shape of [C_out, C_in, k, k] in Conv2d - weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly - # shape of [C_out] in Conv2d - bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\ - (bias_on_the_fly - bn.running_mean) - - return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly) - - -def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm, - conv: nn.modules.conv._ConvNd, - x: torch.Tensor): - """This function controls whether to use `efficient_conv_bn_eval_forward`. - - If the following `bn` is in `eval` mode, then we turn on the special - `efficient_conv_bn_eval_forward`. - """ - if not bn.training: - # bn in eval mode - output = efficient_conv_bn_eval_forward(bn, conv, x) - return output - else: - conv_out = conv._conv_forward(x, conv.weight, conv.bias) - return bn(conv_out) - - -def efficient_conv_bn_eval_graph_transform(fx_model): - """Find consecutive conv+bn calls in the graph, inplace modify the graph - with the fused operation.""" - modules = dict(fx_model.named_modules()) - - patterns = [(torch.nn.modules.conv._ConvNd, - torch.nn.modules.batchnorm._BatchNorm)] - - pairs = [] - # Iterate through nodes in the graph to find ConvBN blocks - for node in fx_model.graph.nodes: - # If our current node isn't calling a Module then we can ignore it. 
- if node.op != 'call_module': - continue - target_module = modules[node.target] - found_pair = False - for conv_class, bn_class in patterns: - if isinstance(target_module, bn_class): - source_module = modules[node.args[0].target] - if isinstance(source_module, conv_class): - found_pair = True - # Not a conv-BN pattern or output of conv is used by other nodes - if not found_pair or len(node.args[0].users) > 1: - continue - - # Find a pair of conv and bn computation nodes to optimize - conv_node = node.args[0] - bn_node = node - pairs.append([conv_node, bn_node]) - - for conv_node, bn_node in pairs: - # set insertion point - fx_model.graph.inserting_before(conv_node) - # create `get_attr` node to access modules - # note that we directly call `create_node` to fill the `name` - # argument. `fx_model.graph.get_attr` and - # `fx_model.graph.call_function` does not allow the `name` argument. - conv_get_node = fx_model.graph.create_node( - op='get_attr', target=conv_node.target, name='get_conv') - bn_get_node = fx_model.graph.create_node( - op='get_attr', target=bn_node.target, name='get_bn') - # prepare args for the fused function - args = (bn_get_node, conv_get_node, conv_node.args[0]) - # create a new node - new_node = fx_model.graph.create_node( - op='call_function', - target=efficient_conv_bn_eval_control, - args=args, - name='efficient_conv_bn_eval') - # this node replaces the original conv + bn, and therefore - # should replace the uses of bn_node - bn_node.replace_all_uses_with(new_node) - # take care of the deletion order: - # delete bn_node first, and then conv_node - fx_model.graph.erase_node(bn_node) - fx_model.graph.erase_node(conv_node) - - # regenerate the code - fx_model.graph.lint() - fx_model.recompile() - - -def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module): - import torch.fx as fx - - # currently we use `fx.symbolic_trace` to trace models. - # in the future, we might turn to pytorch 2.0 compile infrastructure to - # get the `fx.GraphModule` IR. Nonetheless, the graph transform function - # can remain unchanged. We just need to change the way - # we get `fx.GraphModule`. - fx_model: fx.GraphModule = fx.symbolic_trace(model) - efficient_conv_bn_eval_graph_transform(fx_model) - model.forward = fx_model.forward - - -def turn_on_efficient_conv_bn_eval(model: torch.nn.Module, - modules: Union[List[str], str]): - if isinstance(modules, str): - modules = [modules] - for module_name in modules: - module = attrgetter(module_name)(model) - turn_on_efficient_conv_bn_eval_for_single_model(module) diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py deleted file mode 100644 index c623eec8bc..0000000000 --- a/mmengine/model/test_time_aug.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from typing import Dict, List, Optional, Union - -import torch -import torch.nn as nn - -from mmengine.registry import MODELS -from mmengine.structures import BaseDataElement -from .base_model import BaseModel - -# multi-batch inputs processed by different augmentations from the same batch. -EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]] -# multi-batch data samples processed by different augmentations from the same -# batch. The inner list stands for different augmentations and the outer list -# stands for batch. 
-EnhancedBatchDataSamples = List[List[BaseDataElement]] -DATA_BATCH = Union[Dict[str, Union[EnhancedBatchInputs, - EnhancedBatchDataSamples]], tuple, dict] -MergedDataSamples = List[BaseDataElement] - - -@MODELS.register_module() -class BaseTTAModel(BaseModel): - """Base model for inference with test-time augmentation. - - ``BaseTTAModel`` is a wrapper for inference given multi-batch data. - It implements the :meth:`test_step` for multi-batch data inference. - ``multi-batch`` data means data processed by different augmentation - from the same batch. - - During test time augmentation, the data processed by - :obj:`mmcv.transforms.TestTimeAug`, and then collated by - ``pseudo_collate`` will have the following format: - - .. code-block:: - - result = dict( - inputs=[ - [image1_aug1, image2_aug1], - [image1_aug2, image2_aug2] - ], - data_samples=[ - [data_sample1_aug1, data_sample2_aug1], - [data_sample1_aug2, data_sample2_aug2], - ] - ) - - ``image{i}_aug{j}`` means the i-th image of the batch, which is - augmented by the j-th augmentation. - - ``BaseTTAModel`` will collate the data to: - - .. code-block:: - - data1 = dict( - inputs=[image1_aug1, image2_aug1], - data_samples=[data_sample1_aug1, data_sample2_aug1] - ) - - data2 = dict( - inputs=[image1_aug2, image2_aug2], - data_samples=[data_sample1_aug2, data_sample2_aug2] - ) - - ``data1`` and ``data2`` will be passed to model, and the results will be - merged by :meth:`merge_preds`. - - Note: - :meth:`merge_preds` is an abstract method, all subclasses should - implement it. - - Warning: - If ``data_preprocessor`` is not None, it will overwrite the model's - ``data_preprocessor``. - - Args: - module (dict or nn.Module): Tested model. - data_preprocessor (dict or :obj:`BaseDataPreprocessor`, optional): - If model does not define ``data_preprocessor``, it will be the - default value for model. - """ - - def __init__( - self, - module: Union[dict, nn.Module], - data_preprocessor: Union[dict, nn.Module, None] = None, - ): - super().__init__() - if isinstance(module, nn.Module): - self.module = module - elif isinstance(module, dict): - if data_preprocessor is not None: - module['data_preprocessor'] = data_preprocessor - self.module = MODELS.build(module) - else: - raise TypeError('The type of module should be a `nn.Module` ' - f'instance or a dict, but got {module}') - assert hasattr(self.module, 'test_step'), ( - 'Model wrapped by BaseTTAModel must implement `test_step`!') - - @abstractmethod - def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \ - -> MergedDataSamples: - """Merge predictions of enhanced data to one prediction. - - Args: - data_samples_list (EnhancedBatchDataSamples): List of predictions - of all enhanced data. - - Returns: - List[BaseDataElement]: Merged prediction. - """ - - def test_step(self, data): - """Get predictions of each enhanced data, a multiple predictions. - - Args: - data (DataBatch): Enhanced data batch sampled from dataloader. - - Returns: - MergedDataSamples: Merged prediction. 
- """ - data_list: Union[List[dict], List[list]] - if isinstance(data, dict): - num_augs = len(data[next(iter(data))]) - data_list = [{key: value[idx] - for key, value in data.items()} - for idx in range(num_augs)] - elif isinstance(data, (tuple, list)): - num_augs = len(data[0]) - data_list = [[_data[idx] for _data in data] - for idx in range(num_augs)] - else: - raise TypeError('data given by dataLoader should be a dict, ' - f'tuple or a list, but got {type(data)}') - - predictions = [] - for data in data_list: # type: ignore - predictions.append(self.module.test_step(data)) - return self.merge_preds(list(zip(*predictions))) # type: ignore - - def forward(self, - inputs: torch.Tensor, - data_samples: Optional[list] = None, - mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: - """``BaseTTAModel.forward`` should not be called.""" - raise NotImplementedError( - '`BaseTTAModel.forward` will not be called during training or' - 'testing. Please call `test_step` instead. If you want to use' - '`BaseTTAModel.forward`, please implement this method') diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py deleted file mode 100644 index c78ea3134d..0000000000 --- a/mmengine/model/utils.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import warnings -from typing import List, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from mmengine.logging import print_log -from mmengine.utils.dl_utils import mmcv_full_available - - -def stack_batch(tensor_list: List[torch.Tensor], - pad_size_divisor: int = 1, - pad_value: Union[int, float] = 0) -> torch.Tensor: - """Stack multiple tensors to form a batch and pad the tensor to the max - shape use the right bottom padding mode in these images. If - ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is - divisible by ``pad_size_divisor``. - - Args: - tensor_list (List[Tensor]): A list of tensors with the same dim. - pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding - to ensure the shape of each dim is divisible by - ``pad_size_divisor``. This depends on the model, and many - models need to be divisible by 32. Defaults to 1 - pad_value (int, float): The padding value. Defaults to 0. - - Returns: - Tensor: The n dim tensor. - """ - assert isinstance( - tensor_list, - list), (f'Expected input type to be list, but got {type(tensor_list)}') - assert tensor_list, '`tensor_list` could not be an empty list' - assert len({ - tensor.ndim - for tensor in tensor_list - }) == 1, (f'Expected the dimensions of all tensors must be the same, ' - f'but got {[tensor.ndim for tensor in tensor_list]}') - - dim = tensor_list[0].dim() - num_img = len(tensor_list) - all_sizes: torch.Tensor = torch.Tensor( - [tensor.shape for tensor in tensor_list]) - max_sizes = torch.ceil( - torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor - padded_sizes = max_sizes - all_sizes - # The first dim normally means channel, which should not be padded. - padded_sizes[:, 0] = 0 - if padded_sizes.sum() == 0: - return torch.stack(tensor_list) - # `pad` is the second arguments of `F.pad`. If pad is (1, 2, 3, 4), - # it means that padding the last dim with 1(left) 2(right), padding the - # penultimate dim to 3(top) 4(bottom). The order of `pad` is opposite of - # the `padded_sizes`. Therefore, the `padded_sizes` needs to be reversed, - # and only odd index of pad should be assigned to keep padding "right" and - # "bottom". 
- pad = torch.zeros(num_img, 2 * dim, dtype=torch.int) - pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)] - batch_tensor = [] - for idx, tensor in enumerate(tensor_list): - batch_tensor.append( - F.pad(tensor, tuple(pad[idx].tolist()), value=pad_value)) - return torch.stack(batch_tensor) - - -def detect_anomalous_params(loss: torch.Tensor, model) -> None: - parameters_in_graph = set() - visited = set() - - def traverse(grad_fn): - if grad_fn is None: - return - if grad_fn not in visited: - visited.add(grad_fn) - if hasattr(grad_fn, 'variable'): - parameters_in_graph.add(grad_fn.variable) - parents = grad_fn.next_functions - if parents is not None: - for parent in parents: - grad_fn = parent[0] - traverse(grad_fn) - - traverse(loss.grad_fn) - for n, p in model.named_parameters(): - if p not in parameters_in_graph and p.requires_grad: - print_log( - f'{n} with shape {p.size()} is not ' - f'in the computational graph \n', - logger='current', - level=logging.ERROR) - - -def merge_dict(*args): - """Merge all dictionaries into one dictionary. - - If pytorch version >= 1.8, ``merge_dict`` will be wrapped - by ``torch.fx.wrap``, which will make ``torch.fx.symbolic_trace`` skip - trace ``merge_dict``. - - Note: - If a function needs to be traced by ``torch.fx.symbolic_trace``, - but inevitably needs to use ``update`` method of ``dict``(``update`` - is not traceable). It should use ``merge_dict`` to replace - ``xxx.update``. - - Args: - *args: dictionary needs to be merged. - - Returns: - dict: Merged dict from args - """ - output = dict() - for item in args: - assert isinstance( - item, - dict), (f'all arguments of merge_dict should be a dict, but got ' - f'{type(item)}') - output.update(item) - return output - - -# torch.fx is only available when pytorch version >= 1.8. -# If the subclass of `BaseModel` has multiple submodules, and each module -# will return a loss dict during training process, i.e., `TwoStageDetector` -# in mmdet. It should use `merge_dict` to get the total loss, rather than -# `loss.update` to keep model traceable. -try: - import torch.fx - - # make torch.fx skip trace `merge_dict`. - merge_dict = torch.fx.wrap(merge_dict) - -except ImportError: - warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function ' - 'to merge multiple dicts') - - -class _BatchNormXd(nn.modules.batchnorm._BatchNorm): - """A general BatchNorm layer without input dimension check. - - Reproduced from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc - is `_check_input_dim` that is designed for tensor sanity checks. - The check has been bypassed in this class for the convenience of converting - SyncBatchNorm. - """ - - def _check_input_dim(self, input: torch.Tensor): - return - - -def revert_sync_batchnorm(module: nn.Module) -> nn.Module: - """Helper function to convert all `SyncBatchNorm` (SyncBN) and - `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to - `BatchNormXd` layers. - - Adapted from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - - Args: - module (nn.Module): The module containing `SyncBatchNorm` layers. - - Returns: - module_output: The converted module with `BatchNormXd` layers. 
- """ - module_output = module - module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] - - if mmcv_full_available(): - from mmcv.ops import SyncBatchNorm - module_checklist.append(SyncBatchNorm) - - if isinstance(module, tuple(module_checklist)): - module_output = _BatchNormXd(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) - if module.affine: - # no_grad() may not be needed here but - # just to be consistent with `convert_sync_batchnorm()` - with torch.no_grad(): - module_output.weight = module.weight - module_output.bias = module.bias - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - module_output.training = module.training - # qconfig exists in quantized models - if hasattr(module, 'qconfig'): - module_output.qconfig = module.qconfig - for name, child in module.named_children(): - # Some custom modules or 3rd party implemented modules may raise an - # error when calling `add_module`. Therefore, try to catch the error - # and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 # noqa: E501 - # for more details. - try: - module_output.add_module(name, revert_sync_batchnorm(child)) - except Exception: - print_log( - F'Failed to convert {child} from SyncBN to BN!', - logger='current', - level=logging.WARNING) - del module - return module_output - - -def convert_sync_batchnorm(module: nn.Module, - implementation='torch') -> nn.Module: - """Helper function to convert all `BatchNorm` layers in the model to - `SyncBatchNorm` (SyncBN) or `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) - layers. Adapted from `PyTorch convert sync batchnorm`_. - - Args: - module (nn.Module): The module containing `SyncBatchNorm` layers. - implementation (str): The type of `SyncBatchNorm` to convert to. - - - 'torch': convert to `torch.nn.modules.batchnorm.SyncBatchNorm`. - - 'mmcv': convert to `mmcv.ops.sync_bn.SyncBatchNorm`. - - Returns: - nn.Module: The converted module with `SyncBatchNorm` layers. - - .. _PyTorch convert sync batchnorm: - https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm - """ # noqa: E501 - module_output = module - - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): - if implementation == 'torch': - SyncBatchNorm = torch.nn.modules.batchnorm.SyncBatchNorm - elif implementation == 'mmcv': - from mmcv.ops import SyncBatchNorm # type: ignore - else: - raise ValueError('sync_bn should be "torch" or "mmcv", but got ' - f'{implementation}') - - module_output = SyncBatchNorm(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) - - if module.affine: - with torch.no_grad(): - module_output.weight = module.weight - module_output.bias = module.bias - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - if hasattr(module, 'qconfig'): - module_output.qconfig = module.qconfig - for name, child in module.named_children(): - module_output.add_module(name, - convert_sync_batchnorm(child, implementation)) - del module - return module_output diff --git a/mmengine/model/weight_init.py b/mmengine/model/weight_init.py deleted file mode 100644 index b6d0186ed7..0000000000 --- a/mmengine/model/weight_init.py +++ /dev/null @@ -1,682 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import copy -import math -import warnings - -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor - -from mmengine.logging import print_log -from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg - - -def update_init_info(module, init_info): - """Update the `_params_init_info` in the module if the value of parameters - are changed. - - Args: - module (obj:`nn.Module`): The module of PyTorch with a user-defined - attribute `_params_init_info` which records the initialization - information. - init_info (str): The string that describes the initialization. - """ - assert hasattr( - module, - '_params_init_info'), f'Can not find `_params_init_info` in {module}' - for name, param in module.named_parameters(): - - assert param in module._params_init_info, ( - f'Find a new :obj:`Parameter` ' - f'named `{name}` during executing the ' - f'`init_weights` of ' - f'`{module.__class__.__name__}`. ' - f'Please do not add or ' - f'replace parameters during executing ' - f'the `init_weights`. ') - - # The parameter has been changed during executing the - # `init_weights` of module - mean_value = param.data.mean().cpu() - if module._params_init_info[param]['tmp_mean_value'] != mean_value: - module._params_init_info[param]['init_info'] = init_info - module._params_init_info[param]['tmp_mean_value'] = mean_value - - -def constant_init(module, val, bias=0): - if hasattr(module, 'weight') and module.weight is not None: - nn.init.constant_(module.weight, val) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def xavier_init(module, gain=1, bias=0, distribution='normal'): - assert distribution in ['uniform', 'normal'] - if hasattr(module, 'weight') and module.weight is not None: - if distribution == 'uniform': - nn.init.xavier_uniform_(module.weight, gain=gain) - else: - nn.init.xavier_normal_(module.weight, gain=gain) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def normal_init(module, mean=0, std=1, bias=0): - if hasattr(module, 'weight') and module.weight is not None: - nn.init.normal_(module.weight, mean, std) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def trunc_normal_init(module: nn.Module, - mean: float = 0, - std: float = 1, - a: float = -2, - b: float = 2, - bias: float = 0) -> None: - if hasattr(module, 'weight') and module.weight is not None: - trunc_normal_(module.weight, mean, std, a, b) # type: ignore - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) # type: ignore - - -def uniform_init(module, a=0, b=1, bias=0): - if hasattr(module, 'weight') and module.weight is not None: - nn.init.uniform_(module.weight, a, b) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def kaiming_init(module, - a=0, - mode='fan_out', - nonlinearity='relu', - bias=0, - distribution='normal'): - assert distribution in ['uniform', 'normal'] - if hasattr(module, 'weight') and module.weight is not None: - if distribution == 'uniform': - nn.init.kaiming_uniform_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) - else: - nn.init.kaiming_normal_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def caffe2_xavier_init(module, bias=0): - # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 
- # Acknowledgment to FAIR's internal code - kaiming_init( - module, - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - bias=bias, - distribution='uniform') - - -def bias_init_with_prob(prior_prob): - """Initialize conv/fc bias value according to a given probability value.""" - bias_init = float(-np.log((1 - prior_prob) / prior_prob)) - return bias_init - - -def _get_bases_name(m): - return [b.__name__ for b in m.__class__.__bases__] - - -class BaseInit: - - def __init__(self, *, bias=0, bias_prob=None, layer=None): - self.wholemodule = False - if not isinstance(bias, (int, float)): - raise TypeError(f'bias must be a number, but got a {type(bias)}') - - if bias_prob is not None: - if not isinstance(bias_prob, float): - raise TypeError(f'bias_prob type must be float, \ - but got {type(bias_prob)}') - - if layer is not None: - if not isinstance(layer, (str, list)): - raise TypeError(f'layer must be a str or a list of str, \ - but got a {type(layer)}') - else: - layer = [] - - if bias_prob is not None: - self.bias = bias_init_with_prob(bias_prob) - else: - self.bias = bias - self.layer = [layer] if isinstance(layer, str) else layer - - def _get_init_info(self): - info = f'{self.__class__.__name__}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Constant') -class ConstantInit(BaseInit): - """Initialize module parameters with constant values. - - Args: - val (int | float): the value to fill the weights in the module with - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, val, **kwargs): - super().__init__(**kwargs) - self.val = val - - def __call__(self, module): - - def init(m): - if self.wholemodule: - constant_init(m, self.val, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - constant_init(m, self.val, self.bias) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Xavier') -class XavierInit(BaseInit): - r"""Initialize module parameters with values according to the method - described in the paper below. - - `Understanding the difficulty of training deep feedforward - neural networks - Glorot, X. & Bengio, Y. (2010). - `_ - - Args: - gain (int | float): an optional scaling factor. Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - distribution (str): distribution either be ``'normal'`` - or ``'uniform'``. Defaults to ``'normal'``. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. 
- """ - - def __init__(self, gain=1, distribution='normal', **kwargs): - super().__init__(**kwargs) - self.gain = gain - self.distribution = distribution - - def __call__(self, module): - - def init(m): - if self.wholemodule: - xavier_init(m, self.gain, self.bias, self.distribution) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - xavier_init(m, self.gain, self.bias, self.distribution) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: gain={self.gain}, ' \ - f'distribution={self.distribution}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Normal') -class NormalInit(BaseInit): - r"""Initialize module parameters with the values drawn from the normal - distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. - - Args: - mean (int | float):the mean of the normal distribution. Defaults to 0. - std (int | float): the standard deviation of the normal distribution. - Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, mean=0, std=1, **kwargs): - super().__init__(**kwargs) - self.mean = mean - self.std = std - - def __call__(self, module): - - def init(m): - if self.wholemodule: - normal_init(m, self.mean, self.std, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - normal_init(m, self.mean, self.std, self.bias) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: mean={self.mean},' \ - f' std={self.std}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='TruncNormal') -class TruncNormalInit(BaseInit): - r"""Initialize module parameters with the values drawn from the normal - distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values - outside :math:`[a, b]`. - - Args: - mean (float): the mean of the normal distribution. Defaults to 0. - std (float): the standard deviation of the normal distribution. - Defaults to 1. - a (float): The minimum cutoff value. - b ( float): The maximum cutoff value. - bias (float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. 
- """ - - def __init__(self, - mean: float = 0, - std: float = 1, - a: float = -2, - b: float = 2, - **kwargs) -> None: - super().__init__(**kwargs) - self.mean = mean - self.std = std - self.a = a - self.b = b - - def __call__(self, module: nn.Module) -> None: - - def init(m): - if self.wholemodule: - trunc_normal_init(m, self.mean, self.std, self.a, self.b, - self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - trunc_normal_init(m, self.mean, self.std, self.a, self.b, - self.bias) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \ - f' mean={self.mean}, std={self.std}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Uniform') -class UniformInit(BaseInit): - r"""Initialize module parameters with values drawn from the uniform - distribution :math:`\mathcal{U}(a, b)`. - - Args: - a (int | float): the lower bound of the uniform distribution. - Defaults to 0. - b (int | float): the upper bound of the uniform distribution. - Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, a=0, b=1, **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __call__(self, module): - - def init(m): - if self.wholemodule: - uniform_init(m, self.a, self.b, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - uniform_init(m, self.a, self.b, self.bias) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: a={self.a},' \ - f' b={self.b}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Kaiming') -class KaimingInit(BaseInit): - r"""Initialize module parameters with the values according to the method - described in the paper below. - - `Delving deep into rectifiers: Surpassing human-level - performance on ImageNet classification - He, K. et al. (2015). - `_ - - Args: - a (int | float): the negative slope of the rectifier used after this - layer (only used with ``'leaky_relu'``). Defaults to 0. - mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing - ``'fan_in'`` preserves the magnitude of the variance of the weights - in the forward pass. Choosing ``'fan_out'`` preserves the - magnitudes in the backwards pass. Defaults to ``'fan_out'``. - nonlinearity (str): the non-linear function (`nn.functional` name), - recommended to use only with ``'relu'`` or ``'leaky_relu'`` . - Defaults to 'relu'. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - distribution (str): distribution either be ``'normal'`` or - ``'uniform'``. Defaults to ``'normal'``. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. 
- """ - - def __init__(self, - a=0, - mode='fan_out', - nonlinearity='relu', - distribution='normal', - **kwargs): - super().__init__(**kwargs) - self.a = a - self.mode = mode - self.nonlinearity = nonlinearity - self.distribution = distribution - - def __call__(self, module): - - def init(m): - if self.wholemodule: - kaiming_init(m, self.a, self.mode, self.nonlinearity, - self.bias, self.distribution) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & set([layername] + basesname)): - kaiming_init(m, self.a, self.mode, self.nonlinearity, - self.bias, self.distribution) - - module.apply(init) - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \ - f'nonlinearity={self.nonlinearity}, ' \ - f'distribution ={self.distribution}, bias={self.bias}' - return info - - -@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier') -class Caffe2XavierInit(KaimingInit): - # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch - # Acknowledgment to FAIR's internal code - def __init__(self, **kwargs): - super().__init__( - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - distribution='uniform', - **kwargs) - - def __call__(self, module): - super().__call__(module) - - -@WEIGHT_INITIALIZERS.register_module(name='Pretrained') -class PretrainedInit: - """Initialize module by loading a pretrained model. - - Args: - checkpoint (str): the checkpoint file of the pretrained model should - be load. - prefix (str, optional): the prefix of a sub-module in the pretrained - model. it is for loading a part of the pretrained model to - initialize. For example, if we would like to only load the - backbone of a detector model, we can set ``prefix='backbone.'``. - Defaults to None. - map_location (str): map tensors into proper locations. Defaults to cpu. - """ - - def __init__(self, checkpoint, prefix=None, map_location='cpu'): - self.checkpoint = checkpoint - self.prefix = prefix - self.map_location = map_location - - def __call__(self, module): - from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix, - load_checkpoint, - load_state_dict) - if self.prefix is None: - print_log(f'load model from: {self.checkpoint}', logger='current') - load_checkpoint( - module, - self.checkpoint, - map_location=self.map_location, - strict=False, - logger='current') - else: - print_log( - f'load {self.prefix} in model from: {self.checkpoint}', - logger='current') - state_dict = _load_checkpoint_with_prefix( - self.prefix, self.checkpoint, map_location=self.map_location) - load_state_dict(module, state_dict, strict=False, logger='current') - - if hasattr(module, '_params_init_info'): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f'{self.__class__.__name__}: load from {self.checkpoint}' - return info - - -def _initialize(module, cfg, wholemodule=False): - func = build_from_cfg(cfg, WEIGHT_INITIALIZERS) - # wholemodule flag is for override mode, there is no layer key in override - # and initializer will give init values for the whole module with the name - # in override. 
- func.wholemodule = wholemodule - func(module) - - -def _initialize_override(module, override, cfg): - if not isinstance(override, (dict, list)): - raise TypeError(f'override must be a dict or a list of dict, \ - but got {type(override)}') - - override = [override] if isinstance(override, dict) else override - - for override_ in override: - - cp_override = copy.deepcopy(override_) - name = cp_override.pop('name', None) - if name is None: - raise ValueError('`override` must contain the key "name",' - f'but got {cp_override}') - # if override only has name key, it means use args in init_cfg - if not cp_override: - cp_override.update(cfg) - # if override has name key and other args except type key, it will - # raise error - elif 'type' not in cp_override.keys(): - raise ValueError( - f'`override` need "type" key, but got {cp_override}') - - if hasattr(module, name): - _initialize(getattr(module, name), cp_override, wholemodule=True) - else: - raise RuntimeError(f'module did not have attribute {name}, ' - f'but init_cfg is {cp_override}.') - - -def initialize(module, init_cfg): - r"""Initialize a module. - - Args: - module (``torch.nn.Module``): the module will be initialized. - init_cfg (dict | list[dict]): initialization configuration dict to - define initializer. OpenMMLab has implemented 6 initializers - including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, - ``Kaiming``, and ``Pretrained``. - - Example: - >>> module = nn.Linear(2, 3, bias=True) - >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) - >>> initialize(module, init_cfg) - >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) - >>> # define key ``'layer'`` for initializing layer with different - >>> # configuration - >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1), - dict(type='Constant', layer='Linear', val=2)] - >>> initialize(module, init_cfg) - >>> # define key``'override'`` to initialize some specific part in - >>> # module - >>> class FooNet(nn.Module): - >>> def __init__(self): - >>> super().__init__() - >>> self.feat = nn.Conv2d(3, 16, 3) - >>> self.reg = nn.Conv2d(16, 10, 3) - >>> self.cls = nn.Conv2d(16, 5, 3) - >>> model = FooNet() - >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', - >>> override=dict(type='Constant', name='reg', val=3, bias=4)) - >>> initialize(model, init_cfg) - >>> model = ResNet(depth=50) - >>> # Initialize weights with the pretrained model. - >>> init_cfg = dict(type='Pretrained', - checkpoint='torchvision://resnet50') - >>> initialize(model, init_cfg) - >>> # Initialize weights of a sub-module with the specific part of - >>> # a pretrained model by using "prefix". 
- >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\ - >>> 'retinanet_r50_fpn_1x_coco/'\ - >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' - >>> init_cfg = dict(type='Pretrained', - checkpoint=url, prefix='backbone.') - """ - if not isinstance(init_cfg, (dict, list)): - raise TypeError(f'init_cfg must be a dict or a list of dict, \ - but got {type(init_cfg)}') - - if isinstance(init_cfg, dict): - init_cfg = [init_cfg] - - for cfg in init_cfg: - # should deeply copy the original config because cfg may be used by - # other modules, e.g., one init_cfg shared by multiple bottleneck - # blocks, the expected cfg will be changed after pop and will change - # the initialization behavior of other modules - cp_cfg = copy.deepcopy(cfg) - override = cp_cfg.pop('override', None) - _initialize(module, cp_cfg) - - if override is not None: - cp_cfg.pop('layer', None) - _initialize_override(module, override, cp_cfg) - else: - # All attributes in module have same initialization. - pass - - -def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, - b: float) -> Tensor: - # Method based on - # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - # Modified from - # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' - 'The distribution of values may be incorrect.', - stacklevel=2) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - lower = norm_cdf((a - mean) / std) - upper = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [lower, upper], then translate - # to [2lower-1, 2upper-1]. - tensor.uniform_(2 * lower - 1, 2 * upper - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor: Tensor, - mean: float = 0., - std: float = 1., - a: float = -2., - b: float = 2.) -> Tensor: - r"""Fills the input Tensor with values drawn from a truncated normal - distribution. The values are effectively drawn from the normal distribution - :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside - :math:`[a, b]` redrawn until they are within the bounds. The method used - for generating the random values works best when :math:`a \leq \text{mean} - \leq b`. - - Modified from - https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py - - Args: - tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. - mean (float): the mean of the normal distribution. - std (float): the standard deviation of the normal distribution. - a (float): the minimum cutoff value. - b (float): the maximum cutoff value. - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py deleted file mode 100644 index 90eddabbe1..0000000000 --- a/mmengine/model/wrappers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
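A small usage sketch of ``trunc_normal_`` defined above: it fills a tensor in place with values from N(mean, std^2) truncated to [a, b] (the ``mmengine.model`` import path is an assumption; the tensor shape is a placeholder):

import torch
from mmengine.model import trunc_normal_

w = torch.empty(768, 768)
trunc_normal_(w, mean=0., std=0.02, a=-2., b=2.)  # in-place; returns the same tensor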
-from mmengine.utils.dl_utils import TORCH_VERSION -from mmengine.utils.version_utils import digit_version -from .distributed import MMDistributedDataParallel -from .seperate_distributed import MMSeparateDistributedDataParallel -from .utils import is_model_wrapper - -__all__ = [ - 'MMDistributedDataParallel', 'is_model_wrapper', - 'MMSeparateDistributedDataParallel' -] - -if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): - from .fully_sharded_distributed import \ - MMFullyShardedDataParallel # noqa:F401 - __all__.append('MMFullyShardedDataParallel') diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py deleted file mode 100644 index 4113aebf9e..0000000000 --- a/mmengine/model/wrappers/distributed.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Union - -import torch -from torch.nn.parallel import DataParallel, DistributedDataParallel - -from mmengine.optim import OptimWrapper -from mmengine.registry import MODEL_WRAPPERS -from ..utils import detect_anomalous_params - -MODEL_WRAPPERS.register_module(module=DistributedDataParallel) -MODEL_WRAPPERS.register_module(module=DataParallel) - - -@MODEL_WRAPPERS.register_module() -class MMDistributedDataParallel(DistributedDataParallel): - """A distributed model wrapper used for training,testing and validation in - loop. - - Different from DistributedDataParallel, MMDistributedDataParallel - implements three methods :meth:`train_step`, :meth:`val_step` and - :meth:`test_step`, which will be called by ``train_loop``, ``val_loop`` - and ``test_loop``. - - - ``train_step``: Called by ``runner.train_loop``, and implement - default model forward, gradient back propagation, parameter updating - logic. To take advantage of DistributedDataParallel's automatic gradient - synchronization, ``train_step`` calls ``DistributedDataParallel.forward`` - to calculate the losses, and call other methods of :class:`BaseModel` to - pre-process data and parse losses. Finally, update model parameters by - :class:`OptimWrapper` and return the loss dictionary used - for logging. - - - ``val_step``: Called by ``runner.val_loop`` and get the inference - results. Since there is no gradient synchronization requirement, - this procedure is equivalent to ``BaseModel.val_step`` - - - ``test_step``: Called by ``runner.test_loop``, equivalent ``val_step``. - - Args: - detect_anomalous_params (bool): This option is only used for - debugging which will slow down the training speed. - Detect anomalous parameters that are not included in - the computational graph with `loss` as the root. - There are two cases - - - Parameters were not used during forward pass. - - Parameters were not used to produce loss. - - Defaults to False. - - **kwargs: keyword arguments passed to ``DistributedDataParallel``. - - - device_ids (List[int] or torch.device, optional): CUDA devices - for module. - - output_device (int or torch.device, optional): Device location of - output for single-device CUDA modules. - - dim (int): Defaults to 0. - - broadcast_buffers (bool): Flag that enables syncing ( - broadcasting) buffers of the module at beginning of the - ``forward`` function. Defaults to True - - find_unused_parameters (bool): Whether to find parameters of - module, which are not in the forward graph. Defaults to False. - - process_group (ProcessGroup, optional): The process group to be - used for distributed data all-reduction. - - bucket_cap_mb (int): bucket size in MegaBytes (MB). 
Defaults - to 25. - - check_reduction (bool): This argument is deprecated. Defaults - to False. - - gradient_as_bucket_view (bool): Defaults to False. - - static_graph (bool): Defaults to False. - - See more information about arguments in - :class:`torch.nn.parallel.DistributedDataParallel`. - - Note: - If model has multiple submodules and each module has - separate optimization strategies, - :class:`MMSeparateDistributedDataParallel` should be used to wrap - the model. - - Note: - If model itself has custom optimization strategy, rather than - simply forward model and update model. A custom model wrapper - inherit from ``MMDistributedDataParallel`` should be defined and - override the ``train_step`` method. - """ - - def __init__(self, - module, - detect_anomalous_params: bool = False, - **kwargs): - super().__init__(module=module, **kwargs) - self.detect_anomalous_params = detect_anomalous_params - - def train_step(self, data: Union[dict, tuple, list], - optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: - """Interface for model forward, backward and parameters updating during - training process. - - :meth:`train_step` will perform the following steps in order: - - - If :attr:`module` defines the preprocess method, - call ``module.preprocess`` to pre-processing data. - - Call ``module.forward(**data)`` and get losses. - - Parse losses. - - Call ``optim_wrapper.optimizer_step`` to update parameters. - - Return log messages of losses. - - Args: - data (dict or tuple or list): Data sampled from dataset. - optim_wrapper (OptimWrapper): A wrapper of optimizer to - update parameters. - - Returns: - Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. - """ - # Enable automatic mixed precision training context. - with optim_wrapper.optim_context(self): - data = self.module.data_preprocessor(data, training=True) - losses = self._run_forward(data, mode='loss') - parsed_loss, log_vars = self.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - if self.detect_anomalous_params: - detect_anomalous_params(parsed_loss, model=self) - return log_vars - - def val_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.val_step(data) - - def test_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the predictions of module during testing process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.test_step(data) - - def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, (list, tuple)): - results = self(*data, mode=mode) - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list, tuple or dict, but got {type(data)}') - return results diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py deleted file mode 100644 index df128597b1..0000000000 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright (c) OpenMMLab. 
All rights reserved. -from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup -# yapf: disable -from torch.distributed.fsdp.api import (FullStateDictConfig, - LocalOptimStateDictConfig, - LocalStateDictConfig, - OptimStateDictConfig, - ShardedOptimStateDictConfig, - ShardedStateDictConfig, - ShardingStrategy, StateDictConfig, - StateDictSettings, StateDictType) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, CPUOffload, FullOptimStateDictConfig, - FullyShardedDataParallel, MixedPrecision) - -# yapf: enable -from mmengine.optim import OptimWrapper -from mmengine.registry import FUNCTIONS, MODEL_WRAPPERS -from mmengine.structures import BaseDataElement -from mmengine.utils import digit_version, is_seq_of - - -@MODEL_WRAPPERS.register_module() -class MMFullyShardedDataParallel(FullyShardedDataParallel): - """A wrapper for sharding Module parameters across data parallel workers. - - Different from FullyShardedDataParallel, MMFullyShardedDataParallel - implements three methods :meth:`train_step`, :meth:`val_step` and - :meth:`test_step`, which will be called by ``train_loop``, ``val_loop`` - and ``test_loop``. - - - ``train_step``: Called by ``runner.train_loop``, and implement - default model forward, gradient back propagation, parameter updating - logic. - - - ``val_step``: Called by ``runner.val_loop`` and get the inference - results. Specially, since MMFullyShardedDataParallel will wrap model - recursively, it may cause some problem if one just use - ``BaseModel.val_step`` to implement ``val_step`` here. To avoid that, - ``val_step`` will call methods of :obj:`BaseModel` to pre-process - data first, and use ``FullyShardedDataParallel.forward`` to get result. - - - ``test_step``: Called by ``runner.test_loop`` and get the inference - results. Its logic is equivalent to ``val_loop``. - - Args: - module (nn.Module): module to be wrapped with FSDP. - process_group (ProcessGroup, optional): process group for sharding. - cpu_offload (bool, CPUOffload, optional): - CPU offloading config. - Different from FullyShardedDataParallel,Since it can be set by - users' pre-defined config in MMEngine,its type is expected to be - `None`, `bool` or `CPUOffload`. - - Currently, only parameter and gradient CPU offload is supported. - It can be enabled via passing in - ``cpu_offload=CPUOffload(offload_params=True)``. Note that this - currently implicitly enables gradient offloading to CPU in order - for params and grads to be on same device to work with optimizer. - This API is subject to change. Default is ``None`` in which case - there will be no offloading. - auto_wrap_policy (str or Callable, optional): - Specifying a policy to recursively wrap layers with FSDP. - Different from FullyShardedDataParallel, Since it can be set by - users' pre-defined config in MMEngine, its type is expected to be - `None`, `str` or `Callable`. If it's `str`, then - MMFullyShardedDataParallel will try to get specified method in - ``FSDP_WRAP_POLICIES`` registry,and this method will be passed to - FullyShardedDataParallel to finally initialize model. - - Note that this policy currently will only apply to child modules of - the passed in module. The remainder modules are always wrapped in - the returned FSDP root instance. 
- ``default_auto_wrap_policy`` written in - ``torch.distributed.fsdp.wrap`` is an example of - ``auto_wrap_policy`` callable, this policy wraps layers with - parameter sizes larger than 100M. Users can supply the customized - ``auto_wrap_policy`` callable that should accept following - arguments: ``module: nn.Module``, ``recurse: bool``, - ``unwrapped_params: int``, extra customized arguments could be - added to the customized ``auto_wrap_policy`` callable as well. - - Example:: - - >>> def custom_auto_wrap_policy( - >>> module: nn.Module, - >>> recurse: bool, - >>> unwrapped_params: int, - >>> # These are customizable for this policy function. - >>> min_num_params: int = int(1e8), - >>> ) -> bool: - >>> return unwrapped_params >= min_num_params - - backward_prefetch (str or BackwardPrefetch, optional): - Different from FullyShardedDataParallel, this argument could be a - string or a BackwardPrefetch instance. If it's a string, then - it should be ``BACKWARD_PRE`` or ``BACKWARD_POST`` - mixed_precision (dict or MixedPrecision, optional): - This configures native mixed precision for FSDP. If this is set to - ``None``. Different from the native FSDP, this argument can a dict - like this: - - Examples: - >>> mixed_precision=dict(param_dtype='float16', - >>> buffer_dtype='float32', - >>> reduce_dtype='float32') - - Defaults to None. - use_orig_params (bool): Different from native - ``FullyShardedDataParallel``, it defaults to True. - **kwargs: Keyword arguments passed to - :class:`FullyShardedDataParallel`. - """ - - def __init__( - self, - module: nn.Module, - process_group: Union[dict, ProcessGroup, None] = None, - sharding_strategy: Union[str, ShardingStrategy] = None, - cpu_offload: Union[bool, CPUOffload, None] = None, - auto_wrap_policy: Union[str, Callable, None] = None, - backward_prefetch: Union[str, BackwardPrefetch, None] = None, - mixed_precision: Union[dict, MixedPrecision, None] = None, - param_init_fn: Union[str, Callable[ - [nn.Module], None]] = None, # type: ignore # noqa: E501 - use_orig_params: bool = True, - **kwargs, - ): - if isinstance(sharding_strategy, str): - sharding_strategy = ShardingStrategy[sharding_strategy] - if not (isinstance(sharding_strategy, ShardingStrategy) - or sharding_strategy is None): - raise TypeError( - 'sharding_strategy must be str or enum of `ShardingStrategy` ' - f', but got {sharding_strategy}') - - if isinstance(cpu_offload, bool): - cpu_offload = CPUOffload(offload_params=cpu_offload) - if not (isinstance(cpu_offload, CPUOffload) or cpu_offload is None): - raise TypeError( - '`cpu_offload` should be `None`, `bool`' - f'or `CPUOffload`, but has type {type(cpu_offload)}') - - with FUNCTIONS.switch_scope_and_registry(None): - if isinstance(auto_wrap_policy, str): - auto_wrap_policy = FUNCTIONS.get( # type: ignore - auto_wrap_policy) - if auto_wrap_policy is None: - raise ValueError('`auto_wrap_policy` is not registered!') - elif isinstance(auto_wrap_policy, dict): - policy = auto_wrap_policy.pop('type') - if isinstance(policy, str): - policy = FUNCTIONS.get(policy) # type: ignore - if policy is None: - raise ValueError('`auto_wrap_policy` is not registered!') - auto_wrap_policy = partial(policy, **auto_wrap_policy) - - if not (auto_wrap_policy is None - or callable(auto_wrap_policy)): # type: ignore - raise TypeError('`auto_wrap_policy` should be a str, a ' - 'callable, a dict or None, but has type ' - f'{type(auto_wrap_policy)}') - - if isinstance(backward_prefetch, str): - backward_prefetch = BackwardPrefetch[backward_prefetch] - if not 
(isinstance(backward_prefetch, BackwardPrefetch) - or backward_prefetch is None): - raise TypeError( - '`backward_prefetch` should be `None`, string of ' - '"BACKWARD_PRE" and "BACKWARD_POST", or ' - f'`BackwardPrefetch`, but has type {type(backward_prefetch)}' # noqa: E501 - ) - - if isinstance(param_init_fn, str): - param_init_fn = FUNCTIONS.get( # type: ignore - param_init_fn) - if param_init_fn is None: - raise ValueError('`param_init_fn` is not registered!') - elif isinstance(param_init_fn, dict): - init_fn = param_init_fn.pop('type') - if isinstance(param_init_fn, str): - init_fn = FUNCTIONS.get(init_fn) # type: ignore - if init_fn is None: - raise ValueError('`param_init_fn` is not registered!') - param_init_fn = partial(init_fn, **param_init_fn) - - if not (callable(param_init_fn) or param_init_fn is None): - raise TypeError('`param_init_fn` should be a str, a ' - 'callable, a dict or None, but has type ' - f'{type(param_init_fn)}') - - def parse_dtype(dtype): - if dtype is None: - return None - elif isinstance(dtype, str): - return getattr(torch, dtype) - elif isinstance(dtype, torch.dtype): - return dtype - else: - raise TypeError( - '`dtype` should be `None`, `str` or `torch.dtype`, ' - f'but has type {type(dtype)}') - - if isinstance(mixed_precision, dict): - mixed_precision['param_dtype'] = parse_dtype( - mixed_precision.get('param_dtype', None)) - mixed_precision['reduce_dtype'] = parse_dtype( - mixed_precision.get('reduce_dtype', None)) - mixed_precision['buffer_dtype'] = parse_dtype( - mixed_precision.get('buffer_dtype', None)) - mixed_precision = MixedPrecision(**mixed_precision) - elif isinstance(mixed_precision, MixedPrecision): - mixed_precision = mixed_precision - elif mixed_precision is not None: - raise TypeError( - '`mixed_precision` should be `None`, `dict` or ' - f'`MixedPrecision`, but has type {type(mixed_precision)}') - - # ignored_parameters and ignored_modules will be deprecated by PyTorch. - # Therefore we hide them in **kwargs. - # TODO: Update when PyTorch 2.1.0 released - if 'ignored_parameters' in kwargs: - kwargs['ignored_parameters'] = self._get_ignored_params( - module, kwargs['ignored_parameters']) - - if 'ignored_modules' in kwargs: - kwargs['ignored_modules'] = self._get_ignored_modules( - module, kwargs['ignored_modules']) - - super().__init__( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - auto_wrap_policy=auto_wrap_policy, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - param_init_fn=param_init_fn, - use_orig_params=use_orig_params, - **kwargs) - - def train_step(self, data: dict, - optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: - """Interface for model forward, backward and parameters updating during - training process. - - :meth:`train_step` will perform the following steps in order: - - - If :attr:`module` defines the preprocess method, - call ``module.preprocess`` to pre-processing data. - - Call ``module.forward(**data)`` and get losses. - - Parse losses. - - Call ``optim_wrapper.optimizer_step`` to update parameters. - - Return log messages of losses. - - Args: - data (dict): Data sampled by dataloader. - optim_wrapper (OptimWrapper): A wrapper of optimizer to - update parameters. - - Returns: - Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. - """ - # enable automatic mixed precision training context. 
- with optim_wrapper.optim_context(self): - data = self.module.data_preprocessor(data, training=True) - if isinstance(data, dict): - losses = self(**data, mode='loss') - elif isinstance(data, (list, tuple)): - losses = self(*data, mode='loss') - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list tuple or dict, but got {type(data)}') - parsed_loss, log_vars = self.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - return log_vars - - def val_step(self, data: dict) -> List[BaseDataElement]: - """Gets the prediction of module during validation process. - - Args: - data (dict): Data sampled by dataloader. - - Returns: - List[BaseDataElement] or dict: The predictions of given data. - """ - data = self.module.data_preprocessor(data, False) - return self._run_forward(data, mode='predict') # type: ignore - - def test_step(self, data: dict) -> List[BaseDataElement]: - """Gets the predictions of module during testing process. - - Args: - data (dict): Data sampled by dataloader. - - Returns: - List[BaseDataElement]: The predictions of given data. - """ - data = self.module.data_preprocessor(data, False) - return self._run_forward(data, mode='predict') # type: ignore - - def _run_forward(self, data: Union[dict, tuple, list], - mode: str) -> Union[Dict[str, torch.Tensor], list]: - """Unpacks data for :meth:`forward` - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, (list, tuple)): - results = self(*data, mode=mode) - else: - raise TypeError('Output of `data_preprocessor` should be ' - f'list, tuple or dict, but got {type(data)}') - return results - - def _get_ignored_params(self, module: nn.Module, - ignored_parameters: Union[Iterable[str], - Iterable[nn.Module]]): - """Get params from string.""" - params_dict = dict(module.named_parameters()) - if is_seq_of(ignored_parameters, str): - ignored_parameters = [ - params_dict[name] for name in ignored_parameters - ] - if not is_seq_of(ignored_parameters, - nn.Parameter) and ignored_parameters is not None: - raise TypeError( - '`ignored_modules` should be `None`, `Iterable[str]` or ' - '`Iterable[nn.Parameters]`, but has type ' - f'{type(ignored_parameters)}') - return ignored_parameters - - def _get_ignored_modules(self, module: nn.Module, - ignored_modules: Union[Iterable[str], - Iterable[nn.Module]]): - """Get modules from string.""" - modules_dict = dict(module.named_modules()) - if is_seq_of(ignored_modules, str): - ignored_modules = [modules_dict[name] for name in ignored_modules] - if not is_seq_of(ignored_modules, - nn.Module) and ignored_modules is not None: - raise TypeError( - '`ignored_modules` should be `None`, `Iterable[str]` or ' - '`Iterable[nn.Module]`, but has type ' - f'{type(ignored_modules)}') - return ignored_modules - - if digit_version(torch.__version__) < digit_version('2.0.1'): - - @staticmethod - def optim_state_dict( - model: torch.nn.Module, - optim: torch.optim.Optimizer, - group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: - """Copied from pytorch 2.0.1 which has fixed some bugs.""" - state_dict_settings = FullyShardedDataParallel.get_state_dict_type( - model) - return FullyShardedDataParallel._optim_state_dict_impl( - model=model, - optim=optim, - optim_state_dict=optim.state_dict(), - optim_input=None, - 
rank0_only=getattr(state_dict_settings.optim_state_dict_config, - 'rank0_only', False), - full_state_dict=state_dict_settings.state_dict_type == - StateDictType.FULL_STATE_DICT, - group=group, - ) - - @staticmethod - def set_state_dict_type( - module: nn.Module, - state_dict_type: StateDictType, - state_dict_config: Optional[StateDictConfig] = None, - optim_state_dict_config: Optional[OptimStateDictConfig] = None, - ) -> StateDictSettings: - """Copied from pytorch 2.0.1 which has fixed some bugs.""" - import torch.distributed.fsdp._traversal_utils as traversal_utils - _state_dict_type_to_config = { - StateDictType.FULL_STATE_DICT: FullStateDictConfig, - StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, - StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, - } - _optim_state_dict_type_to_config = { - StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, - StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, - StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, - } - - # Use the default config if a state_dict config is not set. - state_dict_config_type = _state_dict_type_to_config[ - state_dict_type] - optim_state_dict_config_type = _optim_state_dict_type_to_config[ - state_dict_type] - if state_dict_config is None: - state_dict_config = state_dict_config_type() - if optim_state_dict_config is None: - optim_state_dict_config = optim_state_dict_config_type() - if state_dict_config_type != type(state_dict_config): - raise RuntimeError('Expected state_dict_config of type ' - f'{state_dict_config_type} ' - f'but got {type(state_dict_config)}') - if optim_state_dict_config_type != type(optim_state_dict_config): - raise RuntimeError('Expected optim_state_dict_config of type ' - f'{optim_state_dict_config_type} ' - f'but got {type(optim_state_dict_config)}') - - # Set the state_dict type and configurations. - prev_state_dict_type = None - prev_state_dict_config = None - prev_optim_state_dict_config = None - for submodule in traversal_utils._get_fsdp_states(module): - if prev_state_dict_type is None: - prev_state_dict_type = submodule._state_dict_type - else: - assert ( - prev_state_dict_type == submodule._state_dict_type - ), 'All FSDP modules should have the same state_dict_type.' - if prev_state_dict_config is None: - prev_state_dict_config = submodule._state_dict_config - else: - assert isinstance( - submodule._state_dict_config, - type(prev_state_dict_config)), ( - 'All FSDP modules must have the same type of ' - 'state_dict_config.') - if prev_optim_state_dict_config is None: - prev_optim_state_dict_config = \ - submodule._optim_state_dict_config - else: - assert isinstance( - submodule._optim_state_dict_config, - type(prev_optim_state_dict_config), - ), ('All FSDP modules must have the same type of ' - 'optim_state_dict_config.') - - submodule._state_dict_type = state_dict_type - submodule._state_dict_config = state_dict_config - submodule._optim_state_dict_config = optim_state_dict_config - - return StateDictSettings(prev_state_dict_type, - prev_state_dict_config, - prev_optim_state_dict_config) diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py deleted file mode 100644 index ac9c2383c3..0000000000 --- a/mmengine/model/wrappers/seperate_distributed.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
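A hedged sketch of how the FSDP wrapper above is commonly selected; using a runner-level ``model_wrapper_cfg`` is an assumption here, and the concrete values are placeholders. The constructor shown above converts the string and dict options into the corresponding FSDP objects:

model_wrapper_cfg = dict(
    type='MMFullyShardedDataParallel',
    sharding_strategy='FULL_SHARD',   # looked up in ShardingStrategy by name
    cpu_offload=False,                # a bool is wrapped into CPUOffload
    mixed_precision=dict(             # strings are converted to torch dtypes
        param_dtype='float16',
        reduce_dtype='float32',
        buffer_dtype='float32'),
)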
-from contextlib import ExitStack, contextmanager -from typing import Dict, Union - -import torch -import torch.nn as nn -from torch.nn.parallel.distributed import DistributedDataParallel - -from mmengine.device import get_device -from mmengine.optim import OptimWrapperDict -from mmengine.registry import MODEL_WRAPPERS -from .distributed import MMDistributedDataParallel - - -@MODEL_WRAPPERS.register_module() -class MMSeparateDistributedDataParallel(DistributedDataParallel): - """A DistributedDataParallel wrapper for models in MMGeneration. - - In MMedting and MMGeneration there is a need to wrap different modules in - the models with separate DistributedDataParallel. Otherwise, it will cause - errors for GAN training. For example, the GAN model, usually has two - submodules: generator and discriminator. If we wrap both of them in one - standard DistributedDataParallel, it will cause errors during training, - because when we update the parameters of the generator (or discriminator), - the parameters of the discriminator (or generator) is not updated, which is - not allowed for DistributedDataParallel. So we design this wrapper to - separately wrap DistributedDataParallel for generator and discriminator. - In this wrapper, we perform two operations: - - 1. Wraps each module in the models with separate MMDistributedDataParallel. - Note that only modules with parameters will be wrapped. - 2. Calls ``train_step``, ``val_step`` and ``test_step`` of submodules to - get losses and predictions. - - Args: - module (nn.Module): model contain multiple submodules which have - separately updating strategy. - broadcast_buffers (bool): Same as that in - ``torch.nn.parallel.distributed.DistributedDataParallel``. - Defaults to False. - find_unused_parameters (bool): Same as that in - ``torch.nn.parallel.distributed.DistributedDataParallel``. - Traverse the autograd graph of all tensors contained in returned - value of the wrapped module's forward function. Defaults to False. - **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``. - - - device_ids (List[int] or torch.device, optional): CUDA devices - for module. - - output_device (int or torch.device, optional): Device location of - output for single-device CUDA modules. - - dim (int): Defaults to 0. - - process_group (ProcessGroup, optional): The process group to be - used for distributed data all-reduction. - - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults - to 25. - - check_reduction (bool): This argument is deprecated. Defaults - to False. - - gradient_as_bucket_view (bool): Defaults to False. - - static_graph (bool): Defaults to False. - - See more information about arguments in - :class:`torch.nn.parallel.DistributedDataParallel`. - """ - - def __init__(self, - module: nn.Module, - broadcast_buffers: bool = False, - find_unused_parameters: bool = False, - **kwargs): - super(DistributedDataParallel, self).__init__() - self.module = module - device = get_device() - # Wrap the submodule with parameters of `self.module` to - # `MMDistributedDataParallel` - for name, sub_module in module._modules.items(): - # module without parameters. 
- if next(sub_module.parameters(), None) is None: - sub_module = sub_module.to(device) - elif all(not p.requires_grad for p in sub_module.parameters()): - sub_module = sub_module.to(device) - else: - sub_module = MMDistributedDataParallel( - module=sub_module.to(device), - broadcast_buffers=broadcast_buffers, - find_unused_parameters=find_unused_parameters, - **kwargs) - module._modules[name] = sub_module - - def train_step(self, data: Union[dict, tuple, list], - optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: - """Interface for model forward, backward and parameters updating during - training process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - optim_wrapper (OptimWrapperDict): A wrapper of optimizer to - update parameters. - - Returns: - Dict[str, torch.Tensor]: A dict of tensor for logging. - """ - return self.module.train_step(data, optim_wrapper) - - def val_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.val_step(data) - - def test_step(self, data: Union[dict, tuple, list]) -> list: - """Gets the predictions of module during testing process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.test_step(data) - - @contextmanager - def no_sync(self): - """Enables ``no_sync`` context of all sub ``MMDistributedDataParallel`` - modules.""" - with ExitStack() as stack: - for sub_ddp_model in self.module._modules.values(): - stack.enter_context(sub_ddp_model.no_sync()) - yield - - def train(self, mode: bool = True) -> 'MMSeparateDistributedDataParallel': - """Sets the module in training mode. - - In order to make the ddp wrapper inheritance hierarchy more uniform, - ``MMSeparateDistributedDataParallel`` inherits from - ``DistributedDataParallel``, but will not call its constructor. - Since the attributes of ``DistributedDataParallel`` have not been - initialized, call the ``train`` method of ``DistributedDataParallel`` - will raise an error if pytorch version <= 1.9. Therefore, override - this method to call the ``train`` method of submodules. - - Args: - mode (bool): whether to set training mode (``True``) or evaluation - mode (``False``). Defaults to ``True``. - - Returns: - Module: self. - """ - self.training = mode - self.module.train(mode) - return self diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py deleted file mode 100644 index 86e1e123b9..0000000000 --- a/mmengine/model/wrappers/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from mmengine.registry import MODEL_WRAPPERS, Registry - - -def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS): - """Check if a module is a model wrapper. - - The following 4 model in MMEngine (and their subclasses) are regarded as - model wrappers: DataParallel, DistributedDataParallel, - MMDataParallel, MMDistributedDataParallel. You may add you own - model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``. - - Args: - model (nn.Module): The model to be checked. - registry (Registry): The parent registry to search for model wrappers. - - Returns: - bool: True if the input model is a model wrapper. 
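A minimal sketch of the typical use of ``is_model_wrapper`` documented above: unwrap a possibly wrapped model before accessing attributes of the inner module (the helper name ``unwrap`` is ours):

import torch.nn as nn
from mmengine.model import is_model_wrapper

def unwrap(model: nn.Module) -> nn.Module:
    # If ``model`` is a registered wrapper (e.g. MMDistributedDataParallel),
    # return the wrapped module; otherwise return the model unchanged.
    return model.module if is_model_wrapper(model) else model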
- """ - module_wrappers = tuple(registry.module_dict.values()) - if isinstance(model, module_wrappers): - return True - - if not registry.children: - return False - - return any( - is_model_wrapper(model, child) for child in registry.children.values()) diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py deleted file mode 100644 index c0a2ec6e37..0000000000 --- a/mmengine/optim/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, - AmpOptimWrapper, ApexOptimWrapper, BaseOptimWrapper, - DefaultOptimWrapperConstructor, OptimWrapper, - OptimWrapperDict, ZeroRedundancyOptimizer, - build_optim_wrapper) -# yapf: disable -from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, - CosineAnnealingLR, CosineAnnealingMomentum, - CosineAnnealingParamScheduler, ExponentialLR, - ExponentialMomentum, ExponentialParamScheduler, - LinearLR, LinearMomentum, LinearParamScheduler, - MultiStepLR, MultiStepMomentum, - MultiStepParamScheduler, OneCycleLR, - OneCycleParamScheduler, PolyLR, PolyMomentum, - PolyParamScheduler, ReduceOnPlateauLR, - ReduceOnPlateauMomentum, ReduceOnPlateauParamScheduler, - StepLR, StepMomentum, StepParamScheduler, - _ParamScheduler) - -# yapf: enable -__all__ = [ - 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper', - 'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR', - 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum', - 'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum', - 'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler', - 'CosineAnnealingParamScheduler', 'ExponentialParamScheduler', - 'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler', - '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper', - 'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', - 'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR', - 'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler', - 'ZeroRedundancyOptimizer', 'BaseOptimWrapper' -] diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py deleted file mode 100644 index ebf1f1e3a5..0000000000 --- a/mmengine/optim/optimizer/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .amp_optimizer_wrapper import AmpOptimWrapper -from .apex_optimizer_wrapper import ApexOptimWrapper -from .base import BaseOptimWrapper -from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, - build_optim_wrapper) -from .default_constructor import DefaultOptimWrapperConstructor -from .optimizer_wrapper import OptimWrapper -from .optimizer_wrapper_dict import OptimWrapperDict -from .zero_optimizer import ZeroRedundancyOptimizer - -__all__ = [ - 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', - 'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper', - 'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict', - 'ZeroRedundancyOptimizer', 'BaseOptimWrapper' -] diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py deleted file mode 100644 index 4f3323f2cc..0000000000 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from contextlib import contextmanager -from typing import Union - -import torch -import torch.nn as nn - -from mmengine.device import (is_cuda_available, is_mlu_available, - is_musa_available, is_npu_available) -from mmengine.registry import OPTIM_WRAPPERS -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION -from .optimizer_wrapper import OptimWrapper - -if is_npu_available(): - from torch.npu.amp import GradScaler -elif is_mlu_available(): - from torch.mlu.amp import GradScaler -else: - from torch.cuda.amp import GradScaler - - -@OPTIM_WRAPPERS.register_module() -class AmpOptimWrapper(OptimWrapper): - """A subclass of :class:`OptimWrapper` that supports automatic mixed - precision training based on torch.cuda.amp. - - ``AmpOptimWrapper`` provides a unified interface with - ``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way - as ``OptimWrapper``. - - Warnings: - ``AmpOptimWrapper`` requires PyTorch >= 1.6. - - Args: - loss_scale (float or str or dict): The initial configuration of - `torch.cuda.amp.GradScaler`. See more specific arguments - introduction at `PyTorch AMP `_ # noqa: E501 - Defaults to ``dynamic``. - - - "dynamic": Initialize GradScale without any arguments. - - float: Initialize GradScaler with ``init_scale``. - - dict: Initialize GradScaler with more detail configuration. - - dtype (str or torch.dtype, optional): The data type to autocast in amp. - If a ``str`` is given, it will be converted to ``torch.dtype``. - Valid ``str`` format are `'float16'`, `'bfloat16'`, `'float32'` and - `'float64'`. If set to ``None``, the default data type will be used. - Defaults to None. - `New in version 0.6.1.` - use_fsdp (bool): Using ``ShardedGradScaler`` when it is True. It should - be enabled when using ``FullyShardedDataParallel``. - Defaults to False. - `New in version 0.8.0.` - **kwargs: Keyword arguments passed to OptimWrapper. - - Warnings: - ``dtype`` argument is only available with PyTorch version >= 1.10.0. If - you use PyTorch of an older version, it will be ignored. - - Note: - If you use ``IterBasedRunner`` and enable gradient accumulation, - the original `max_iters` should be multiplied by - ``accumulative_counts``. - """ - - valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64') - - def __init__(self, - loss_scale: str = 'dynamic', - dtype: Union[str, torch.dtype] = None, - use_fsdp: bool = False, - **kwargs): - assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( - '`torch.cuda.amp` is only available when pytorch version >= 1.6') - assert is_cuda_available() or is_npu_available() or is_mlu_available( - ) or is_musa_available(), ( - '``AmpOptimizerWrapper`` is only available training ' - 'on gpu, npu, mlu or musa') - super().__init__(**kwargs) - self._scale_update_param = None - - if use_fsdp: - if digit_version(torch.__version__) >= digit_version('2.0.0'): - from torch.distributed.fsdp.sharded_grad_scaler import \ - ShardedGradScaler - scaler_type = ShardedGradScaler - else: - raise RuntimeError( - 'PyTorch>=2.0.0 is required when sets `use_fsdp=True`') - else: - scaler_type = GradScaler - - if loss_scale == 'dynamic': - # If loss_scale is a string, it must be 'dynamic', then dynamic - # loss scaling will be used. - self.loss_scaler = scaler_type() - elif isinstance(loss_scale, float): - # Static loss scaling - self._scale_update_param = loss_scale - self.loss_scaler = scaler_type(init_scale=loss_scale) - elif isinstance(loss_scale, dict): - # More specific configuration. 
- self.loss_scaler = scaler_type(**loss_scale) - else: - raise TypeError('loss_scale must be of type float, dict, or ' - f'"dynamic", but got {loss_scale}') - - # convert string value to torch.dtype - if isinstance(dtype, str): - assert dtype in self.valid_dtypes, ( - f'dtype should be any of {self.valid_dtypes}, got {dtype}') - dtype = getattr(torch, dtype) - - assert dtype is None or isinstance(dtype, torch.dtype), ( - f'dtype should be None or instance of torch.dtype, got {dtype}') - self.cast_dtype = dtype - - def backward(self, loss: torch.Tensor, **kwargs): - """Perform gradient back propagation with :attr:`loss_scaler`. - - Args: - loss (torch.Tensor): The loss of current iteration. - kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` - """ - self.loss_scaler.scale(loss).backward(**kwargs) - self._inner_count += 1 - - def step(self, **kwargs): - """Update parameters with :attr:`loss_scaler`. - - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.step`. - """ - if self.clip_grad_kwargs: - self.loss_scaler.unscale_(self.optimizer) - self._clip_grad() - self.loss_scaler.step(self.optimizer, **kwargs) - self.loss_scaler.update(self._scale_update_param) - - def state_dict(self) -> dict: - """Get the state dictionary of :attr:`optimizer` and - :attr:`loss_scaler`. - - Based on the state dictionary of the optimizer, the returned state - dictionary will add a key named "loss_scaler". - - Returns: - dict: The merged state dict of :attr:`loss_scaler` and - :attr:`optimizer`. - """ - # save state_dict of loss_scaler - state_dict = super().state_dict() - state_dict['loss_scaler'] = self.loss_scaler.state_dict() - return state_dict - - def load_state_dict(self, state_dict: dict): - """Load and parse the state dictionary of :attr:`optimizer` and - :attr:`loss_scaler`. - - If state_dict contains "loss_scaler.", the :attr:`loss_scaler` will - load the corresponding keys. Otherwise, only the :attr:`optimizer` - will load the state dictionary. - - Args: - state_dict (dict): The state dict of :attr:`optimizer` and - :attr:`loss_scaler` - """ - if 'loss_scaler' in state_dict: - self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler')) - - if 'base_param_settings' in state_dict: - self.base_param_settings = state_dict.pop('base_param_settings') - - # load state_dict of optimizer - self.optimizer.load_state_dict(state_dict) - - @contextmanager - def optim_context(self, model: nn.Module): - """Enables the context for mixed precision training, and enables the - context for disabling gradient synchronization during gradient - accumulation context. - - Args: - model (nn.Module): The training model. - """ - from mmengine.runner.amp import autocast - with super().optim_context(model), autocast(dtype=self.cast_dtype): - yield diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py deleted file mode 100644 index a2e6190460..0000000000 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
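An illustrative config for ``AmpOptimWrapper`` matching the ``loss_scale`` and ``dtype`` options documented above (all values are placeholders):

amp_optim_wrapper_cfg = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    loss_scale='dynamic',  # or a float (static scale), or a dict of GradScaler kwargs
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))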
-from contextlib import contextmanager -from typing import Optional, Union - -import torch -import torch.nn as nn - -# a circular import will be caused by -# from mmengine.model.wrappers import is_model_wrapper -import mmengine -from mmengine.registry import OPTIM_WRAPPERS -from .optimizer_wrapper import OptimWrapper - -try: - import apex.amp as apex_amp -except ImportError: - apex_amp = None - - -@OPTIM_WRAPPERS.register_module() -class ApexOptimWrapper(OptimWrapper): - """A subclass of :class:`OptimWrapper` that supports automatic mixed - precision training based on apex.amp. - - ``ApexOptimWrapper`` provides a unified interface with - ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``. - - Warning: - ``ApexOptimWrapper`` requires `nvidia apex `_ - - Args: - opt_level (str): Pure or mixed precision optimization level. Accepted - values are "O0", "O1", "O2", and "O3". Defaults to "O1". - loss_scale (float or str, optional): If passed as a string, must be a - string representing a number, e.g., "128.0", or the string - "dynamic". Defaults to "dynamic". - enabled (bool): If False, renders all Amp calls no-ops, so your script - should run as if Amp were not present. Defaults to True. - cast_model_type (torch.dtype, optional): Model's parameters and - buffers to the desired type. Defaults to None. - patch_torch_functions (bool, optional): Patch all Torch functions - and Tensor methods to perform Tensor Core-friendly ops like GEMMs - and convolutions in FP16, and any ops that benefit from FP32 - precision in FP32. Defaults to None. - keep_batchnorm_fp32 (bool or str, optional): To enhance precision - and enable cudnn batchnorm (which improves performance), - it's often beneficial to keep batchnorm weights in FP32 - even if the rest of the model is FP16. - If passed as a string, must be the string "True" or "False". - Defaults to None. - master_weights (bool, optional): Maintain FP32 master weights to - accompany any FP16 model weights. FP32 master weights are stepped - by the optimizer to enhance precision and capture small gradients. - Defaults to None. - cast_model_outputs (torch.dtype, optional): Option to ensure that - the outputs of your model(s) are always cast to a particular type - regardless of ``opt_level``. Defaults to None. - num_losses (int): Option to tell Amp in advance how many - losses/backward passes you plan to use. Defaults to 1. - verbosity (int): Set to 0 to suppress Amp-related output. - Defaults to 1. - min_loss_scale (float, optional): Sets a floor for the loss scale - values that can be chosen by dynamic loss scaling. - The default value of None means that no floor is imposed. - If dynamic loss scaling is not used, `min_loss_scale` is ignored. - Defaults to None. - max_loss_scale (float, optional): Sets a ceiling for the loss scale - values that can be chosen by dynamic loss scaling. If dynamic - loss scaling is not used, `max_loss_scale` is ignored. - Defaults to 2.**24. - **kwargs: Keyword arguments passed to OptimWrapper. - - Note: - If you use ``IterBasedRunner`` and enable gradient accumulation, - the original `max_iters` should be multiplied by - ``accumulative_counts``. 
- - Note: - `New in version 0.6.0.` - """ # noqa: E501 - - def __init__(self, - opt_level: str = 'O1', - loss_scale: Union[float, str, None] = 'dynamic', - enabled: Optional[bool] = True, - cast_model_type: Optional[torch.dtype] = None, - patch_torch_functions: Optional[bool] = None, - keep_batchnorm_fp32: Union[bool, str, None] = None, - master_weights: Optional[bool] = None, - cast_model_outputs: Optional[torch.dtype] = None, - num_losses: int = 1, - verbosity: int = 1, - min_loss_scale: Optional[float] = None, - max_loss_scale: Optional[float] = 2.**24, - **kwargs): - assert apex_amp is not None, \ - 'Apex is not installed. Please check ' \ - 'https://github.com/NVIDIA/apex#linux.' - super().__init__(**kwargs) - self.opt_level = opt_level - self.loss_scale = loss_scale - self.enabled = enabled - self.cast_model_type = cast_model_type - self.patch_torch_functions = patch_torch_functions - self.keep_batchnorm_fp32 = keep_batchnorm_fp32 - self.master_weights = master_weights - self.cast_model_outputs = cast_model_outputs - self.num_losses = num_losses - self.verbosity = verbosity - self.min_loss_scale = min_loss_scale - self.max_loss_scale = max_loss_scale - self._apex_amp_state_dict = None - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Perform gradient back propagation with :attr:`loss_scaler`. - - Args: - loss (torch.Tensor): The loss of current iteration. - kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` - """ - with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward(**kwargs) - self._inner_count += 1 - - def state_dict(self) -> dict: - """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`. - - Based on the state dictionary of the optimizer, the returned state - dictionary will add a key named "apex_amp". - - Returns: - dict: The merged state dict of :attr:`apex_amp` and - :attr:`optimizer`. - """ - state_dict = self.optimizer.state_dict() - state_dict['apex_amp'] = apex_amp.state_dict() - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - """Load and parse the state dictionary of :attr:`optimizer` and - :attr:`apex_amp`. - - If state_dict contains "apex_amp", the :attr:`apex_amp` will - load the corresponding keys. Otherwise, only the :attr:`optimizer` - will load the state dictionary. - - Note: - :meth:`load_state_dict` shuold be called after - `apex_amp.initialize` is called. - Args: - state_dict (dict): The state dict of :attr:`optimizer` and - :attr:`apex_amp` - """ - if 'apex_amp' in state_dict: - # when `apex_amp` is not initialized, calling `load_state_dict` - # will raise an error, so we temporarily cache the apex_amp - # part, and then load it into `apex_amp` after completing - # the `apex_amp` initialization in `optim_context` method - if hasattr(self.optimizer, '_amp_stash'): - apex_amp.load_state_dict(state_dict.pop('apex_amp')) - else: - self._apex_amp_state_dict = state_dict.pop('apex_amp') - self.optimizer.load_state_dict(state_dict) - - @contextmanager - def optim_context(self, model: nn.Module): - """Enables the context for mixed precision training, and enables the - context for disabling gradient synchronization during gradient - accumulation context. - - Args: - model (nn.Module): The training model. 
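For comparison, a hedged sketch of how the ``ApexOptimWrapper`` described above is usually configured; it assumes NVIDIA apex is installed, and the optimizer settings are illustrative only:

    # Sketch only: apex-based mixed precision. 'O1' patches Torch functions
    # so Tensor-Core-friendly ops run in FP16 while the rest stay in FP32.
    optim_wrapper = dict(
        type='ApexOptimWrapper',
        opt_level='O1',
        loss_scale='dynamic',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))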
- """ - with super().optim_context(model): - # when a given optimizer be passed through apex_amp.initialize, - # the "_amp_stash" property will be added - if not hasattr(self.optimizer, '_amp_stash'): - if mmengine.model.wrappers.is_model_wrapper(model): - model = model.module - model, self.optimizer = apex_amp.initialize( - model, - self.optimizer, - opt_level=self.opt_level, - loss_scale=self.loss_scale, - enabled=self.enabled, - cast_model_type=self.cast_model_type, - patch_torch_functions=self.patch_torch_functions, - keep_batchnorm_fp32=self.keep_batchnorm_fp32, - master_weights=self.master_weights, - cast_model_outputs=self.cast_model_outputs, - num_losses=self.num_losses, - verbosity=self.verbosity, - min_loss_scale=self.min_loss_scale, - max_loss_scale=self.max_loss_scale) - # loading apex_amp state_dict after initialization of apex_amp - if self._apex_amp_state_dict is not None: - apex_amp.load_state_dict(self._apex_amp_state_dict) - self._apex_amp_state_dict = None - yield diff --git a/mmengine/optim/optimizer/base.py b/mmengine/optim/optimizer/base.py deleted file mode 100644 index ee53f508b1..0000000000 --- a/mmengine/optim/optimizer/base.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod -from typing import Dict, List - -import torch - - -class BaseOptimWrapper(metaclass=ABCMeta): - - def __init__(self, optimizer): - self.optimizer = optimizer - - # The Following code is used to initialize `base_param_settings`. - # `base_param_settings` is used to store the parameters that are not - # updated by the optimizer. - # The `base_param_settings` used for tracking the base learning in the - # optimizer. If the optimizer has multiple parameter groups, this - # params will not be scaled by the loss factor. - if len(optimizer.param_groups) > 1: - self.base_param_settings = { - 'params': torch.tensor([0.0], dtype=torch.float) - } - self.base_param_settings.update(**self.optimizer.defaults) - else: - self.base_param_settings = None # type: ignore - - @abstractmethod - def update_params(self, *args, **kwargs): - """Update parameters in :attr:`optimizer`.""" - - @abstractmethod - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Perform gradient back propagation.""" - - @abstractmethod - def zero_grad(self, **kwargs) -> None: - """A wrapper of ``Optimizer.zero_grad``.""" - - @abstractmethod - def step(self, **kwargs): - """Call the step method of optimizer.""" - - def state_dict(self) -> dict: - """A wrapper of ``Optimizer.state_dict``.""" - state_dict = self.optimizer.state_dict() - if self.base_param_settings is not None: - state_dict['base_param_settings'] = self.base_param_settings - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - """A wrapper of ``Optimizer.load_state_dict``. load the state dict of - :attr:`optimizer`. - - Provide unified ``load_state_dict`` interface compatible with automatic - mixed precision training. Subclass can overload this method to - implement the required logic. For example, the state dictionary of - GradScaler should be loaded when training with ``torch.cuda.amp``. - - Args: - state_dict (dict): The state dictionary of :attr:`optimizer`. 
- """ - base_param_settings = state_dict.pop('base_param_settings', None) - - if base_param_settings is not None: - self.base_param_settings = base_param_settings - - # load state_dict of optimizer - self.optimizer.load_state_dict(state_dict) - - @property - def param_groups(self) -> List[dict]: - """A wrapper of ``Optimizer.param_groups``. - - Make OptimizeWrapper compatible with :class:`_ParamScheduler`. - - Returns: - dict: the ``param_groups`` of :attr:`optimizer`. - """ - if self.base_param_settings is not None: - return self.optimizer.param_groups + [self.base_param_settings] - else: - return self.optimizer.param_groups - - @property - def defaults(self) -> dict: - """A wrapper of ``Optimizer.defaults``. - - Make OptimizeWrapper compatible with :class:`_ParamScheduler`. - - Returns: - dict: the ``param_groups`` of :attr:`optimizer`. - """ - return self.optimizer.defaults - - def get_lr(self): - """Get the learning rate of the optimizer. - - Provide unified interface to get learning rate of optimizer. - - Returns: - Dict[str, List[float]]: - param_groups learning rate of the optimizer. - """ - res = {} - if self.base_param_settings is not None: - res['base_lr'] = [self.base_param_settings['lr']] - - res['lr'] = [group['lr'] for group in self.optimizer.param_groups] - - return res - - def get_momentum(self) -> Dict[str, List[float]]: - """Get the momentum of the optimizer. - - Provide unified interface to get momentum of optimizer. - - Returns: - Dict[str, List[float]]: Momentum of the optimizer. - """ - momentum = [] - for group in self.optimizer.param_groups: - # Get momentum of SGD. - if 'momentum' in group.keys(): - momentum.append(group['momentum']) - # Get momentum of Adam. - elif 'betas' in group.keys(): - momentum.append(group['betas'][0]) - else: - momentum.append(0) - return dict(momentum=momentum) diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py deleted file mode 100644 index fef95f729a..0000000000 --- a/mmengine/optim/optimizer/builder.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import inspect -from typing import List, Union - -import torch -import torch.nn as nn - -from mmengine.config import Config, ConfigDict -from mmengine.device import is_npu_available, is_npu_support_full_precision -from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS -from .optimizer_wrapper import OptimWrapper - - -def register_torch_optimizers() -> List[str]: - """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - torch_optimizers = [] - for module_name in dir(torch.optim): - if module_name.startswith('__'): - continue - _optim = getattr(torch.optim, module_name) - if inspect.isclass(_optim) and issubclass(_optim, - torch.optim.Optimizer): - if module_name == 'Adafactor': - OPTIMIZERS.register_module( - name='TorchAdafactor', module=_optim) - else: - OPTIMIZERS.register_module(module=_optim) - torch_optimizers.append(module_name) - return torch_optimizers - - -TORCH_OPTIMIZERS = register_torch_optimizers() - - -def register_torch_npu_optimizers() -> List[str]: - """Register optimizers in ``torch npu`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. 
-    """
-    if not is_npu_available():
-        return []
-
-    import torch_npu
-    if not hasattr(torch_npu, 'optim'):
-        return []
-
-    torch_npu_optimizers = []
-    for module_name in dir(torch_npu.optim):
-        if module_name.startswith('__') or module_name in OPTIMIZERS:
-            continue
-        _optim = getattr(torch_npu.optim, module_name)
-        if inspect.isclass(_optim) and issubclass(_optim,
-                                                  torch.optim.Optimizer):
-            OPTIMIZERS.register_module(module=_optim)
-            torch_npu_optimizers.append(module_name)
-    return torch_npu_optimizers
-
-
-NPU_OPTIMIZERS = register_torch_npu_optimizers()
-
-
-def register_dadaptation_optimizers() -> List[str]:
-    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.
-
-    Returns:
-        List[str]: A list of registered optimizers' name.
-    """
-    dadaptation_optimizers = []
-    try:
-        import dadaptation
-    except ImportError:
-        pass
-    else:
-        for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
-            _optim = getattr(dadaptation, module_name)
-            if inspect.isclass(_optim) and issubclass(_optim,
-                                                      torch.optim.Optimizer):
-                OPTIMIZERS.register_module(module=_optim)
-                dadaptation_optimizers.append(module_name)
-    return dadaptation_optimizers
-
-
-DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
-
-
-def register_lion_optimizers() -> List[str]:
-    """Register Lion optimizer to the ``OPTIMIZERS`` registry.
-
-    Returns:
-        List[str]: A list of registered optimizers' name.
-    """
-    optimizers = []
-    try:
-        from lion_pytorch import Lion
-    except ImportError:
-        pass
-    else:
-        OPTIMIZERS.register_module(module=Lion)
-        optimizers.append('Lion')
-    return optimizers
-
-
-LION_OPTIMIZERS = register_lion_optimizers()
-
-
-def register_sophia_optimizers() -> List[str]:
-    """Register Sophia optimizer to the ``OPTIMIZERS`` registry.
-
-    Returns:
-        List[str]: A list of registered optimizers' name.
-    """
-    optimizers = []
-    try:
-        import Sophia
-    except ImportError:
-        pass
-    else:
-        for module_name in dir(Sophia):
-            _optim = getattr(Sophia, module_name)
-            if inspect.isclass(_optim) and issubclass(_optim,
-                                                      torch.optim.Optimizer):
-                OPTIMIZERS.register_module(module=_optim)
-                optimizers.append(module_name)
-    return optimizers
-
-
-SOPHIA_OPTIMIZERS = register_sophia_optimizers()
-
-
-def register_bitsandbytes_optimizers() -> List[str]:
-    """Register optimizers in ``bitsandbytes`` to the ``OPTIMIZERS`` registry.
-
-    In the `bitsandbytes` library, optimizers that have the same name as the
-    default optimizers in PyTorch are prefixed with ``bnb_``. For example,
-    ``bnb_Adagrad``.
-
-    Returns:
-        List[str]: A list of registered optimizers' name.
-    """
-    bitsandbytes_optimizers = []
-    try:
-        import bitsandbytes as bnb
-        # import bnb may trigger cuda related error without nvidia gpu resources
-    except (ImportError, RuntimeError):
-        pass
-    else:
-        optim_classes = inspect.getmembers(
-            bnb.optim, lambda _optim: (inspect.isclass(_optim) and issubclass(
-                _optim, torch.optim.Optimizer)))
-        for name, optim_cls in optim_classes:
-            if name in OPTIMIZERS:
-                name = f'bnb_{name}'
-            OPTIMIZERS.register_module(module=optim_cls, name=name)
-            bitsandbytes_optimizers.append(name)
-    return bitsandbytes_optimizers
-
-
-BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()
-
-
-def register_transformers_optimizers():
-    transformer_optimizers = []
-    try:
-        from transformers import Adafactor
-    except ImportError:
-        pass
-    else:
-        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
-        transformer_optimizers.append('Adafactor')
-    return transformer_optimizers
-
-
-TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
-
-
-def build_optim_wrapper(model: nn.Module,
-                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
-    """Build function of OptimWrapper.
-
-    If ``constructor`` is set in the ``cfg``, this method will build an
-    optimizer wrapper constructor, and use optimizer wrapper constructor to
-    build the optimizer wrapper. If ``constructor`` is not set, the
-    ``DefaultOptimWrapperConstructor`` will be used by default.
-
-    Args:
-        model (nn.Module): Model to be optimized.
-        cfg (dict): Config of optimizer wrapper, optimizer constructor and
-            optimizer.
-
-    Returns:
-        OptimWrapper: The built optimizer wrapper.
-    """
-    optim_wrapper_cfg = copy.deepcopy(cfg)
-    constructor_type = optim_wrapper_cfg.pop('constructor',
-                                             'DefaultOptimWrapperConstructor')
-    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
-
-    # Since the current generation of NPU (Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision
-    # to make the training run normally
-    if is_npu_available() and not is_npu_support_full_precision():
-        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
-
-    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
-        dict(
-            type=constructor_type,
-            optim_wrapper_cfg=optim_wrapper_cfg,
-            paramwise_cfg=paramwise_cfg))
-    optim_wrapper = optim_wrapper_constructor(model)
-    return optim_wrapper
diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
deleted file mode 100644
index b623a3e70e..0000000000
--- a/mmengine/optim/optimizer/default_constructor.py
+++ /dev/null
@@ -1,321 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import inspect
-import logging
-from typing import List, Optional, Union
-
-import torch
-import torch.nn as nn
-from torch.nn import GroupNorm, LayerNorm
-
-from mmengine.logging import print_log
-from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
-                               OPTIMIZERS)
-from mmengine.utils import is_list_of
-from mmengine.utils.dl_utils import mmcv_full_available
-from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
-from .optimizer_wrapper import OptimWrapper
-
-
-@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
-class DefaultOptimWrapperConstructor:
-    """Default constructor for optimizers.
- - By default, each parameter share the same optimizer settings, and we - provide an argument ``paramwise_cfg`` to specify parameter-wise settings. - It is a dict and may contain the following fields: - - - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If - one of the keys in ``custom_keys`` is a substring of the name of one - parameter, then the setting of the parameter will be specified by - ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will - be ignored. It should be noted that the aforementioned ``key`` is the - longest key that is a substring of the name of the parameter. If there - are multiple matched keys with the same length, then the key with lower - alphabet order will be chosen. - ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` - and ``decay_mult``. See Example 2 below. - - ``bias_lr_mult`` (float): It will be multiplied to the learning - rate for all bias parameters (except for those in normalization - layers and offset layers of DCN). - - ``bias_decay_mult`` (float): It will be multiplied to the weight - decay for all bias parameters (except for those in - normalization layers, depthwise conv layers, offset layers of DCN). - - ``norm_decay_mult`` (float): It will be multiplied to the weight - decay for all weight and bias parameters of normalization - layers. - - ``flat_decay_mult`` (float): It will be multiplied to the weight - decay for all one-dimensional parameters - - ``dwconv_decay_mult`` (float): It will be multiplied to the weight - decay for all weight and bias parameters of depthwise conv - layers. - - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning - rate for parameters of offset layer in the deformable convs - of a model. - - ``bypass_duplicate`` (bool): If true, the duplicate parameters - would not be added into optimizer. Defaults to False. - - Note: - - 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will - override the effect of ``bias_lr_mult`` in the bias of offset layer. - So be careful when using both ``bias_lr_mult`` and - ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset - layer in deformable convs, set ``dcn_offset_lr_mult`` to the original - ``dcn_offset_lr_mult`` * ``bias_lr_mult``. - - 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will - apply it to all the DCN layers in the model. So be careful when the - model contains multiple DCN layers in places other than backbone. - - Args: - optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. - - Required fields of ``optim_wrapper_cfg`` are - - - ``type``: class name of the OptimizerWrapper - - ``optimizer``: The configuration of optimizer. - - Optional fields of ``optim_wrapper_cfg`` are - - - any arguments of the corresponding optimizer wrapper type, - e.g., accumulative_counts, clip_grad, etc. - - Required fields of ``optimizer`` are - - - `type`: class name of the optimizer. - - Optional fields of ``optimizer`` are - - - any arguments of the corresponding optimizer type, e.g., - lr, weight_decay, momentum, etc. - - paramwise_cfg (dict, optional): Parameter-wise options. - - Example 1: - >>> model = torch.nn.modules.Conv1d(1, 1, 1) - >>> optim_wrapper_cfg = dict( - >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, - >>> momentum=0.9, weight_decay=0.0001)) - >>> paramwise_cfg = dict(norm_decay_mult=0.) 
- >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( - >>> optim_wrapper_cfg, paramwise_cfg) - >>> optim_wrapper = optim_wrapper_builder(model) - - Example 2: - >>> # assume model have attribute model.backbone and model.cls_head - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - >>> type='SGD', lr=0.01, weight_decay=0.95)) - >>> paramwise_cfg = dict(custom_keys={ - >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) - >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( - >>> optim_wrapper_cfg, paramwise_cfg) - >>> optim_wrapper = optim_wrapper_builder(model) - >>> # Then the `lr` and `weight_decay` for model.backbone is - >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for - >>> # model.cls_head is (0.01, 0.95). - """ - - def __init__(self, - optim_wrapper_cfg: dict, - paramwise_cfg: Optional[dict] = None): - if not isinstance(optim_wrapper_cfg, dict): - raise TypeError('optimizer_cfg should be a dict', - f'but got {type(optim_wrapper_cfg)}') - assert 'optimizer' in optim_wrapper_cfg, ( - '`optim_wrapper_cfg` must contain "optimizer" config') - self.optim_wrapper_cfg = optim_wrapper_cfg.copy() - self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer') - self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg - self.base_lr = self.optimizer_cfg.get('lr', None) - self.base_wd = self.optimizer_cfg.get('weight_decay', None) - self._validate_cfg() - - def _validate_cfg(self) -> None: - """Verify the correctness of the config.""" - if not isinstance(self.paramwise_cfg, dict): - raise TypeError('paramwise_cfg should be None or a dict, ' - f'but got {type(self.paramwise_cfg)}') - - if 'custom_keys' in self.paramwise_cfg: - if not isinstance(self.paramwise_cfg['custom_keys'], dict): - raise TypeError( - 'If specified, custom_keys must be a dict, ' - f'but got {type(self.paramwise_cfg["custom_keys"])}') - if self.base_wd is None: - for key in self.paramwise_cfg['custom_keys']: - if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]: - raise ValueError('base_wd should not be None') - - # get base lr and weight decay - # weight_decay must be explicitly specified if mult is specified - if ('bias_decay_mult' in self.paramwise_cfg - or 'norm_decay_mult' in self.paramwise_cfg - or 'dwconv_decay_mult' in self.paramwise_cfg): - if self.base_wd is None: - raise ValueError('base_wd should not be None') - - def _is_in(self, param_group: dict, param_group_list: list) -> bool: - """Check whether the `param_group` is in the`param_group_list`""" - assert is_list_of(param_group_list, dict) - param = set(param_group['params']) - param_set = set() - for group in param_group_list: - param_set.update(set(group['params'])) - - return not param.isdisjoint(param_set) - - def add_params(self, - params: List[dict], - module: nn.Module, - prefix: str = '', - is_dcn_module: Optional[Union[int, float]] = None) -> None: - """Add all parameters of module to the params list. - - The parameters of the given module will be added to the list of param - groups, with specific rules defined by paramwise_cfg. - - Args: - params (list[dict]): A list of param groups, it will be modified - in place. - module (nn.Module): The module to be added. - prefix (str): The prefix of the module - is_dcn_module (int|float|None): If the current module is a - submodule of DCN, `is_dcn_module` will be passed to - control conv_offset layer's learning rate. Defaults to None. 
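Complementing the two examples above, the sketch below combines several of the global ``paramwise_cfg`` fields and feeds the result to ``build_optim_wrapper``; the multiplier values are illustrative and ``model`` is a placeholder:

    # Sketch only: biases get a doubled learning rate, while normalization,
    # depthwise-conv and other one-dimensional parameters get no weight decay.
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
        paramwise_cfg=dict(
            bias_lr_mult=2.0,
            norm_decay_mult=0.0,
            dwconv_decay_mult=0.0,
            flat_decay_mult=0.0,
            bypass_duplicate=True))
    # optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)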
- """ - # get param-wise options - custom_keys = self.paramwise_cfg.get('custom_keys', {}) - # first sort with alphabet order and then sort with reversed len of str - sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) - - bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None) - bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) - norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) - dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None) - flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) - bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) - dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None) - - # special rules for norm layers and depth-wise conv layers - is_norm = isinstance(module, - (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) - is_dwconv = ( - isinstance(module, torch.nn.Conv2d) - and module.in_channels == module.groups) - - for name, param in module.named_parameters(recurse=False): - param_group = {'params': [param]} - if bypass_duplicate and self._is_in(param_group, params): - print_log( - f'{prefix} is duplicate. It is skipped since ' - f'bypass_duplicate={bypass_duplicate}', - logger='current', - level=logging.WARNING) - continue - if not param.requires_grad: - print_log((f'{prefix}.{name} is skipped since its ' - f'requires_grad={param.requires_grad}'), - logger='current', - level=logging.WARNING) - continue - - # if the parameter match one of the custom keys, ignore other rules - is_custom = False - for key in sorted_keys: - if key in f'{prefix}.{name}': - is_custom = True - lr_mult = custom_keys[key].get('lr_mult', 1.) - param_group['lr'] = self.base_lr * lr_mult - if self.base_wd is not None: - decay_mult = custom_keys[key].get('decay_mult', 1.) 
- param_group['weight_decay'] = self.base_wd * decay_mult - # add custom settings to param_group - for k, v in custom_keys[key].items(): - param_group[k] = v - break - - if not is_custom: - # bias_lr_mult affects all bias parameters - # except for norm.bias dcn.conv_offset.bias - if name == 'bias' and not ( - is_norm or is_dcn_module) and bias_lr_mult is not None: - param_group['lr'] = self.base_lr * bias_lr_mult - - if (prefix.find('conv_offset') != -1 and is_dcn_module - and dcn_offset_lr_mult is not None - and isinstance(module, torch.nn.Conv2d)): - # deal with both dcn_offset's bias & weight - param_group['lr'] = self.base_lr * dcn_offset_lr_mult - - # apply weight decay policies - if self.base_wd is not None: - # norm decay - if is_norm and norm_decay_mult is not None: - param_group[ - 'weight_decay'] = self.base_wd * norm_decay_mult - # bias lr and decay - elif (name == 'bias' and not is_dcn_module - and bias_decay_mult is not None): - param_group[ - 'weight_decay'] = self.base_wd * bias_decay_mult - # depth-wise conv - elif is_dwconv and dwconv_decay_mult is not None: - param_group[ - 'weight_decay'] = self.base_wd * dwconv_decay_mult - # flatten parameters except dcn offset - elif (param.ndim == 1 and not is_dcn_module - and flat_decay_mult is not None): - param_group[ - 'weight_decay'] = self.base_wd * flat_decay_mult - params.append(param_group) - for key, value in param_group.items(): - if key == 'params': - continue - full_name = f'{prefix}.{name}' if prefix else name - print_log( - f'paramwise_options -- {full_name}:{key}={value}', - logger='current') - - if mmcv_full_available(): - from mmcv.ops import DeformConv2d, ModulatedDeformConv2d - is_dcn_module = isinstance(module, - (DeformConv2d, ModulatedDeformConv2d)) - else: - is_dcn_module = False - for child_name, child_mod in module.named_children(): - child_prefix = f'{prefix}.{child_name}' if prefix else child_name - self.add_params( - params, - child_mod, - prefix=child_prefix, - is_dcn_module=is_dcn_module) - - def __call__(self, model: nn.Module) -> OptimWrapper: - if hasattr(model, 'module'): - model = model.module - - optim_wrapper_cfg = self.optim_wrapper_cfg.copy() - optim_wrapper_cfg.setdefault('type', 'OptimWrapper') - optimizer_cfg = self.optimizer_cfg.copy() - optimizer_cls = self.optimizer_cfg['type'] - # Optimizer like HybridAdam in colossalai requires the argument name - # `model_params` rather than `params`. Here we get the first argument - # name and fill it with the model parameters. - if isinstance(optimizer_cls, str): - with OPTIMIZERS.switch_scope_and_registry(None) as registry: - optimizer_cls = registry.get(self.optimizer_cfg['type']) - fisrt_arg_name = next( - iter(inspect.signature(optimizer_cls).parameters)) - # if no paramwise option is specified, just use the global setting - if not self.paramwise_cfg: - optimizer_cfg[fisrt_arg_name] = model.parameters() - optimizer = OPTIMIZERS.build(optimizer_cfg) - else: - # set param-wise lr and weight decay recursively - params: List = [] - self.add_params(params, model) - optimizer_cfg[fisrt_arg_name] = params - optimizer = OPTIMIZERS.build(optimizer_cfg) - optim_wrapper = OPTIM_WRAPPERS.build( - optim_wrapper_cfg, default_args=dict(optimizer=optimizer)) - return optim_wrapper diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py deleted file mode 100644 index 41218ef768..0000000000 --- a/mmengine/optim/optimizer/optimizer_wrapper.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright (c) OpenMMLab. 
All rights reserved. -import logging -from contextlib import contextmanager -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -from torch.optim import Optimizer - -from mmengine.logging import MessageHub, print_log -from mmengine.registry import OPTIM_WRAPPERS -from mmengine.utils.dl_utils import has_batch_norm -from .base import BaseOptimWrapper - - -@OPTIM_WRAPPERS.register_module() -class OptimWrapper(BaseOptimWrapper): - """Optimizer wrapper provides a common interface for updating parameters. - - Optimizer wrapper provides a unified interface for single precision - training and automatic mixed precision training with different hardware. - OptimWrapper encapsulates optimizer to provide simplified interfaces - for commonly used training techniques such as gradient accumulative and - grad clips. ``OptimWrapper`` implements the basic logic of gradient - accumulation and gradient clipping based on ``torch.optim.Optimizer``. - The subclasses only need to override some methods to implement the mixed - precision training. See more information in :class:`AmpOptimWrapper`. - - Args: - optimizer (Optimizer): Optimizer used to update model parameters. - accumulative_counts (int): The number of iterations to accumulate - gradients. The parameters will be updated per - ``accumulative_counts``. - clip_grad (dict, optional): If ``clip_grad`` is not None, it will be - the arguments of :func:`torch.nn.utils.clip_grad_norm_` or - :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a - dict, and the keys could be set as follows: - - If the key ``type`` is not set, or ``type`` is "norm", - the accepted keys are as follows: - - - max_norm (float or int): Max norm of the gradients. - - norm_type (float or int): Type of the used p-norm. Can be - ``'inf'`` for infinity norm. - - error_if_nonfinite (bool): If True, an error is thrown if - the total norm of the gradients from :attr:`parameters` is - ``nan``, ``inf``, or ``-inf``. Defaults to False (will switch - to True in the future) - - If the key ``type`` is set to "value", the accepted keys are as - follows: - - - clip_value (float or int): maximum allowed value of the - gradients. The gradients are clipped in the range - ``(-clip_value, +clip_value)``. - - Note: - If ``accumulative_counts`` is larger than 1, perform - :meth:`update_params` under the context of ``optim_context`` - could avoid unnecessary gradient synchronization. - - Note: - If you use ``IterBasedRunner`` and enable gradient accumulation, - the original `max_iters` should be multiplied by - ``accumulative_counts``. - - Note: - The subclass should ensure that once :meth:`update_params` is called, - ``_inner_count += 1`` is automatically performed. - - Examples: - >>> # Config sample of OptimWrapper and enable clipping gradient by - >>> # norm. - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=1, - >>> clip_grad=dict(max_norm=0.2)) - >>> # Config sample of OptimWrapper and enable clipping gradient by - >>> # value. - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=1, - >>> clip_grad=dict(type='value', clip_value=0.2)) - >>> # Use OptimWrapper to update model. 
- >>> import torch.nn as nn - >>> import torch - >>> from torch.optim import SGD - >>> from torch.utils.data import DataLoader - >>> from mmengine.optim import OptimWrapper - >>> - >>> model = nn.Linear(1, 1) - >>> dataset = torch.randn(10, 1, 1) - >>> dataloader = DataLoader(dataset) - >>> optimizer = SGD(model.parameters(), lr=0.1) - >>> optim_wrapper = OptimWrapper(optimizer) - >>> - >>> for data in dataloader: - >>> loss = model(data) - >>> optim_wrapper.update_params(loss) - >>> # Enable gradient accumulation - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=3, - >>> clip_grad=dict(max_norm=0.2)) - >>> ddp_model = DistributedDataParallel(model) - >>> optimizer = SGD(ddp_model.parameters(), lr=0.1) - >>> optim_wrapper = OptimWrapper(optimizer) - >>> optim_wrapper.initialize_count_status(0, len(dataloader)) - >>> # If model is a subclass instance of DistributedDataParallel, - >>> # `optim_context` context manager can avoid unnecessary gradient - >>> # synchronize. - >>> for iter, data in enumerate(dataloader): - >>> with optim_wrapper.optim_context(ddp_model): - >>> loss = model(data) - >>> optim_wrapper.update_params(loss) - """ - - def __init__(self, - optimizer: Optimizer, - accumulative_counts: int = 1, - clip_grad: Optional[dict] = None): - assert accumulative_counts > 0, ( - '_accumulative_counts at least greater than or equal to 1') - self._accumulative_counts = accumulative_counts - self.optimizer = optimizer - - if clip_grad is not None: - # clip_grad_kwargs should not be non-empty dict. - assert isinstance(clip_grad, dict) and clip_grad, ( - 'If `clip_grad` is not None, it should be a `dict` ' - 'which is the arguments of `torch.nn.utils.clip_grad_norm_` ' - 'or clip_grad_value_`.') - clip_type = clip_grad.pop('type', 'norm') - if clip_type == 'norm': - self.clip_func = torch.nn.utils.clip_grad_norm_ - self.grad_name = 'grad_norm' - elif clip_type == 'value': - self.clip_func = torch.nn.utils.clip_grad_value_ - self.grad_name = 'grad_value' - else: - raise ValueError('type of clip_grad should be "norm" or ' - f'"value" but got {clip_type}') - assert clip_grad, ('`clip_grad` should contain other arguments ' - 'besides `type`. The arguments should match ' - 'with the `torch.nn.utils.clip_grad_norm_` or ' - 'clip_grad_value_`') - self.clip_grad_kwargs = clip_grad - # Used to update `grad_norm` log message. - self.message_hub = MessageHub.get_current_instance() - self._inner_count = 0 - # `_max_counts` means the total number of parameter updates. It - # ensures that the gradient of the last few iterations will not be - # lost when the `_max_counts` is not divisible by - # `accumulative_counts`. - self._max_counts = -1 - # The `_remainder_iter` is used for calculating loss factor at the - # last few iterations. If `_max_counts` has not been initialized, - # the loss factor will always be the same as `_accumulative_counts`. - self._remainder_counts = -1 - - # The Following code is used to initialize `base_param_settings`. - # `base_param_settings` is used to store the parameters that are not - # updated by the optimizer. - # The `base_param_settings` used for tracking the base learning in the - # optimizer. If the optimizer has multiple parameter groups, this - # params will not be scaled by the loss factor. 
- if len(optimizer.param_groups) > 1: - self.base_param_settings = { - 'params': torch.tensor([0.0], dtype=torch.float) - } - self.base_param_settings.update(**self.optimizer.defaults) - else: - self.base_param_settings = None # type: ignore - - def update_params( # type: ignore - self, - loss: torch.Tensor, - step_kwargs: Optional[Dict] = None, - zero_kwargs: Optional[Dict] = None) -> None: - """Update parameters in :attr:`optimizer`. - - Args: - loss (torch.Tensor): A tensor for back propagation. - step_kwargs (dict): Arguments for optimizer.step. - Defaults to None. - New in version v0.4.0. - zero_kwargs (dict): Arguments for optimizer.zero_grad. - Defaults to None. - New in version v0.4.0. - """ - if step_kwargs is None: - step_kwargs = {} - if zero_kwargs is None: - zero_kwargs = {} - loss = self.scale_loss(loss) - self.backward(loss) - # Update parameters only if `self._inner_count` is divisible by - # `self._accumulative_counts` or `self._inner_count` equals to - # `self._max_counts` - if self.should_update(): - self.step(**step_kwargs) - self.zero_grad(**zero_kwargs) - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Perform gradient back propagation. - - Provide unified ``backward`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. For example, ``torch.cuda.amp`` require some extra - operation on GradScaler during backward process. - - Note: - If subclasses inherit from ``OptimWrapper`` override - ``backward``, ``_inner_count +=1`` must be implemented. - - Args: - loss (torch.Tensor): The loss of current iteration. - kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`. - """ - loss.backward(**kwargs) - self._inner_count += 1 - - def zero_grad(self, **kwargs) -> None: - """A wrapper of ``Optimizer.zero_grad``. - - Provide unified ``zero_grad`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. - - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.zero_grad`. - """ - self.optimizer.zero_grad(**kwargs) - - def step(self, **kwargs) -> None: - """A wrapper of ``Optimizer.step``. - - Provide unified ``step`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. For example, ``torch.cuda.amp`` require some extra - operation on ``GradScaler`` during step process. - - Clip grad if :attr:`clip_grad_kwargs` is not None, and then update - parameters. - - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.step`. - """ - if self.clip_grad_kwargs: - self._clip_grad() - self.optimizer.step(**kwargs) - - @contextmanager - def optim_context(self, model: nn.Module): - """A Context for gradient accumulation and automatic mix precision - training. - - If subclasses need to enable the context for mix precision training, - e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be - enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32 - training, ``optim_context`` will only enable the context for - blocking the unnecessary gradient synchronization during gradient - accumulation - - If model is an instance with ``no_sync`` method (which means - blocking the gradient synchronization) and - ``self._accumulative_counts != 1``. The model will not automatically - synchronize gradients if ``cur_iter`` is divisible by - ``self._accumulative_counts``. 
Otherwise, this method will enable an - empty context. - - Args: - model (nn.Module): The training model. - """ - # During gradient accumulation process, the gradient synchronize - # should only happen before updating parameters. - if not self.should_sync() and hasattr(model, 'no_sync'): - with model.no_sync(): - yield - else: - yield - - def _clip_grad(self) -> None: - """Clip the gradients of parameters.""" - params: List[torch.Tensor] = [] - for param_group in self.optimizer.param_groups: - params.extend(param_group['params']) - - params = list( - filter(lambda p: p.requires_grad and p.grad is not None, params)) - if len(params) > 0: - grad = self.clip_func(params, **self.clip_grad_kwargs) - # `torch.nn.utils.clip_grad_value_` will return None. - if grad is not None: - self.message_hub.update_scalar(f'train/{self.grad_name}', - float(grad)) - - def initialize_count_status(self, model: nn.Module, init_counts: int, - max_counts: int) -> None: - """Initialize gradient accumulation related attributes. - - ``OptimWrapper`` can be used without calling - ``initialize_iter_status``. However, Consider the case of ``len( - dataloader) == 10``, and the ``accumulative_iter == 3``. Since 10 is - not divisible by 3, the last iteration will not trigger - ``optimizer.step()``, resulting in one less parameter updating. - - Args: - model (nn.Module): Training model - init_counts (int): The initial value of the inner count. - max_counts (int): The maximum value of the inner count. - """ - self._inner_count = init_counts - self._max_counts = max_counts - if self._inner_count % self._accumulative_counts != 0: - print_log( - 'Resumed iteration number is not divisible by ' - '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, ' - 'which means the gradient of some iterations is lost and the ' - 'result may be influenced slightly.', - logger='current', - level=logging.WARNING) - - if has_batch_norm(model) and self._accumulative_counts > 1: - print_log( - 'Gradient accumulative may slightly decrease ' - 'performance because the model has BatchNorm layers.', - logger='current', - level=logging.WARNING) - # Remainder of `_max_counts` divided by `_accumulative_counts` - self._remainder_counts = self._max_counts % self._accumulative_counts - - def should_update(self) -> bool: - """Decide whether the parameters should be updated at the current - iteration. - - Called by :meth:`update_params` and check whether the optimizer - wrapper should update parameters at current iteration. - - Returns: - bool: Whether to update parameters. - """ - return (self._inner_count % self._accumulative_counts == 0 - or self._inner_count == self._max_counts) - - def should_sync(self) -> bool: - """Decide whether the automatic gradient synchronization should be - allowed at the current iteration. - - It takes effect when gradient accumulation is used to skip - synchronization at the iterations where the parameter is not updated. - - Since ``should_sync`` is called by :meth:`optim_context`, and it is - called before :meth:`backward` which means ``self._inner_count += 1`` - has not happened yet. Therefore, ``self._inner_count += 1`` should be - performed manually here. - - Returns: - bool: Whether to block the automatic gradient synchronization. - """ - return ((self._inner_count + 1) % self._accumulative_counts == 0 - or (self._inner_count + 1) == self._max_counts) - - def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: - """Get scaled loss according to ``_accumulative_counts``, - ``_inner_count`` and max_counts. 
- - Args: - loss (torch.Tensor): Original loss calculated by model. - - Returns: - loss (torch.Tensor): Scaled loss. - """ - if self._accumulative_counts == 1: - # update parameters without gradient accumulation. The gradient - # should not be rescaled and `loss_factor=1`. - loss_factor = 1 - elif self._max_counts == -1: - loss_factor = self._accumulative_counts - else: - # if `self._accumulative_counts > 1`, the gradient needs to be - # rescaled and accumulated. In most cases, `loss_factor` equals to - # `self._accumulative_counts`. However, `self._max_counts` may not - # be divisible by `self._accumulative_counts`, so the - # `loss_scale` for the last few iterations needs to be - # recalculated. - if self._inner_count < self._max_counts - self._remainder_counts: - loss_factor = self._accumulative_counts - else: - loss_factor = self._remainder_counts - assert loss_factor > 0, ( - 'loss_factor should be larger than zero! This error could ' - 'happened when initialize_iter_status called with an ' - 'error `init_counts` or `max_counts`') - - loss = loss / loss_factor - return loss - - @property - def inner_count(self): - """Get the number of updating parameters of optimizer wrapper.""" - return self._inner_count - - def __repr__(self): - wrapper_info = (f'Type: {type(self).__name__}\n' - f'_accumulative_counts: {self._accumulative_counts}\n' - 'optimizer: \n') - optimizer_str = repr(self.optimizer) + '\n' - return wrapper_info + optimizer_str diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py deleted file mode 100644 index efa7705c9e..0000000000 --- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from contextlib import contextmanager -from typing import Dict, Iterator, List, Optional, Tuple - -import torch -import torch.nn as nn - -from .optimizer_wrapper import OptimWrapper - - -class OptimWrapperDict(OptimWrapper): - """A dictionary container of :obj:`OptimWrapper`. - - If runner is training with multiple optimizers, all optimizer wrappers - should be managed by :obj:`OptimWrapperDict` which is built by - ``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save - the state dictionary of all optimizer wrappers. - - Consider the semantic ambiguity of calling :meth:``update_params``, - :meth:`backward` of all optimizer wrappers, ``OptimWrapperDict`` will not - implement these methods. - - Examples: - >>> import torch.nn as nn - >>> from torch.optim import SGD - >>> from mmengine.optim import OptimWrapperDict, OptimWrapper - >>> model1 = nn.Linear(1, 1) - >>> model2 = nn.Linear(1, 1) - >>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1)) - >>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1)) - >>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1, - >>> model2=optim_wrapper2) - - Note: - The optimizer wrapper contained in ``OptimWrapperDict`` can be accessed - in the same way as `dict`. - - Args: - **optim_wrappers: A dictionary of ``OptimWrapper`` instance. 
- """ - - def __init__(self, **optim_wrapper_dict: OptimWrapper): - for key, value in optim_wrapper_dict.items(): - assert isinstance(value, OptimWrapper), ( - '`OptimWrapperDict` only accept OptimWrapper instance, ' - f'but got {key}: {type(value)}') - self.optim_wrappers = optim_wrapper_dict - - def update_params( # type: ignore - self, - loss: torch.Tensor, - step_kwargs: Optional[Dict] = None, - zero_kwargs: Optional[Dict] = None) -> None: - """Update all optimizer wrappers would lead to a duplicate backward - errors, and OptimWrapperDict does not know which optimizer wrapper - should be updated. - - Therefore, this method is not implemented. The optimizer wrapper of - OptimWrapperDict should be accessed and call its `update_params`. - """ - raise NotImplementedError('`update_params` should be called by each ' - 'optimizer separately`') - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Since OptimWrapperDict doesn't know which optimizer wrapper's - backward method should be called (``loss_scaler`` maybe different in - different :obj:AmpOptimWrapper), this method is not implemented. - - The optimizer wrapper of OptimWrapperDict should be accessed and call - its `backward`. - """ - raise NotImplementedError('`backward` should be called by each ' - 'optimizer separately`') - - def step(self, **kwargs) -> None: - """Since the backward method is not implemented, the step should not be - implemented either.""" - raise NotImplementedError('`step` should be called by each ' - 'optimizer separately`') - - def zero_grad(self, **kwargs) -> None: - """Set the gradients of all optimizer wrappers to zero.""" - for optim_wrapper in self.optim_wrappers.values(): - optim_wrapper.zero_grad() - - @contextmanager - def optim_context(self, model: nn.Module): - """``optim_context`` should be called by each optimizer separately.""" - raise NotImplementedError( - '`optim_context` should be called by each optimizer separately') - - def initialize_count_status(self, model: nn.Module, cur_iter, - max_iters) -> None: - """Do nothing but provide unified interface for :obj:`OptimWrapper` - - Since ``OptimWrapperDict`` does not know the correspondence between - model and optimizer wrapper. ``initialize_iter_status`` will do nothing - and each optimizer wrapper should call ``initialize_iter_status`` - separately. - """ - return - - @property - def param_groups(self): - """Returns the parameter groups of each OptimWrapper.""" - param_groups = dict() - for key, value in self.optim_wrappers.items(): - param_groups[key] = value.param_groups - return param_groups - - def get_lr(self) -> Dict[str, List[float]]: - """Get the learning rate of all optimizers. - - Returns: - Dict[str, List[float]]: Learning rate of all optimizers. - """ - lr_dict = dict() - for name, optim_wrapper in self.optim_wrappers.items(): - inner_lr_dict = optim_wrapper.get_lr() - if 'base_lr' in inner_lr_dict: - lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr'] - lr_dict[f'{name}.lr'] = inner_lr_dict['lr'] - return lr_dict - - def get_momentum(self) -> Dict[str, List[float]]: - """Get the momentum of all optimizers. - - Returns: - Dict[str, List[float]]: momentum of all optimizers. - """ - momentum_dict = dict() - for name, optim_wrapper in self.optim_wrappers.items(): - momentum_dict[f'{name}.momentum'] = optim_wrapper.get_momentum( - )['momentum'] - return momentum_dict - - def state_dict(self) -> dict: - """Get the state dictionary of all optimizer wrappers. 
- - Returns: - dict: Each key-value pair in the dictionary represents the name - and state dictionary of corresponding :obj:`OptimWrapper`. - """ - state_dict = dict() - for name, optim_wrapper in self.optim_wrappers.items(): - state_dict[name] = optim_wrapper.state_dict() - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - """Load the state dictionary from the ``state_dict``. - - Args: - state_dict (dict): Each key-value pair in `state_dict` represents - the name and the state dictionary of corresponding - :obj:`OptimWrapper`. - """ - for name, _state_dict in state_dict.items(): - assert name in self.optim_wrappers, ( - f'Mismatched `state_dict`! cannot found {name} in ' - 'OptimWrapperDict') - self.optim_wrappers[name].load_state_dict(_state_dict) - - def items(self) -> Iterator[Tuple[str, OptimWrapper]]: - """A generator to get the name and corresponding :obj:`OptimWrapper`""" - yield from self.optim_wrappers.items() - - def values(self) -> Iterator[OptimWrapper]: - """A generator to get :obj:`OptimWrapper`""" - yield from self.optim_wrappers.values() - - def keys(self) -> Iterator[str]: - """A generator to get the name of :obj:`OptimWrapper`""" - yield from self.optim_wrappers.keys() - - def __getitem__(self, key: str) -> OptimWrapper: - assert key in self.optim_wrappers, ( - f'Cannot find {key} in OptimWrapperDict, please check ' - 'your optimizer constructor.') - return self.optim_wrappers[key] - - def __contains__(self, key: str) -> bool: - return key in self.optim_wrappers - - def __len__(self) -> int: - return len(self.optim_wrappers) - - def __repr__(self) -> str: - desc = '' - for name, optim_wrapper in self.optim_wrappers.items(): - desc += f'name: {name}\n' - desc += repr(optim_wrapper) - return desc diff --git a/mmengine/optim/optimizer/zero_optimizer.py b/mmengine/optim/optimizer/zero_optimizer.py deleted file mode 100644 index 0c5630a765..0000000000 --- a/mmengine/optim/optimizer/zero_optimizer.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import torch -from torch.distributed.rpc import is_available - -from mmengine.dist import is_main_process -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION - -try: - from torch.distributed.optim import \ - ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer -except ImportError: - _ZeroRedundancyOptimizer = object - -from .builder import OPTIMIZERS - - -@OPTIMIZERS.register_module() -class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer): - """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a - optimizer type as string. - - This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its - states across ranks in the group as described by ZeRO_. The local optimizer - instance in each rank is only responsible for updating approximately - ``1 / world_size`` parameters and hence only needs to keep - ``1 / world_size`` optimizer states. After parameters are updated locally, - each rank will broadcast its parameters to all other peers to keep all - model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used - in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to - reduce per-rank peak memory consumption. - - ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number - of parameters at each rank. Each parameter belongs to a single rank and is - not divided among ranks. The partition is arbitrary and might not match the - the parameter registration or usage order. 
-
-    Warnings:
-        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.
-
-    Warnings:
-        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
-        groups.
-
-    Args:
-        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
-            or :class:`dict` s giving all parameters, which will be sharded
-            across ranks.
-        optimizer_type (str): the string of the local optimizer class.
-
-    .. _ZeRO: https://arxiv.org/abs/1910.02054
-    """
-
-    def __init__(self, params, optimizer_type: str, **kwargs):
-        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
-            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
-            'available when pytorch version >= 1.8.')
-        assert is_available(), 'torch.distributed.rpc is not available.'
-        # Avoid the generator becoming empty after the following check
-        params = list(params)
-        assert (
-            all(isinstance(p, torch.Tensor) for p in params)
-            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
-                'PyTorch ZeroRedundancyOptimizer started to support param '
-                'groups since 1.12.0. Please update your pytorch version to '
-                'enable this feature, or disable param groups by deleting '
-                'the `paramwise_cfg` field in the config file.')
-        optimizer_class = getattr(torch.optim, optimizer_type)
-        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
-        # Currently only `overlap_with_ddp=False` is supported. For more
-        # details, please refer to PyTorch's official documentation.
-        super().__init__(params, optimizer_class, **kwargs)
-
-    def state_dict(self):
-        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
-        self.consolidate_state_dict()
-        state_dict = super().state_dict() if is_main_process() else dict()
-        return state_dict
diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py
deleted file mode 100644
index 48ccc34bc4..0000000000
--- a/mmengine/optim/scheduler/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
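How the ``ZeroRedundancyOptimizer`` wrapper above is typically selected is sketched below, assuming a distributed (DDP) training job; the inner SGD hyper-parameters are placeholders:

    # Sketch only: shard the SGD optimizer states across ranks (ZeRO stage 1).
    optim_wrapper = dict(
        type='OptimWrapper',
        optimizer=dict(
            type='ZeroRedundancyOptimizer',
            optimizer_type='SGD',  # name of the wrapped torch.optim class
            lr=0.01,
            momentum=0.9,
            weight_decay=0.0001))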
-# yapf: disable -from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR, - ExponentialLR, LinearLR, MultiStepLR, OneCycleLR, - PolyLR, ReduceOnPlateauLR, StepLR) -from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum, - CosineRestartMomentum, ExponentialMomentum, - LinearMomentum, MultiStepMomentum, - PolyMomentum, ReduceOnPlateauMomentum, - StepMomentum) -from .param_scheduler import (ConstantParamScheduler, - CosineAnnealingParamScheduler, - CosineRestartParamScheduler, - ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, - ReduceOnPlateauParamScheduler, - StepParamScheduler, _ParamScheduler) - -# yapf: enable -__all__ = [ - 'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', - 'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum', - 'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum', - 'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', - 'ExponentialParamScheduler', 'LinearParamScheduler', - 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler', - 'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler', - 'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR', - 'CosineRestartMomentum', 'ReduceOnPlateauParamScheduler', - 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum' -] diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py deleted file mode 100644 index 13bc61d542..0000000000 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.registry import PARAM_SCHEDULERS -# yapf: disable -from .param_scheduler import (ConstantParamScheduler, - CosineAnnealingParamScheduler, - CosineRestartParamScheduler, - ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, - ReduceOnPlateauParamScheduler, - StepParamScheduler) - -# yapf: enable - - -class LRSchedulerMixin: - """A mixin class for learning rate schedulers.""" - - def __init__(self, optimizer, *args, **kwargs): - super().__init__(optimizer, 'lr', *args, **kwargs) - - -@PARAM_SCHEDULERS.register_module() -class ConstantLR(LRSchedulerMixin, ConstantParamScheduler): - """Decays the learning rate value of each parameter group by a small - constant factor until the number of epoch reaches a pre-defined milestone: - ``end``. Notice that such decay can happen simultaneously with other - changes to the learning rate value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - factor (float): The number we multiply learning rate until the - milestone. Defaults to 1./3. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without state - dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. 
- """ - - -@PARAM_SCHEDULERS.register_module() -class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial value and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - Notice that because the schedule - is defined recursively, the learning rate can be simultaneously modified - outside this scheduler by other operators. If the learning rate is set - solely by this scheduler, the learning rate at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this - only implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - T_max (int): Maximum number of iterations. - eta_min (float): Minimum learning rate. Defaults to None. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - eta_min_ratio (float, optional): The ratio of the minimum parameter - value to the base parameter value. Either `eta_min` or - `eta_min_ratio` should be specified. Defaults to None. - New in version 0.3.2. - - .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - -@PARAM_SCHEDULERS.register_module() -class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler): - """Decays the learning rate of each parameter group by gamma every epoch. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - gamma (float): Multiplicative factor of learning rate decay. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class LinearLR(LRSchedulerMixin, LinearParamScheduler): - """Decays the learning rate of each parameter group by linearly changing - small multiplicative factor until the number of epoch reaches a pre-defined - milestone: ``end``. - - Notice that such decay can happen simultaneously with other changes to the - learning rate from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - start_factor (float): The number we multiply learning rate in the - first epoch. 
The multiplication factor changes towards end_factor - in the following epochs. Defaults to 1./3. - end_factor (float): The number we multiply learning rate at the end - of linear changing process. Defaults to 1.0. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler): - """Decays the specified learning rate in each parameter group by gamma once - the number of epoch reaches one of the milestones. Notice that such decay - can happen simultaneously with other changes to the learning rate from - outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - milestones (list): List of epoch indices. Must be increasing. - gamma (float): Multiplicative factor of learning rate decay. - Defaults to 0.1. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class StepLR(LRSchedulerMixin, StepParamScheduler): - """Decays the learning rate of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the learning rate from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - step_size (int): Period of learning rate decay. - gamma (float): Multiplicative factor of learning rate decay. - Defaults to 0.1. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class PolyLR(LRSchedulerMixin, PolyParamScheduler): - """Decays the learning rate of each parameter group in a polynomial decay - scheme. - - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - eta_min (float): Minimum learning rate at the end of scheduling. - Defaults to 0. - power (float): The power of the polynomial. Defaults to 1.0. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. 
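# Illustrative sketch, not part of the deleted code above: the LR classes in
# this file are thin registry wrappers that fix param_name to "lr", so they
# compose in a config like any other param scheduler. A common
# warmup-then-decay combination (all numbers illustrative):
param_scheduler = [
    # linear warmup over the first 500 iterations
    dict(type="LinearLR", start_factor=1.0 / 3, by_epoch=False, begin=0, end=500),
    # multiply the learning rate by 0.1 at epochs 8 and 11
    dict(type="MultiStepLR", by_epoch=True, milestones=[8, 11], gamma=0.1),
]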
- verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler): - r"""Sets the learning rate of each parameter group according to the 1cycle - learning rate policy. The 1cycle policy anneals the learning rate from an - initial learning rate to some maximum learning rate and then from that - maximum learning rate to some minimum learning rate much lower than the - initial learning rate. This policy was initially described in the paper - `Super-Convergence: Very Fast Training of Neural Networks Using Large - Learning Rates`_. - - The 1cycle learning rate policy changes the learning rate after every - batch. `step` should be called after a batch has been used for training. - - This scheduler is not chainable. - - Note also that the total number of steps in the cycle can be determined in - one of two ways (listed in order of precedence): - - #. A value for total_steps is explicitly provided. - #. A number of epochs (epochs) and a number of steps per epoch - (steps_per_epoch) are provided. - In this case, the number of total steps is inferred by - total_steps = epochs * steps_per_epoch - - You must either provide a value for total_steps or provide a value for both - epochs and steps_per_epoch. - - The default behaviour of this scheduler follows the fastai implementation - of 1cycle, which claims that "unpublished work has shown even better - results by using only two phases". To mimic the behaviour of the original - paper instead, set ``three_phase=True``. - - Args: - optimizer (Optimizer): Wrapped optimizer. - eta_max (float or list): Upper parameter value boundaries in the cycle - for each parameter group. - total_steps (int): The total number of steps in the cycle. Note that - if a value is not provided here, then it must be inferred by - providing a value for epochs and steps_per_epoch. - Defaults to None. - pct_start (float): The percentage of the cycle (in number of steps) - spent increasing the learning rate. - Defaults to 0.3 - anneal_strategy (str): {'cos', 'linear'} - Specifies the annealing strategy: "cos" for cosine annealing, - "linear" for linear annealing. - Defaults to 'cos' - div_factor (float): Determines the initial learning rate via - initial_param = eta_max/div_factor - Defaults to 25 - final_div_factor (float): Determines the minimum learning rate via - eta_min = initial_param/final_div_factor - Defaults to 1e4 - three_phase (bool): If ``True``, use a third phase of the schedule to - annihilate the learning rate according to 'final_div_factor' - instead of modifying the second phase (the first two phases will be - symmetrical about the step indicated by 'pct_start'). - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - - .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: - https://arxiv.org/abs/1708.07120 - """# noqa E501 - - -@PARAM_SCHEDULERS.register_module() -class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler): - """Sets the learning rate of each parameter group according to the cosine - annealing with restarts scheme. 
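# Illustrative sketch, not part of the deleted code above: OneCycleLR adjusts
# the value after every batch, so a config would normally run it
# iteration-wise and give the cycle length explicitly (or let it be inferred
# as described in the docstring above). Numbers are illustrative.
param_scheduler = dict(
    type="OneCycleLR",
    eta_max=0.01,          # peak learning rate of the cycle
    total_steps=10000,     # length of the cycle in iterations
    pct_start=0.3,         # fraction of the cycle spent increasing the lr
    div_factor=25,         # initial lr = eta_max / div_factor
    final_div_factor=1e4,  # minimum lr = initial lr / final_div_factor
    by_epoch=False,
)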
The cosine restart policy anneals the - learning rate from the initial value to `eta_min` with a cosine annealing - schedule and then restarts another period from the maximum value multiplied - with `restart_weight`. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - periods (list[int]): Periods for each cosine anneling cycle. - restart_weights (list[float]): Restart weights at each - restart iteration. Defaults to [1]. - eta_min (float): Minimum parameter value at the end of scheduling. - Defaults to None. - eta_min_ratio (float, optional): The ratio of minimum parameter value - to the base parameter value. Either `min_lr` or `min_lr_ratio` - should be specified. Defaults to None. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class ReduceOnPlateauLR(LRSchedulerMixin, ReduceOnPlateauParamScheduler): - """Reduce the learning rate of each parameter group when a metric has - stopped improving. Models often benefit from reducing the learning rate by - a factor of 2-10 once learning stagnates. This scheduler reads a metrics - quantity and if no improvement is seen for a ``patience`` number of epochs, - the learning rate is reduced. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - monitor (str): Key name of the value to monitor in metrics dict. - rule (str): One of `less`, `greater`. In `less` rule, learning rate - will be reduced when the quantity monitored has stopped - decreasing; in `greater` rule it will be reduced when the - quantity monitored has stopped increasing. Defaults to 'less'. - The ``rule`` is the renaming of ``mode`` in pytorch. - factor (float): Factor by which the learning rate will be - reduced. new_param = param * factor. Defaults to 0.1. - patience (int): Number of epochs with no improvement after - which learning rate will be reduced. For example, if - ``patience = 2``, then we will ignore the first 2 epochs - with no improvement, and will only decrease the learning rate after - the 3rd epoch if the monitor value still hasn't improved then. - Defaults to 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Defaults to 1e-4. - threshold_rule (str): One of `rel`, `abs`. In `rel` rule, - dynamic_threshold = best * ( 1 + threshold ) in 'greater' - rule or best * ( 1 - threshold ) in `less` rule. - In `abs` rule, dynamic_threshold = best + threshold in - `greater` rule or best - threshold in `less` rule. - Defaults to 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after learning rate has been reduced. - Defaults to 0. - min_value (float or list[float]): A scalar or a sequence of scalars. - A lower bound on the learning rate of each parameter group - respectively. Defaults to 0. . - eps (float): Minimal decay applied to learning rate. If the difference - between new and old learning rate is smaller than eps, the update - is ignored. Defaults to 1e-8. 
- begin (int): Step at which to start triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to 0. - end (int): Step at which to stop triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py deleted file mode 100644 index e356e70f7b..0000000000 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.registry import PARAM_SCHEDULERS -# yapf: disable -from .param_scheduler import (ConstantParamScheduler, - CosineAnnealingParamScheduler, - CosineRestartParamScheduler, - ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, PolyParamScheduler, - ReduceOnPlateauParamScheduler, - StepParamScheduler) - -# yapf: enable - - -class MomentumSchedulerMixin: - """A mixin class for momentum schedulers. - - It can schedule the momentum in SGD and the beta_0 in Adam series. - """ - - def __init__(self, optimizer, *args, **kwargs): - self.use_betas = False - if 'momentum' in optimizer.defaults: - param_name = 'momentum' - elif 'betas' in optimizer.defaults: - # for Adam series optimizer, the momentum is beta_0 - self.use_betas = True - param_name = 'momentum' - for group in optimizer.param_groups: - # set a reference momentum in the param groups for scheduling - group[param_name] = group['betas'][0] - else: - raise ValueError( - 'optimizer must support momentum when using momentum scheduler' - ) - super().__init__(optimizer, param_name, *args, **kwargs) - - def step(self): - """Adjusts the momentum of each parameter group based on the specified - schedule.""" - super().step() - if self.use_betas: - for group in self.optimizer.param_groups: - _, beta_1 = group['betas'] - # update the betas with the calculated value - group['betas'] = (group['momentum'], beta_1) - - -@PARAM_SCHEDULERS.register_module() -class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler): - """Decays the momentum value of each parameter group by a small constant - factor until the number of epoch reaches a pre-defined milestone: ``end``. - Notice that such decay can happen simultaneously with other changes to the - momentum value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - factor (float): The number we multiply momentum until the milestone. - Defaults to 1./3. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without state - dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by epochs. - Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. 
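# Illustrative sketch, not part of the deleted code above: the momentum
# schedulers reuse the same machinery with param_name="momentum"; for
# Adam-family optimizers the mixin above transparently schedules beta_0
# instead. A momentum schedule is usually paired with the matching LR
# schedule, e.g. (numbers illustrative):
param_scheduler = [
    dict(type="CosineAnnealingLR", T_max=100, by_epoch=True),
    dict(type="CosineAnnealingMomentum", T_max=100, by_epoch=True),
]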
- """ - - -@PARAM_SCHEDULERS.register_module() -class CosineAnnealingMomentum(MomentumSchedulerMixin, - CosineAnnealingParamScheduler): - r"""Set the momentum of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial value and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - Notice that because the schedule - is defined recursively, the momentum can be simultaneously modified - outside this scheduler by other operators. If the momentum is set - solely by this scheduler, the momentum at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this - only implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - T_max (int): Maximum number of iterations. - eta_min (float): Minimum momentum value. Defaults to None. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - eta_min_ratio (float, optional): The ratio of the minimum parameter - value to the base parameter value. Either `eta_min` or - `eta_min_ratio` should be specified. Defaults to None. - New in version 0.3.2. - - .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ - - -@PARAM_SCHEDULERS.register_module() -class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler): - """Decays the momentum of each parameter group by gamma every epoch. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - gamma (float): Multiplicative factor of momentum value decay. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler): - """Decays the momentum of each parameter group by linearly changing - small multiplicative factor until the number of epoch reaches a pre-defined - milestone: ``end``. - - Notice that such decay can happen simultaneously with other changes to the - momentum from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - start_factor (float): The number we multiply momentum in the - first epoch. 
The multiplication factor changes towards end_factor - in the following epochs. Defaults to 1./3. - end_factor (float): The number we multiply momentum at the end - of linear changing process. Defaults to 1.0. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler): - """Decays the specified momentum in each parameter group by gamma once the - number of epoch reaches one of the milestones. Notice that such decay can - happen simultaneously with other changes to the momentum from outside this - scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - milestones (list): List of epoch indices. Must be increasing. - gamma (float): Multiplicative factor of momentum value decay. - Defaults to 0.1. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class StepMomentum(MomentumSchedulerMixin, StepParamScheduler): - """Decays the momentum of each parameter group by gamma every step_size - epochs. Notice that such decay can happen simultaneously with other changes - to the momentum from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - step_size (int): Period of momentum value decay. - gamma (float): Multiplicative factor of momentum value decay. - Defaults to 0.1. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler): - """Decays the momentum of each parameter group in a polynomial decay - scheme. - - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - eta_min (float): Minimum momentum at the end of scheduling. - Defaults to 0. - power (float): The power of the polynomial. Defaults to 1.0. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. 
- verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class CosineRestartMomentum(MomentumSchedulerMixin, - CosineRestartParamScheduler): - """Sets the momentum of each parameter group according to the cosine - annealing with restarts scheme. The cosine restart policy anneals the - momentum from the initial value to `eta_min` with a cosine annealing - schedule and then restarts another period from the maximum value multiplied - with `restart_weight`. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - periods (list[int]): Periods for each cosine anneling cycle. - restart_weights (list[float]): Restart weights at each - restart iteration. Defaults to [1]. - eta_min (float): Minimum parameter value at the end of scheduling. - Defaults to None. - eta_min_ratio (float, optional): The ratio of minimum parameter value - to the base parameter value. Either `min_lr` or `min_lr_ratio` - should be specified. Defaults to None. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module() -class ReduceOnPlateauMomentum(MomentumSchedulerMixin, - ReduceOnPlateauParamScheduler): - """Reduce the momentum of each parameter group when a metric has stopped - improving. Models often benefit from reducing the momentum by a factor of - 2-10 once learning stagnates. This scheduler reads a metrics quantity and - if no improvement is seen for a ``patience`` number of epochs, the momentum - is reduced. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - monitor (str): Key name of the value to monitor in metrics dict. - rule (str): One of `less`, `greater`. In `less` rule, momentum will - be reduced when the quantity monitored has stopped - decreasing; in `greater` rule it will be reduced when the - quantity monitored has stopped increasing. Defaults to 'less'. - The ``rule`` is the renaming of ``mode`` in pytorch. - factor (float): Factor by which the momentum will be - reduced. new_param = param * factor. Defaults to 0.1. - patience (int): Number of epochs with no improvement after - which momentum will be reduced. For example, if - ``patience = 2``, then we will ignore the first 2 epochs - with no improvement, and will only decrease the momentum after - the 3rd epoch if the monitor value still hasn't improved then. - Defaults to 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Defaults to 1e-4. - threshold_rule (str): One of `rel`, `abs`. In `rel` rule, - dynamic_threshold = best * ( 1 + threshold ) in 'greater' - rule or best * ( 1 - threshold ) in `less` rule. - In `abs` rule, dynamic_threshold = best + threshold in - `greater` rule or best - threshold in `less` rule. - Defaults to 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after momentum has been reduced. Defaults to 0. - min_value (float or list[float]): A scalar or a sequence of scalars. - A lower bound on the momentum of each parameter group - respectively. Defaults to 0. . 
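# Illustrative sketch, not part of the deleted code above: unlike the other
# schedulers, the ReduceOnPlateau variants are driven by a monitored metric
# rather than by the step count alone; the ``step(metrics)`` override further
# below feeds them the evaluation results. A hypothetical config:
param_scheduler = dict(
    type="ReduceOnPlateauMomentum",
    monitor="loss",   # key looked up in the metrics dict passed to step()
    rule="less",      # reduce once the monitored value stops decreasing
    factor=0.5,
    patience=5,
)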
- eps (float): Minimal decay applied to momentum. If the difference - between new and old momentum is smaller than eps, the update is - ignored. Defaults to 1e-8. - begin (int): Step at which to start triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to 0. - end (int): Step at which to stop triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def step(self, metrics=None): - """Adjusts the momentum of each parameter group based on the specified - schedule. - - Args: - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - Defaults to None. - """ - super(MomentumSchedulerMixin, self).step(metrics) - if self.use_betas: - for group in self.optimizer.param_groups: - _, beta_1 = group['betas'] - # update the betas with the calculated value - group['betas'] = (group['momentum'], beta_1) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py deleted file mode 100644 index 2dcb1af072..0000000000 --- a/mmengine/optim/scheduler/param_scheduler.py +++ /dev/null @@ -1,1578 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# ------------------------------------------------------------------------ -# Modified from https://github.com/pytorch/pytorch -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# ------------------------------------------------------------------------ - -import math -import warnings -import weakref -from collections import Counter -from functools import wraps -from typing import Callable, List, Optional, Sequence, Union - -from torch.optim import Optimizer - -from mmengine.logging import print_log -from mmengine.optim import BaseOptimWrapper -from mmengine.registry import PARAM_SCHEDULERS - -INF = int(1e9) - -OptimizerType = Union[BaseOptimWrapper, Optimizer] - - -class _ParamScheduler: - """Base class for parameter schedulers. - - It should be inherited by all schedulers that schedule parameters in the - optimizer's ``param_groups``. All subclasses should overwrite the - ``_get_value()`` according to their own schedule strategy. - The implementation is motivated by - https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py. - - Args: - optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resuming without - state dict. Default value ``-1`` means the ``step`` function is - never be called before. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. 
- """ # noqa: E501 - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - - # Attach optimizer - if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)): - raise TypeError('``optimizer`` should be an Optimizer,' - 'but got {}'.format(type(optimizer).__name__)) - self.optimizer = optimizer - self.param_name = param_name - - if end <= begin: - raise ValueError('end should be larger than begin, but got' - ' begin={}, end={}'.format(begin, end)) - self.begin = begin - self.end = end - - self.by_epoch = by_epoch - - assert isinstance(last_step, int) and last_step >= -1 - # Initialize valid step count and base values - if last_step == -1: - for group in optimizer.param_groups: - # If the param is never be scheduled, record the current value - # as the initial value. - group.setdefault(f'initial_{param_name}', group[param_name]) - else: - for i, group in enumerate(optimizer.param_groups): - if f'initial_{param_name}' not in group: - raise KeyError( - f"param 'initial_{param_name}' is not specified " - 'in param_groups[{}] when resuming an optimizer'. - format(i)) - self.base_values = [ - group[f'initial_{param_name}'] for group in optimizer.param_groups - ] - self.last_step = last_step - - # Following https://github.com/pytorch/pytorch/issues/20124 - # We would like to ensure that `scheduler.step()` is called after - # `optimizer.step()` - def with_counter(method: Callable): - if getattr(method, '_with_counter', False): - # `optimizer.step()` has already been replaced, return. - return method - - # Keep a weak reference to the optimizer instance to prevent - # cyclic references. - instance_ref = weakref.ref(method.__self__) # type: ignore - # Get the unbound method for the same purpose. - func = method.__func__ # type: ignore - cls = instance_ref().__class__ # type: ignore - del method - - @wraps(func) - def wrapper(*args, **kwargs): - instance = instance_ref() - instance._global_step += 1 - wrapped = func.__get__(instance, cls) - return wrapped(*args, **kwargs) - - # Note that the returned function here is no longer a bound method, - # so attributes like `__func__` and `__self__` no longer exist. - wrapper._with_counter = True # type: ignore - return wrapper - - # add counter to optimizer - self.optimizer.step = with_counter(self.optimizer.step) # type: ignore - self.optimizer._global_step = -1 # type: ignore - - self._global_step = -1 - self.verbose = verbose - - self.step() - - def state_dict(self) -> dict: - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which is not - the optimizer. - - Returns: - dict: scheduler state. - """ - return { - key: value - for key, value in self.__dict__.items() if key != 'optimizer' - } - - def load_state_dict(self, state_dict: dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_value(self): - """Return the last computed value by current scheduler. - - Returns: - list: A list of the last computed value of the optimizer's - ``param_group``. - """ - return self._last_value - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - raise NotImplementedError - - def print_value(self, is_verbose: bool, group: int, value: float): - """Display the current parameter value. 
- - Args: - is_verbose (bool): Whether to print the value. - group (int): The index of the current ``param_group``. - value (float): The parameter value. - """ - if is_verbose: - print_log( - f'Adjusting parameter value of group {group} to {value:.4e}.', - logger='current') - - def step(self): - """Adjusts the parameter value of each parameter group based on the - specified schedule.""" - # Raise a warning if old pattern is detected - # https://github.com/pytorch/pytorch/issues/20124 - if self._global_step == 0: - if not hasattr(self.optimizer.step, '_with_counter'): - warnings.warn( - 'Seems like `optimizer.step()` has been overridden after ' - 'parameter value scheduler initialization. Please, make ' - 'sure to call `optimizer.step()` before ' - '`scheduler.step()`. See more details at ' - 'https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate', # noqa: E501 - UserWarning) - - # Just check if there were two first scheduler.step() calls - # before optimizer.step() - elif self.optimizer._global_step < 0: - warnings.warn( - 'Detected call of `scheduler.step()` before ' - '`optimizer.step()`. In PyTorch 1.1.0 and later, you ' - 'should call them in the opposite order: ' - '`optimizer.step()` before `scheduler.step()`. ' - 'Failure to do this will result in PyTorch skipping ' - 'the first value of the parameter value schedule. ' - 'See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate', # noqa: E501 - UserWarning) - self._global_step += 1 - - # Compute parameter value per param group in the effective range - if self.begin <= self._global_step < self.end: - self.last_step += 1 - values = self._get_value() - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, value = data - param_group[self.param_name] = value - self.print_value(self.verbose, i, value) - - self._last_value = [ - group[self.param_name] for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class StepParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the parameter value from outside this scheduler. - - Args: - optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - step_size (int): Period of parameter value decay. - gamma (float): Multiplicative factor of parameter value decay. - Defaults to 0.1. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. 
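# Illustrative sketch, not part of the deleted code above: subclasses only
# need to implement _get_value(); the begin/end gating, state_dict handling
# and verbose printing all come from the base class shown above.
# HalvingParamScheduler is a hypothetical name, and the import paths follow
# the package as it stood before this deletion.
from mmengine.optim import _ParamScheduler
from mmengine.registry import PARAM_SCHEDULERS


@PARAM_SCHEDULERS.register_module()
class HalvingParamScheduler(_ParamScheduler):
    """Halve the scheduled parameter on every step (for illustration only)."""

    def _get_value(self):
        return [
            group[self.param_name] * 0.5
            for group in self.optimizer.param_groups
        ]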
- """ - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - step_size: int, - gamma: float = 0.1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - self.step_size = step_size - self.gamma = gamma - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - step_size, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - step_size = step_size * epoch_length - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls( - *args, - step_size=step_size, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if (self.last_step == 0) or (self.last_step % self.step_size != 0): - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - return [ - group[self.param_name] * self.gamma - for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class MultiStepParamScheduler(_ParamScheduler): - """Decays the specified parameter in each parameter group by gamma once the - number of epoch reaches one of the milestones. Notice that such decay can - happen simultaneously with other changes to the parameter from outside this - scheduler. - - Args: - optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - milestones (list): List of epoch indices. Must be increasing. - gamma (float): Multiplicative factor of parameter value decay. - Defaults to 0.1. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - milestones: List[int], - gamma: float = 0.1, - last_step: int = -1, - begin: int = 0, - end: int = INF, - by_epoch: bool = True, - verbose: bool = False): - self.milestones = Counter(milestones) - self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - milestones, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' 
- by_epoch = False - milestones = [i * epoch_length for i in milestones] - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls( - *args, - milestones=milestones, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step not in self.milestones: - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - return [ - group[self.param_name] * - self.gamma**self.milestones[self.last_step] - for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class ConstantParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group by a small constant - factor until the number of epoch reaches a pre-defined milestone: ``end``. - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - factor (float): The number we multiply parameter value until the - milestone. Defaults to 1./3. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - factor: float = 1.0 / 3, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - if factor > 1.0 or factor < 0: - raise ValueError( - 'Constant multiplicative factor should between 0 and 1.') - - self.factor = factor - self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step == 0: - return [ - group[self.param_name] * self.factor - for group in self.optimizer.param_groups - ] - - if (self.last_step > self.total_iters - or (self.last_step != self.total_iters)): - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - - if self.last_step == self.total_iters: - return [ - group[self.param_name] * (1.0 / self.factor) - for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class ExponentialParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group by gamma every epoch. 
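# Illustrative sketch, not part of the deleted code above: what
# build_iter_from_epoch does in plain numbers. With 500 iterations per epoch,
# an epoch-based MultiStep schedule with milestones=[8, 11] becomes an
# iteration-based one with milestones=[4000, 5500]; begin/end are scaled the
# same way and by_epoch is flipped to False.
epoch_length = 500
milestones = [8, 11]
iter_milestones = [m * epoch_length for m in milestones]  # [4000, 5500]
begin, end = 0, 12
iter_begin, iter_end = int(begin * epoch_length), int(end * epoch_length)  # 0, 6000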
- - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - gamma (float): Multiplicative factor of parameter value decay. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - gamma: float, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step == 0: - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - return [ - group[self.param_name] * self.gamma - for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class CosineAnnealingParamScheduler(_ParamScheduler): - r"""Set the parameter value of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial value and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - Notice that because the schedule - is defined recursively, the parameter value can be simultaneously modified - outside this scheduler by other operators. If the parameter value is set - solely by this scheduler, the parameter value at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this - only implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - T_max (int, optional): Maximum number of iterations. If not specified, - use ``end - begin``. Defaults to None. 
- eta_min (float, optional): Minimum parameter value. Defaults to None. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - eta_min_ratio (float, optional): The ratio of the minimum parameter - value to the base parameter value. Either `eta_min` or - `eta_min_ratio` should be specified. Defaults to None. - New in version 0.3.2. - - .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: - https://arxiv.org/abs/1608.03983 - """ # noqa: E501 - - def __init__(self, - optimizer: Union[Optimizer, BaseOptimWrapper], - param_name: str, - T_max: Optional[int] = None, - eta_min: Optional[float] = None, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False, - eta_min_ratio: Optional[float] = None): - # To preserve backwards compatibility - if eta_min is None and eta_min_ratio is None: - eta_min = 0. - assert (eta_min is None) ^ (eta_min_ratio is None), \ - 'Either `eta_min` or `eta_min_ratio should be specified' - self.T_max = T_max or (end - begin) - self.eta_min = eta_min - self.eta_min_ratio = eta_min_ratio - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - T_max=None, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - if T_max is not None: - T_max = T_max * epoch_length - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls( - *args, - T_max=T_max, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) - - def _get_value(self) -> list: - """Compute value using chainable form of the scheduler.""" - - def _get_eta_min(base_value): - if self.eta_min_ratio is None: - return self.eta_min - return base_value * self.eta_min_ratio - - if self.last_step == 0: - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: - return [ - group[self.param_name] + - (base_value - _get_eta_min(base_value)) * - (1 - math.cos(math.pi / self.T_max)) / 2 - for base_value, group in zip(self.base_values, - self.optimizer.param_groups) - ] - return [(1 + math.cos(math.pi * self.last_step / self.T_max)) / - (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) * - (group[self.param_name] - _get_eta_min(base_value)) + - _get_eta_min(base_value) for base_value, group in zip( - self.base_values, self.optimizer.param_groups)] - - -@PARAM_SCHEDULERS.register_module() -class LinearParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group by linearly changing - small multiplicative factor until the number of epoch reaches a pre-defined - milestone: ``end``. 
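# Illustrative sketch, not part of the deleted code above: a quick numeric
# check of the closed form the cosine annealing scheduler follows when it is
# the only thing modifying the value,
#   eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2.
import math

eta_max, eta_min, T_max = 0.1, 0.0, 100
eta = [eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2
       for t in range(T_max + 1)]
assert abs(eta[0] - eta_max) < 1e-12 and abs(eta[-1] - eta_min) < 1e-12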
- - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - start_factor (float): The number we multiply parameter value in the - first epoch. The multiplication factor changes towards end_factor - in the following epochs. Defaults to 1./3. - end_factor (float): The number we multiply parameter value at the end - of linear changing process. Defaults to 1.0. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: Union[Optimizer, BaseOptimWrapper], - param_name: str, - start_factor: float = 1.0 / 3, - end_factor: float = 1.0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - if start_factor > 1.0 or start_factor < 0: - raise ValueError( - 'Starting multiplicative factor should between 0 and 1.') - - if end_factor > 1.0 or end_factor < 0: - raise ValueError( - 'Ending multiplicative factor should between 0 and 1.') - - self.start_factor = start_factor - self.end_factor = end_factor - self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step == 0: - return [ - group[self.param_name] * self.start_factor - for group in self.optimizer.param_groups - ] - - return [ - group[self.param_name] * - (1. + (self.end_factor - self.start_factor) / - (self.total_iters * self.start_factor + (self.last_step - 1) * - (self.end_factor - self.start_factor))) - for group in self.optimizer.param_groups - ] - - -@PARAM_SCHEDULERS.register_module() -class PolyParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group in a polynomial decay - scheme. - - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - eta_min (float): Minimum parameter value at the end of scheduling. - Defaults to 0. - power (float): The power of the polynomial. Defaults to 1.0. 
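# Illustrative sketch, not part of the deleted code above: the multiplicative
# update in LinearParamScheduler._get_value is equivalent (up to float error)
# to walking the factor linearly from start_factor to end_factor over
# total_iters = end - begin - 1 steps:
#   value_t = base * (start_factor + (end_factor - start_factor) * t / total_iters)
base, start_factor, end_factor, total_iters = 0.01, 1.0 / 3, 1.0, 100
warmup = [base * (start_factor + (end_factor - start_factor) * t / total_iters)
          for t in range(total_iters + 1)]
# starts at base * start_factor and reaches base at the final step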
- begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: Union[Optimizer, BaseOptimWrapper], - param_name: str, - eta_min: float = 0, - power: float = 1.0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - - self.eta_min = eta_min - self.power = power - self.total_iters = end - begin - 1 - - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step == 0: - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - - return [(group[self.param_name] - self.eta_min) * - (1 - 1 / (self.total_iters - self.last_step + 1))**self.power + - self.eta_min for group in self.optimizer.param_groups] - - -@PARAM_SCHEDULERS.register_module() -class OneCycleParamScheduler(_ParamScheduler): - r"""Sets the parameters of each parameter group according to the 1cycle - learning rate policy. The 1cycle policy anneals the learning rate from an - initial learning rate to some maximum learning rate and then from that - maximum learning rate to some minimum learning rate much lower than the - initial learning rate. This policy was initially described in the paper - `Super-Convergence: Very Fast Training of Neural Networks Using Large - Learning Rates`_. - - The 1cycle learning rate policy changes the learning rate after every - batch. `step` should be called after a batch has been used for training. - - This scheduler is not chainable. - - Note also that the total number of steps in the cycle can be determined in - one of two ways (listed in order of precedence): - - #. A value for total_steps is explicitly provided. - #. If total_steps is not defined, begin and end of the ParamSchedul will - works for it. In this case, the number of total steps is inferred by - total_steps = end - begin - - The default behaviour of this scheduler follows the fastai implementation - of 1cycle, which claims that "unpublished work has shown even better - results by using only two phases". To mimic the behaviour of the original - paper instead, set ``three_phase=True``. - - Args: - optimizer (Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - eta_max (float or list): Upper parameter value boundaries in the cycle - for each parameter group. 
-        total_steps (int): The total number of steps in the cycle. Note that
-            if a value is not provided here, then it will be equal to
-            ``end - begin``. Defaults to None.
-        pct_start (float): The percentage of the cycle (in number of steps)
-            spent increasing the learning rate.
-            Defaults to 0.3.
-        anneal_strategy (str): {'cos', 'linear'}
-            Specifies the annealing strategy: "cos" for cosine annealing,
-            "linear" for linear annealing.
-            Defaults to 'cos'.
-        div_factor (float): Determines the initial learning rate via
-            initial_param = eta_max/div_factor
-            Defaults to 25.
-        final_div_factor (float): Determines the minimum learning rate via
-            eta_min = initial_param/final_div_factor
-            Defaults to 1e4.
-        three_phase (bool): If ``True``, use a third phase of the schedule to
-            annihilate the learning rate according to 'final_div_factor'
-            instead of modifying the second phase (the first two phases will be
-            symmetrical about the step indicated by 'pct_start').
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-
-    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
-        https://arxiv.org/abs/1708.07120
-    """  # noqa: E501
-
-    def __init__(self,
-                 optimizer: Union[Optimizer, BaseOptimWrapper],
-                 param_name: str,
-                 eta_max: float = 0,
-                 total_steps: Optional[int] = None,
-                 pct_start: float = 0.3,
-                 anneal_strategy: str = 'cos',
-                 div_factor: float = 25.,
-                 final_div_factor: float = 1e4,
-                 three_phase: bool = False,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-
-        assert param_name == 'lr', ('OneCycle only works for learning rate '
-                                    'updating, but got param_name as '
-                                    f'{param_name}')
-
-        self.eta_max = eta_max
-        self.div_factor = div_factor
-        self.final_div_factor = final_div_factor
-
-        # Validate total_steps
-        if total_steps is not None:
-            if total_steps <= 0 or not isinstance(total_steps, int):
-                raise ValueError('Expected positive integer total_steps, '
-                                 f'but got {total_steps}')
-            self.total_steps = total_steps
-        else:
-            self.total_steps = end - begin
-
-        # Validate pct_start
-        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
-            raise ValueError('Expected float between 0 and 1 for pct_start, '
-                             f'but got {pct_start}')
-
-        # Validate anneal_strategy
-        if anneal_strategy not in ['cos', 'linear']:
-            raise ValueError(
-                'anneal_strategy must be one of "cos" or "linear", '
-                f'instead got {anneal_strategy}')
-        elif anneal_strategy == 'cos':
-            self.anneal_func = self._annealing_cos
-        elif anneal_strategy == 'linear':
-            self.anneal_func = self._annealing_linear
-
-        if three_phase:
-            self._schedule_phases = [
-                {
-                    'end_step': float(pct_start * self.total_steps) - 1,
-                    f'start_{param_name}': f'initial_{param_name}',
-                    f'end_{param_name}': f'max_{param_name}'
-                },
-                {
-                    'end_step': float(2 * pct_start * self.total_steps) - 2,
-                    f'start_{param_name}': f'max_{param_name}',
-                    f'end_{param_name}': f'initial_{param_name}'
-                },
-                {
-                    'end_step': self.total_steps - 1,
-                    f'start_{param_name}': f'initial_{param_name}',
-                    f'end_{param_name}': f'min_{param_name}'
-                },
-            ]
-        else:
-            self._schedule_phases = [
-                {
-                    'end_step': float(pct_start * self.total_steps) - 1,
-                    f'start_{param_name}': f'initial_{param_name}',
-                    f'end_{param_name}': f'max_{param_name}'
-                },
-                {
-                    'end_step':
self.total_steps - 1, - f'start_{param_name}': f'max_{param_name}', - f'end_{param_name}': f'min_{param_name}' - }, - ] - - # Initialize parameters - max_values = self._format_param(f'max_{param_name}', optimizer, - eta_max) - if last_step == -1: - for idx, group in enumerate(optimizer.param_groups): - group[f'initial_{param_name}'] = max_values[idx] / div_factor - group[f'max_{param_name}'] = max_values[idx] - group[f'min_{param_name}'] = \ - group[f'initial_{param_name}'] / final_div_factor - - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError( - f'expected {len(optimizer.param_groups)} values ' - f'for {name}, got {len(param)}') - return param - else: - return [param] * len(optimizer.param_groups) - - @staticmethod - def _annealing_cos(start, end, pct): - """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" - - cos_out = math.cos(math.pi * pct) + 1 - return end + (start - end) / 2.0 * cos_out - - @staticmethod - def _annealing_linear(start, end, pct): - """Linearly anneal from `start` to `end` as pct goes from 0.0 to - 1.0.""" - return (end - start) * pct + start - - @classmethod - def build_iter_from_epoch(cls, - *args, - begin=0, - end=INF, - total_steps=None, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - if total_steps is not None: - total_steps = total_steps * epoch_length - return cls( - *args, - begin=begin, - end=end, - total_steps=total_steps, - by_epoch=by_epoch, - **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - - params = [] - step_num = self.last_step - - if step_num > self.total_steps: - raise ValueError( - f'Tried to step {step_num + 1} times. ' - f'The specified number of total steps is {self.total_steps}') - - for group in self.optimizer.param_groups: - start_step = 0 - for i, phase in enumerate(self._schedule_phases): - end_step = phase['end_step'] - if step_num <= end_step or i == len(self._schedule_phases) - 1: - pct = (step_num - start_step) / (end_step - start_step) - computed_param = self.anneal_func( - group[phase['start_' + self.param_name]], - group[phase['end_' + self.param_name]], pct) - break - start_step = phase['end_step'] - - params.append(computed_param) - - return params - - -@PARAM_SCHEDULERS.register_module() -class CosineRestartParamScheduler(_ParamScheduler): - """Sets the parameters of each parameter group according to the cosine - annealing with restarts scheme. The cosine restart policy anneals the - parameter from the initial value to `eta_min` with a cosine annealing - schedule and then restarts another period from the maximum value multiplied - with `restart_weight`. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. 
- param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - periods (list[int]): Periods for each cosine anneling cycle. - restart_weights (list[float]): Restart weights at each - restart iteration. Defaults to [1]. - eta_min (float, optional): Minimum parameter value at the end of - scheduling. Defaults to None. - eta_min_ratio (float, optional): The ratio of minimum parameter value - to the base parameter value. Either `eta_min` or `eta_min_ratio` - should be specified. Defaults to None. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__(self, - optimizer: Union[Optimizer, BaseOptimWrapper], - param_name: str, - periods: List[int], - restart_weights: Sequence[float] = (1, ), - eta_min: Optional[float] = None, - eta_min_ratio: Optional[float] = None, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - assert (eta_min is None) ^ (eta_min_ratio is None) - self.periods = periods - self.eta_min = eta_min - self.eta_min_ratio = eta_min_ratio - self.restart_weights = restart_weights - assert (len(self.periods) == len(self.restart_weights) - ), 'periods and restart_weights should have the same length.' - self.cumulative_periods = [ - sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) - ] - - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - - @classmethod - def build_iter_from_epoch(cls, - *args, - periods, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ - 'be converted to iter-based.' - assert epoch_length is not None and epoch_length > 0, \ - f'`epoch_length` must be a positive integer, ' \ - f'but got {epoch_length}.' 
- periods = [p * epoch_length for p in periods] - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls( - *args, - periods=periods, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - idx = self.get_position_from_periods(self.last_step, - self.cumulative_periods) - # if current step is not in the periods, return origin parameters - if idx is None: - return [ - group[self.param_name] for group in self.optimizer.param_groups - ] - current_weight = self.restart_weights[idx] - nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] - current_periods = self.periods[idx] - step = self.last_step - nearest_restart - values = [] - for base_value, group in zip(self.base_values, - self.optimizer.param_groups): - eta_max = base_value * current_weight - if self.eta_min_ratio is None: - eta_min = self.eta_min - else: - eta_min = base_value * self.eta_min_ratio - if step == 0: - values.append(eta_max) - else: - values.append( - (1 + math.cos(math.pi * step / current_periods)) / - (1 + math.cos(math.pi * (step - 1) / current_periods)) * - (group[self.param_name] - eta_min) + eta_min) - - return values - - @staticmethod - def get_position_from_periods( - iteration: int, cumulative_periods: List[int]) -> Optional[int]: - """Get the position from a period list. - - It will return the index of the right-closest number in the period - list. - For example, the cumulative_periods = [100, 200, 300, 400], - if iteration == 50, return 0; - if iteration == 210, return 2; - if iteration == 300, return 3. - - Args: - iteration (int): Current iteration. - cumulative_periods (list[int]): Cumulative period list. - - Returns: - Optional[int]: The position of the right-closest number in the - period list. If not in the period, return None. - """ - for i, period in enumerate(cumulative_periods): - if iteration < period: - return i - return None - - -@PARAM_SCHEDULERS.register_module() -class ReduceOnPlateauParamScheduler(_ParamScheduler): - """Reduce the parameters of each parameter group when a metric has stopped - improving. Models often benefit from reducing the parameters by a factor of - 2-10 once learning stagnates. This scheduler reads a metrics quantity and - if no improvement is seen for a ``patience`` number of epochs, the - parameters are reduced. - - The implementation is motivated by `PyTorch ReduceLROnPlateau`_. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - monitor (str): The name of the metric to measure whether - the performance of the model is improved. - rule (str): One of `less`, `greater`. In `less` rule, parameters will - be reduced when the quantity monitored has stopped - decreasing; in `greater` rule it will be reduced when the - quantity monitored has stopped increasing. Defaults to 'less'. - The ``rule`` is the renaming of ``mode`` in pytorch. - factor (float): Factor by which the parameters will be - reduced. new_param = param * factor. Defaults to 0.1. - patience (int): Number of epochs with no improvement after - which parameters will be reduced. For example, if - ``patience = 2``, then we will ignore the first 2 epochs - with no improvement, and will only decrease the parameters after - the 3rd epoch if the monitor value still hasn't improved then. - Defaults to 10. 
- threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Defaults to 1e-4. - threshold_rule (str): One of `rel`, `abs`. In `rel` rule, - dynamic_threshold = best * ( 1 + threshold ) in 'greater' - rule or best * ( 1 - threshold ) in `less` rule. - In `abs` rule, dynamic_threshold = best + threshold in - `greater` rule or best - threshold in `less` rule. - Defaults to 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after parameters have been reduced. Defaults to 0. - min_value (float or list[float]): A scalar or a sequence of scalars. - A lower bound on the parameters of each parameter group - respectively. Defaults to 0. . - eps (float): Minimal decay applied to parameters. If the difference - between new and old parameters are smaller than eps, the update is - ignored. Defaults to 1e-8. - begin (int): Step at which to start triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to 0. - end (int): Step at which to stop triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - - .. _PyTorch ReduceLROnPlateau: - https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py - """ - - need_val_args = True - - def __init__(self, - optimizer: OptimizerType, - param_name: str, - monitor: str = 'loss', - rule: str = 'less', - factor: float = 0.1, - patience: int = 10, - threshold: float = 1e-4, - threshold_rule: str = 'rel', - cooldown: int = 0, - min_value: Union[float, Sequence[float]] = 0., - eps: float = 1e-8, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - - # Attach optimizer - if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)): - raise TypeError('``optimizer`` should be an Optimizer,' - 'but got {}'.format(type(optimizer).__name__)) - self.optimizer = optimizer - self.param_name = param_name - - if end <= begin: - raise ValueError('end should be larger than begin, but got' - ' begin={}, end={}'.format(begin, end)) - self.begin = begin - self.end = end - - assert by_epoch, \ - f'Now {type(self).__name__} only support by_epoch=True' - self.by_epoch = by_epoch - - assert isinstance(last_step, int) and last_step >= -1 - # Initialize valid step count and base values - if last_step == -1: - for group in optimizer.param_groups: - # If the param is never be scheduled, record the current value - # as the initial value. - group.setdefault(f'initial_{param_name}', group[param_name]) - else: - for i, group in enumerate(optimizer.param_groups): - if f'initial_{param_name}' not in group: - raise KeyError( - f"param 'initial_{param_name}' is not specified " - 'in param_groups[{}] when resuming an optimizer'. - format(i)) - - self.last_step = last_step - - self._global_step = 0 - self.verbose = verbose - - if factor >= 1.0: - raise ValueError('Factor should be < 1.0.') - self.factor = factor - - # This code snippet handles compatibility with the optimizer wrapper. - # The optimizer wrapper includes an additional parameter to record the - # base learning rate (lr) which is not affected by the paramwise_cfg. 
- # By retrieving the base lr, we can obtain the actual base lr that - # reflects the learning progress. - if isinstance(optimizer, BaseOptimWrapper): - raw_optimizer = optimizer.optimizer - else: - raw_optimizer = optimizer - - if isinstance(min_value, (list, tuple)): - if len(min_value) != len(raw_optimizer.param_groups): - raise ValueError('expected {} min_lrs, got {}'.format( - len(raw_optimizer.param_groups), len(min_value))) - self.min_values = list(min_value) - # Consider the `min_value` of the last param_groups - # as the base setting. And we only add this value when - # the optimizer is OptimWrapper. - if isinstance(optimizer, BaseOptimWrapper) and \ - optimizer.base_param_settings is not None: # type: ignore - self.min_values.append(self.min_values[-1]) - - else: - self.min_values = [min_value] * len( # type: ignore - optimizer.param_groups) - - self.patience = patience - self.cooldown = cooldown - self.cooldown_counter = 0 - self.rule_worse = None # the worse value for the chosen mode - self.best = None - self.num_bad_epochs = 0 - self.eps = eps - - self.monitor = monitor - self._init_is_better( - rule=rule, threshold=threshold, threshold_rule=threshold_rule) - self._reset() - - # remove call self.step() and init self._global_step = 0 - self._last_value = [ - group[self.param_name] for group in self.optimizer.param_groups - ] - - def step(self, metrics=None): - """Adjusts the parameter value of each parameter group based on the - specified schedule. - - Args: - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - Defaults to None. - """ - if metrics is None: - # only to count self._global_step - self._global_step += 1 - return - - if not isinstance(metrics, dict): - raise TypeError('metrics type should be dict,' - f' but got type {type(metrics)}') - - # Compute parameter value per param group in the effective range - if self.begin <= self._global_step < self.end: - self.last_step += 1 - - # convert `metric` to float, in case it's a zero-dim Tensor - metric = metrics.get(self.monitor, None) - if metric is not None: - if self._is_better(metric, self.best): - self.best = metric - self.num_bad_epochs = 0 - else: - self.num_bad_epochs += 1 - - if self._in_cooldown(): - self.cooldown_counter -= 1 - self.num_bad_epochs = 0 # ignore bad epochs in cooldown - - if self.num_bad_epochs > self.patience: - values = self._get_value() - - for i, data in enumerate( - zip(self.optimizer.param_groups, values)): - param_group, value = data - if param_group[self.param_name] - value > self.eps: - param_group[self.param_name] = value - self.print_value(self.verbose, i, value) - self.cooldown_counter = self.cooldown - self.num_bad_epochs = 0 - - else: - raise KeyError(f'Excepted key in {list(metrics.keys())},' - f' but got key {self.monitor} is not in dict') - - self._last_value = [ - group[self.param_name] for group in self.optimizer.param_groups - ] - - def print_value(self, is_verbose: bool, group: int, value: float) -> None: - """Display the current parameter value. - - Args: - is_verbose (bool): Whether to print the value. - group (int): The index of the current ``param_group``. - value (float): The parameter value. 
- """ - if is_verbose: - step_name = 'epoch' if self.by_epoch else 'iter' - print_log( - f'Adjusting parameter value of group {group} to {value:.4e} ' - f'in {step_name} {self.last_step}.', - logger='current') - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - values = [ - float(group[self.param_name]) * self.factor - for group in self.optimizer.param_groups - ] - return [max(v, min_v) for v, min_v in zip(values, self.min_values)] - - def _in_cooldown(self): - """Judge whether it is in cooldown.""" - return self.cooldown_counter > 0 - - def _is_better(self, a, best): - """Judge whether the monitor value is better.""" - if self.rule == 'less' and self.threshold_rule == 'rel': - rel_epsilon = 1. - self.threshold - return a < best * rel_epsilon - - elif self.rule == 'less' and self.threshold_rule == 'abs': - return a < best - self.threshold - - elif self.rule == 'greater' and self.threshold_rule == 'rel': - rel_epsilon = self.threshold + 1. - return a > best * rel_epsilon - - else: # rule == 'greater' and epsilon_mode == 'abs': - return a > best + self.threshold - - def _init_is_better(self, rule, threshold, threshold_rule): - """Initialize rule and its associated values.""" - if threshold < 0: - raise ValueError(f'threshold {threshold} should be >= 0.') - if rule not in {'less', 'greater'}: - raise ValueError(f'mode {rule} is unknown!') - if threshold_rule not in {'rel', 'abs'}: - raise ValueError(f'threshold mode {threshold_rule}' - ' is unknown!') - - if rule == 'less': - self.rule_worse = INF - else: # rule == 'greater': - self.rule_worse = -INF - - self.rule = rule - self.threshold = threshold - self.threshold_rule = threshold_rule - - def _reset(self): - """Resets num_bad_epochs counter and cooldown counter.""" - self.best = self.rule_worse - self.cooldown_counter = 0 - self.num_bad_epochs = 0 diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py deleted file mode 100644 index cce2737043..0000000000 --- a/mmengine/registry/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
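This hunk removes the package's public entry point, which is how downstream projects reached the registries. A small sketch of the usual register-and-build round trip, using the re-exports listed below (the class name is a made-up example):

from mmengine.registry import MODELS, build_from_cfg

@MODELS.register_module()
class TinyNet:  # 'TinyNet' is a hypothetical example class
    def __init__(self, depth: int = 18):
        self.depth = depth

# For a plain dict config this is equivalent to MODELS.build(...)
model = build_from_cfg(dict(type='TinyNet', depth=50), MODELS)
assert model.depth == 50
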
-from .build_functions import (build_from_cfg, build_model_from_cfg, - build_runner_from_cfg, build_scheduler_from_cfg) -from .default_scope import DefaultScope -from .registry import Registry -from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS, - INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, - MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, - OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, - STRATEGIES, TASK_UTILS, TRANSFORMS, VISBACKENDS, - VISUALIZERS, WEIGHT_INITIALIZERS) -from .utils import (count_registered_modules, init_default_scope, - traverse_registry_tree) - -__all__ = [ - 'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', - 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', - 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', - 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS', - 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'INFERENCERS', - 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules', - 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg', - 'build_scheduler_from_cfg', 'init_default_scope', 'FUNCTIONS', 'STRATEGIES' -] diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py deleted file mode 100644 index 3de6798514..0000000000 --- a/mmengine/registry/build_functions.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -from typing import TYPE_CHECKING, Any, Optional, Union - -from mmengine.config import Config, ConfigDict -from mmengine.utils import ManagerMixin, digit_version -from .registry import Registry - -if TYPE_CHECKING: - import torch.nn as nn - - from mmengine.optim.scheduler import _ParamScheduler - from mmengine.runner import Runner - - -def build_from_cfg( - cfg: Union[dict, ConfigDict, Config], - registry: Registry, - default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: - """Build a module from config dict when it is a class configuration, or - call a function from config dict when it is a function configuration. - - If the global variable default scope (:obj:`DefaultScope`) exists, - :meth:`build` will firstly get the responding registry and then call - its own :meth:`build`. - - At least one of the ``cfg`` and ``default_args`` contains the key "type", - which should be either str or class. If they all contain it, the key - in ``cfg`` will be used because ``cfg`` has a high priority than - ``default_args`` that means if a key exists in both of them, the value of - the key will be ``cfg[key]``. They will be merged first and the key "type" - will be popped up and the remaining keys will be used as initialization - arguments. - - Examples: - >>> from mmengine import Registry, build_from_cfg - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> def __init__(self, depth, stages=4): - >>> self.depth = depth - >>> self.stages = stages - >>> cfg = dict(type='ResNet', depth=50) - >>> model = build_from_cfg(cfg, MODELS) - >>> # Returns an instantiated object - >>> @MODELS.register_module() - >>> def resnet50(): - >>> pass - >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) - >>> # Return a result of the calling function - - Args: - cfg (dict or ConfigDict or Config): Config dict. It should at least - contain the key "type". - registry (:obj:`Registry`): The registry to search the type from. 
- default_args (dict or ConfigDict or Config, optional): Default - initialization arguments. Defaults to None. - - Returns: - object: The constructed object. - """ - # Avoid circular import - from ..logging import print_log - - if not isinstance(cfg, (dict, ConfigDict, Config)): - raise TypeError( - f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}') - - if 'type' not in cfg: - if default_args is None or 'type' not in default_args: - raise KeyError( - '`cfg` or `default_args` must contain the key "type", ' - f'but got {cfg}\n{default_args}') - - if not isinstance(registry, Registry): - raise TypeError('registry must be a mmengine.Registry object, ' - f'but got {type(registry)}') - - if not (isinstance(default_args, - (dict, ConfigDict, Config)) or default_args is None): - raise TypeError( - 'default_args should be a dict, ConfigDict, Config or None, ' - f'but got {type(default_args)}') - - args = cfg.copy() - if default_args is not None: - for name, value in default_args.items(): - args.setdefault(name, value) - - # Instance should be built under target scope, if `_scope_` is defined - # in cfg, current default scope should switch to specified scope - # temporarily. - scope = args.pop('_scope_', None) - with registry.switch_scope_and_registry(scope) as registry: - obj_type = args.pop('type') - if isinstance(obj_type, str): - obj_cls = registry.get(obj_type) - if obj_cls is None: - raise KeyError( - f'{obj_type} is not in the {registry.scope}::{registry.name} registry. ' # noqa: E501 - f'Please check whether the value of `{obj_type}` is ' - 'correct or it was registered as expected. More details ' - 'can be found at ' - 'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 - ) - # this will include classes, functions, partial functions and more - elif callable(obj_type): - obj_cls = obj_type - else: - raise TypeError( - f'type must be a str or valid type, but got {type(obj_type)}') - - # If `obj_cls` inherits from `ManagerMixin`, it should be - # instantiated by `ManagerMixin.get_instance` to ensure that it - # can be accessed globally. - if inspect.isclass(obj_cls) and \ - issubclass(obj_cls, ManagerMixin): # type: ignore - obj = obj_cls.get_instance(**args) # type: ignore - else: - obj = obj_cls(**args) # type: ignore - - if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls) - or inspect.ismethod(obj_cls)): - print_log( - f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 - 'registry, and its implementation can be found in ' - f'{obj_cls.__module__}', # type: ignore - logger='current', - level=logging.DEBUG) - else: - print_log( - 'An instance is built from registry, and its constructor ' - f'is {obj_cls}', - logger='current', - level=logging.DEBUG) - return obj - - -def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], - registry: Registry) -> 'Runner': - """Build a Runner object. - - Examples: - >>> from mmengine.registry import Registry, build_runner_from_cfg - >>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg) - >>> @RUNNERS.register_module() - >>> class CustomRunner(Runner): - >>> def setup_env(env_cfg): - >>> pass - >>> cfg = dict(runner_type='CustomRunner', ...) - >>> custom_runner = RUNNERS.build(cfg) - - Args: - cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key - exists, it will be used to build a custom runner. Otherwise, it - will be used to build a default runner. 
- registry (:obj:`Registry`): The registry to search the type from. - - Returns: - object: The constructed runner object. - """ - from ..config import Config, ConfigDict - from ..logging import print_log - - assert isinstance( - cfg, - (dict, ConfigDict, Config - )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}' - assert isinstance( - registry, Registry), ('registry should be a mmengine.Registry object', - f'but got {type(registry)}') - - args = cfg.copy() - # Runner should be built under target scope, if `_scope_` is defined - # in cfg, current default scope should switch to specified scope - # temporarily. - scope = args.pop('_scope_', None) - with registry.switch_scope_and_registry(scope) as registry: - obj_type = args.get('runner_type', 'Runner') - if isinstance(obj_type, str): - runner_cls = registry.get(obj_type) - if runner_cls is None: - raise KeyError( - f'{obj_type} is not in the {registry.name} registry. ' - f'Please check whether the value of `{obj_type}` is ' - 'correct or it was registered as expected. More details ' - 'can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 - ) - elif inspect.isclass(obj_type): - runner_cls = obj_type - else: - raise TypeError( - f'type must be a str or valid type, but got {type(obj_type)}') - - runner = runner_cls.from_cfg(args) # type: ignore - print_log( - f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 - 'registry, its implementation can be found in' - f'{runner_cls.__module__}', # type: ignore - logger='current', - level=logging.DEBUG) - return runner - - -def build_model_from_cfg( - cfg: Union[dict, ConfigDict, Config], - registry: Registry, - default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None -) -> 'nn.Module': - """Build a PyTorch model from config dict(s). Different from - ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. - - Args: - cfg (dict, list[dict]): The config of modules, which is either a config - dict or a list of config dicts. If cfg is a list, the built - modules will be wrapped with ``nn.Sequential``. - registry (:obj:`Registry`): A registry the module belongs to. - default_args (dict, optional): Default arguments to build the module. - Defaults to None. - - Returns: - nn.Module: A built nn.Module. - """ - from ..model import Sequential - if isinstance(cfg, list): - modules = [ - build_from_cfg(_cfg, registry, default_args) for _cfg in cfg - ] - return Sequential(*modules) - else: - return build_from_cfg(cfg, registry, default_args) - - -def build_optimizer_from_cfg( - cfg: Union[dict, ConfigDict, Config], - registry: Registry, - default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: - import torch - - from ..logging import print_log - if 'type' in cfg \ - and 'Adafactor' == cfg['type'] \ - and digit_version(torch.__version__) >= digit_version('2.5.0'): - print_log( - 'the torch version of Adafactor is registered as TorchAdafactor') - return build_from_cfg(cfg, registry, default_args) - - -def build_scheduler_from_cfg( - cfg: Union[dict, ConfigDict, Config], - registry: Registry, - default_args: Optional[Union[dict, ConfigDict, Config]] = None -) -> '_ParamScheduler': - """Builds a ``ParamScheduler`` instance from config. - - ``ParamScheduler`` supports building instance by its constructor or - method ``build_iter_from_epoch``. Therefore, its registry needs a build - function to handle both cases. 
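A sketch of how this build function is normally reached through the ``PARAM_SCHEDULERS`` registry; the scheduler type, optimizer, and ``epoch_length`` below are illustrative:

import torch
from mmengine.registry import PARAM_SCHEDULERS

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# ``convert_to_iter_based`` routes the call to ``build_iter_from_epoch``,
# which rescales ``begin``/``end`` by ``epoch_length`` (iterations per epoch).
cfg = dict(type='LinearLR', start_factor=0.01, end=5, convert_to_iter_based=True)
scheduler = PARAM_SCHEDULERS.build(
    cfg, default_args=dict(optimizer=optimizer, epoch_length=100))
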
- - Args: - cfg (dict or ConfigDict or Config): Config dictionary. If it contains - the key ``convert_to_iter_based``, instance will be built by method - ``convert_to_iter_based``, otherwise instance will be built by its - constructor. - registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry. - default_args (dict or ConfigDict or Config, optional): Default - initialization arguments. It must contain key ``optimizer``. If - ``convert_to_iter_based`` is defined in ``cfg``, it must - additionally contain key ``epoch_length``. Defaults to None. - - Returns: - object: The constructed ``ParamScheduler``. - """ - assert isinstance( - cfg, - (dict, ConfigDict, Config - )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}' - assert isinstance( - registry, Registry), ('registry should be a mmengine.Registry object', - f'but got {type(registry)}') - - args = cfg.copy() - if default_args is not None: - for name, value in default_args.items(): - args.setdefault(name, value) - scope = args.pop('_scope_', None) - with registry.switch_scope_and_registry(scope) as registry: - convert_to_iter = args.pop('convert_to_iter_based', False) - if convert_to_iter: - scheduler_type = args.pop('type') - assert 'epoch_length' in args and args.get('by_epoch', True), ( - 'Only epoch-based parameter scheduler can be converted to ' - 'iter-based, and `epoch_length` should be set') - if isinstance(scheduler_type, str): - scheduler_cls = registry.get(scheduler_type) - if scheduler_cls is None: - raise KeyError( - f'{scheduler_type} is not in the {registry.name} ' - 'registry. Please check whether the value of ' - f'`{scheduler_type}` is correct or it was registered ' - 'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 - ) - elif inspect.isclass(scheduler_type): - scheduler_cls = scheduler_type - else: - raise TypeError('type must be a str or valid type, but got ' - f'{type(scheduler_type)}') - return scheduler_cls.build_iter_from_epoch( # type: ignore - **args) - else: - args.pop('epoch_length', None) - return build_from_cfg(args, registry) diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py deleted file mode 100644 index f1347689e0..0000000000 --- a/mmengine/registry/default_scope.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import time -from contextlib import contextmanager -from typing import Generator, Optional - -from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock - - -class DefaultScope(ManagerMixin): - """Scope of current task used to reset the current registry, which can be - accessed globally. - - Consider the case of resetting the current ``Registry`` by - ``default_scope`` in the internal module which cannot access runner - directly, it is difficult to get the ``default_scope`` defined in - ``Runner``. However, if ``Runner`` created ``DefaultScope`` instance - by given ``default_scope``, the internal module can get - ``default_scope`` by ``DefaultScope.get_current_instance`` everywhere. - - Args: - name (str): Name of default scope for global access. - scope_name (str): Scope of current task. - - Examples: - >>> from mmengine.model import MODELS - >>> # Define default scope in runner. - >>> DefaultScope.get_instance('task', scope_name='mmdet') - >>> # Get default scope globally. 
- >>> scope_name = DefaultScope.get_instance('task').scope_name - """ - - def __init__(self, name: str, scope_name: str): - super().__init__(name) - assert isinstance( - scope_name, - str), (f'scope_name should be a string, but got {scope_name}') - self._scope_name = scope_name - - @property - def scope_name(self) -> str: - """ - Returns: - str: Get current scope. - """ - return self._scope_name - - @classmethod - def get_current_instance(cls) -> Optional['DefaultScope']: - """Get latest created default scope. - - Since default_scope is an optional argument for ``Registry.build``. - ``get_current_instance`` should return ``None`` if there is no - ``DefaultScope`` created. - - Examples: - >>> default_scope = DefaultScope.get_current_instance() - >>> # There is no `DefaultScope` created yet, - >>> # `get_current_instance` return `None`. - >>> default_scope = DefaultScope.get_instance( - >>> 'instance_name', scope_name='mmengine') - >>> default_scope.scope_name - mmengine - >>> default_scope = DefaultScope.get_current_instance() - >>> default_scope.scope_name - mmengine - - Returns: - Optional[DefaultScope]: Return None If there has not been - ``DefaultScope`` instance created yet, otherwise return the - latest created DefaultScope instance. - """ - _accquire_lock() - if cls._instance_dict: - instance = super().get_current_instance() - else: - instance = None - _release_lock() - return instance - - @classmethod - @contextmanager - def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator: - """Overwrite the current default scope with `scope_name`""" - if scope_name is None: - yield - else: - tmp = copy.deepcopy(cls._instance_dict) - # To avoid create an instance with the same name. - time.sleep(1e-6) - cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name) - try: - yield - finally: - cls._instance_dict = tmp diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py deleted file mode 100644 index e7d8962be4..0000000000 --- a/mmengine/registry/registry.py +++ /dev/null @@ -1,669 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -import sys -from collections.abc import Callable -from contextlib import contextmanager -from importlib import import_module -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union - -from rich.console import Console -from rich.table import Table - -from mmengine.config.utils import MODULE2PACKAGE -from mmengine.utils import get_object_from_string, is_seq_of -from .default_scope import DefaultScope - - -class Registry: - """A registry to map strings to classes or functions. - - Registered object could be built from registry. Meanwhile, registered - functions could be called from registry. - - Args: - name (str): Registry name. - build_func (callable, optional): A function to construct instance - from Registry. :func:`build_from_cfg` is used if neither ``parent`` - or ``build_func`` is specified. If ``parent`` is specified and - ``build_func`` is not given, ``build_func`` will be inherited - from ``parent``. Defaults to None. - parent (:obj:`Registry`, optional): Parent registry. The class - registered in children registry could be built from parent. - Defaults to None. - scope (str, optional): The scope of registry. It is the key to search - for children registry. If not specified, scope will be the name of - the package where class is defined, e.g. mmdet, mmcls, mmseg. - Defaults to None. 
- locations (list): The locations to import the modules registered - in this registry. Defaults to []. - New in version 0.4.0. - - Examples: - >>> # define a registry - >>> MODELS = Registry('models') - >>> # registry the `ResNet` to `MODELS` - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - >>> # build model from `MODELS` - >>> resnet = MODELS.build(dict(type='ResNet')) - >>> @MODELS.register_module() - >>> def resnet50(): - >>> pass - >>> resnet = MODELS.build(dict(type='resnet50')) - - >>> # hierarchical registry - >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det') - >>> @DETECTORS.register_module() - >>> class FasterRCNN: - >>> pass - >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN')) - - >>> # add locations to enable auto import - >>> DETECTORS = Registry('detectors', parent=MODELS, - >>> scope='det', locations=['det.models.detectors']) - >>> # define this class in 'det.models.detectors' - >>> @DETECTORS.register_module() - >>> class MaskRCNN: - >>> pass - >>> # The registry will auto import det.models.detectors.MaskRCNN - >>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN')) - - More advanced usages can be found at - https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. - """ - - def __init__(self, - name: str, - build_func: Optional[Callable] = None, - parent: Optional['Registry'] = None, - scope: Optional[str] = None, - locations: List = []): - from .build_functions import build_from_cfg - self._name = name - self._module_dict: Dict[str, Type] = dict() - self._children: Dict[str, 'Registry'] = dict() - self._locations = locations - self._imported = False - - if scope is not None: - assert isinstance(scope, str) - self._scope = scope - else: - self._scope = self.infer_scope() - - # See https://mypy.readthedocs.io/en/stable/common_issues.html# - # variables-vs-type-aliases for the use - self.parent: Optional['Registry'] - if parent is not None: - assert isinstance(parent, Registry) - parent._add_child(self) - self.parent = parent - else: - self.parent = None - - # self.build_func will be set with the following priority: - # 1. build_func - # 2. parent.build_func - # 3. build_from_cfg - self.build_func: Callable - if build_func is None: - if self.parent is not None: - self.build_func = self.parent.build_func - else: - self.build_func = build_from_cfg - else: - self.build_func = build_func - - def __len__(self): - return len(self._module_dict) - - def __contains__(self, key): - return self.get(key) is not None - - def __repr__(self): - table = Table(title=f'Registry of {self._name}') - table.add_column('Names', justify='left', style='cyan') - table.add_column('Objects', justify='left', style='green') - - for name, obj in sorted(self._module_dict.items()): - table.add_row(name, str(obj)) - - console = Console() - with console.capture() as capture: - console.print(table, end='') - - return capture.get() - - @staticmethod - def infer_scope() -> str: - """Infer the scope of registry. - - The name of the package where registry is defined will be returned. - - Returns: - str: The inferred scope name. - - Examples: - >>> # in mmdet/models/backbone/resnet.py - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - >>> # The scope of ``ResNet`` will be ``mmdet``. - """ - from ..logging import print_log - - # `sys._getframe` returns the frame object that many calls below the - # top of the stack. 
The call stack for `infer_scope` can be listed as - # follow: - # frame-0: `infer_scope` itself - # frame-1: `__init__` of `Registry` which calls the `infer_scope` - # frame-2: Where the `Registry(...)` is called - module = inspect.getmodule(sys._getframe(2)) - if module is not None: - filename = module.__name__ - split_filename = filename.split('.') - scope = split_filename[0] - else: - # use "mmengine" to handle some cases which can not infer the scope - # like initializing Registry in interactive mode - scope = 'mmengine' - print_log( - 'set scope as "mmengine" when scope can not be inferred. You ' - 'can silence this warning by passing a "scope" argument to ' - 'Registry like `Registry(name, scope="toy")`', - logger='current', - level=logging.WARNING) - - return scope - - @staticmethod - def split_scope_key(key: str) -> Tuple[Optional[str], str]: - """Split scope and key. - - The first scope will be split from key. - - Return: - tuple[str | None, str]: The former element is the first scope of - the key, which can be ``None``. The latter is the remaining key. - - Examples: - >>> Registry.split_scope_key('mmdet.ResNet') - 'mmdet', 'ResNet' - >>> Registry.split_scope_key('ResNet') - None, 'ResNet' - """ - split_index = key.find('.') - if split_index != -1: - return key[:split_index], key[split_index + 1:] - else: - return None, key - - @property - def name(self): - return self._name - - @property - def scope(self): - return self._scope - - @property - def module_dict(self): - return self._module_dict - - @property - def children(self): - return self._children - - @property - def root(self): - return self._get_root_registry() - - @contextmanager - def switch_scope_and_registry(self, scope: Optional[str]) -> Generator: - """Temporarily switch default scope to the target scope, and get the - corresponding registry. - - If the registry of the corresponding scope exists, yield the - registry, otherwise yield the current itself. - - Args: - scope (str, optional): The target scope. - - Examples: - >>> from mmengine.registry import Registry, DefaultScope, MODELS - >>> import time - >>> # External Registry - >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet', - >>> parent=MODELS) - >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls', - >>> parent=MODELS) - >>> # Local Registry - >>> CUSTOM_MODELS = Registry('custom_model', scope='custom', - >>> parent=MODELS) - >>> - >>> # Initiate DefaultScope - >>> DefaultScope.get_instance(f'scope_{time.time()}', - >>> scope_name='custom') - >>> # Check default scope - >>> DefaultScope.get_current_instance().scope_name - custom - >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry. - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry: - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> registry.scope - mmcls - >>> # Nested switch scope - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry: - >>> DefaultScope.get_current_instance().scope_name - mmdet - >>> mmdet_registry.scope - mmdet - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry: - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> mmcls_registry.scope - mmcls - >>> - >>> # Check switch back to original scope. - >>> DefaultScope.get_current_instance().scope_name - custom - """ # noqa: E501 - from ..logging import print_log - - # Switch to the given scope temporarily. 
If the corresponding registry - # can be found in root registry, return the registry under the scope, - # otherwise return the registry itself. - with DefaultScope.overwrite_default_scope(scope): - # Get the global default scope - default_scope = DefaultScope.get_current_instance() - # Get registry by scope - if default_scope is not None: - scope_name = default_scope.scope_name - try: - import_module(f'{scope_name}.registry') - except (ImportError, AttributeError, ModuleNotFoundError): - if scope in MODULE2PACKAGE: - print_log( - f'{scope} is not installed and its ' - 'modules will not be registered. If you ' - 'want to use modules defined in ' - f'{scope}, Please install {scope} by ' - f'`pip install {MODULE2PACKAGE[scope]}.', - logger='current', - level=logging.WARNING) - else: - print_log( - f'Failed to import `{scope}.registry` ' - f'make sure the registry.py exists in `{scope}` ' - 'package.', - logger='current', - level=logging.WARNING) - root = self._get_root_registry() - registry = root._search_child(scope_name) - if registry is None: - # if `default_scope` can not be found, fallback to argument - # `registry` - print_log( - f'Failed to search registry with scope "{scope_name}" ' - f'in the "{root.name}" registry tree. ' - f'As a workaround, the current "{self.name}" registry ' - f'in "{self.scope}" is used to build instance. This ' - 'may cause unexpected failure when running the built ' - f'modules. Please check whether "{scope_name}" is a ' - 'correct scope, or whether the registry is ' - 'initialized.', - logger='current', - level=logging.WARNING) - registry = self - # If there is no built default scope, just return current registry. - else: - registry = self - yield registry - - def _get_root_registry(self) -> 'Registry': - """Return the root registry.""" - root = self - while root.parent is not None: - root = root.parent - return root - - def import_from_location(self) -> None: - """Import modules from the pre-defined locations in self._location.""" - if not self._imported: - # Avoid circular import - from ..logging import print_log - - # avoid BC breaking - if len(self._locations) == 0 and self.scope in MODULE2PACKAGE: - print_log( - f'The "{self.name}" registry in {self.scope} did not ' - 'set import location. Fallback to call ' - f'`{self.scope}.utils.register_all_modules` ' - 'instead.', - logger='current', - level=logging.DEBUG) - try: - module = import_module(f'{self.scope}.utils') - except (ImportError, AttributeError, ModuleNotFoundError): - if self.scope in MODULE2PACKAGE: - print_log( - f'{self.scope} is not installed and its ' - 'modules will not be registered. If you ' - 'want to use modules defined in ' - f'{self.scope}, Please install {self.scope} by ' - f'`pip install {MODULE2PACKAGE[self.scope]}.', - logger='current', - level=logging.WARNING) - else: - print_log( - f'Failed to import {self.scope} and register ' - 'its modules, please make sure you ' - 'have registered the module manually.', - logger='current', - level=logging.WARNING) - else: - # The import errors triggered during the registration - # may be more complex, here just throwing - # the error to avoid causing more implicit registry errors - # like `xxx`` not found in `yyy` registry. 
- module.register_all_modules(False) # type: ignore - - for loc in self._locations: - import_module(loc) - print_log( - f"Modules of {self.scope}'s {self.name} registry have " - f'been automatically imported from {loc}', - logger='current', - level=logging.DEBUG) - self._imported = True - - def get(self, key: str) -> Optional[Type]: - """Get the registry record. - - If `key`` represents the whole object name with its module - information, for example, `mmengine.model.BaseModel`, ``get`` - will directly return the class object :class:`BaseModel`. - - Otherwise, it will first parse ``key`` and check whether it - contains a scope name. The logic to search for ``key``: - - - ``key`` does not contain a scope name, i.e., it is purely a module - name like "ResNet": :meth:`get` will search for ``ResNet`` from the - current registry to its parent or ancestors until finding it. - - - ``key`` contains a scope name and it is equal to the scope of the - current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get` - will only search for ``ResNet`` in the current registry. - - - ``key`` contains a scope name and it is not equal to the scope of - the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the - scope exists in its children, :meth:`get` will get "FCNet" from - them. If not, :meth:`get` will first get the root registry and root - registry call its own :meth:`get` method. - - Args: - key (str): Name of the registered item, e.g., the class name in - string format. - - Returns: - Type or None: Return the corresponding class if ``key`` exists, - otherwise return None. - - Examples: - >>> # define a registry - >>> MODELS = Registry('models') - >>> # register `ResNet` to `MODELS` - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - >>> resnet_cls = MODELS.get('ResNet') - - >>> # hierarchical registry - >>> DETECTORS = Registry('detector', parent=MODELS, scope='det') - >>> # `ResNet` does not exist in `DETECTORS` but `get` method - >>> # will try to search from its parents or ancestors - >>> resnet_cls = DETECTORS.get('ResNet') - >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls') - >>> @CLASSIFIER.register_module() - >>> class MobileNet: - >>> pass - >>> # `get` from its sibling registries - >>> mobilenet_cls = DETECTORS.get('cls.MobileNet') - """ - # Avoid circular import - from ..logging import print_log - - if not isinstance(key, str): - raise TypeError( - 'The key argument of `Registry.get` must be a str, ' - f'got {type(key)}') - - scope, real_key = self.split_scope_key(key) - obj_cls = None - registry_name = self.name - scope_name = self.scope - - # lazy import the modules to register them into the registry - self.import_from_location() - - if scope is None or scope == self._scope: - # get from self - if real_key in self._module_dict: - obj_cls = self._module_dict[real_key] - elif scope is None: - # try to get the target from its parent or ancestors - parent = self.parent - while parent is not None: - if real_key in parent._module_dict: - obj_cls = parent._module_dict[real_key] - registry_name = parent.name - scope_name = parent.scope - break - parent = parent.parent - else: - # import the registry to add the nodes into the registry tree - try: - import_module(f'{scope}.registry') - print_log( - f'Registry node of {scope} has been automatically ' - 'imported.', - logger='current', - level=logging.DEBUG) - except (ImportError, AttributeError, ModuleNotFoundError): - print_log( - f'Cannot auto import {scope}.registry, please check ' - f'whether the package 
"{scope}" is installed correctly ' - 'or import the registry manually.', - logger='current', - level=logging.DEBUG) - # get from self._children - if scope in self._children: - obj_cls = self._children[scope].get(real_key) - registry_name = self._children[scope].name - scope_name = scope - else: - root = self._get_root_registry() - - if scope != root._scope and scope not in root._children: - # If not skip directly, `root.get(key)` will recursively - # call itself until RecursionError is thrown. - pass - else: - obj_cls = root.get(key) - - if obj_cls is None: - # Actually, it's strange to implement this `try ... except` to - # get the object by its name in `Registry.get`. However, If we - # want to build the model using a configuration like - # `dict(type='mmengine.model.BaseModel')`, which can - # be dumped by lazy import config, we need this code snippet - # for `Registry.get` to work. - try: - obj_cls = get_object_from_string(key) - except Exception: - raise RuntimeError(f'Failed to get {key}') - - if obj_cls is not None: - # For some rare cases (e.g. obj_cls is a partial function), obj_cls - # doesn't have `__name__`. Use default value to prevent error - cls_name = getattr(obj_cls, '__name__', str(obj_cls)) - print_log( - f'Get class `{cls_name}` from "{registry_name}"' - f' registry in "{scope_name}"', - logger='current', - level=logging.DEBUG) - - return obj_cls - - def _search_child(self, scope: str) -> Optional['Registry']: - """Depth-first search for the corresponding registry in its children. - - Note that the method only search for the corresponding registry from - the current registry. Therefore, if we want to search from the root - registry, :meth:`_get_root_registry` should be called to get the - root registry first. - - Args: - scope (str): The scope name used for searching for its - corresponding registry. - - Returns: - Registry or None: Return the corresponding registry if ``scope`` - exists, otherwise return None. - """ - if self._scope == scope: - return self - - for child in self._children.values(): - registry = child._search_child(scope) - if registry is not None: - return registry - - return None - - def build(self, cfg: dict, *args, **kwargs) -> Any: - """Build an instance. - - Build an instance by calling :attr:`build_func`. - - Args: - cfg (dict): Config dict needs to be built. - - Returns: - Any: The constructed object. - - Examples: - >>> from mmengine import Registry - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> def __init__(self, depth, stages=4): - >>> self.depth = depth - >>> self.stages = stages - >>> cfg = dict(type='ResNet', depth=50) - >>> model = MODELS.build(cfg) - """ - return self.build_func(cfg, *args, **kwargs, registry=self) - - def _add_child(self, registry: 'Registry') -> None: - """Add a child for a registry. - - Args: - registry (:obj:`Registry`): The ``registry`` will be added as a - child of the ``self``. - """ - - assert isinstance(registry, Registry) - assert registry.scope is not None - assert registry.scope not in self.children, \ - f'scope {registry.scope} exists in {self.name} registry' - self.children[registry.scope] = registry - - def _register_module(self, - module: Type, - module_name: Optional[Union[str, List[str]]] = None, - force: bool = False) -> None: - """Register a module. - - Args: - module (type): Module to be registered. Typically a class or a - function, but generally all ``Callable`` are acceptable. 
- module_name (str or list of str, optional): The module name to be - registered. If not specified, the class name will be used. - Defaults to None. - force (bool): Whether to override an existing class with the same - name. Defaults to False. - """ - if not callable(module): - raise TypeError(f'module must be Callable, but got {type(module)}') - - if module_name is None: - module_name = module.__name__ - if isinstance(module_name, str): - module_name = [module_name] - for name in module_name: - if not force and name in self._module_dict: - existed_module = self.module_dict[name] - raise KeyError(f'{name} is already registered in {self.name} ' - f'at {existed_module.__module__}') - self._module_dict[name] = module - - def register_module( - self, - name: Optional[Union[str, List[str]]] = None, - force: bool = False, - module: Optional[Type] = None) -> Union[type, Callable]: - """Register a module. - - A record will be added to ``self._module_dict``, whose key is the class - name or the specified name, and value is the class itself. - It can be used as a decorator or a normal function. - - Args: - name (str or list of str, optional): The module name to be - registered. If not specified, the class name will be used. - force (bool): Whether to override an existing class with the same - name. Defaults to False. - module (type, optional): Module class or function to be registered. - Defaults to None. - - Examples: - >>> backbones = Registry('backbone') - >>> # as a decorator - >>> @backbones.register_module() - >>> class ResNet: - >>> pass - >>> backbones = Registry('backbone') - >>> @backbones.register_module(name='mnet') - >>> class MobileNet: - >>> pass - - >>> # as a normal function - >>> class ResNet: - >>> pass - >>> backbones.register_module(module=ResNet) - """ - if not isinstance(force, bool): - raise TypeError(f'force must be a boolean, but got {type(force)}') - - # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_seq_of(name, str)): - raise TypeError( - 'name must be None, an instance of str, or a sequence of str, ' - f'but got {type(name)}') - - # use it as a normal method: x.register_module(module=SomeClass) - if module is not None: - self._register_module(module=module, module_name=name, force=force) - return module - - # use it as a decorator: @x.register_module() - def _register(module): - self._register_module(module=module, module_name=name, force=force) - return module - - return _register diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py deleted file mode 100644 index eb9a225a91..0000000000 --- a/mmengine/registry/root.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""MMEngine provides 20 root registries to support using modules across -projects. - -More datails can be found at -https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. 
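As a small illustrative aside on the ``name`` and ``force`` arguments documented above (the classes here are placeholders):

from mmengine.registry import Registry

TOOLS = Registry('tools')

@TOOLS.register_module(name=['Sorter', 'DefaultSorter'])
class Sorter:
    pass

# Re-using an existing name raises a KeyError unless force=True is given,
# in which case the old entry is overwritten; other aliases keep pointing
# at the original class.
@TOOLS.register_module(name='Sorter', force=True)
class StableSorter:
    pass

assert TOOLS.get('Sorter') is StableSorter
assert TOOLS.get('DefaultSorter') is Sorter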
-""" - -from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg, - build_runner_from_cfg, build_scheduler_from_cfg) -from .registry import Registry - -# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` -RUNNERS = Registry('runner', build_func=build_runner_from_cfg) -# manage runner constructors that define how to initialize runners -RUNNER_CONSTRUCTORS = Registry('runner constructor') -# manage all kinds of loops like `EpochBasedTrainLoop` -LOOPS = Registry('loop') -# manage all kinds of hooks like `CheckpointHook` -HOOKS = Registry('hook') - -# manage all kinds of strategies like `NativeStrategy` and `DDPStrategy` -STRATEGIES = Registry('strategy') - -# manage data-related modules -DATASETS = Registry('dataset') -DATA_SAMPLERS = Registry('data sampler') -TRANSFORMS = Registry('transform') - -# mangage all kinds of modules inheriting `nn.Module` -MODELS = Registry('model', build_model_from_cfg) -# mangage all kinds of model wrappers like 'MMDistributedDataParallel' -MODEL_WRAPPERS = Registry('model_wrapper') -# mangage all kinds of weight initialization modules like `Uniform` -WEIGHT_INITIALIZERS = Registry('weight initializer') - -# mangage all kinds of optimizers like `SGD` and `Adam` -OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg) -# manage optimizer wrapper -OPTIM_WRAPPERS = Registry('optim_wrapper') -# manage constructors that customize the optimization hyperparameters. -OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor') -# mangage all kinds of parameter schedulers like `MultiStepLR` -PARAM_SCHEDULERS = Registry( - 'parameter scheduler', build_func=build_scheduler_from_cfg) - -# manage all kinds of metrics -METRICS = Registry('metric') -# manage evaluator -EVALUATOR = Registry('evaluator') - -# manage task-specific modules like anchor generators and box coders -TASK_UTILS = Registry('task util') - -# manage visualizer -VISUALIZERS = Registry('visualizer') -# manage visualizer backend -VISBACKENDS = Registry('vis_backend') - -# manage logprocessor -LOG_PROCESSORS = Registry('log_processor') - -# manage inferencer -INFERENCERS = Registry('inferencer') - -# manage function -FUNCTIONS = Registry('function') diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py deleted file mode 100644 index 2737e879a7..0000000000 --- a/mmengine/registry/utils.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import datetime -import logging -import os.path as osp -from typing import Optional - -from mmengine.fileio import dump -from mmengine.logging import print_log -from . import root -from .default_scope import DefaultScope -from .registry import Registry - - -def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list: - """Traverse the whole registry tree from any given node, and collect - information of all registered modules in this registry tree. - - Args: - registry (Registry): a registry node in the registry tree. - verbose (bool): Whether to print log. Defaults to True - - Returns: - list: Statistic results of all modules in each node of the registry - tree. 
- """ - root_registry = registry.root - modules_info = [] - - def _dfs_registry(_registry): - if isinstance(_registry, Registry): - num_modules = len(_registry.module_dict) - scope = _registry.scope - registry_info = dict(num_modules=num_modules, scope=scope) - for name, registered_class in _registry.module_dict.items(): - folder = '/'.join(registered_class.__module__.split('.')[:-1]) - if folder in registry_info: - registry_info[folder].append(name) - else: - registry_info[folder] = [name] - if verbose: - print_log( - f"Find {num_modules} modules in {scope}'s " - f"'{_registry.name}' registry ", - logger='current') - modules_info.append(registry_info) - else: - return - for _, child in _registry.children.items(): - _dfs_registry(child) - - _dfs_registry(root_registry) - return modules_info - - -def count_registered_modules(save_path: Optional[str] = None, - verbose: bool = True) -> dict: - """Scan all modules in MMEngine's root and child registries and dump to - json. - - Args: - save_path (str, optional): Path to save the json file. - verbose (bool): Whether to print log. Defaults to True. - - Returns: - dict: Statistic results of all registered modules. - """ - # import modules to trigger registering - import mmengine.dataset - import mmengine.evaluator - import mmengine.hooks - import mmengine.model - import mmengine.optim - import mmengine.runner - import mmengine.visualization # noqa: F401 - - registries_info = {} - # traverse all registries in MMEngine - for item in dir(root): - if not item.startswith('__'): - registry = getattr(root, item) - if isinstance(registry, Registry): - registries_info[item] = traverse_registry_tree( - registry, verbose) - scan_data = dict( - scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - registries=registries_info) - if verbose: - print_log( - f'Finish registry analysis, got: {scan_data}', logger='current') - if save_path is not None: - json_path = osp.join(save_path, 'modules_statistic_results.json') - dump(scan_data, json_path, indent=2) - print_log(f'Result has been saved to {json_path}', logger='current') - return scan_data - - -def init_default_scope(scope: str) -> None: - """Initialize the given default scope. - - Args: - scope (str): The name of the default scope. - """ - never_created = DefaultScope.get_current_instance( - ) is None or not DefaultScope.check_instance_created(scope) - if never_created: - DefaultScope.get_instance(scope, scope_name=scope) - return - current_scope = DefaultScope.get_current_instance() # type: ignore - if current_scope.scope_name != scope: # type: ignore - print_log( - 'The current default scope ' # type: ignore - f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore - '`init_default_scope` will force set the current' - f'default scope to "{scope}".', - logger='current', - level=logging.WARNING) - # avoid name conflict - new_instance_name = f'{scope}-{datetime.datetime.now()}' - DefaultScope.get_instance(new_instance_name, scope_name=scope) diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py deleted file mode 100644 index b00f8e8391..0000000000 --- a/mmengine/runner/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from ._flexible_runner import FlexibleRunner -from .activation_checkpointing import turn_on_activation_checkpointing -from .amp import autocast -from .base_loop import BaseLoop -from .checkpoint import (CheckpointLoader, find_latest_checkpoint, - get_deprecated_model_names, get_external_models, - get_mmcls_models, get_state_dict, - get_torchvision_models, load_checkpoint, - load_state_dict, save_checkpoint, weights_to_cpu) -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .runner import Runner -from .utils import set_random_seed - -__all__ = [ - 'BaseLoop', 'load_state_dict', 'get_torchvision_models', - 'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names', - 'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict', - 'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop', - 'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint', - 'autocast', 'LogProcessor', 'set_random_seed', 'FlexibleRunner', - 'turn_on_activation_checkpointing' -] diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py deleted file mode 100644 index 5160a5cfb0..0000000000 --- a/mmengine/runner/_flexible_runner.py +++ /dev/null @@ -1,1650 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -import os.path as osp -import pickle -import warnings -from functools import partial -from typing import Callable, Dict, List, Optional, Union - -import torch.nn as nn -from torch.utils.data import DataLoader - -import mmengine -from mmengine._strategy import BaseStrategy -from mmengine.config import Config, ConfigDict -from mmengine.dataset import worker_init_fn as default_worker_init_fn -from mmengine.dist import get_rank, infer_launcher, master_only -from mmengine.evaluator import Evaluator -from mmengine.fileio import FileClient, join_path -from mmengine.hooks import Hook -from mmengine.logging import MessageHub, print_log -from mmengine.optim import OptimWrapper, OptimWrapperDict, _ParamScheduler -from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, - HOOKS, LOG_PROCESSORS, LOOPS, RUNNERS, - STRATEGIES, VISUALIZERS, DefaultScope) -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION -from mmengine.visualization import Visualizer -from .base_loop import BaseLoop -from .checkpoint import find_latest_checkpoint -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .utils import _get_batch_size - -ConfigType = Union[Dict, Config, ConfigDict] -ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, - List[_ParamScheduler]]] -OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] - - -@RUNNERS.register_module() -class FlexibleRunner: - """A training helper for PyTorch. - - Runner object can be built from config by ``runner = Runner.from_cfg(cfg)`` - where the ``cfg`` usually contains training, validation, and test-related - configurations to build corresponding components. We usually use the - same config to launch training, testing, and validation tasks. However, - only some of these components are necessary at the same time, e.g., - testing a model does not need training or validation-related components. 
- - To avoid repeatedly modifying config, the construction of ``Runner`` adopts - lazy initialization to only initialize components when they are going to be - used. Therefore, the model is always initialized at the beginning, and - training, validation, and, testing related components are only initialized - when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``, - respectively. - - Warning: - This is an experimental feature, and its interface is subject to - change. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It can be - a dict used for build a model. - - Kwargs: - work_dir (str, optional): The working directory to save checkpoints. - The logs will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. Defaults to 'work_dir'. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as ``experiment_name``. - Defaults to None. - train_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping training steps. Defaults to None. - See :meth:`build_dataloader` for more details. - optim_wrapper (OptimWrapper or dict, optional): - Computing gradient of model parameters. If specified, - :attr:`train_dataloader` should also be specified. If automatic - mixed precision or gradient accmulation - training is required. The type of ``optim_wrapper`` should be - AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for - examples. Defaults to None. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optimizer` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - train_cfg (dict, optional): A dict to build a training loop. If it does - not provide "type" key, it should contain "by_epoch" to decide - which type of training loop :class:`EpochBasedTrainLoop` or - :class:`IterBasedTrainLoop` should be used. If ``train_cfg`` - specified, :attr:`train_dataloader` should also be specified. - Defaults to None. See :meth:`build_train_loop` for more details. - val_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping validation steps. Defaults to None. - See :meth:`build_dataloader` for more details. - val_evaluator (Evaluator or dict or list, optional): A evaluator object - used for computing metrics for validation. It can be a dict or a - list of dict to build a evaluator. If specified, - :attr:`val_dataloader` should also be specified. Defaults to None. - val_cfg (dict, optional): A dict to build a validation loop. If it does - not provide "type" key, :class:`ValLoop` will be used by default. - If ``val_cfg`` specified, :attr:`val_dataloader` should also be - specified. If ``ValLoop`` is built with `fp16=True``, - ``runner.val()`` will be performed under fp16 precision. - test_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping test steps. Defaults to None. - See :meth:`build_dataloader` for more details. - Defaults to None. See :meth:`build_val_loop` for more details. - test_evaluator (Evaluator or dict or list, optional): A evaluator - object used for computing metrics for test steps. It can be a dict - or a list of dict to build a evaluator. If specified, - :attr:`test_dataloader` should also be specified. 
Defaults to None. - test_cfg (dict, optional): A dict to build a test loop. If it does - not provide "type" key, :class:`TestLoop` will be used by default. - If ``test_cfg`` specified, :attr:`test_dataloader` should also be - specified. If ``ValLoop`` is built with `fp16=True``, - ``runner.val()`` will be performed under fp16 precision. - Defaults to None. See :meth:`build_test_loop` for more details. - strategy (BaseStrategy or dict, optional): A strategy object or a dict - to build a strategy. Defaults to None. If not specified, the - strategy will be inferred automatically. - auto_scale_lr (dict, Optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to - execute default actions like updating model parameters and saving - checkpoints. Default hooks are ``OptimizerHook``, - ``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and - ``CheckpointHook``. Defaults to None. - See :meth:`register_default_hooks` for more details. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - data_preprocessor (dict, optional): The pre-process config of - :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict - and doesn't contain the key ``data_preprocessor``, set the argument - as the ``data_preprocessor`` of the ``model`` dict. - Defaults to None. - load_from (str, optional): The checkpoint file to load from. - Defaults to None. - resume (bool): Whether to resume training. Defaults to False. If - ``resume`` is True and ``load_from`` is None, automatically to - find latest checkpoint from ``work_dir``. If not found, resuming - does nothing. - launcher (str, optional): Way to launcher multi-process. Supported - launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is - provided, non-distributed environment will be launched. - If launcher is None, the launcher will be inferred according some - specified environments. Defaults to None. - env_cfg (dict): A dict used for setting environment. Defaults to - dict(dist_cfg=dict(backend='nccl')). - log_processor (dict, optional): A processor to format logs. Defaults to - None. - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - visualizer (Visualizer or dict, optional): A Visualizer object or a - dict build Visualizer object. Defaults to None. If not - specified, default config will be used. - default_scope (str): Used to reset registries location. - Defaults to "mmengine". - randomness (dict): Some settings to make the experiment as reproducible - as possible like seed and deterministic. - Defaults to ``dict(seed=None)``. If seed is None, a random number - will be generated and it will be broadcasted to all other processes - if in distributed environment. If ``cudnn_benchmark`` is - ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in - ``randomness``, the value of ``torch.backends.cudnn.benchmark`` - will be ``False`` finally. - compile (bool or dict, optional): Whether to enable ``torch.compile``. - Defaults to False. - cfg (dict or Configdict or :obj:`Config`, optional): Full config. - Defaults to None. - - Note: - Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in - `compile = True`. 
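As an illustrative aside on the ``optim_wrapper`` argument documented above, a hedged sketch of a mixed-precision setup with gradient accumulation (the wrapper class is exposed as ``AmpOptimWrapper`` in mmengine.optim):

optim_wrapper = dict(
    type='AmpOptimWrapper',                      # autocast + GradScaler under the hood
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    accumulative_counts=2,                       # accumulate gradients over 2 iterations
)
# A full-precision setup would simply use type='OptimWrapper'.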
If you want to control compile options, you - can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``. - Refer to `PyTorch API Documentation `_ for more valid - options. - - Examples: - >>> from mmengine.runner import Runner - >>> cfg = dict( - >>> model=dict(type='ToyModel'), - >>> work_dir='path/of/work_dir', - >>> train_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=True), - >>> batch_size=1, - >>> num_workers=0), - >>> val_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> test_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> auto_scale_lr=dict(base_batch_size=16, enable=False), - >>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict( - >>> type='SGD', lr=0.01)), - >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - >>> val_evaluator=dict(type='ToyEvaluator'), - >>> test_evaluator=dict(type='ToyEvaluator'), - >>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1), - >>> val_cfg=dict(), - >>> test_cfg=dict(), - >>> custom_hooks=[], - >>> default_hooks=dict( - >>> timer=dict(type='IterTimerHook'), - >>> checkpoint=dict(type='CheckpointHook', interval=1), - >>> logger=dict(type='LoggerHook'), - >>> optimizer=dict(type='OptimizerHook', grad_clip=False), - >>> param_scheduler=dict(type='ParamSchedulerHook')), - >>> launcher='none', - >>> env_cfg=dict(dist_cfg=dict(backend='nccl')), - >>> log_processor=dict(window_size=20), - >>> visualizer=dict(type='Visualizer', - >>> vis_backends=[dict(type='LocalVisBackend', - >>> save_dir='temp_dir')]) - >>> ) - >>> runner = Runner.from_cfg(cfg) - >>> runner.train() - >>> runner.test() - """ - cfg: Config - _train_loop: Optional[Union[BaseLoop, Dict]] - _val_loop: Optional[Union[BaseLoop, Dict]] - _test_loop: Optional[Union[BaseLoop, Dict]] - - def __init__( - self, - model: Union[nn.Module, Dict], - *, - work_dir: str = 'work_dirs', - experiment_name: Optional[str] = None, - train_dataloader: Optional[Union[DataLoader, Dict]] = None, - optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None, - param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None, - train_cfg: Optional[Dict] = None, - val_dataloader: Optional[Union[DataLoader, Dict]] = None, - val_evaluator: Optional[Union[Evaluator, Dict, List]] = None, - val_cfg: Optional[Dict] = None, - test_dataloader: Optional[Union[DataLoader, Dict]] = None, - test_evaluator: Optional[Union[Evaluator, Dict, List]] = None, - test_cfg: Optional[Dict] = None, - strategy: Optional[Union[BaseStrategy, Dict]] = None, - auto_scale_lr: Optional[Dict] = None, - default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, - custom_hooks: Optional[List[Union[Hook, Dict]]] = None, - data_preprocessor: Union[nn.Module, Dict, None] = None, - load_from: Optional[str] = None, - resume: Union[str, bool] = False, - launcher: Optional[str] = None, - env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), - log_processor: Optional[Dict] = None, - log_level: str = 'INFO', - visualizer: Optional[Union[Visualizer, Dict]] = None, - default_scope: Optional[str] = 'mmengine', - randomness: Dict = dict(seed=None), - compile: Union[bool, Dict] = False, - cfg: Optional[ConfigType] = None, - ): - if isinstance(model, dict) and data_preprocessor is not None: - # Merge the data_preprocessor to model config. 
- model.setdefault('data_preprocessor', data_preprocessor) - self.model = model - - self._work_dir = osp.abspath(work_dir) - mmengine.mkdir_or_exist(self._work_dir) - - # recursively copy the `cfg` because `self.cfg` will be modified - # everywhere. - if cfg is not None: - if isinstance(cfg, Config): - self.cfg = copy.deepcopy(cfg) - elif isinstance(cfg, dict): - self.cfg = Config(cfg) - else: - self.cfg = Config(dict()) - - # lazy initialization - training_related = [train_dataloader, train_cfg, optim_wrapper] - if not (all(item is None for item in training_related) - or all(item is not None for item in training_related)): - raise ValueError( - 'train_dataloader, train_cfg, and optim_wrapper should be ' - 'either all None or not None, but got ' - f'train_dataloader={train_dataloader}, ' - f'train_cfg={train_cfg}, ' - f'optim_wrapper={optim_wrapper}.') - self._train_dataloader = train_dataloader - self._train_loop = train_cfg - - self.optim_wrapper: Optional[Union[OptimWrapper, dict]] - self.optim_wrapper = optim_wrapper - - self._auto_scale_lr = auto_scale_lr - - # If there is no need to adjust learning rate, momentum or other - # parameters of optimizer, param_scheduler can be None - if param_scheduler is not None and self.optim_wrapper is None: - raise ValueError( - 'param_scheduler should be None when optim_wrapper is None, ' - f'but got {param_scheduler}') - - self.param_schedulers = param_scheduler - - val_related = [val_dataloader, val_cfg, val_evaluator] - if not (all(item is None - for item in val_related) or all(item is not None - for item in val_related)): - raise ValueError( - 'val_dataloader, val_cfg, and val_evaluator should be either ' - 'all None or not None, but got ' - f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, ' - f'val_evaluator={val_evaluator}') - self._val_dataloader = val_dataloader - self._val_loop = val_cfg - self._val_evaluator = val_evaluator - - test_related = [test_dataloader, test_cfg, test_evaluator] - if not (all(item is None for item in test_related) - or all(item is not None for item in test_related)): - raise ValueError( - 'test_dataloader, test_cfg, and test_evaluator should be ' - 'either all None or not None, but got ' - f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, ' - f'test_evaluator={test_evaluator}') - self._test_dataloader = test_dataloader - self._test_loop = test_cfg - self._test_evaluator = test_evaluator - - if not isinstance(compile, bool) and not isinstance(compile, dict): - raise TypeError( - f'compile should be a bool or dict, but got {type(compile)}') - self._compile = compile - - if isinstance(resume, str) and load_from is not None: - raise ValueError('If resume is a str, load_from should be None.') - self._load_from = load_from - self._resume = resume - # flag to mark whether checkpoint has been loaded or resumed - self._has_loaded = False - - if launcher is None: - launcher = infer_launcher() - - if experiment_name is None and self.cfg.filename is not None: - experiment_name = osp.splitext(osp.basename(self.cfg.filename))[0] - - self._randomness_cfg = randomness - self.strategy = self.build_strategy( - strategy, - launcher=launcher, - randomness=randomness, - env_cfg=env_cfg, - experiment_name=experiment_name, - log_level=log_level, - ) - - # Used to reset registries location. See :meth:`Registry.build` for - # more details. 
- if default_scope is not None: - default_scope = DefaultScope.get_instance( # type: ignore - self.experiment_name, - scope_name=default_scope) - self.default_scope = default_scope - # Build log processor to format message. - log_processor = dict() if log_processor is None else log_processor - self.log_processor = self.build_log_processor(log_processor) - - # Collect and log environment information. - self._log_env() - - # Build `message_hub` for communication among components. - # `message_hub` can store log scalars (loss, learning rate) and - # runtime information (iter and epoch). Those components that do not - # have access to the runner can get iteration or epoch information - # from `message_hub`. For example, models can get the latest created - # `message_hub` by - # `self.message_hub=MessageHub.get_current_instance()` and then get - # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. - # See `MessageHub` and `ManagerMixin` for more details. - self.message_hub = self.build_message_hub() - # visualizer used for writing log or visualizing all kinds of data - self.visualizer = self.build_visualizer(visualizer) - if self.cfg: - self.visualizer.add_config(self.cfg) - - self._hooks: List[Hook] = [] - # register hooks to `self._hooks` - self.register_hooks(default_hooks, custom_hooks) - # log hooks information - self.logger.info(f'Hooks will be executed in the following ' - f'order:\n{self.get_hooks_info()}') - - # dump `cfg` to `work_dir` - self.dump_config() - - @classmethod - def from_cfg(cls, cfg: ConfigType) -> 'FlexibleRunner': - """Build a runner from config. - - Args: - cfg (ConfigType): A config used for building runner. Keys of - ``cfg`` can see :meth:`__init__`. - - Returns: - Runner: A runner build from ``cfg``. - """ - cfg = copy.deepcopy(cfg) - runner = cls( - model=cfg['model'], - work_dir=cfg.get('work_dir', 'work_dirs'), - experiment_name=cfg.get('experiment_name'), - train_dataloader=cfg.get('train_dataloader'), - optim_wrapper=cfg.get('optim_wrapper'), - param_scheduler=cfg.get('param_scheduler'), - train_cfg=cfg.get('train_cfg'), - val_dataloader=cfg.get('val_dataloader'), - val_evaluator=cfg.get('val_evaluator'), - val_cfg=cfg.get('val_cfg'), - test_dataloader=cfg.get('test_dataloader'), - test_evaluator=cfg.get('test_evaluator'), - test_cfg=cfg.get('test_cfg'), - strategy=cfg.get('strategy'), - auto_scale_lr=cfg.get('auto_scale_lr'), - default_hooks=cfg.get('default_hooks'), - custom_hooks=cfg.get('custom_hooks'), - data_preprocessor=cfg.get('data_preprocessor'), - load_from=cfg.get('load_from'), - resume=cfg.get('resume', False), - launcher=cfg.get('launcher'), - env_cfg=cfg.get('env_cfg'), # type: ignore - log_processor=cfg.get('log_processor'), - log_level=cfg.get('log_level', 'INFO'), - visualizer=cfg.get('visualizer'), - default_scope=cfg.get('default_scope', 'mmengine'), - randomness=cfg.get('randomness', dict(seed=None)), - cfg=cfg, - ) - - return runner - - @property - def experiment_name(self): - """str: Name of experiment.""" - return self.strategy.experiment_name - - @property - def model_name(self): - """str: Name of the model, usually the module class name.""" - return self._model_name - - @property - def work_dir(self): - """str: The working directory to save checkpoints and logs.""" - return self._work_dir - - @property - def log_dir(self): - return self.strategy.log_dir - - @property - def logger(self): - return self.strategy.logger - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - if 
isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_epochs - else: - return 0 - - @property - def max_iters(self): - """int: Total iterations to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_iters - else: - return 0 - - @property - def epoch(self): - """int: Current epoch.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.epoch - else: - return 0 - - @property - def iter(self): - """int: Current iteration.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.iter - else: - return 0 - - @property - def distributed(self): - """bool: Whether current environment is distributed.""" - return self.strategy.distributed - - @property - def rank(self): - """int: Rank of current process.""" - return self.strategy.rank - - @property - def world_size(self): - """int: Number of processes participating in the job.""" - return self.strategy.world_size - - @property - def deterministic(self): - """int: Whether cudnn to select deterministic algorithms.""" - return self._deterministic - - @property - def seed(self): - """int: A number to set random modules.""" - return self.strategy.seed - - @property - def timestamp(self): - """str: Timestamp when creating experiment.""" - return self.strategy.timestamp - - @property - def hooks(self): - """List[:obj:`Hook`]: A list of registered hooks.""" - return self._hooks - - @property - def train_loop(self): - """:obj:`BaseLoop`: A loop to run training.""" - if isinstance(self._train_loop, BaseLoop) or self._train_loop is None: - return self._train_loop - else: - self._train_loop = self.build_train_loop(self._train_loop) - return self._train_loop - - @property - def val_loop(self): - """:obj:`BaseLoop`: A loop to run validation.""" - if isinstance(self._val_loop, BaseLoop) or self._val_loop is None: - return self._val_loop - else: - self._val_loop = self.build_val_loop(self._val_loop) - return self._val_loop - - @property - def test_loop(self): - """:obj:`BaseLoop`: A loop to run testing.""" - if isinstance(self._test_loop, BaseLoop) or self._test_loop is None: - return self._test_loop - else: - self._test_loop = self.build_test_loop(self._test_loop) - return self._test_loop - - @property - def train_dataloader(self): - """The data loader for training.""" - return self.train_loop.dataloader - - @property - def val_dataloader(self): - """The data loader for validation.""" - return self.val_loop.dataloader - - @property - def test_dataloader(self): - """The data loader for testing.""" - return self.test_loop.dataloader - - @property - def val_evaluator(self): - """:obj:`Evaluator`: An evaluator for validation.""" - return self.val_loop.evaluator - - @property - def test_evaluator(self): - """:obj:`Evaluator`: An evaluator for testing.""" - return self.test_loop.evaluator - - @property - def val_interval(self): - """int: Interval to run validation during training.""" - return self.train_loop.val_interval - - @property - def val_begin(self): - """int: The epoch/iteration to start running validation during - training.""" - return self.train_loop.val_begin - - def build_strategy( - self, - strategy: Optional[Union[BaseStrategy, Dict]] = None, - launcher: str = 'none', - randomness: Optional[dict] = None, - env_cfg: dict = dict(dist_cfg=dict(backend='nccl')), - experiment_name: Optional[str] = None, - log_level: Optional[str] = None, - ) -> BaseStrategy: - """Build a strategy. - - Args: - strategy (BaseStrategy, optional): A strategy object or dict to - build the strategy. 
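For example, picking the strategy explicitly instead of letting the builder infer it from the launcher, as the implementation below does by default; which strategies are usable depends on the installed backends:

strategy = dict(type='SingleDeviceStrategy')   # default for launcher='none'
# strategy = dict(type='DDPStrategy')          # default for distributed launchers
# strategy = dict(type='DeepSpeedStrategy')    # requires deepspeed; the builder
#                                              # injects train_micro_batch_size_per_gpu
# runner = FlexibleRunner(model=dict(type='ToyModel'), strategy=strategy)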
Defaults to None. - - Returns: - BaseStrategy: A strategy object. - """ - if isinstance(strategy, BaseStrategy): - strategy_obj = strategy - else: - if launcher == 'none': - if strategy is None: - strategy = dict(type='SingleDeviceStrategy') - else: - if strategy is None: - strategy = dict(type='DDPStrategy') - - assert isinstance(strategy, dict) - - # train_micro_batch_size_per_gpu is required by DeepSpeed - if isinstance(strategy['type'], str): - strategy_name = strategy['type'] - else: - strategy_name = strategy['type'].__name__ - if strategy_name == 'DeepSpeedStrategy': - if self._train_dataloader is None: - strategy['train_micro_batch_size_per_gpu'] = 1 - else: - strategy['train_micro_batch_size_per_gpu'] = \ - _get_batch_size(self._train_dataloader) - - strategy.setdefault('work_dir', self._work_dir) - strategy.setdefault('experiment_name', experiment_name) - strategy.setdefault('auto_scale_lr', self._auto_scale_lr) - - env_kwargs = dict( - launcher=launcher, - randomness=randomness, - **env_cfg, - ) - strategy.setdefault('env_kwargs', env_kwargs) - - log_kwargs = dict(log_level=log_level) - strategy.setdefault('log_kwargs', log_kwargs) - - strategy_obj = STRATEGIES.build(strategy) - - return strategy_obj - - def build_message_hub( - self, - message_hub: Optional[Dict] = None, - ) -> MessageHub: - """Build a global asscessable MessageHub. - - Args: - message_hub (dict, optional): A dict to build MessageHub object. - If not specified, default config will be used to build - MessageHub object. Defaults to None. - - Returns: - MessageHub: A MessageHub object build from ``message_hub``. - """ - if message_hub is None: - message_hub = dict(name=self.experiment_name) - elif isinstance(message_hub, dict): - # ensure message_hub containing name key - message_hub.setdefault('name', self.experiment_name) - else: - raise TypeError( - f'message_hub should be dict or None, but got {message_hub}') - - return MessageHub.get_instance(**message_hub) - - def build_visualizer( - self, - visualizer: Optional[Union[Visualizer, Dict]] = None, - ) -> Visualizer: - """Build a global asscessable Visualizer. - - Args: - visualizer (Visualizer or dict, optional): A Visualizer object - or a dict to build Visualizer object. If ``visualizer`` is a - Visualizer object, just returns itself. If not specified, - default config will be used to build Visualizer object. - Defaults to None. - - Returns: - Visualizer: A Visualizer object build from ``visualizer``. - """ - if visualizer is None: - visualizer = dict( - name=self.experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self.log_dir) - return Visualizer.get_instance(**visualizer) - - if isinstance(visualizer, Visualizer): - return visualizer - - if isinstance(visualizer, dict): - # ensure visualizer containing name key - visualizer.setdefault('name', self.experiment_name) - visualizer.setdefault('save_dir', self.log_dir) - return VISUALIZERS.build(visualizer) - else: - raise TypeError( - 'visualizer should be Visualizer object, a dict or None, ' - f'but got {visualizer}') - - def build_evaluator( - self, - evaluator: Union[Dict, List, Evaluator], - ) -> Evaluator: - """Build evaluator. 
- - Examples of ``evaluator``:: - - # evaluator could be a built Evaluator instance - evaluator = Evaluator(metrics=[ToyMetric()]) - - # evaluator can also be a list of dict - evaluator = [ - dict(type='ToyMetric1'), - dict(type='ToyEvaluator2') - ] - - # evaluator can also be a list of built metric - evaluator = [ToyMetric1(), ToyMetric2()] - - # evaluator can also be a dict with key metrics - evaluator = dict(metrics=ToyMetric()) - # metric is a list - evaluator = dict(metrics=[ToyMetric()]) - - Args: - evaluator (Evaluator or dict or list): An Evaluator object or a - config dict or list of config dict used to build an Evaluator. - - Returns: - Evaluator: Evaluator build from ``evaluator``. - """ - if isinstance(evaluator, Evaluator): - return evaluator - elif isinstance(evaluator, dict): - # if `metrics` in dict keys, it means to build customized evalutor - if 'metrics' in evaluator: - evaluator.setdefault('type', 'Evaluator') - return EVALUATOR.build(evaluator) - # otherwise, default evalutor will be built - else: - return Evaluator(evaluator) # type: ignore - elif isinstance(evaluator, list): - # use the default `Evaluator` - return Evaluator(evaluator) # type: ignore - else: - raise TypeError( - 'evaluator should be one of dict, list of dict, and Evaluator' - f', but got {evaluator}') - - @staticmethod - def build_dataloader( - dataloader: Union[DataLoader, Dict], - seed: Optional[int] = None, - diff_rank_seed: bool = False, - ) -> DataLoader: - """Build dataloader. - - The method builds three components: - - - Dataset - - Sampler - - Dataloader - - An example of ``dataloader``:: - - dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=1, - num_workers=9 - ) - - Args: - dataloader (DataLoader or dict): A Dataloader object or a dict to - build Dataloader object. If ``dataloader`` is a Dataloader - object, just returns itself. - seed (int, optional): Random seed. Defaults to None. - diff_rank_seed (bool): Whether or not set different seeds to - different ranks. If True, the seed passed to sampler is set - to None, in order to synchronize the seeds used in samplers - across different ranks. Defaults to False. - - Returns: - Dataloader: DataLoader build from ``dataloader_cfg``. 
- """ - if isinstance(dataloader, DataLoader): - return dataloader - - dataloader_cfg = copy.deepcopy(dataloader) - - # build dataset - dataset_cfg = dataloader_cfg.pop('dataset') - if isinstance(dataset_cfg, dict): - dataset = DATASETS.build(dataset_cfg) - if hasattr(dataset, 'full_init'): - dataset.full_init() - else: - # fallback to raise error in dataloader - # if `dataset_cfg` is not a valid type - dataset = dataset_cfg - - # build sampler - sampler_cfg = dataloader_cfg.pop('sampler') - if isinstance(sampler_cfg, dict): - sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) - else: - # fallback to raise error in dataloader - # if `sampler_cfg` is not a valid type - sampler = sampler_cfg - - # build batch sampler - batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None) - if batch_sampler_cfg is None: - batch_sampler = None - elif isinstance(batch_sampler_cfg, dict): - batch_sampler = DATA_SAMPLERS.build( - batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) - else: - # fallback to raise error in dataloader - # if `batch_sampler_cfg` is not a valid type - batch_sampler = batch_sampler_cfg - - # build dataloader - init_fn: Optional[partial] - if 'worker_init_fn' in dataloader_cfg: - worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn') - worker_init_fn_type = worker_init_fn_cfg.pop('type') - worker_init_fn = FUNCTIONS.get(worker_init_fn_type) - assert callable(worker_init_fn) - init_fn = partial(worker_init_fn, - **worker_init_fn_cfg) # type: ignore - else: - if seed is not None: - disable_subprocess_warning = dataloader_cfg.pop( - 'disable_subprocess_warning', False) - assert isinstance(disable_subprocess_warning, bool), ( - 'disable_subprocess_warning should be a bool, but got ' - f'{type(disable_subprocess_warning)}') - init_fn = partial( - default_worker_init_fn, - num_workers=dataloader_cfg.get('num_workers'), - rank=get_rank(), - seed=seed, - disable_subprocess_warning=disable_subprocess_warning) - else: - init_fn = None - - # `persistent_workers` requires pytorch version >= 1.7 - if ('persistent_workers' in dataloader_cfg - and digit_version(TORCH_VERSION) < digit_version('1.7.0')): - print_log( - '`persistent_workers` is only available when ' - 'pytorch version >= 1.7', - logger='current', - level=logging.WARNING) - dataloader_cfg.pop('persistent_workers') - - # The default behavior of `collat_fn` in dataloader is to - # merge a list of samples to form a mini-batch of Tensor(s). - # However, in mmengine, if `collate_fn` is not defined in - # dataloader_cfg, `pseudo_collate` will only convert the list of - # samples into a dict without stacking the batch tensor. 
- collate_fn_cfg = dataloader_cfg.pop('collate_fn', - dict(type='pseudo_collate')) - if isinstance(collate_fn_cfg, dict): - collate_fn_type = collate_fn_cfg.pop('type') - if isinstance(collate_fn_type, str): - collate_fn = FUNCTIONS.get(collate_fn_type) - else: - collate_fn = collate_fn_type - collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore - elif callable(collate_fn_cfg): - collate_fn = collate_fn_cfg - else: - raise TypeError( - 'collate_fn should be a dict or callable object, but got ' - f'{collate_fn_cfg}') - data_loader = DataLoader( - dataset=dataset, - sampler=sampler if batch_sampler is None else None, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - worker_init_fn=init_fn, - **dataloader_cfg) - return data_loader - - def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build training loop. - - Examples of ``loop``:: - - # `EpochBasedTrainLoop` will be used - loop = dict(by_epoch=True, max_epochs=3) - - # `IterBasedTrainLoop` will be used - loop = dict(by_epoch=False, max_epochs=3) - - # custom training loop - loop = dict(type='CustomTrainLoop', max_epochs=3) - - Args: - loop (BaseLoop or dict): A training loop or a dict to build - training loop. If ``loop`` is a training loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Training loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) - - if 'type' in loop_cfg and 'by_epoch' in loop_cfg: - raise RuntimeError( - 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) - else: - by_epoch = loop_cfg.pop('by_epoch') - if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) - else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) - return loop # type: ignore - - def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build validation loop. - - Examples of ``loop``: - - # `ValLoop` will be used - loop = dict() - - # custom validation loop - loop = dict(type='CustomValLoop') - - Args: - loop (BaseLoop or dict): A validation loop or a dict to build - validation loop. If ``loop`` is a validation loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Validation loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'train_loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) - else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore - - return loop # type: ignore - - def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build test loop. - - Examples of ``loop``:: - - # `TestLoop` will be used - loop = dict() - - # custom test loop - loop = dict(type='CustomTestLoop') - - Args: - loop (BaseLoop or dict): A test loop or a dict to build test loop. - If ``loop`` is a test loop object, just returns itself. 
- - Returns: - :obj:`BaseLoop`: Test loop object build from ``loop_cfg``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'train_loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) # type: ignore - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) - else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore - - return loop # type: ignore - - def build_log_processor( - self, - log_processor: Union[LogProcessor, Dict], - ) -> LogProcessor: - """Build test log_processor. - - Examples of ``log_processor``: - - # `LogProcessor` will be used - log_processor = dict() - - # custom log_processor - log_processor = dict(type='CustomLogProcessor') - - Args: - log_processor (LogProcessor or dict): A log processor or a dict - to build log processor. If ``log_processor`` is a log processor - object, just returns itself. - - Returns: - :obj:`LogProcessor`: Log processor object build from - ``log_processor_cfg``. - """ - if isinstance(log_processor, LogProcessor): - return log_processor - elif not isinstance(log_processor, dict): - raise TypeError( - 'log processor should be a LogProcessor object or dict, but' - f'got {log_processor}') - - log_processor_cfg = copy.deepcopy(log_processor) # type: ignore - - if 'type' in log_processor_cfg: - log_processor = LOG_PROCESSORS.build(log_processor_cfg) - else: - log_processor = LogProcessor(**log_processor_cfg) # type: ignore - - return log_processor # type: ignore - - def get_hooks_info(self) -> str: - # Get hooks info in each stage - stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages} - for hook in self.hooks: - try: - priority = Priority(hook.priority).name # type: ignore - except ValueError: - priority = hook.priority # type: ignore - classname = hook.__class__.__name__ - hook_info = f'({priority:<12}) {classname:<35}' - for trigger_stage in hook.get_triggered_stages(): - stage_hook_map[trigger_stage].append(hook_info) - - stage_hook_infos = [] - for stage in Hook.stages: - hook_infos = stage_hook_map[stage] - if len(hook_infos) > 0: - info = f'{stage}:\n' - info += '\n'.join(hook_infos) - info += '\n -------------------- ' - stage_hook_infos.append(info) - return '\n'.join(stage_hook_infos) - - def load_or_resume(self): - """Load or resume checkpoint.""" - if self._has_loaded: - return None - - if not self._resume and self._load_from is None: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if isinstance(self._resume, str): - resume_from = self._resume - elif self._resume and self._load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self.work_dir) - self.logger.info( - f'Auto resumed from the latest checkpoint {resume_from}.') - elif self._resume and self._load_from is not None: - # resume from the specified checkpoint - resume_from = self._load_from - - if resume_from is not None: - self.resume(resume_from) - self._has_loaded = True - elif self._load_from is not None: - self.load_checkpoint(self._load_from) - self._has_loaded = True - - def train(self) -> nn.Module: - """Launch training. - - Returns: - nn.Module: The model after training. 
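A tiny sketch of a ``log_processor`` config accepted by ``build_log_processor`` above; both keys are standard ``LogProcessor`` arguments:

log_processor = dict(window_size=20, by_epoch=True)   # smooth logged scalars over 20 iters
# log_processor = dict(type='CustomLogProcessor')      # or a registered custom class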
- """ - if self._train_loop is None: - raise RuntimeError( - '`self._train_loop` should not be None when calling train ' - 'method. Please provide `train_dataloader`, `train_cfg`, ' - '`optimizer` and `param_scheduler` arguments when ' - 'initializing runner.') - - self._train_loop = self.build_train_loop( - self._train_loop) # type: ignore - - if self._val_loop is not None: - self._val_loop = self.build_val_loop( - self._val_loop) # type: ignore - - compile: Union[dict, bool] = False - if isinstance(self._compile, bool): - if self._compile: - compile = dict(target='train_step') - else: - compile = copy.copy(self._compile) - compile.setdefault('target', 'train_step') - - dispatch_kwargs = dict( - epoch_length=len(self.train_dataloader), - max_epochs=self.max_epochs, - max_iters=self.max_iters, - train_micro_batch_size_per_gpu=_get_batch_size( - self.train_dataloader)) # type: ignore - - self.strategy.prepare( - self.model, - optim_wrapper=self.optim_wrapper, - param_scheduler=self.param_schedulers, - compile=compile, - dispatch_kwargs=dispatch_kwargs, - ) - - self.model = self.strategy.model - self.optim_wrapper = self.strategy.optim_wrapper # type: ignore - if self.param_schedulers is not None: - self.param_schedulers = self.strategy.param_schedulers - - self.load_or_resume() - - # TODO: add a contextmanager to avoid calling `before_run` many times - self.call_hook('before_run') - - model = self.train_loop.run() # type: ignore - self.call_hook('after_run') - return model - - def val(self) -> dict: - """Launch validation. - - Returns: - dict: A dict of metrics on validation set. - """ - if self._val_loop is None: - raise RuntimeError( - '`self._val_loop` should not be None when calling val method.' - 'Please provide `val_dataloader`, `val_cfg` and ' - '`val_evaluator` arguments when initializing runner.') - - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) - self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) - self.model = self.strategy.model - - self.load_or_resume() - - self.call_hook('before_run') - metrics = self.val_loop.run() # type: ignore - self.call_hook('after_run') - - return metrics - - def test(self) -> dict: - """Launch test. - - Returns: - dict: A dict of metrics on testing set. - """ - if self._test_loop is None: - raise RuntimeError( - '`self._test_loop` should not be None when calling test ' - 'method. Please provide `test_dataloader`, `test_cfg` and ' - '`test_evaluator` arguments when initializing runner.') - - self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) - self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) - self.model = self.strategy.model - - self.load_or_resume() - - self.call_hook('before_run') - metrics = self.test_loop.run() # type: ignore - self.call_hook('after_run') - - return metrics - - def call_hook(self, fn_name: str, **kwargs) -> None: - """Call all hooks. - - Args: - fn_name (str): The function name in each hook to be called, such as - "before_train_epoch". - **kwargs: Keyword arguments passed to hook. 
- """ - for hook in self._hooks: - # support adding additional custom hook methods - if hasattr(hook, fn_name): - try: - getattr(hook, fn_name)(self, **kwargs) - except TypeError as e: - raise TypeError(f'{e} in {hook}') from e - - def register_hook( - self, - hook: Union[Hook, Dict], - priority: Optional[Union[str, int, Priority]] = None, - ) -> None: - """Register a hook into the hook list. - - The hook will be inserted into a priority queue, with the specified - priority (See :class:`Priority` for details of priorities). - For hooks with the same priority, they will be triggered in the same - order as they are registered. - - Priority of hook will be decided with the following priority: - - - ``priority`` argument. If ``priority`` is given, it will be priority - of hook. - - If ``hook`` argument is a dict and ``priority`` in it, the priority - will be the value of ``hook['priority']``. - - If ``hook`` argument is a dict but ``priority`` not in it or ``hook`` - is an instance of ``hook``, the priority will be ``hook.priority``. - - Args: - hook (:obj:`Hook` or dict): The hook to be registered. - priority (int or str or :obj:`Priority`, optional): Hook priority. - Lower value means higher priority. - """ - if not isinstance(hook, (Hook, dict)): - raise TypeError( - f'hook should be an instance of Hook or dict, but got {hook}') - - _priority = None - if isinstance(hook, dict): - if 'priority' in hook: - _priority = hook.pop('priority') - - hook_obj = HOOKS.build(hook) - else: - hook_obj = hook - - if priority is not None: - hook_obj.priority = priority - elif _priority is not None: - hook_obj.priority = _priority - - inserted = False - for i in range(len(self._hooks) - 1, -1, -1): - if get_priority(hook_obj.priority) >= get_priority( - self._hooks[i].priority): - self._hooks.insert(i + 1, hook_obj) - inserted = True - break - if not inserted: - self._hooks.insert(0, hook_obj) - - def register_default_hooks( - self, - hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, - ) -> None: - """Register default hooks into hook list. - - ``hooks`` will be registered into runner to execute some default - actions like updating model parameters or saving checkpoints. - - Default hooks and their priorities: - - +----------------------+-------------------------+ - | Hooks | Priority | - +======================+=========================+ - | RuntimeInfoHook | VERY_HIGH (10) | - +----------------------+-------------------------+ - | IterTimerHook | NORMAL (50) | - +----------------------+-------------------------+ - | DistSamplerSeedHook | NORMAL (50) | - +----------------------+-------------------------+ - | LoggerHook | BELOW_NORMAL (60) | - +----------------------+-------------------------+ - | ParamSchedulerHook | LOW (70) | - +----------------------+-------------------------+ - | CheckpointHook | VERY_LOW (90) | - +----------------------+-------------------------+ - - If ``hooks`` is None, above hooks will be registered by - default:: - - default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - - If not None, ``hooks`` will be merged into ``default_hooks``. 
- If there are None value in default_hooks, the corresponding item will - be popped from ``default_hooks``:: - - hooks = dict(timer=None) - - The final registered default hooks will be :obj:`RuntimeInfoHook`, - :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`, - :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`. - - Args: - hooks (dict[str, Hook or dict], optional): Default hooks or configs - to be registered. - """ - default_hooks: dict = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - if hooks is not None: - for name, hook in hooks.items(): - if name in default_hooks and hook is None: - # remove hook from _default_hooks - default_hooks.pop(name) - else: - assert hook is not None - default_hooks[name] = hook - - for hook in default_hooks.values(): - self.register_hook(hook) - - def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None: - """Register custom hooks into hook list. - - Args: - hooks (list[Hook | dict]): List of hooks or configs to be - registered. - """ - for hook in hooks: - self.register_hook(hook) - - def register_hooks( - self, - default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, - custom_hooks: Optional[List[Union[Hook, Dict]]] = None, - ) -> None: - """Register default hooks and custom hooks into hook list. - - Args: - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks - to execute default actions like updating model parameters and - saving checkpoints. Defaults to None. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - """ - self.register_default_hooks(default_hooks) - - if custom_hooks is not None: - self.register_custom_hooks(custom_hooks) - - def resume( - self, - filename: str, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default', - ) -> None: - """Resume model from checkpoint. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - """ - - def callback(checkpoint): - self.call_hook('after_load_checkpoint', checkpoint=checkpoint) - - checkpoint = self.strategy.resume( - filename, - resume_optimizer=resume_optimizer, - resume_param_scheduler=resume_param_scheduler, - map_location=map_location, - callback=callback, - ) - - self.train_loop._epoch = checkpoint['meta']['epoch'] - self.train_loop._iter = checkpoint['meta']['iter'] - - # check whether the number of GPU used for current experiment - # is consistent with resuming from checkpoint - if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') - previous_gpu_ids = config.get('gpu_ids', None) - if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 - and len(previous_gpu_ids) != self.world_size): - # TODO, should we modify the iteration? 
- self.logger.info( - 'Number of GPU used for current experiment is not ' - 'consistent with resuming from checkpoint') - if (self._auto_scale_lr is None - or not self._auto_scale_lr.get('enable', False)): - raise RuntimeError( - 'Cannot automatically rescale lr in resuming. Please ' - 'make sure the number of GPU is consistent with the ' - 'previous training state resuming from the checkpoint ' - 'or set `enable` in `auto_scale_lr to False.') - - resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None) - dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None) - - # `resumed_dataset_meta` and `dataset_meta` could be object like - # np.ndarray, which cannot be directly judged as equal or not, - # therefore we just compared their dumped results. - if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta): - self.logger.warning( - 'The dataset metainfo from the resumed checkpoint is ' - 'different from the current training dataset, please ' - 'check the correctness of the checkpoint or the training ' - 'dataset.') - - self.message_hub.load_state_dict(checkpoint['message_hub']) - - self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}') - - def load_checkpoint(self, - filename: str, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')]): - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - """ - - def callback(checkpoint): - self.call_hook('after_load_checkpoint', checkpoint=checkpoint) - - self.strategy.load_checkpoint( - filename, - map_location=map_location, - strict=strict, - revise_keys=revise_keys, - callback=callback) - - def save_checkpoint( - self, - out_dir: str, - filename: str, - file_client_args: Optional[dict] = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: Optional[dict] = None, - by_epoch: bool = True, - backend_args: Optional[dict] = None, - ): - """Save checkpoints. - - ``CheckpointHook`` invokes this method to save checkpoints - periodically. - - Args: - out_dir (str): The directory that checkpoints are saved. - filename (str): The checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for - details. Defaults to None. It will be deprecated in future. - Please use `backend_args` instead. - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - meta (dict, optional): The meta information to be saved in the - checkpoint. Defaults to None. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. 
- """ - if meta is None: - meta = {} - elif not isinstance(meta, dict): - raise TypeError( - f'meta should be a dict or None, but got {type(meta)}') - - if by_epoch: - # self.epoch increments 1 after - # `self.call_hook('after_train_epoch)` but `save_checkpoint` is - # called by `after_train_epoch`` method of `CheckpointHook` so - # `epoch` should be `self.epoch + 1` - meta.update(epoch=self.epoch + 1, iter=self.iter) - else: - meta.update(epoch=self.epoch, iter=self.iter + 1) - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set at ' - 'the same time.') - - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) - else: - filepath = join_path( # type: ignore - out_dir, filename, backend_args=backend_args) - - meta.update( - cfg=self.cfg.pretty_text, experiment_name=self.experiment_name) - - if hasattr(self.train_dataloader.dataset, 'metainfo'): - meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) - - checkpoint = { - 'meta': meta, - 'message_hub': self.message_hub.state_dict() - } - - def callback(checkpoint): - self.call_hook('before_save_checkpoint', checkpoint=checkpoint) - - self.strategy.save_checkpoint( - filename=filepath, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=checkpoint, - callback=callback, - ) - - @master_only - def dump_config(self) -> None: - """Dump config to `work_dir`.""" - if self.cfg.filename is not None: - filename = osp.basename(self.cfg.filename) - else: - filename = f'{self.timestamp}.py' - self.cfg.dump(osp.join(self.work_dir, filename)) - - def _log_env(self) -> None: - """Logging environment information of the current task. - - Args: - env_cfg (dict): The environment config of the runner. - """ - # Collect and log environment information. - system_env, runtime_env = self.strategy.collect_env() - - env_info = '\n ' + '\n '.join(f'{k}: {v}' - for k, v in system_env.items()) - runtime_env_info = '\n ' + '\n '.join( - f'{k}: {v}' for k, v in runtime_env.items()) - dash_line = '-' * 60 - self.logger.info('\n' + dash_line + '\nSystem environment:' + - env_info + '\n' - '\nRuntime environment:' + runtime_env_info + '\n' + - dash_line + '\n') - - if self.cfg._cfg_dict: - self.logger.info(f'Config:\n{self.cfg.pretty_text}') diff --git a/mmengine/runner/activation_checkpointing.py b/mmengine/runner/activation_checkpointing.py deleted file mode 100644 index 3db67f057c..0000000000 --- a/mmengine/runner/activation_checkpointing.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from functools import wraps -from operator import attrgetter -from typing import List, Union - -import torch -from torch.utils.checkpoint import checkpoint - - -def wrap_forward(forward): - - @wraps(forward) - def wrapper(*args): - return checkpoint(forward, *args) - - return wrapper - - -def turn_on_activation_checkpointing(model: torch.nn.Module, - modules: Union[List[str], str]): - - if isinstance(modules, str): - modules = [modules] - for module_name in modules: - module = attrgetter(module_name)(model) - module.forward = wrap_forward(module.forward) diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py deleted file mode 100644 index 198babc582..0000000000 --- a/mmengine/runner/amp.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from contextlib import contextmanager -from typing import Optional - -import torch - -from mmengine.device import (get_device, is_cuda_available, is_mlu_available, - is_npu_available) -from mmengine.logging import print_log -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION - - -@contextmanager -def autocast(device_type: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - enabled: bool = True, - cache_enabled: Optional[bool] = None): - """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``. - - Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in - mixed precision , and update it to ``torch.autocast`` in 1.10.0. - Both interfaces have different arguments, and ``torch.autocast`` - support running with cpu additionally. - - This function provides a unified interface by wrapping - ``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the - compatibility issues that ``torch.cuda.amp.autocast`` does not support - running mixed precision with cpu, and both contexts have different - arguments. We suggest users using this function in the code - to achieve maximized compatibility of different PyTorch versions. - - Note: - ``autocast`` requires pytorch version >= 1.5.0. If pytorch version - <= 1.10.0 and cuda is not available, it will raise an error with - ``enabled=True``, since ``torch.cuda.amp.autocast`` only support cuda - mode. - - Examples: - >>> # case1: 1.10 > Pytorch version >= 1.5.0 - >>> with autocast(): - >>> # run in mixed precision context - >>> pass - >>> with autocast(device_type='cpu'):: - >>> # raise error, torch.cuda.amp.autocast only support cuda mode. - >>> pass - >>> # case2: Pytorch version >= 1.10.0 - >>> with autocast(): - >>> # default cuda mixed precision context - >>> pass - >>> with autocast(device_type='cpu'): - >>> # cpu mixed precision context - >>> pass - >>> with autocast( - >>> device_type='cuda', enabled=True, cache_enabled=True): - >>> # enable precision context with more specific arguments. - >>> pass - - Args: - device_type (str, required): Whether to use 'cuda' or 'cpu' device. - enabled(bool): Whether autocasting should be enabled in the region. - Defaults to True - dtype (torch_dtype, optional): Whether to use ``torch.float16`` or - ``torch.bfloat16``. - cache_enabled(bool, optional): Whether the weight cache inside - autocast should be enabled. - """ - # If `enabled` is True, enable an empty context and all calculations - # are performed under fp32. 
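For reference, a minimal sketch of applying the `turn_on_activation_checkpointing` helper removed above to a toy model; `TinyNet` is hypothetical and the import path is the pre-removal location.

import torch
import torch.nn as nn

from mmengine.runner.activation_checkpointing import (
    turn_on_activation_checkpointing)  # import path valid before this removal


class TinyNet(nn.Module):
    """Hypothetical two-block model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.block2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.block2(self.block1(x))


model = TinyNet()
# Recompute `block1` activations during backward to trade compute for memory.
turn_on_activation_checkpointing(model, modules='block1')
loss = model(torch.randn(4, 8, requires_grad=True)).sum()
loss.backward()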
- assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), ( - 'The minimum pytorch version requirements of mmengine is 1.5.0, but ' - f'got {TORCH_VERSION}') - - if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) < - digit_version('1.10.0')): - # If pytorch version is between 1.5.0 and 1.10.0, the default value of - # dtype for `torch.cuda.amp.autocast` is torch.float16. - assert ( - device_type == 'cuda' or device_type == 'mlu' - or device_type is None), ( - 'Pytorch version under 1.10.0 only supports running automatic ' - 'mixed training with cuda or mlu') - if dtype is not None or cache_enabled is not None: - print_log( - f'{dtype} and {device_type} will not work for ' - '`autocast` since your Pytorch version: ' - f'{TORCH_VERSION} <= 1.10.0', - logger='current', - level=logging.WARNING) - - if is_npu_available(): - with torch.npu.amp.autocast(enabled=enabled): - yield - elif is_mlu_available(): - with torch.mlu.amp.autocast(enabled=enabled): - yield - elif is_cuda_available(): - with torch.cuda.amp.autocast(enabled=enabled): - yield - else: - if not enabled: - yield - else: - raise RuntimeError( - 'If pytorch versions is between 1.5.0 and 1.10, ' - '`autocast` is only available in gpu mode') - - else: - # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py # noqa: E501 - # This code should update with the `torch.autocast`. - if cache_enabled is None: - cache_enabled = torch.is_autocast_cache_enabled() - device = get_device() - device_type = device if device_type is None else device_type - - if device_type == 'cuda': - if dtype is None: - dtype = torch.get_autocast_gpu_dtype() - - if dtype == torch.bfloat16 and not \ - torch.cuda.is_bf16_supported(): - raise RuntimeError( - 'Current CUDA Device does not support bfloat16. Please ' - 'switch dtype to float16.') - - elif device_type == 'cpu': - if dtype is None: - dtype = torch.bfloat16 - assert dtype == torch.bfloat16, ( - 'In CPU autocast, only support `torch.bfloat16` dtype') - - elif device_type == 'mlu': - pass - - elif device_type == 'npu': - pass - elif device_type == 'musa': - if dtype is None: - dtype = torch.get_autocast_gpu_dtype() - with torch.musa.amp.autocast( - enabled=enabled, dtype=dtype, cache_enabled=cache_enabled): - yield - return - else: - # Device like MPS does not support fp16 training or testing. - # If an inappropriate device is set and fp16 is enabled, an error - # will be thrown. - if enabled is False: - yield - return - else: - raise ValueError('User specified autocast device_type must be ' - f'cuda or cpu, but got {device_type}') - - with torch.autocast( - device_type=device_type, - enabled=enabled, - dtype=dtype, - cache_enabled=cache_enabled): - yield diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py deleted file mode 100644 index 5bae459a20..0000000000 --- a/mmengine/runner/base_loop.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Union - -from torch.utils.data import DataLoader - - -class BaseLoop(metaclass=ABCMeta): - """Base loop class. - - All subclasses inherited from ``BaseLoop`` should overwrite the - :meth:`run` method. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. 
- """ - - def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None: - self._runner = runner - if isinstance(dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get( - 'diff_rank_seed', False) - self.dataloader = runner.build_dataloader( - dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed) - else: - self.dataloader = dataloader - - @property - def runner(self): - return self._runner - - @abstractmethod - def run(self) -> Any: - """Execute loop.""" diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py deleted file mode 100644 index 2bf5f50f7c..0000000000 --- a/mmengine/runner/checkpoint.py +++ /dev/null @@ -1,815 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import io -import logging -import os -import os.path as osp -import pkgutil -import re -from collections import OrderedDict, namedtuple -from importlib import import_module -from tempfile import TemporaryDirectory -from typing import Callable, Dict, Optional - -import torch - -import mmengine -from mmengine.dist import get_dist_info -from mmengine.fileio import FileClient, get_file_backend -from mmengine.fileio import load as load_file -from mmengine.logging import print_log -from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import (apply_to, deprecated_function, digit_version, - mkdir_or_exist) -from mmengine.utils.dl_utils import load_url - -# `MMENGINE_HOME` is the highest priority directory to save checkpoints -# downloaded from Internet. If it is not set, as a workaround, using -# `XDG_CACHE_HOME`` or `~/.cache` instead. -# Note that `XDG_CACHE_HOME` defines the base directory relative to which -# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME` -# is either not set or empty, a default equal to `~/.cache` should be used. -ENV_MMENGINE_HOME = 'MMENGINE_HOME' -ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' -DEFAULT_CACHE_DIR = '~/.cache' - - -class _IncompatibleKeys( - namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): - - def __repr__(self): - if not self.missing_keys and not self.unexpected_keys: - return '' - return super().__repr__() - - __str__ = __repr__ - - -def _get_mmengine_home(): - mmengine_home = os.path.expanduser( - os.getenv( - ENV_MMENGINE_HOME, - os.path.join( - os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) - - mkdir_or_exist(mmengine_home) - return mmengine_home - - -def load_state_dict(module, state_dict, strict=False, logger=None): - """Load state_dict to a module. - - This method is modified from :meth:`torch.nn.Module.load_state_dict`. - Default value for ``strict`` is set to ``False`` and the message for - param mismatch will be shown even if strict is False. - - Args: - module (Module): Module that receives the state_dict. - state_dict (OrderedDict): Weights. - strict (bool): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Defaults to False. - logger (:obj:`logging.Logger`, optional): Logger to log the error - message. If not specified, print function will be used. 
- """ - unexpected_keys = [] - missing_keys = [] - err_msg = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - # use _load_from_state_dict to enable checkpoint version control - def load(module, local_state_dict, prefix=''): - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module) or isinstance(module, BaseTTAModel): - module = module.module - local_metadata = {} if metadata is None else metadata.get( - prefix[:-1], {}) - module._load_from_state_dict(local_state_dict, prefix, local_metadata, - True, missing_keys, unexpected_keys, - err_msg) - for name, child in module._modules.items(): - if child is not None: - child_prefix = prefix + name + '.' - child_state_dict = { - k: v - for k, v in local_state_dict.items() - if k.startswith(child_prefix) - } - load(child, child_state_dict, child_prefix) - - # Note that the hook can modify missing_keys and unexpected_keys. - incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) - if hasattr(module, '_load_state_dict_post_hooks'): - for hook in module._load_state_dict_post_hooks.values(): - out = hook(module, incompatible_keys) - assert out is None, ( - 'Hooks registered with ' - '``register_load_state_dict_post_hook`` are not expected ' - 'to return new values, if incompatible_keys need to be ' - 'modified, it should be done inplace.') - - load(module, state_dict) - load = None # break load->load reference cycle - - # ignore "num_batches_tracked" of BN layers - missing_keys = [ - key for key in missing_keys if 'num_batches_tracked' not in key - ] - - if unexpected_keys: - err_msg.append('unexpected key in source ' - f'state_dict: {", ".join(unexpected_keys)}\n') - if missing_keys: - err_msg.append( - f'missing keys in source state_dict: {", ".join(missing_keys)}\n') - - rank, _ = get_dist_info() - if len(err_msg) > 0 and rank == 0: - err_msg.insert( - 0, 'The model and loaded state dict do not match exactly\n') - err_msg = '\n'.join(err_msg) - if strict: - raise RuntimeError(err_msg) - else: - print_log(err_msg, logger=logger, level=logging.WARNING) - - -def get_torchvision_models(): - import torchvision - if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): - model_urls = dict() - # When the version of torchvision is lower than 0.13, the model url is - # not declared in `torchvision.model.__init__.py`, so we need to - # iterate through `torchvision.models.__path__` to get the url for each - # model. - for _, name, ispkg in pkgutil.walk_packages( - torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) - else: - # Since torchvision bumps to v0.13, the weight loading logic, - # model keys and model urls have been changed. Here the URLs of old - # version is loaded to avoid breaking back compatibility. If the - # torchvision version>=0.13.0, new URLs will be added. Users can get - # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', - # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. 
- json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json') - model_urls = mmengine.load(json_path) - if digit_version(torchvision.__version__) < digit_version('0.14.0a0'): - weights_list = [ - cls for cls_name, cls in torchvision.models.__dict__.items() - if cls_name.endswith('_Weights') - ] - else: - weights_list = [ - torchvision.models.get_model_weights(model) - for model in torchvision.models.list_models(torchvision.models) - ] - - for cls in weights_list: - # The name of torchvision model weights classes ends with - # `_Weights` such as `ResNet18_Weights`. However, some model weight - # classes, such as `MNASNet0_75_Weights` does not have any urls in - # torchvision 0.13.0 and cannot be iterated. Here we simply check - # `DEFAULT` attribute to ensure the class is not empty. - if not hasattr(cls, 'DEFAULT'): - continue - # Since `cls.DEFAULT` can not be accessed by iterating cls, we set - # default urls explicitly. - cls_name = cls.__name__ - cls_key = cls_name.replace('_Weights', '').lower() - model_urls[f'{cls_key}.default'] = cls.DEFAULT.url - for weight_enum in cls: - cls_key = cls_name.replace('_Weights', '').lower() - cls_key = f'{cls_key}.{weight_enum.name.lower()}' - model_urls[cls_key] = weight_enum.url - - return model_urls - - -def get_external_models(): - mmengine_home = _get_mmengine_home() - default_json_path = osp.join(mmengine.__path__[0], 'hub/openmmlab.json') - default_urls = load_file(default_json_path) - assert isinstance(default_urls, dict) - external_json_path = osp.join(mmengine_home, 'open_mmlab.json') - if osp.exists(external_json_path): - external_urls = load_file(external_json_path) - assert isinstance(external_urls, dict) - default_urls.update(external_urls) - - return default_urls - - -def get_mmcls_models(): - mmcls_json_path = osp.join(mmengine.__path__[0], 'hub/mmcls.json') - mmcls_urls = load_file(mmcls_json_path) - - return mmcls_urls - - -def get_deprecated_model_names(): - deprecate_json_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json') - deprecate_urls = load_file(deprecate_json_path) - assert isinstance(deprecate_urls, dict) - - return deprecate_urls - - -def _process_mmcls_checkpoint(checkpoint): - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - # Some checkpoints converted from 3rd-party repo don't - # have the "state_dict" key. - state_dict = checkpoint - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - if k.startswith('backbone.'): - new_state_dict[k[9:]] = v - new_checkpoint = dict(state_dict=new_state_dict) - - return new_checkpoint - - -class CheckpointLoader: - """A general checkpoint loader to manage all schemes.""" - - _schemes: Dict[str, Callable] = {} - - @classmethod - def _register_scheme(cls, prefixes, loader, force=False): - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - for prefix in prefixes: - if (prefix not in cls._schemes) or force: - cls._schemes[prefix] = loader - else: - raise KeyError( - f'{prefix} is already registered as a loader backend, ' - 'add "force=True" if you want to override it') - # sort, longer prefixes take priority - cls._schemes = OrderedDict( - sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) - - @classmethod - def register_scheme(cls, prefixes, loader=None, force=False): - """Register a loader to CheckpointLoader. - - This method can be used as a normal class method or a decorator. 
- - Args: - prefixes (str or list[str] or tuple[str]): - The prefix of the registered loader. - loader (function, optional): The loader function to be registered. - When this method is used as a decorator, loader is None. - Defaults to None. - force (bool, optional): Whether to override the loader - if the prefix has already been registered. Defaults to False. - """ - - if loader is not None: - cls._register_scheme(prefixes, loader, force=force) - return - - def _register(loader_cls): - cls._register_scheme(prefixes, loader_cls, force=force) - return loader_cls - - return _register - - @classmethod - def _get_checkpoint_loader(cls, path): - """Finds a loader that supports the given path. Falls back to the local - loader if no other loader is found. - - Args: - path (str): checkpoint path - - Returns: - callable: checkpoint loader - """ - for p in cls._schemes: - # use regular match to handle some cases that where the prefix of - # loader has a prefix. For example, both 's3://path' and - # 'open-mmlab:s3://path' should return `load_from_ceph` - if re.match(p, path) is not None: - return cls._schemes[p] - - @classmethod - def load_checkpoint(cls, filename, map_location=None, logger='current'): - """Load checkpoint through URL scheme path. - - Args: - filename (str): checkpoint file name with given prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - logger (str): The logger for message. Defaults to 'current'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint_loader = cls._get_checkpoint_loader(filename) - class_name = checkpoint_loader.__name__ - print_log( - f'Loads checkpoint by {class_name[10:]} backend from path: ' - f'{filename}', - logger=logger) - return checkpoint_loader(filename, map_location) - - -@CheckpointLoader.register_scheme(prefixes='') -def load_from_local(filename, map_location): - """Load checkpoint by local file path. - - Args: - filename (str): local checkpoint file path - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - filename = osp.expanduser(filename) - if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) -def load_from_http(filename, - map_location=None, - model_dir=None, - progress=os.isatty(0)): - """Load checkpoint through HTTP or HTTPS scheme path. In distributed - setting, this function only download checkpoint at local rank 0. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - model_dir (string, optional): directory in which to save the object, - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - rank, world_size = get_dist_info() - if rank == 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) - if world_size > 1: - torch.distributed.barrier() - if rank > 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes='pavi://') -def load_from_pavi(filename, map_location=None): - """Load checkpoint through the file path prefixed with pavi. 
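For reference, a sketch of extending the `CheckpointLoader` shown above with a custom URI scheme via the decorator form of `register_scheme`; the `myfs://` prefix and loader are hypothetical.

import torch

from mmengine.runner.checkpoint import CheckpointLoader  # pre-removal path


@CheckpointLoader.register_scheme(prefixes='myfs://')
def load_from_myfs(filename, map_location=None):
    """Hypothetical loader that strips the prefix and reads a local file."""
    local_path = filename[len('myfs://'):]
    return torch.load(local_path, map_location=map_location)


# Any path starting with 'myfs://' is now routed to the loader above:
# ckpt = CheckpointLoader.load_checkpoint('myfs:///tmp/model.pth')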
In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with pavi prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - assert filename.startswith('pavi://'), \ - f'Expected filename startswith `pavi://`, but get {filename}' - model_path = filename[7:] - - try: - from pavi import modelcloud - except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') - - model = modelcloud.get(model_path) - with TemporaryDirectory() as tmp_dir: - downloaded_file = osp.join(tmp_dir, model.name) - model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme( - prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) -def load_from_ceph(filename, map_location=None, backend='petrel'): - """Load checkpoint through the file path prefixed with s3. In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with s3 prefix - map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. - Defaults to 'petrel'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - file_backend = get_file_backend( - filename, backend_args={'backend': backend}) - with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) -def load_from_torchvision(filename, map_location=None): - """Load checkpoint through the file path prefixed with modelzoo or - torchvision. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - model_urls = get_torchvision_models() - if filename.startswith('modelzoo://'): - print_log( - 'The URL scheme of "modelzoo://" is deprecated, please ' - 'use "torchvision://" instead', - logger='current', - level=logging.WARNING) - model_name = filename[11:] - else: - model_name = filename[14:] - return load_from_http(model_urls[model_name], map_location=map_location) - - -@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) -def load_from_openmmlab(filename, map_location=None): - """Load checkpoint through the file path prefixed with open-mmlab or - openmmlab. - - Args: - filename (str): checkpoint file path with open-mmlab or - openmmlab prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. 
- """ - - model_urls = get_external_models() - prefix_str = 'open-mmlab://' - if filename.startswith(prefix_str): - model_name = filename[13:] - else: - model_name = filename[12:] - prefix_str = 'openmmlab://' - - deprecated_urls = get_deprecated_model_names() - if model_name in deprecated_urls: - print_log( - f'{prefix_str}{model_name} is deprecated in favor ' - f'of {prefix_str}{deprecated_urls[model_name]}', - logger='current', - level=logging.WARNING) - model_name = deprecated_urls[model_name] - model_url = model_urls[model_name] - # check if is url - if model_url.startswith(('http://', 'https://')): - checkpoint = load_from_http(model_url, map_location=map_location) - else: - filename = osp.join(_get_mmengine_home(), model_url) - if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes='mmcls://') -def load_from_mmcls(filename, map_location=None): - """Load checkpoint through the file path prefixed with mmcls. - - Args: - filename (str): checkpoint file path with mmcls prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - model_urls = get_mmcls_models() - model_name = filename[8:] - checkpoint = load_from_http( - model_urls[model_name], map_location=map_location) - checkpoint = _process_mmcls_checkpoint(checkpoint) - return checkpoint - - -def _load_checkpoint(filename, map_location=None, logger=None): - """Load checkpoint from somewhere (modelzoo, file, url). - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None. - logger (:mod:`logging.Logger`, optional): The logger for error message. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. It can be either an - OrderedDict storing model weights or a dict containing other - information, which depends on the checkpoint. - """ - return CheckpointLoader.load_checkpoint(filename, map_location, logger) - - -def _load_checkpoint_with_prefix(prefix, filename, map_location=None): - """Load partial pretrained model with specific prefix. - - Args: - prefix (str): The prefix of sub-module. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str | None): Same as :func:`torch.load`. - Defaults to None. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint = _load_checkpoint(filename, map_location=map_location) - - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - if not prefix.endswith('.'): - prefix += '.' 
- prefix_len = len(prefix) - - state_dict = { - k[prefix_len:]: v - for k, v in state_dict.items() if k.startswith(prefix) - } - - assert state_dict, f'{prefix} is not in the pretrained model' - return state_dict - - -def _load_checkpoint_to_model(model, - checkpoint, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): - - # get state_dict from checkpoint - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - - # strip prefix of state_dict - metadata = getattr(state_dict, '_metadata', OrderedDict()) - for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) - # Keep metadata in state_dict - state_dict._metadata = metadata - - # load state_dict - load_state_dict(model, state_dict, strict, logger) - return checkpoint - - -def load_checkpoint(model, - filename, - map_location=None, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): - """Load checkpoint from a file or URI. - - Args: - model (Module): Module to load checkpoint. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str): Same as :func:`torch.load`. - strict (bool): Whether to allow different params for the model and - checkpoint. - logger (:mod:`logging.Logger` or None): The logger for error message. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - checkpoint = _load_checkpoint(filename, map_location, logger) - # OrderedDict is a subclass of dict - if not isinstance(checkpoint, dict): - raise RuntimeError( - f'No state_dict found in checkpoint file {filename}') - - return _load_checkpoint_to_model(model, checkpoint, strict, logger, - revise_keys) - - -def weights_to_cpu(state_dict): - """Copy a model state_dict to cpu. - - Args: - state_dict (OrderedDict): Model weights on GPU. - - Returns: - OrderedDict: Model weights on GPU. - """ - # stash metadata to put in state_dict later - metadata = getattr(state_dict, '_metadata', OrderedDict()) - state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), - lambda x: x.cpu()) - state_dict._metadata = metadata - return state_dict - - -@deprecated_function( - since='0.3.0', - removed_in='0.5.0', - instructions='`_save_to_state_dict` will be deprecated in the future, ' - 'please use `nn.Module._save_to_state_dict` directly.') -def _save_to_state_dict(module, destination, prefix, keep_vars): - """Saves module state to `destination` dictionary. - - This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. - - Args: - module (nn.Module): The module to generate state_dict. - destination (dict): A dict where state will be stored. - prefix (str): The prefix for parameters and buffers used in this - module. - keep_vars (bool): Whether to keep the variable property of the - parameters. 
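For reference, a minimal usage sketch of the standalone `load_checkpoint` and `find_latest_checkpoint` helpers removed above; the temporary path is illustrative and the import path is the pre-removal location.

import torch
import torch.nn as nn

from mmengine.runner.checkpoint import (find_latest_checkpoint,
                                        load_checkpoint)  # pre-removal path

model = nn.Linear(4, 2)
torch.save({'state_dict': model.state_dict()}, '/tmp/demo.pth')

# Local paths, http(s) URLs and any registered prefixes all work here.
load_checkpoint(model, '/tmp/demo.pth', map_location='cpu', strict=True)

# Returns the path recorded in `<path>/last_checkpoint`, or None if absent.
print(find_latest_checkpoint('/tmp'))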
- """ - for name, param in module._parameters.items(): - if param is not None: - destination[prefix + name] = param if keep_vars else param.detach() - for name, buf in module._buffers.items(): - if buf is not None and name not in module._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - - -def get_state_dict(module, destination=None, prefix='', keep_vars=False): - """Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - This method is modified from :meth:`torch.nn.Module.state_dict` to - recursively check parallel module in case that the model has a complicated - structure, e.g., nn.Module(nn.Module(DDP)). - - Args: - module (nn.Module): The module to generate state_dict. - destination (OrderedDict): Returned dict for the state of the - module. - prefix (str): Prefix of the key. - keep_vars (bool): Whether to keep the variable property of the - parameters. Defaults to False. - - Returns: - dict: A dictionary containing a whole state of the module. - """ - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module): - module = module.module - - # below is the same as torch.nn.Module.state_dict() - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = dict( - version=module._version) - module._save_to_state_dict(destination, prefix, keep_vars) - for name, child in module._modules.items(): - if child is not None: - get_state_dict( - child, destination, prefix + name + '.', keep_vars=keep_vars) - for hook in module._state_dict_hooks.values(): - hook_result = hook(module, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - -def save_checkpoint(checkpoint, - filename, - file_client_args=None, - backend_args=None): - """Save checkpoint to file. - - Args: - checkpoint (dict): Module whose params are to be saved. - filename (str): Checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - `backend_args` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - """ - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. 
' - 'Please use "backend_args" instead', - logger='current', - level=logging.WARNING) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set ' - 'at the same time.') - - if filename.startswith('pavi://'): - if file_client_args is not None or backend_args is not None: - raise ValueError( - '"file_client_args" or "backend_args" should be "None" if ' - 'filename starts with "pavi://"') - try: - from pavi import exception, modelcloud - except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') - model_path = filename[7:] - root = modelcloud.Folder() - model_dir, model_name = osp.split(model_path) - try: - model = modelcloud.get(model_dir) - except exception.NodeNotFoundError: - model = root.create_training_model(model_dir) - with TemporaryDirectory() as tmp_dir: - checkpoint_file = osp.join(tmp_dir, model_name) - with open(checkpoint_file, 'wb') as f: - torch.save(checkpoint, f) - f.flush() - model.create_file(checkpoint_file, name=model_name) - else: - file_client = FileClient.infer_client(file_client_args, filename) - if file_client_args is None: - file_backend = get_file_backend( - filename, backend_args=backend_args) - else: - file_backend = file_client - - with io.BytesIO() as f: - torch.save(checkpoint, f) - file_backend.put(f.getvalue(), filename) - - -def find_latest_checkpoint(path: str) -> Optional[str]: - """Find the latest checkpoint from the given path. - - Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501 - - Args: - path(str): The path to find checkpoints. - - Returns: - str or None: File path of the latest checkpoint. - """ - save_file = osp.join(path, 'last_checkpoint') - last_saved: Optional[str] - if os.path.exists(save_file): - with open(save_file) as f: - last_saved = f.read().strip() - else: - print_log('Did not find last_checkpoint to be resumed.') - last_saved = None - return last_saved diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py deleted file mode 100644 index 98183ae317..0000000000 --- a/mmengine/runner/log_processor.py +++ /dev/null @@ -1,582 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import datetime -import re -from collections import OrderedDict -from itertools import chain -from typing import List, Optional, Tuple - -import numpy as np -import torch - -from mmengine.device import (get_max_cuda_memory, get_max_musa_memory, - is_cuda_available, is_musa_available) -from mmengine.registry import LOG_PROCESSORS - - -@LOG_PROCESSORS.register_module() -class LogProcessor: - """A log processor used to format log information collected from - ``runner.message_hub.log_scalars``. - - ``LogProcessor`` instance is built by runner and will format - ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can - directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument - ``custom_cfg`` of constructor can control the statistics method of logs. - - Args: - window_size (int): default smooth interval. Defaults to 10. - by_epoch (bool): Whether to format logs with epoch stype. Defaults to - True. - custom_cfg (list[dict], optional): Contains multiple log config dict, - in which key means the data source name of log and value means the - statistic method and corresponding arguments used to count the - data source. Defaults to None. 
- - - If custom_cfg is None, all logs will be formatted via default - methods, such as smoothing loss by default window_size. If - custom_cfg is defined as a list of config dict, for example: - [dict(data_src='loss', method='mean', log_name='global_loss', - window_size='global')]. It means the log item ``loss`` will be - counted as global mean and additionally logged as ``global_loss`` - (defined by ``log_name``). If ``log_name`` is not defined in - config dict, the original logged key will be overwritten. - - - The original log item cannot be overwritten twice. Here is - an error example: - [dict(data_src='loss', method='mean', window_size='global'), - dict(data_src='loss', method='mean', window_size='epoch')]. - Both log config dict in custom_cfg do not have ``log_name`` key, - which means the loss item will be overwritten twice. - - - For those statistic methods with the ``window_size`` argument, - if ``by_epoch`` is set to False, ``windows_size`` should not be - `epoch` to statistics log value by epoch. - num_digits (int): The number of significant digit shown in the - logging message. Defaults to 4. - log_with_hierarchy (bool): Whether to log with hierarchy. If it is - True, the information is written to visualizer backend such as - :obj:`LocalVisBackend` and :obj:`TensorboardBackend` - with hierarchy. For example, ``loss`` will be saved as - ``train/loss``, and accuracy will be saved as ``val/accuracy``. - Defaults to False. - `New in version 0.7.0.` - mean_pattern (str): This is a regular expression used to match the log - that need to be included in the smoothing statistics. - `New in version 0.7.3.` - - Examples: - >>> # `log_name` is defined, `loss_large_window` will be an additional - >>> # record. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> log_name='loss_large_window', - >>> method_name='mean', - >>> window_size=100)]) - >>> # `log_name` is not defined. `loss` will be overwritten. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100)]) - >>> # Record loss with different statistics methods. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> log_name='loss_large_window', - >>> method_name='mean', - >>> window_size=100), - >>> dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100)]) - >>> # Overwrite loss item twice will raise an error. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100), - >>> dict(data_src='loss', - >>> method_name='max', - >>> window_size=100)]) - AssertionError - """ - - def __init__(self, - window_size=10, - by_epoch=True, - custom_cfg: Optional[List[dict]] = None, - num_digits: int = 4, - log_with_hierarchy: bool = False, - mean_pattern=r'.*(loss|time|data_time|grad_norm).*'): - self.window_size = window_size - self.by_epoch = by_epoch - self.custom_cfg = custom_cfg if custom_cfg else [] - self.num_digits = num_digits - self.log_with_hierarchy = log_with_hierarchy - self.mean_pattern = re.compile(mean_pattern) - self._check_custom_cfg() - - def get_log_after_iter(self, runner, batch_idx: int, - mode: str) -> Tuple[dict, str]: - """Format log string after training, validation or testing iteration. - - Args: - runner (Runner): The runner of training phase. 
- batch_idx (int): The index of the current batch in the current - loop. - mode (str): Current mode of runner, train, test or val. - - Return: - Tuple[dict, str]: Formatted log dict/string which will be - recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`. - """ - assert mode in ['train', 'test', 'val'] - # Overwrite ``window_size`` defined in ``custom_cfg`` to int value. - parsed_cfg = self._parse_windows_size(runner, batch_idx, - self.custom_cfg) - # log_tag is used to write log information to terminal - log_tag = self._collect_scalars(parsed_cfg, runner, mode) - - # If `self.log_with_hierarchy` is False, the tag is the same as - # log_tag. Otherwise, each key in tag starts with prefix `train`, - # `test` or `val` - if not self.log_with_hierarchy: - tag = copy.deepcopy(log_tag) - else: - tag = self._collect_scalars(parsed_cfg, runner, mode, True) - - # Record learning rate. - lr_str_list = [] - for key, value in tag.items(): - if key.endswith('lr'): - key = self._remove_prefix(key, f'{mode}/') - log_tag.pop(key) - lr_str_list.append(f'{key}: ' - f'{value:.{self.num_digits}e}') - lr_str = ' '.join(lr_str_list) - # Format log header. - # by_epoch == True - # train/val: Epoch [5][5/10] ... - # test: Epoch [5/10] - # by_epoch == False - # train: Epoch [5/10000] ... (divided by `max_iter`) - # val/test: Epoch [5/2000] ... (divided by length of dataloader) - if self.by_epoch: - # Align the iteration log: - # Epoch(train) [ 9][010/270] - # ... ||| ||| - # Epoch(train) [ 10][100/270] - dataloader_len = self._get_dataloader_size(runner, mode) - cur_iter = self._get_iter(runner, batch_idx) - cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len))) - if mode in ['train', 'val']: - cur_epoch = self._get_epoch(runner, mode) - if not (isinstance(runner._train_loop, dict) - or runner._train_loop is None): - # Right Align the epoch log: - # Epoch(train) [9][100/270] - # ... || - # Epoch(train) [100][100/270] - max_epochs = runner.max_epochs - # 3 means the three characters: "[", "]", and " " occupied - # in " [{max_epochs}]" - cur_epoch_str = f'[{cur_epoch}]'.rjust( - len(str(max_epochs)) + 3, ' ') - else: - cur_epoch_str = f'[{cur_epoch}]' - tag['epoch'] = cur_epoch - log_str = (f'Epoch({mode}){cur_epoch_str}' - f'[{cur_iter_str}/{dataloader_len}] ') - else: - log_str = (f'Epoch({mode}) ' - f'[{cur_iter_str}/{dataloader_len}] ') - else: - if mode == 'train': - cur_iter = self._get_iter(runner, batch_idx) - cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters))) - log_str = (f'Iter({mode}) ' - f'[{cur_iter_str}/{runner.max_iters}] ') - else: - dataloader_len = self._get_dataloader_size(runner, mode) - cur_iter_str = str(batch_idx + 1).rjust( - len(str(dataloader_len))) - log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ') - # Add global iter. - if isinstance(runner._train_loop, dict) or runner._train_loop is None: - tag['iter'] = 0 - else: - tag['iter'] = runner.iter + 1 - # Concatenate lr, momentum string with log header. - log_str += f'{lr_str} ' - # If IterTimerHook used in runner, eta, time, and data_time should be - # recorded. 
- if (all(item in log_tag for item in ['time', 'data_time']) - and 'eta' in runner.message_hub.runtime_info): - eta = runner.message_hub.get_info('eta') - eta_str = str(datetime.timedelta(seconds=int(eta))) - log_str += f'eta: {eta_str} ' - log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} ' - f'data_time: ' - f'{log_tag["data_time"]:.{self.num_digits}f} ') - # Pop recorded keys - log_tag.pop('time') - log_tag.pop('data_time') - - # If cuda/musa is available, - # the max memory occupied should be calculated. - if is_cuda_available() or is_musa_available(): - max_memory = self._get_max_memory(runner) - log_str += f'memory: {max_memory} ' - tag['memory'] = max_memory - - # Loop left keys to fill `log_str`. - if mode in ('train', 'val'): - log_items = [] - for name, val in log_tag.items(): - if mode == 'val' and not name.startswith('val/loss'): - continue - if isinstance(val, float): - val = f'{val:.{self.num_digits}f}' - log_items.append(f'{name}: {val}') - log_str += ' '.join(log_items) - return tag, log_str - - def get_log_after_epoch(self, - runner, - batch_idx: int, - mode: str, - with_non_scalar: bool = False) -> Tuple[dict, str]: - """Format log string after validation or testing epoch. - - Args: - runner (Runner): The runner of validation/testing phase. - batch_idx (int): The index of the current batch in the current - loop. - mode (str): Current mode of runner. - with_non_scalar (bool): Whether to include non-scalar infos in the - returned tag. Defaults to False. - - Return: - Tuple[dict, str]: Formatted log dict/string which will be - recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`. - """ - assert mode in [ - 'test', 'val' - ], ('`_get_metric_log_str` only accept val or test mode, but got ' - f'{mode}') - dataloader_len = self._get_dataloader_size(runner, mode) - - # By epoch: - # Epoch(val) [10][1000/1000] ... - # Epoch(test) [1000/1000] ... - # By iteration: - # Iteration(val) [1000/1000] ... - # Iteration(test) [1000/1000] ... - if self.by_epoch: - if mode == 'val': - cur_epoch = self._get_epoch(runner, mode) - log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/' - f'{dataloader_len}] ') - else: - log_str = ( - f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ') - - else: - log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ') - - custom_cfg_copy = copy.deepcopy(self.custom_cfg) - # remove prefix - custom_keys = [ - self._remove_prefix(cfg['data_src'], f'{mode}/') - for cfg in custom_cfg_copy - ] - # Count the averaged time and data_time by epoch - if 'time' not in custom_keys: - custom_cfg_copy.append( - dict(data_src='time', window_size='epoch', method_name='mean')) - if 'data_time' not in custom_keys: - custom_cfg_copy.append( - dict( - data_src='data_time', - window_size='epoch', - method_name='mean')) - parsed_cfg = self._parse_windows_size(runner, batch_idx, - custom_cfg_copy) - # tag is used to write log information to different backends. - ori_tag = self._collect_scalars(parsed_cfg, runner, mode, - self.log_with_hierarchy) - non_scalar_tag = self._collect_non_scalars(runner, mode) - # move `time` or `data_time` to the end of the log - tag = OrderedDict() - time_tag = OrderedDict() - for key, value in ori_tag.items(): - if key in (f'{mode}/time', f'{mode}/data_time', 'time', - 'data_time'): - time_tag[key] = value - else: - tag[key] = value - # Log other messages. 
- log_items = [] - log_str += ' ' - for name, val in chain(tag.items(), non_scalar_tag.items(), - time_tag.items()): - if isinstance(val, float): - val = f'{val:.{self.num_digits}f}' - if isinstance(val, (torch.Tensor, np.ndarray)): - # newline to display tensor and array. - val = f'\n{val}\n' - log_items.append(f'{name}: {val}') - log_str += ' '.join(log_items) - - if with_non_scalar: - tag.update(non_scalar_tag) - tag.update(time_tag) - return tag, log_str - - def _collect_scalars(self, - custom_cfg: List[dict], - runner, - mode: str, - reserve_prefix: bool = False) -> dict: - """Collect log information to compose a dict according to mode. - - Args: - custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int - ``window_size``. - runner (Runner): The runner of the training/testing/validation - process. - mode (str): Current mode of runner. - reserve_prefix (bool): Whether to reserve the prefix of the key. - - Returns: - dict: Statistical values of logs. - """ - custom_cfg = copy.deepcopy(custom_cfg) - tag = OrderedDict() - # history_scalars of train/val/test phase. - history_scalars = runner.message_hub.log_scalars - # corresponding mode history_scalars - mode_history_scalars = OrderedDict() - # extract log scalars and remove prefix to `mode_history_scalars` - # according to mode. - for prefix_key, log_buffer in history_scalars.items(): - if prefix_key.startswith(mode): - if not reserve_prefix: - key = self._remove_prefix(prefix_key, f'{mode}/') - else: - key = prefix_key - mode_history_scalars[key] = log_buffer - for key in mode_history_scalars: - # Update the latest learning rate and smoothed time logs. - if re.search(self.mean_pattern, key) is not None: - tag[key] = mode_history_scalars[key].mean(self.window_size) - else: - # Default statistic method is current. - tag[key] = mode_history_scalars[key].current() - # Update custom keys. - for log_cfg in custom_cfg: - data_src = log_cfg.pop('data_src') - log_name = log_cfg.pop('log_name', data_src) - if reserve_prefix: - data_src = f'{mode}/{data_src}' - log_name = f'{mode}/{log_name}' - # log item in custom_cfg could only exist in train or val - # mode. - if data_src in mode_history_scalars: - tag[log_name] = mode_history_scalars[data_src].statistics( - **log_cfg) - return tag - - def _collect_non_scalars(self, runner, mode: str) -> dict: - """Collect log information to compose a dict according to mode. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - mode (str): Current mode of runner. - - Returns: - dict: non-scalar infos of the specified mode. - """ - # infos of train/val/test phase. - infos = runner.message_hub.runtime_info - # corresponding mode infos - mode_infos = OrderedDict() - # extract log info and remove prefix to `mode_infos` according to mode. - for prefix_key, value in infos.items(): - if prefix_key.startswith(mode): - if self.log_with_hierarchy: - key = prefix_key - else: - key = self._remove_prefix(prefix_key, f'{mode}/') - mode_infos[key] = value - return mode_infos - - def _remove_prefix(self, string: str, prefix: str): - """Remove the prefix ``train``, ``val`` and ``test`` of the key.""" - if string.startswith(prefix): - return string[len(prefix):] - else: - return string - - def _check_custom_cfg(self) -> None: - """Check the legality of ``self.custom_cfg``.""" - - def _check_window_size(): - for log_cfg in self.custom_cfg: - if not self.by_epoch: - assert log_cfg['window_size'] != 'epoch', \ - 'window_size cannot be epoch if LoggerHook.by_epoch' \ - ' is False.' 
- - def _check_repeated_log_name(): - # The `log_name` of the same data_src should not be repeated. - # If `log_name` is not specified, `data_src` will be overwritten. - # But only allowed to be overwritten once. - check_set = set() - for log_cfg in self.custom_cfg: - assert 'data_src' in log_cfg - data_src = log_cfg['data_src'] - log_name = log_cfg.get('log_name', data_src) - assert log_name not in check_set, ( - f'Found duplicate {log_name} for {data_src}. Please check' - 'your `custom_cfg` for `log_processor`. You should ' - f'neither define duplicate `{log_name}` for {data_src} ' - f'nor do not define any {log_name} for multiple ' - f'{data_src}, See more information in the docstring of ' - 'LogProcessor') - - check_set.add(log_name) - - _check_repeated_log_name() - _check_window_size() - - def _parse_windows_size(self, - runner, - batch_idx: int, - custom_cfg: Optional[list] = None) -> list: - """Parse window_size defined in custom_cfg to int value. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - batch_idx (int): The iteration index of current dataloader. - custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None - to keep backward compatibility. - """ - if custom_cfg is None: - custom_cfg = copy.deepcopy(self.custom_cfg) - else: - custom_cfg = copy.deepcopy(custom_cfg) - for log_cfg in custom_cfg: - window_size = log_cfg.get('window_size', None) - if window_size is None or isinstance(window_size, int): - continue - elif window_size == 'epoch': - log_cfg['window_size'] = batch_idx + 1 - elif window_size == 'global': - log_cfg['window_size'] = runner.iter + 1 - else: - raise TypeError( - 'window_size should be int, epoch or global, but got ' - f'invalid {window_size}') - return custom_cfg - - def _get_max_memory(self, runner) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) - for a given device. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - - Returns: - The maximum GPU memory occupied by tensors in megabytes for a given - device. - """ - - device = getattr(runner.model, 'output_device', None) - - if is_musa_available(): - return get_max_musa_memory(device) - return get_max_cuda_memory(device) - - def _get_iter(self, runner, batch_idx: int) -> int: - """Get current iteration index. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - batch_idx (int): The iteration index of current - dataloader. Defaults to None. - - Returns: - int: The current global iter or inner iter. - """ - if self.by_epoch: - current_iter = batch_idx + 1 - else: - current_iter = runner.iter + 1 - return current_iter - - def _get_epoch(self, runner, mode: str) -> int: - """Get current epoch according to mode. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - mode (str): Current mode of runner. - - Returns: - int: The current epoch. - """ - if mode == 'train': - epoch = runner.epoch + 1 - elif mode == 'val': - if (isinstance(runner._train_loop, dict) - or runner._train_loop is None): - epoch = 0 - else: - # normal val mode - # runner.epoch += 1 has been done before validation - epoch = runner.epoch - else: - raise ValueError( - f"runner mode should be 'train' or 'val', but got {mode}") - return epoch - - def _get_cur_loop(self, runner, mode: str): - """Get current loop according to mode. - - Args: - runner (Runner): The runner of the training/validation/testing - process. - mode (str): Current mode of runner. 
- - Returns: - BaseLoop: Current loop of runner. - """ - # returns type hint will occur circular import - if mode == 'train': - return runner.train_loop - elif mode == 'val': - return runner.val_loop - else: - return runner.test_loop - - def _get_dataloader_size(self, runner, mode) -> int: - """Get dataloader size of current loop. - - Args: - runner (Runner): The runner of the training/validation/testing - mode (str): Current mode of runner. - - Returns: - int: The dataloader size of current loop. - """ - return len(self._get_cur_loop(runner=runner, mode=mode).dataloader) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py deleted file mode 100644 index 5a678db7b9..0000000000 --- a/mmengine/runner/loops.py +++ /dev/null @@ -1,550 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import bisect -import logging -import time -from typing import Dict, List, Optional, Sequence, Tuple, Union - -import torch -from torch.utils.data import DataLoader - -from mmengine.evaluator import Evaluator -from mmengine.logging import HistoryBuffer, print_log -from mmengine.registry import LOOPS -from mmengine.structures import BaseDataElement -from mmengine.utils import is_list_of -from .amp import autocast -from .base_loop import BaseLoop -from .utils import calc_dynamic_intervals - - -@LOOPS.register_module() -class EpochBasedTrainLoop(BaseLoop): - """Loop for epoch-based training. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - max_epochs (int): Total training epochs. - val_begin (int): The epoch that begins validating. - Defaults to 1. - val_interval (int): Validation interval. Defaults to 1. - dynamic_intervals (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is a interval. The interval is used after the - corresponding milestone. Defaults to None. - """ - - def __init__( - self, - runner, - dataloader: Union[DataLoader, Dict], - max_epochs: int, - val_begin: int = 1, - val_interval: int = 1, - dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: - super().__init__(runner, dataloader) - self._max_epochs = int(max_epochs) - assert self._max_epochs == max_epochs, \ - f'`max_epochs` should be a integer number, but get {max_epochs}.' - self._max_iters = self._max_epochs * len(self.dataloader) - self._epoch = 0 - self._iter = 0 - self.val_begin = val_begin - self.val_interval = val_interval - # This attribute will be updated by `EarlyStoppingHook` - # when it is enabled. - self.stop_training = False - if hasattr(self.dataloader.dataset, 'metainfo'): - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - print_log( - f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. 
``dataset_meta`` in visualizer will be ' - 'None.', - logger='current', - level=logging.WARNING) - - self.dynamic_milestones, self.dynamic_intervals = \ - calc_dynamic_intervals( - self.val_interval, dynamic_intervals) - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - return self._max_epochs - - @property - def max_iters(self): - """int: Total iterations to train model.""" - return self._max_iters - - @property - def epoch(self): - """int: Current epoch.""" - return self._epoch - - @property - def iter(self): - """int: Current iteration.""" - return self._iter - - def run(self) -> torch.nn.Module: - """Launch training.""" - self.runner.call_hook('before_train') - - while self._epoch < self._max_epochs and not self.stop_training: - self.run_epoch() - - self._decide_current_val_interval() - if (self.runner.val_loop is not None - and self._epoch >= self.val_begin - and (self._epoch % self.val_interval == 0 - or self._epoch == self._max_epochs)): - self.runner.val_loop.run() - - self.runner.call_hook('after_train') - return self.runner.model - - def run_epoch(self) -> None: - """Iterate one epoch.""" - self.runner.call_hook('before_train_epoch') - self.runner.model.train() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - self.runner.call_hook('after_train_epoch') - self._epoch += 1 - - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one min-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self.runner.call_hook( - 'before_train_iter', batch_idx=idx, data_batch=data_batch) - # Enable gradient accumulation mode and avoid unnecessary gradient - # synchronization during gradient accumulation process. - # outputs should be a dict of loss. - outputs = self.runner.model.train_step( - data_batch, optim_wrapper=self.runner.optim_wrapper) - - self.runner.call_hook( - 'after_train_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - self._iter += 1 - - def _decide_current_val_interval(self) -> None: - """Dynamically modify the ``val_interval``.""" - step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1)) - self.val_interval = self.dynamic_intervals[step - 1] - - -class _InfiniteDataloaderIterator: - """An infinite dataloader iterator wrapper for IterBasedTrainLoop. - - It resets the dataloader to continue iterating when the iterator has - iterated over all the data. However, this approach is not efficient, as the - workers need to be restarted every time the dataloader is reset. It is - recommended to use `mmengine.dataset.InfiniteSampler` to enable the - dataloader to iterate infinitely. - """ - - def __init__(self, dataloader: DataLoader) -> None: - self._dataloader = dataloader - self._iterator = iter(self._dataloader) - self._epoch = 0 - - def __iter__(self): - return self - - def __next__(self) -> Sequence[dict]: - try: - data = next(self._iterator) - except StopIteration: - print_log( - 'Reach the end of the dataloader, it will be ' - 'restarted and continue to iterate. It is ' - 'recommended to use ' - '`mmengine.dataset.InfiniteSampler` to enable the ' - 'dataloader to iterate infinitely.', - logger='current', - level=logging.WARNING) - self._epoch += 1 - if hasattr(self._dataloader, 'sampler') and hasattr( - self._dataloader.sampler, 'set_epoch'): - # In case the` _SingleProcessDataLoaderIter` has no sampler, - # or data loader uses `SequentialSampler` in Pytorch. 
- self._dataloader.sampler.set_epoch(self._epoch) - - elif hasattr(self._dataloader, 'batch_sampler') and hasattr( - self._dataloader.batch_sampler.sampler, 'set_epoch'): - # In case the` _SingleProcessDataLoaderIter` has no batch - # sampler. batch sampler in pytorch warps the sampler as its - # attributes. - self._dataloader.batch_sampler.sampler.set_epoch(self._epoch) - time.sleep(2) # Prevent possible deadlock during epoch transition - self._iterator = iter(self._dataloader) - data = next(self._iterator) - return data - - -@LOOPS.register_module() -class IterBasedTrainLoop(BaseLoop): - """Loop for iter-based training. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - max_iters (int): Total training iterations. - val_begin (int): The iteration that begins validating. - Defaults to 1. - val_interval (int): Validation interval. Defaults to 1000. - dynamic_intervals (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is a interval. The interval is used after the - corresponding milestone. Defaults to None. - """ - - def __init__( - self, - runner, - dataloader: Union[DataLoader, Dict], - max_iters: int, - val_begin: int = 1, - val_interval: int = 1000, - dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: - super().__init__(runner, dataloader) - self._max_iters = int(max_iters) - assert self._max_iters == max_iters, \ - f'`max_iters` should be a integer number, but get {max_iters}' - self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop - self._epoch = 0 - self._iter = 0 - self.val_begin = val_begin - self.val_interval = val_interval - # This attribute will be updated by `EarlyStoppingHook` - # when it is enabled. - self.stop_training = False - if hasattr(self.dataloader.dataset, 'metainfo'): - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - print_log( - f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. ``dataset_meta`` in visualizer will be ' - 'None.', - logger='current', - level=logging.WARNING) - # get the iterator of the dataloader - self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader) - - self.dynamic_milestones, self.dynamic_intervals = \ - calc_dynamic_intervals( - self.val_interval, dynamic_intervals) - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - return self._max_epochs - - @property - def max_iters(self): - """int: Total iterations to train model.""" - return self._max_iters - - @property - def epoch(self): - """int: Current epoch.""" - return self._epoch - - @property - def iter(self): - """int: Current iteration.""" - return self._iter - - def run(self) -> None: - """Launch training.""" - self.runner.call_hook('before_train') - # In iteration-based training loop, we treat the whole training process - # as a big epoch and execute the corresponding hook. 
- self.runner.call_hook('before_train_epoch') - if self._iter > 0: - print_log( - f'Advance dataloader {self._iter} steps to skip data ' - 'that has already been trained', - logger='current', - level=logging.WARNING) - for _ in range(self._iter): - next(self.dataloader_iterator) - while self._iter < self._max_iters and not self.stop_training: - self.runner.model.train() - - data_batch = next(self.dataloader_iterator) - self.run_iter(data_batch) - - self._decide_current_val_interval() - if (self.runner.val_loop is not None - and self._iter >= self.val_begin - and (self._iter % self.val_interval == 0 - or self._iter == self._max_iters)): - self.runner.val_loop.run() - - self.runner.call_hook('after_train_epoch') - self.runner.call_hook('after_train') - return self.runner.model - - def run_iter(self, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self.runner.call_hook( - 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) - # Enable gradient accumulation mode and avoid unnecessary gradient - # synchronization during gradient accumulation process. - # outputs should be a dict of loss. - outputs = self.runner.model.train_step( - data_batch, optim_wrapper=self.runner.optim_wrapper) - - self.runner.call_hook( - 'after_train_iter', - batch_idx=self._iter, - data_batch=data_batch, - outputs=outputs) - self._iter += 1 - - def _decide_current_val_interval(self) -> None: - """Dynamically modify the ``val_interval``.""" - step = bisect.bisect(self.dynamic_milestones, (self._iter + 1)) - self.val_interval = self.dynamic_intervals[step - 1] - - -@LOOPS.register_module() -class ValLoop(BaseLoop): - """Loop for validation. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to - False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False) -> None: - super().__init__(runner, dataloader) - - if isinstance(evaluator, (dict, list)): - self.evaluator = runner.build_evaluator(evaluator) # type: ignore - else: - assert isinstance(evaluator, Evaluator), ( - 'evaluator must be one of dict, list or Evaluator instance, ' - f'but got {type(evaluator)}.') - self.evaluator = evaluator # type: ignore - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - print_log( - f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. 
``dataset_meta`` in evaluator, metric and ' - 'visualizer will be None.', - logger='current', - level=logging.WARNING) - self.fp16 = fp16 - self.val_loss: Dict[str, HistoryBuffer] = dict() - - def run(self) -> dict: - """Launch validation.""" - self.runner.call_hook('before_val') - self.runner.call_hook('before_val_epoch') - self.runner.model.eval() - - # clear val loss - self.val_loss.clear() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - if self.val_loss: - loss_dict = _parse_losses(self.val_loss, 'val') - metrics.update(loss_dict) - - self.runner.call_hook('after_val_epoch', metrics=metrics) - self.runner.call_hook('after_val') - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]): - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data - from dataloader. - """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.val_step(data_batch) - - outputs, self.val_loss = _update_losses(outputs, self.val_loss) - - self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -@LOOPS.register_module() -class TestLoop(BaseLoop): - """Loop for test. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 testing. Defaults to - False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False): - super().__init__(runner, dataloader) - - if isinstance(evaluator, dict) or isinstance(evaluator, list): - self.evaluator = runner.build_evaluator(evaluator) # type: ignore - else: - self.evaluator = evaluator # type: ignore - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - print_log( - f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. ``dataset_meta`` in evaluator, metric and ' - 'visualizer will be None.', - logger='current', - level=logging.WARNING) - self.fp16 = fp16 - self.test_loss: Dict[str, HistoryBuffer] = dict() - - def run(self) -> dict: - """Launch test.""" - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - self.runner.model.eval() - - # clear test loss - self.test_loss.clear() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - if self.test_loss: - loss_dict = _parse_losses(self.test_loss, 'test') - metrics.update(loss_dict) - - self.runner.call_hook('after_test_epoch', metrics=metrics) - self.runner.call_hook('after_test') - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. 
- """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.test_step(data_batch) - - outputs, self.test_loss = _update_losses(outputs, self.test_loss) - - self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -def _parse_losses(losses: Dict[str, HistoryBuffer], - stage: str) -> Dict[str, float]: - """Parses the raw losses of the network. - - Args: - losses (dict): raw losses of the network. - stage (str): The stage of loss, e.g., 'val' or 'test'. - - Returns: - dict[str, float]: The key is the loss name, and the value is the - average loss. - """ - all_loss = 0 - loss_dict: Dict[str, float] = dict() - - for loss_name, loss_value in losses.items(): - avg_loss = loss_value.mean() - loss_dict[loss_name] = avg_loss - if 'loss' in loss_name: - all_loss += avg_loss - - loss_dict[f'{stage}_loss'] = all_loss - return loss_dict - - -def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]: - """Update and record the losses of the network. - - Args: - outputs (list): The outputs of the network. - losses (dict): The losses of the network. - - Returns: - list: The updated outputs of the network. - dict: The updated losses of the network. - """ - if isinstance(outputs[-1], - BaseDataElement) and outputs[-1].keys() == ['loss']: - loss = outputs[-1].loss # type: ignore - outputs = outputs[:-1] - else: - loss = dict() - - for loss_name, loss_value in loss.items(): - if loss_name not in losses: - losses[loss_name] = HistoryBuffer() - if isinstance(loss_value, torch.Tensor): - losses[loss_name].update(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - for loss_value_i in loss_value: - losses[loss_name].update(loss_value_i.item()) - return outputs, losses diff --git a/mmengine/runner/priority.py b/mmengine/runner/priority.py deleted file mode 100644 index ff644043b8..0000000000 --- a/mmengine/runner/priority.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from enum import Enum -from typing import Union - - -class Priority(Enum): - """Hook priority levels. - - +--------------+------------+ - | Level | Value | - +==============+============+ - | HIGHEST | 0 | - +--------------+------------+ - | VERY_HIGH | 10 | - +--------------+------------+ - | HIGH | 30 | - +--------------+------------+ - | ABOVE_NORMAL | 40 | - +--------------+------------+ - | NORMAL | 50 | - +--------------+------------+ - | BELOW_NORMAL | 60 | - +--------------+------------+ - | LOW | 70 | - +--------------+------------+ - | VERY_LOW | 90 | - +--------------+------------+ - | LOWEST | 100 | - +--------------+------------+ - """ - - HIGHEST = 0 - VERY_HIGH = 10 - HIGH = 30 - ABOVE_NORMAL = 40 - NORMAL = 50 - BELOW_NORMAL = 60 - LOW = 70 - VERY_LOW = 90 - LOWEST = 100 - - -def get_priority(priority: Union[int, str, Priority]) -> int: - """Get priority value. - - Args: - priority (int or str or :obj:`Priority`): Priority. - - Returns: - int: The priority value. 
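A short usage sketch of the priority resolution documented above; the values follow the table in the ``Priority`` docstring, and the import path mirrors the module removed in this diff (``mmengine/runner/priority.py``).

```python
from mmengine.runner.priority import Priority, get_priority

# Integers in [0, 100] pass through unchanged.
assert get_priority(30) == 30
# Strings are upper-cased and looked up in the Priority enum.
assert get_priority('very_high') == Priority.VERY_HIGH.value == 10
# Enum members resolve to their numeric value.
assert get_priority(Priority.BELOW_NORMAL) == 60
```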
- """ - if isinstance(priority, int): - if priority < 0 or priority > 100: - raise ValueError('priority must be between 0 and 100') - return priority - elif isinstance(priority, Priority): - return priority.value - elif isinstance(priority, str): - return Priority[priority.upper()].value - else: - raise TypeError('priority must be an integer or Priority enum value') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py deleted file mode 100644 index 7d1f655aad..0000000000 --- a/mmengine/runner/runner.py +++ /dev/null @@ -1,2413 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -import os -import os.path as osp -import pickle -import platform -import time -import warnings -from collections import OrderedDict -from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Union - -import torch -import torch.nn as nn -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -import mmengine -from mmengine.config import Config, ConfigDict -from mmengine.dataset import worker_init_fn as default_worker_init_fn -from mmengine.device import get_device -from mmengine.dist import (broadcast, get_dist_info, get_rank, get_world_size, - init_dist, is_distributed, master_only) -from mmengine.evaluator import Evaluator -from mmengine.fileio import FileClient, join_path -from mmengine.hooks import Hook -from mmengine.logging import MessageHub, MMLogger, print_log -from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm, - is_model_wrapper, revert_sync_batchnorm) -from mmengine.model.efficient_conv_bn_eval import \ - turn_on_efficient_conv_bn_eval -from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, - build_optim_wrapper) -from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, - HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, - MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, - RUNNERS, VISUALIZERS, DefaultScope) -from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of -from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, - set_multi_processing) -from mmengine.visualization import Visualizer -from .activation_checkpointing import turn_on_activation_checkpointing -from .base_loop import BaseLoop -from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, - find_latest_checkpoint, save_checkpoint, - weights_to_cpu) -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .utils import _get_batch_size, set_random_seed - -ConfigType = Union[Dict, Config, ConfigDict] -ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, - List[_ParamScheduler]]] -OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] - - -class _SlicedDataset: - - def __init__(self, dataset, length) -> None: - self._dataset = dataset - self._length = length - - def __getattr__(self, name): - return getattr(self._dataset, name) - - def __getitem__(self, idx): - return self._dataset[idx] - - def __len__(self): - return self._length - - -@RUNNERS.register_module() -class Runner: - """A training helper for PyTorch. - - Runner object can be built from config by ``runner = Runner.from_cfg(cfg)`` - where the ``cfg`` usually contains training, validation, and test-related - configurations to build corresponding components. 
We usually use the - same config to launch training, testing, and validation tasks. However, - only some of these components are necessary at the same time, e.g., - testing a model does not need training or validation-related components. - - To avoid repeatedly modifying config, the construction of ``Runner`` adopts - lazy initialization to only initialize components when they are going to be - used. Therefore, the model is always initialized at the beginning, and - training, validation, and, testing related components are only initialized - when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``, - respectively. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It can be - a dict used for build a model. - work_dir (str): The working directory to save checkpoints. The logs - will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. - train_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping training steps. Defaults to None. - See :meth:`build_dataloader` for more details. - val_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping validation steps. Defaults to None. - See :meth:`build_dataloader` for more details. - test_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping test steps. Defaults to None. - See :meth:`build_dataloader` for more details. - train_cfg (dict, optional): A dict to build a training loop. If it does - not provide "type" key, it should contain "by_epoch" to decide - which type of training loop :class:`EpochBasedTrainLoop` or - :class:`IterBasedTrainLoop` should be used. If ``train_cfg`` - specified, :attr:`train_dataloader` should also be specified. - Defaults to None. See :meth:`build_train_loop` for more details. - val_cfg (dict, optional): A dict to build a validation loop. If it does - not provide "type" key, :class:`ValLoop` will be used by default. - If ``val_cfg`` specified, :attr:`val_dataloader` should also be - specified. If ``ValLoop`` is built with `fp16=True``, - ``runner.val()`` will be performed under fp16 precision. - Defaults to None. See :meth:`build_val_loop` for more details. - test_cfg (dict, optional): A dict to build a test loop. If it does - not provide "type" key, :class:`TestLoop` will be used by default. - If ``test_cfg`` specified, :attr:`test_dataloader` should also be - specified. If ``ValLoop`` is built with `fp16=True``, - ``runner.val()`` will be performed under fp16 precision. - Defaults to None. See :meth:`build_test_loop` for more details. - auto_scale_lr (dict, Optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. - optim_wrapper (OptimWrapper or dict, optional): - Computing gradient of model parameters. If specified, - :attr:`train_dataloader` should also be specified. If automatic - mixed precision or gradient accmulation - training is required. The type of ``optim_wrapper`` should be - AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for - examples. Defaults to None. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. 
If - specified, :attr:`optimizer` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - val_evaluator (Evaluator or dict or list, optional): A evaluator object - used for computing metrics for validation. It can be a dict or a - list of dict to build a evaluator. If specified, - :attr:`val_dataloader` should also be specified. Defaults to None. - test_evaluator (Evaluator or dict or list, optional): A evaluator - object used for computing metrics for test steps. It can be a dict - or a list of dict to build a evaluator. If specified, - :attr:`test_dataloader` should also be specified. Defaults to None. - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to - execute default actions like updating model parameters and saving - checkpoints. Default hooks are ``OptimizerHook``, - ``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and - ``CheckpointHook``. Defaults to None. - See :meth:`register_default_hooks` for more details. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - data_preprocessor (dict, optional): The pre-process config of - :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict - and doesn't contain the key ``data_preprocessor``, set the argument - as the ``data_preprocessor`` of the ``model`` dict. - Defaults to None. - load_from (str, optional): The checkpoint file to load from. - Defaults to None. - resume (bool): Whether to resume training. Defaults to False. If - ``resume`` is True and ``load_from`` is None, automatically to - find latest checkpoint from ``work_dir``. If not found, resuming - does nothing. - launcher (str): Way to launcher multi-process. Supported launchers - are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided, - non-distributed environment will be launched. - env_cfg (dict): A dict used for setting environment. Defaults to - dict(dist_cfg=dict(backend='nccl')). - log_processor (dict, optional): A processor to format logs. Defaults to - None. - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - visualizer (Visualizer or dict, optional): A Visualizer object or a - dict build Visualizer object. Defaults to None. If not - specified, default config will be used. - default_scope (str): Used to reset registries location. - Defaults to "mmengine". - randomness (dict): Some settings to make the experiment as reproducible - as possible like seed and deterministic. - Defaults to ``dict(seed=None)``. If seed is None, a random number - will be generated and it will be broadcasted to all other processes - if in distributed environment. If ``cudnn_benchmark`` is - ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in - ``randomness``, the value of ``torch.backends.cudnn.benchmark`` - will be ``False`` finally. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as ``experiment_name``. - Defaults to None. - cfg (dict or Configdict or :obj:`Config`, optional): Full config. - Defaults to None. - - Note: - Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in - `cfg.compile = True`. If you want to control compile options, you - can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``. - Refer to `PyTorch API Documentation `_ for more valid - options. 
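To make the note above concrete, here is a minimal sketch of enabling ``torch.compile`` through the config; only ``cfg.compile`` itself comes from the docstring, the surrounding keys are placeholders.

```python
# Requires PyTorch >= 2.0.0. `compile` may be a bool or a dict of
# torch.compile options; the docstring names `backend` as one such
# option (e.g. dict(backend='eager')).
cfg = dict(
    model=dict(type='ToyModel'),   # placeholder model config
    work_dir='./work_dir',
    compile=True,                  # or: compile=dict(backend='eager')
)
# runner = Runner.from_cfg(cfg); runner.train()
```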
- - Examples: - >>> from mmengine.runner import Runner - >>> cfg = dict( - >>> model=dict(type='ToyModel'), - >>> work_dir='path/of/work_dir', - >>> train_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=True), - >>> batch_size=1, - >>> num_workers=0), - >>> val_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> test_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> auto_scale_lr=dict(base_batch_size=16, enable=False), - >>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict( - >>> type='SGD', lr=0.01)), - >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - >>> val_evaluator=dict(type='ToyEvaluator'), - >>> test_evaluator=dict(type='ToyEvaluator'), - >>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1), - >>> val_cfg=dict(), - >>> test_cfg=dict(), - >>> custom_hooks=[], - >>> default_hooks=dict( - >>> timer=dict(type='IterTimerHook'), - >>> checkpoint=dict(type='CheckpointHook', interval=1), - >>> logger=dict(type='LoggerHook'), - >>> optimizer=dict(type='OptimizerHook', grad_clip=False), - >>> param_scheduler=dict(type='ParamSchedulerHook')), - >>> launcher='none', - >>> env_cfg=dict(dist_cfg=dict(backend='nccl')), - >>> log_processor=dict(window_size=20), - >>> visualizer=dict(type='Visualizer', - >>> vis_backends=[dict(type='LocalVisBackend', - >>> save_dir='temp_dir')]) - >>> ) - >>> runner = Runner.from_cfg(cfg) - >>> runner.train() - >>> runner.test() - """ - cfg: Config - _train_loop: Optional[Union[BaseLoop, Dict]] - _val_loop: Optional[Union[BaseLoop, Dict]] - _test_loop: Optional[Union[BaseLoop, Dict]] - - def __init__( - self, - model: Union[nn.Module, Dict], - work_dir: str, - train_dataloader: Optional[Union[DataLoader, Dict]] = None, - val_dataloader: Optional[Union[DataLoader, Dict]] = None, - test_dataloader: Optional[Union[DataLoader, Dict]] = None, - train_cfg: Optional[Dict] = None, - val_cfg: Optional[Dict] = None, - test_cfg: Optional[Dict] = None, - auto_scale_lr: Optional[Dict] = None, - optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None, - param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None, - val_evaluator: Optional[Union[Evaluator, Dict, List]] = None, - test_evaluator: Optional[Union[Evaluator, Dict, List]] = None, - default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, - custom_hooks: Optional[List[Union[Hook, Dict]]] = None, - data_preprocessor: Union[nn.Module, Dict, None] = None, - load_from: Optional[str] = None, - resume: bool = False, - launcher: str = 'none', - env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), - log_processor: Optional[Dict] = None, - log_level: str = 'INFO', - visualizer: Optional[Union[Visualizer, Dict]] = None, - default_scope: str = 'mmengine', - randomness: Dict = dict(seed=None), - experiment_name: Optional[str] = None, - cfg: Optional[ConfigType] = None, - ): - self._work_dir = osp.abspath(work_dir) - mmengine.mkdir_or_exist(self._work_dir) - - # recursively copy the `cfg` because `self.cfg` will be modified - # everywhere. 
- if cfg is not None: - if isinstance(cfg, Config): - self.cfg = copy.deepcopy(cfg) - elif isinstance(cfg, dict): - self.cfg = Config(cfg) - else: - self.cfg = Config(dict()) - - # lazy initialization - training_related = [train_dataloader, train_cfg, optim_wrapper] - if not (all(item is None for item in training_related) - or all(item is not None for item in training_related)): - raise ValueError( - 'train_dataloader, train_cfg, and optim_wrapper should be ' - 'either all None or not None, but got ' - f'train_dataloader={train_dataloader}, ' - f'train_cfg={train_cfg}, ' - f'optim_wrapper={optim_wrapper}.') - self._train_dataloader = train_dataloader - self._train_loop = train_cfg - - self.optim_wrapper: Optional[Union[OptimWrapper, dict]] - self.optim_wrapper = optim_wrapper - - self.auto_scale_lr = auto_scale_lr - - # If there is no need to adjust learning rate, momentum or other - # parameters of optimizer, param_scheduler can be None - if param_scheduler is not None and self.optim_wrapper is None: - raise ValueError( - 'param_scheduler should be None when optim_wrapper is None, ' - f'but got {param_scheduler}') - - # Parse `param_scheduler` to a list or a dict. If `optim_wrapper` is a - # `dict` with single optimizer, parsed param_scheduler will be a - # list of parameter schedulers. If `optim_wrapper` is - # a `dict` with multiple optimizers, parsed `param_scheduler` will be - # dict with multiple list of parameter schedulers. - self._check_scheduler_cfg(param_scheduler) - self.param_schedulers = param_scheduler - - val_related = [val_dataloader, val_cfg, val_evaluator] - if not (all(item is None - for item in val_related) or all(item is not None - for item in val_related)): - raise ValueError( - 'val_dataloader, val_cfg, and val_evaluator should be either ' - 'all None or not None, but got ' - f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, ' - f'val_evaluator={val_evaluator}') - self._val_dataloader = val_dataloader - self._val_loop = val_cfg - self._val_evaluator = val_evaluator - - test_related = [test_dataloader, test_cfg, test_evaluator] - if not (all(item is None for item in test_related) - or all(item is not None for item in test_related)): - raise ValueError( - 'test_dataloader, test_cfg, and test_evaluator should be ' - 'either all None or not None, but got ' - f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, ' - f'test_evaluator={test_evaluator}') - self._test_dataloader = test_dataloader - self._test_loop = test_cfg - self._test_evaluator = test_evaluator - - self._launcher = launcher - if self._launcher == 'none': - self._distributed = False - else: - self._distributed = True - - # self._timestamp will be set in the `setup_env` method. Besides, - # it also will initialize multi-process and (or) distributed - # environment. - self.setup_env(env_cfg) - # self._deterministic and self._seed will be set in the - # `set_randomness`` method - self._randomness_cfg = randomness - self.set_randomness(**randomness) - - if experiment_name is not None: - self._experiment_name = f'{experiment_name}_{self._timestamp}' - elif self.cfg.filename is not None: - filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0] - self._experiment_name = f'{filename_no_ext}_{self._timestamp}' - else: - self._experiment_name = self.timestamp - self._log_dir = osp.join(self.work_dir, self.timestamp) - mmengine.mkdir_or_exist(self._log_dir) - # Used to reset registries location. See :meth:`Registry.build` for - # more details. 
- if default_scope is not None: - default_scope = DefaultScope.get_instance( # type: ignore - self._experiment_name, - scope_name=default_scope) - self.default_scope = default_scope - - # Build log processor to format message. - log_processor = dict() if log_processor is None else log_processor - self.log_processor = self.build_log_processor(log_processor) - # Since `get_instance` could return any subclass of ManagerMixin. The - # corresponding attribute needs a type hint. - self.logger = self.build_logger(log_level=log_level) - - # Collect and log environment information. - self._log_env(env_cfg) - - # Build `message_hub` for communication among components. - # `message_hub` can store log scalars (loss, learning rate) and - # runtime information (iter and epoch). Those components that do not - # have access to the runner can get iteration or epoch information - # from `message_hub`. For example, models can get the latest created - # `message_hub` by - # `self.message_hub=MessageHub.get_current_instance()` and then get - # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. - # See `MessageHub` and `ManagerMixin` for more details. - self.message_hub = self.build_message_hub() - # visualizer used for writing log or visualizing all kinds of data - self.visualizer = self.build_visualizer(visualizer) - if self.cfg: - self.visualizer.add_config(self.cfg) - - self._load_from = load_from - self._resume = resume - # flag to mark whether checkpoint has been loaded or resumed - self._has_loaded = False - - # build a model - if isinstance(model, dict) and data_preprocessor is not None: - # Merge the data_preprocessor to model config. - model.setdefault('data_preprocessor', data_preprocessor) - self.model = self.build_model(model) - # wrap model - self.model = self.wrap_model( - self.cfg.get('model_wrapper_cfg'), self.model) - - # get model name from the model class - if hasattr(self.model, 'module'): - self._model_name = self.model.module.__class__.__name__ - else: - self._model_name = self.model.__class__.__name__ - - self._hooks: List[Hook] = [] - # register hooks to `self._hooks` - self.register_hooks(default_hooks, custom_hooks) - # log hooks information - self.logger.info(f'Hooks will be executed in the following ' - f'order:\n{self.get_hooks_info()}') - - # dump `cfg` to `work_dir` - self.dump_config() - - @classmethod - def from_cfg(cls, cfg: ConfigType) -> 'Runner': - """Build a runner from config. - - Args: - cfg (ConfigType): A config used for building runner. Keys of - ``cfg`` can see :meth:`__init__`. - - Returns: - Runner: A runner build from ``cfg``. 
- """ - cfg = copy.deepcopy(cfg) - runner = cls( - model=cfg['model'], - work_dir=cfg['work_dir'], - train_dataloader=cfg.get('train_dataloader'), - val_dataloader=cfg.get('val_dataloader'), - test_dataloader=cfg.get('test_dataloader'), - train_cfg=cfg.get('train_cfg'), - val_cfg=cfg.get('val_cfg'), - test_cfg=cfg.get('test_cfg'), - auto_scale_lr=cfg.get('auto_scale_lr'), - optim_wrapper=cfg.get('optim_wrapper'), - param_scheduler=cfg.get('param_scheduler'), - val_evaluator=cfg.get('val_evaluator'), - test_evaluator=cfg.get('test_evaluator'), - default_hooks=cfg.get('default_hooks'), - custom_hooks=cfg.get('custom_hooks'), - data_preprocessor=cfg.get('data_preprocessor'), - load_from=cfg.get('load_from'), - resume=cfg.get('resume', False), - launcher=cfg.get('launcher', 'none'), - env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))), - log_processor=cfg.get('log_processor'), - log_level=cfg.get('log_level', 'INFO'), - visualizer=cfg.get('visualizer'), - default_scope=cfg.get('default_scope', 'mmengine'), - randomness=cfg.get('randomness', dict(seed=None)), - experiment_name=cfg.get('experiment_name'), - cfg=cfg, - ) - - return runner - - @property - def experiment_name(self): - """str: Name of experiment.""" - return self._experiment_name - - @property - def model_name(self): - """str: Name of the model, usually the module class name.""" - return self._model_name - - @property - def work_dir(self): - """str: The working directory to save checkpoints and logs.""" - return self._work_dir - - @property - def log_dir(self): - return self._log_dir - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_epochs - else: - return 0 - - @property - def max_iters(self): - """int: Total iterations to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_iters - else: - return 0 - - @property - def epoch(self): - """int: Current epoch.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.epoch - else: - return 0 - - @property - def iter(self): - """int: Current iteration.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.iter - else: - return 0 - - @property - def launcher(self): - """str: Way to launcher multi processes.""" - return self._launcher - - @property - def distributed(self): - """bool: Whether current environment is distributed.""" - return self._distributed - - @property - def rank(self): - """int: Rank of current process.""" - return self._rank - - @property - def world_size(self): - """int: Number of processes participating in the job.""" - return self._world_size - - @property - def deterministic(self): - """int: Whether cudnn to select deterministic algorithms.""" - return self._deterministic - - @property - def seed(self): - """int: A number to set random modules.""" - return self._seed - - @property - def timestamp(self): - """str: Timestamp when creating experiment.""" - return self._timestamp - - @property - def hooks(self): - """List[:obj:`Hook`]: A list of registered hooks.""" - return self._hooks - - @property - def train_loop(self): - """:obj:`BaseLoop`: A loop to run training.""" - if isinstance(self._train_loop, BaseLoop) or self._train_loop is None: - return self._train_loop - else: - self._train_loop = self.build_train_loop(self._train_loop) - return self._train_loop - - @property - def val_loop(self): - """:obj:`BaseLoop`: A loop to run validation.""" - if isinstance(self._val_loop, BaseLoop) 
or self._val_loop is None: - return self._val_loop - else: - self._val_loop = self.build_val_loop(self._val_loop) - return self._val_loop - - @property - def test_loop(self): - """:obj:`BaseLoop`: A loop to run testing.""" - if isinstance(self._test_loop, BaseLoop) or self._test_loop is None: - return self._test_loop - else: - self._test_loop = self.build_test_loop(self._test_loop) - return self._test_loop - - @property - def train_dataloader(self): - """The data loader for training.""" - return self.train_loop.dataloader - - @property - def val_dataloader(self): - """The data loader for validation.""" - return self.val_loop.dataloader - - @property - def test_dataloader(self): - """The data loader for testing.""" - return self.test_loop.dataloader - - @property - def val_evaluator(self): - """:obj:`Evaluator`: An evaluator for validation.""" - return self.val_loop.evaluator - - @property - def test_evaluator(self): - """:obj:`Evaluator`: An evaluator for testing.""" - return self.test_loop.evaluator - - @property - def val_interval(self): - """int: Interval to run validation during training.""" - return self.train_loop.val_interval - - @property - def val_begin(self): - """int: The epoch/iteration to start running validation during - training.""" - return self.train_loop.val_begin - - def setup_env(self, env_cfg: Dict) -> None: - """Setup environment. - - An example of ``env_cfg``:: - - env_cfg = dict( - cudnn_benchmark=True, - mp_cfg=dict( - mp_start_method='fork', - opencv_num_threads=0 - ), - dist_cfg=dict(backend='nccl', timeout=1800), - resource_limit=4096 - ) - - Args: - env_cfg (dict): Config for setting environment. - """ - if env_cfg.get('cudnn_benchmark'): - torch.backends.cudnn.benchmark = True - - mp_cfg: dict = env_cfg.get('mp_cfg', {}) - set_multi_processing(**mp_cfg, distributed=self.distributed) - - # init distributed env first, since logger depends on the dist info. - if self.distributed and not is_distributed(): - dist_cfg: dict = env_cfg.get('dist_cfg', {}) - init_dist(self.launcher, **dist_cfg) - - self._rank, self._world_size = get_dist_info() - - timestamp = torch.tensor(time.time(), dtype=torch.float64) - # broadcast timestamp from 0 process to other processes - broadcast(timestamp) - self._timestamp = time.strftime('%Y%m%d_%H%M%S', - time.localtime(timestamp.item())) - - # https://github.com/pytorch/pytorch/issues/973 - # set resource limit - if platform.system() != 'Windows': - import resource - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - base_soft_limit = rlimit[0] - hard_limit = rlimit[1] - soft_limit = min( - max(env_cfg.get('resource_limit', 4096), base_soft_limit), - hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, - (soft_limit, hard_limit)) - - def set_randomness(self, - seed, - diff_rank_seed: bool = False, - deterministic: bool = False) -> None: - """Set random seed to guarantee reproducible results. - - Args: - seed (int): A number to set random modules. - diff_rank_seed (bool): Whether or not set different seeds according - to global rank. Defaults to False. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - See https://pytorch.org/docs/stable/notes/randomness.html for - more details. 
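A minimal sketch of a ``randomness`` setting matching the arguments described above; the values are illustrative only.

```python
# Hypothetical `randomness` dict as passed to the Runner, which forwards
# it to set_randomness(). With deterministic=True, cuDNN runs in
# deterministic mode and cudnn benchmark stays off even if env_cfg
# enabled it (see the Runner docstring above).
randomness = dict(seed=2022, diff_rank_seed=False, deterministic=True)
# Equivalent direct call on an existing runner:
# runner.set_randomness(**randomness)
```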
- """ - self._deterministic = deterministic - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) - - def build_logger(self, - log_level: Union[int, str] = 'INFO', - log_file: Optional[str] = None, - **kwargs) -> MMLogger: - """Build a global asscessable MMLogger. - - Args: - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - log_file (str, optional): Path of filename to save log. - Defaults to None. - **kwargs: Remaining parameters passed to ``MMLogger``. - - Returns: - MMLogger: A MMLogger object build from ``logger``. - """ - if log_file is None: - log_file = osp.join(self._log_dir, f'{self.timestamp}.log') - - log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs) - log_cfg.setdefault('name', self._experiment_name) - # `torch.compile` in PyTorch 2.0 could close all user defined handlers - # unexpectedly. Using file mode 'a' can help prevent abnormal - # termination of the FileHandler and ensure that the log file could - # be continuously updated during the lifespan of the runner. - log_cfg.setdefault('file_mode', 'a') - - return MMLogger.get_instance(**log_cfg) # type: ignore - - def build_message_hub(self, - message_hub: Optional[Dict] = None) -> MessageHub: - """Build a global asscessable MessageHub. - - Args: - message_hub (dict, optional): A dict to build MessageHub object. - If not specified, default config will be used to build - MessageHub object. Defaults to None. - - Returns: - MessageHub: A MessageHub object build from ``message_hub``. - """ - if message_hub is None: - message_hub = dict(name=self._experiment_name) - elif isinstance(message_hub, dict): - # ensure message_hub containing name key - message_hub.setdefault('name', self._experiment_name) - else: - raise TypeError( - f'message_hub should be dict or None, but got {message_hub}') - - return MessageHub.get_instance(**message_hub) - - def build_visualizer( - self, - visualizer: Optional[Union[Visualizer, - Dict]] = None) -> Visualizer: - """Build a global asscessable Visualizer. - - Args: - visualizer (Visualizer or dict, optional): A Visualizer object - or a dict to build Visualizer object. If ``visualizer`` is a - Visualizer object, just returns itself. If not specified, - default config will be used to build Visualizer object. - Defaults to None. - - Returns: - Visualizer: A Visualizer object build from ``visualizer``. - """ - if visualizer is None: - visualizer = dict( - name=self._experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self._log_dir) - return Visualizer.get_instance(**visualizer) - - if isinstance(visualizer, Visualizer): - return visualizer - - if isinstance(visualizer, dict): - # ensure visualizer containing name key - visualizer.setdefault('name', self._experiment_name) - visualizer.setdefault('save_dir', self._log_dir) - return VISUALIZERS.build(visualizer) - else: - raise TypeError( - 'visualizer should be Visualizer object, a dict or None, ' - f'but got {visualizer}') - - def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: - """Build model. - - If ``model`` is a dict, it will be used to build a nn.Module object. - Else, if ``model`` is a nn.Module object it will be returned directly. - - An example of ``model``:: - - model = dict(type='ResNet') - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build nn.Module object. If ``model`` is a nn.Module object, - just returns itself. 
- - Note: - The returned model must implement ``train_step``, ``test_step`` - if ``runner.train`` or ``runner.test`` will be called. If - ``runner.val`` will be called or ``val_cfg`` is configured, - model must implement `val_step`. - - Returns: - nn.Module: Model build from ``model``. - """ - if isinstance(model, nn.Module): - return model - elif isinstance(model, dict): - model = MODELS.build(model) - return model # type: ignore - else: - raise TypeError('model should be a nn.Module object or dict, ' - f'but got {model}') - - def wrap_model( - self, model_wrapper_cfg: Optional[Dict], - model: nn.Module) -> Union[DistributedDataParallel, nn.Module]: - """Wrap the model to :obj:`MMDistributedDataParallel` or other custom - distributed data-parallel module wrappers. - - An example of ``model_wrapper_cfg``:: - - model_wrapper_cfg = dict( - broadcast_buffers=False, - find_unused_parameters=False - ) - - Args: - model_wrapper_cfg (dict, optional): Config to wrap model. If not - specified, ``DistributedDataParallel`` will be used in - distributed environment. Defaults to None. - model (nn.Module): Model to be wrapped. - - Returns: - nn.Module or DistributedDataParallel: nn.Module or subclass of - ``DistributedDataParallel``. - """ - if is_model_wrapper(model): - if model_wrapper_cfg is not None: - raise TypeError( - 'model has been wrapped and "model_wrapper_cfg" should be ' - f'None, but got {model_wrapper_cfg}') - - return model - - # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training. - model = model.to(get_device()) - - if not self.distributed: - self.logger.info( - 'Distributed training is not used, all SyncBatchNorm (SyncBN) ' - 'layers in the model will be automatically reverted to ' - 'BatchNormXd layers if they are used.') - model = revert_sync_batchnorm(model) - return model # type: ignore - else: - sync_bn = self.cfg.get('sync_bn', None) - if sync_bn is not None: - try: - model = convert_sync_batchnorm(model, sync_bn) - except ValueError as e: - self.logger.error('cfg.sync_bn should be "torch" or ' - f'"mmcv", but got {sync_bn}') - raise e - if model_wrapper_cfg is None: - find_unused_parameters = self.cfg.get('find_unused_parameters', - False) - # Sets the `find_unused_parameters` parameter in - # torch.nn.parallel.DistributedDataParallel - # TODO: may use a more elegant way to get local device ID. 
- model = MMDistributedDataParallel( - module=model, - device_ids=[int(os.environ['LOCAL_RANK'])], - broadcast_buffers=False, - find_unused_parameters=find_unused_parameters) - else: - model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_cfg.get('type')) # type: ignore - default_args: dict = dict() - if issubclass( - model_wrapper_type, # type: ignore - DistributedDataParallel): - default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])] - default_args['module'] = model - model = MODEL_WRAPPERS.build( - model_wrapper_cfg, default_args=default_args) - return model - - def _init_model_weights(self) -> None: - """Initialize the model weights if the model has - :meth:`init_weights`""" - model = self.model.module if is_model_wrapper( - self.model) else self.model - if hasattr(model, 'init_weights'): - model.init_weights() - # sync params and buffers - for name, params in model.state_dict().items(): - broadcast(params) - - def scale_lr(self, - optim_wrapper: OptimWrapper, - auto_scale_lr: Optional[Dict] = None) -> None: - """Automatically scaling learning rate in training according to the - ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch - size. - - It scales the learning rate linearly according to the - `paper `_. - - Note: - ``scale_lr`` must be called after building optimizer wrappers - and before building parameter schedulers. - - Args: - optim_wrapper (OptimWrapper): An OptimWrapper object whose - parameter groups' learning rate need to be scaled. - auto_scale_lr (Dict, Optional): Config to scale the learning - rate automatically. It includes ``base_batch_size`` and - ``enable``. ``base_batch_size`` is the batch size that the - optimizer lr is based on. ``enable`` is the switch to turn on - and off the feature. - """ - if (auto_scale_lr is None or not auto_scale_lr.get('enable', False)): - return None - - assert 'base_batch_size' in auto_scale_lr, \ - 'Lack of `base_batch_size` in `auto_scale_lr`.' - dataloader: Union[DataLoader, Dict] = self._train_dataloader - bs = dataloader.batch_size if isinstance( - dataloader, DataLoader) else dataloader['batch_size'] - real_bs = self.world_size * bs - base_bs = auto_scale_lr['base_batch_size'] - ratio = float(real_bs) / float(base_bs) - self.logger.info(f'LR is set based on batch size of {base_bs} ' - f'and the current batch size is {real_bs}. ' - f'Scaling the original LR by {ratio}.') - - def _is_built(schedulers): - if isinstance(schedulers, dict): - return False if 'type' in schedulers else any( - _is_built(s) for s in schedulers.values()) - if isinstance(schedulers, list): - return any(_is_built(s) for s in schedulers) - return isinstance(schedulers, _ParamScheduler) - - if _is_built(self.param_schedulers): - raise RuntimeError('`scale_lr` should be called before building ' - 'ParamScheduler because ParamScheduler will ' - 'store initial lr from optimizer wrappers') - - assert isinstance(optim_wrapper, OptimWrapper), \ - '`scale_lr should be called after building OptimWrapper' - wrappers = list(optim_wrapper.values()) if isinstance( - optim_wrapper, OptimWrapperDict) else [optim_wrapper] - for wrapper in wrappers: - for group in wrapper.optimizer.param_groups: - group['lr'] = group['lr'] * ratio - - def build_optim_wrapper( - self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict] - ) -> Union[OptimWrapper, OptimWrapperDict]: - """Build optimizer wrapper. 
- - If ``optim_wrapper`` is a config dict for only one optimizer, - the keys must contain ``optimizer``, and ``type`` is optional. - It will build a :obj:`OptimWrapper` by default. - - If ``optim_wrapper`` is a config dict for multiple optimizers, i.e., - it has multiple keys and each key is for an optimizer wrapper. The - constructor must be specified since - :obj:`DefaultOptimizerConstructor` cannot handle the building of - training with multiple optimizers. - - If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e., - each value of ``optim_wrapper`` represents an ``OptimWrapper`` - instance. ``build_optim_wrapper`` will directly build the - :obj:`OptimWrapperDict` instance from ``optim_wrapper``. - - Args: - optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a - dict to build OptimWrapper objects. If ``optim_wrapper`` is an - OptimWrapper, just return an ``OptimizeWrapper`` instance. - - Note: - For single optimizer training, if `optim_wrapper` is a config - dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it - must contain `optimizer` to build the corresponding optimizer. - - Examples: - >>> # build an optimizer - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)) - >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> # is also valid. - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build optimizer without `type` - >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - maximize: False - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build multiple optimizers - >>> optim_wrapper_cfg = dict( - ... generator=dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)), - ... discriminator=dict(type='OptimWrapper', optimizer=dict( - ... type='Adam', lr=0.001)) - ... # need to customize a multiple optimizer constructor - ... constructor='CustomMultiOptimizerConstructor', - ...) - >>> optim_wrapper = runner.optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - name: generator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.1 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - name: discriminator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - 'discriminator': Adam ( - Parameter Group 0 - dampening: 0 - lr: 0.02 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - - Important: - If you need to build multiple optimizers, you should implement a - MultiOptimWrapperConstructor which gets parameters passed to - corresponding optimizers and compose the ``OptimWrapperDict``. - More details about how to customize OptimizerConstructor can be - found at `optimizer-docs`_. - - Returns: - OptimWrapper: Optimizer wrapper build from ``optimizer_cfg``. - """ # noqa: E501 - if isinstance(optim_wrapper, OptimWrapper): - return optim_wrapper - if isinstance(optim_wrapper, (dict, ConfigDict, Config)): - # optimizer must be defined for single optimizer training. 
- optimizer = optim_wrapper.get('optimizer', None) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. - if isinstance(optimizer, Optimizer): - optim_wrapper.setdefault('type', 'OptimWrapper') - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or 'constructor' in optim_wrapper: - return build_optim_wrapper(self.model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. `build_optim_wrapper` will directly build the - # `OptimWrapperDict` instance from `optim_wrapper.` - optim_wrappers = OrderedDict() - for name, optim in optim_wrapper.items(): - if not isinstance(optim, OptimWrapper): - raise ValueError( - 'each item mush be an optimizer object when ' - '"type" and "constructor" are not in ' - f'optimizer, but got {name}={optim}') - optim_wrappers[name] = optim - return OptimWrapperDict(**optim_wrappers) - else: - raise TypeError('optimizer wrapper should be an OptimWrapper ' - f'object or dict, but got {optim_wrapper}') - - def _build_param_scheduler( - self, scheduler: Union[_ParamScheduler, Dict, List], - optim_wrapper: OptimWrapper) -> List[_ParamScheduler]: - """Build parameter schedulers for a single optimizer. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - optim_wrapper (OptimWrapper): An optimizer wrapper object is - passed to construct ParamScheduler object. - - Returns: - list[_ParamScheduler]: List of parameter schedulers build from - ``scheduler``. - - Note: - If the train loop is built, when building parameter schedulers, - it supports setting the max epochs/iters as the default ``end`` - of schedulers, and supports converting epoch-based schedulers - to iter-based according to the ``convert_to_iter_based`` key. - """ - if not isinstance(scheduler, Sequence): - schedulers = [scheduler] - else: - schedulers = scheduler - - param_schedulers = [] - for scheduler in schedulers: - if isinstance(scheduler, _ParamScheduler): - param_schedulers.append(scheduler) - elif isinstance(scheduler, dict): - _scheduler = copy.deepcopy(scheduler) - - # Set default end - if isinstance(self._train_loop, BaseLoop): - default_end = self.max_epochs if _scheduler.get( - 'by_epoch', True) else self.max_iters - _scheduler.setdefault('end', default_end) - self.logger.debug( - f'The `end` of {_scheduler["type"]} is not set. ' - 'Use the max epochs/iters of train loop as default.') - - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, - epoch_length=len(self.train_dataloader)))) - else: - raise TypeError( - 'scheduler should be a _ParamScheduler object or dict, ' - f'but got {scheduler}') - return param_schedulers - - def build_param_scheduler( - self, scheduler: Union[_ParamScheduler, Dict, - List]) -> ParamSchedulerType: - """Build parameter schedulers. 
- - ``build_param_scheduler`` should be called after - ``build_optim_wrapper`` because the building logic will change - according to the number of optimizers built by the runner. - The cases are as below: - - - Single optimizer: When only one optimizer is built and used in the - runner, ``build_param_scheduler`` will return a list of - parameter schedulers. - - Multiple optimizers: When two or more optimizers are built and used - in runner, ``build_param_scheduler`` will return a dict containing - the same keys with multiple optimizers and each value is a list of - parameter schedulers. Note that, if you want different optimizers to - use different parameter schedulers to update optimizer's - hyper-parameters, the input parameter ``scheduler`` also needs to be - a dict and its key are consistent with multiple optimizers. - Otherwise, the same parameter schedulers will be used to update - optimizer's hyper-parameters. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - - Examples: - >>> # build one scheduler - >>> optim_cfg = dict(dict(type='SGD', lr=0.01)) - >>> runner.optim_wrapper = runner.build_optim_wrapper( - >>> optim_cfg) - >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2]) - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [] # noqa: E501 - - >>> # build multiple schedulers - >>> scheduler_cfg = [ - ... dict(type='MultiStepLR', milestones=[1, 2]), - ... dict(type='StepLR', step_size=1) - ... ] - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [, # noqa: E501 - ] - - Above examples only provide the case of one optimizer and one scheduler - or multiple schedulers. If you want to know how to set parameter - scheduler when using multiple optimizers, you can find more examples - `optimizer-docs`_. - - Returns: - list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of - parameter schedulers or a dictionary contains list of parameter - schedulers build from ``scheduler``. - - .. _optimizer-docs: - https://mmengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html - """ - param_schedulers: ParamSchedulerType - if not isinstance(self.optim_wrapper, OptimWrapperDict): - # Since `OptimWrapperDict` inherits from `OptimWrapper`, - # `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell - # whether `self.optim_wrapper` is an `OptimizerWrapper` or - # `OptimWrapperDict` instance. Therefore, here we simply check - # self.optim_wrapper is not an `OptimWrapperDict` instance and - # then assert it is an OptimWrapper instance. - assert isinstance(self.optim_wrapper, OptimWrapper), ( - '`build_optimizer` should be called before' - '`build_param_scheduler` because the latter depends ' - 'on the former') - param_schedulers = self._build_param_scheduler( - scheduler, self.optim_wrapper) # type: ignore - return param_schedulers - else: - param_schedulers = dict() - for name, optimizer in self.optim_wrapper.items(): - if isinstance(scheduler, dict) and 'type' not in scheduler: - # scheduler is a dict and each item is a ParamScheduler - # object or a config to build ParamScheduler objects - param_schedulers[name] = self._build_param_scheduler( - scheduler[name], optimizer) - else: - param_schedulers[name] = self._build_param_scheduler( - scheduler, optimizer) - - return param_schedulers - - def build_evaluator(self, evaluator: Union[Dict, List, - Evaluator]) -> Evaluator: - """Build evaluator. 
- - Examples of ``evaluator``:: - - # evaluator could be a built Evaluator instance - evaluator = Evaluator(metrics=[ToyMetric()]) - - # evaluator can also be a list of dict - evaluator = [ - dict(type='ToyMetric1'), - dict(type='ToyEvaluator2') - ] - - # evaluator can also be a list of built metric - evaluator = [ToyMetric1(), ToyMetric2()] - - # evaluator can also be a dict with key metrics - evaluator = dict(metrics=ToyMetric()) - # metric is a list - evaluator = dict(metrics=[ToyMetric()]) - - Args: - evaluator (Evaluator or dict or list): An Evaluator object or a - config dict or list of config dict used to build an Evaluator. - - Returns: - Evaluator: Evaluator build from ``evaluator``. - """ - if isinstance(evaluator, Evaluator): - return evaluator - elif isinstance(evaluator, dict): - # if `metrics` in dict keys, it means to build customized evalutor - if 'metrics' in evaluator: - evaluator.setdefault('type', 'Evaluator') - return EVALUATOR.build(evaluator) - # otherwise, default evalutor will be built - else: - return Evaluator(evaluator) # type: ignore - elif isinstance(evaluator, list): - # use the default `Evaluator` - return Evaluator(evaluator) # type: ignore - else: - raise TypeError( - 'evaluator should be one of dict, list of dict, and Evaluator' - f', but got {evaluator}') - - @staticmethod - def build_dataloader(dataloader: Union[DataLoader, Dict], - seed: Optional[int] = None, - diff_rank_seed: bool = False) -> DataLoader: - """Build dataloader. - - The method builds three components: - - - Dataset - - Sampler - - Dataloader - - An example of ``dataloader``:: - - dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=1, - num_workers=9 - ) - - Args: - dataloader (DataLoader or dict): A Dataloader object or a dict to - build Dataloader object. If ``dataloader`` is a Dataloader - object, just returns itself. - seed (int, optional): Random seed. Defaults to None. - diff_rank_seed (bool): Whether or not set different seeds to - different ranks. If True, the seed passed to sampler is set - to None, in order to synchronize the seeds used in samplers - across different ranks. - - - Returns: - Dataloader: DataLoader build from ``dataloader_cfg``. 
- """ - if isinstance(dataloader, DataLoader): - return dataloader - - dataloader_cfg = copy.deepcopy(dataloader) - - # build dataset - dataset_cfg = dataloader_cfg.pop('dataset') - if isinstance(dataset_cfg, dict): - dataset = DATASETS.build(dataset_cfg) - if hasattr(dataset, 'full_init'): - dataset.full_init() - else: - # fallback to raise error in dataloader - # if `dataset_cfg` is not a valid type - dataset = dataset_cfg - - num_batch_per_epoch = dataloader_cfg.pop('num_batch_per_epoch', None) - if num_batch_per_epoch is not None: - world_size = get_world_size() - num_samples = ( - num_batch_per_epoch * _get_batch_size(dataloader_cfg) * - world_size) - dataset = _SlicedDataset(dataset, num_samples) - - # build sampler - sampler_cfg = dataloader_cfg.pop('sampler') - if isinstance(sampler_cfg, dict): - sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) - else: - # fallback to raise error in dataloader - # if `sampler_cfg` is not a valid type - sampler = sampler_cfg - - # build batch sampler - batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None) - if batch_sampler_cfg is None: - batch_sampler = None - elif isinstance(batch_sampler_cfg, dict): - batch_sampler = DATA_SAMPLERS.build( - batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) - else: - # fallback to raise error in dataloader - # if `batch_sampler_cfg` is not a valid type - batch_sampler = batch_sampler_cfg - - # build dataloader - init_fn: Optional[partial] - - if 'worker_init_fn' in dataloader_cfg: - worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn') - worker_init_fn_type = worker_init_fn_cfg.pop('type') - if isinstance(worker_init_fn_type, str): - worker_init_fn = FUNCTIONS.get(worker_init_fn_type) - elif callable(worker_init_fn_type): - worker_init_fn = worker_init_fn_type - else: - raise TypeError( - 'type of worker_init_fn should be string or callable ' - f'object, but got {type(worker_init_fn_type)}') - assert callable(worker_init_fn) - init_fn = partial(worker_init_fn, - **worker_init_fn_cfg) # type: ignore - else: - if seed is not None: - disable_subprocess_warning = dataloader_cfg.pop( - 'disable_subprocess_warning', False) - assert isinstance(disable_subprocess_warning, bool), ( - 'disable_subprocess_warning should be a bool, but got ' - f'{type(disable_subprocess_warning)}') - init_fn = partial( - default_worker_init_fn, - num_workers=dataloader_cfg.get('num_workers'), - rank=get_rank(), - seed=seed, - disable_subprocess_warning=disable_subprocess_warning) - else: - init_fn = None - - # `persistent_workers` requires pytorch version >= 1.7 - if ('persistent_workers' in dataloader_cfg - and digit_version(TORCH_VERSION) < digit_version('1.7.0')): - print_log( - '`persistent_workers` is only available when ' - 'pytorch version >= 1.7', - logger='current', - level=logging.WARNING) - dataloader_cfg.pop('persistent_workers') - - # The default behavior of `collat_fn` in dataloader is to - # merge a list of samples to form a mini-batch of Tensor(s). - # However, in mmengine, if `collate_fn` is not defined in - # dataloader_cfg, `pseudo_collate` will only convert the list of - # samples into a dict without stacking the batch tensor. 
- collate_fn_cfg = dataloader_cfg.pop('collate_fn', - dict(type='pseudo_collate')) - if isinstance(collate_fn_cfg, dict): - collate_fn_type = collate_fn_cfg.pop('type') - if isinstance(collate_fn_type, str): - collate_fn = FUNCTIONS.get(collate_fn_type) - else: - collate_fn = collate_fn_type - collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore - elif callable(collate_fn_cfg): - collate_fn = collate_fn_cfg - else: - raise TypeError( - 'collate_fn should be a dict or callable object, but got ' - f'{collate_fn_cfg}') - data_loader = DataLoader( - dataset=dataset, - sampler=sampler if batch_sampler is None else None, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - worker_init_fn=init_fn, - **dataloader_cfg) - return data_loader - - def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build training loop. - - Examples of ``loop``:: - - # `EpochBasedTrainLoop` will be used - loop = dict(by_epoch=True, max_epochs=3) - - # `IterBasedTrainLoop` will be used - loop = dict(by_epoch=False, max_epochs=3) - - # custom training loop - loop = dict(type='CustomTrainLoop', max_epochs=3) - - Args: - loop (BaseLoop or dict): A training loop or a dict to build - training loop. If ``loop`` is a training loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Training loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'train_loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) - - if 'type' in loop_cfg and 'by_epoch' in loop_cfg: - raise RuntimeError( - 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) - else: - by_epoch = loop_cfg.pop('by_epoch') - if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) - else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) - return loop # type: ignore - - def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build validation loop. - - Examples of ``loop``: - - # `ValLoop` will be used - loop = dict() - - # custom validation loop - loop = dict(type='CustomValLoop') - - Args: - loop (BaseLoop or dict): A validation loop or a dict to build - validation loop. If ``loop`` is a validation loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Validation loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'val_loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) - else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore - - return loop # type: ignore - - def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: - """Build test loop. - - Examples of ``loop``:: - - # `TestLoop` will be used - loop = dict() - - # custom test loop - loop = dict(type='CustomTestLoop') - - Args: - loop (BaseLoop or dict): A test loop or a dict to build test loop. - If ``loop`` is a test loop object, just returns itself. 
- - Returns: - :obj:`BaseLoop`: Test loop object build from ``loop_cfg``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError( - f'test_loop should be a Loop object or dict, but got {loop}') - - loop_cfg = copy.deepcopy(loop) # type: ignore - - if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) - else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore - - return loop # type: ignore - - def build_log_processor( - self, log_processor: Union[LogProcessor, Dict]) -> LogProcessor: - """Build test log_processor. - - Examples of ``log_processor``: - - # `LogProcessor` will be used - log_processor = dict() - - # custom log_processor - log_processor = dict(type='CustomLogProcessor') - - Args: - log_processor (LogProcessor or dict): A log processor or a dict - to build log processor. If ``log_processor`` is a log processor - object, just returns itself. - - Returns: - :obj:`LogProcessor`: Log processor object build from - ``log_processor_cfg``. - """ - if isinstance(log_processor, LogProcessor): - return log_processor - elif not isinstance(log_processor, dict): - raise TypeError( - 'log processor should be a LogProcessor object or dict, but' - f'got {log_processor}') - - log_processor_cfg = copy.deepcopy(log_processor) # type: ignore - - if 'type' in log_processor_cfg: - log_processor = LOG_PROCESSORS.build(log_processor_cfg) - else: - log_processor = LogProcessor(**log_processor_cfg) # type: ignore - - return log_processor # type: ignore - - def get_hooks_info(self) -> str: - # Get hooks info in each stage - stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages} - for hook in self.hooks: - try: - priority = Priority(hook.priority).name # type: ignore - except ValueError: - priority = hook.priority # type: ignore - classname = hook.__class__.__name__ - hook_info = f'({priority:<12}) {classname:<35}' - for trigger_stage in hook.get_triggered_stages(): - stage_hook_map[trigger_stage].append(hook_info) - - stage_hook_infos = [] - for stage in Hook.stages: - hook_infos = stage_hook_map[stage] - if len(hook_infos) > 0: - info = f'{stage}:\n' - info += '\n'.join(hook_infos) - info += '\n -------------------- ' - stage_hook_infos.append(info) - return '\n'.join(stage_hook_infos) - - def load_or_resume(self) -> None: - """Load or resume checkpoint.""" - if self._has_loaded: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if self._resume and self._load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self.work_dir) - self.logger.info( - f'Auto resumed from the latest checkpoint {resume_from}.') - elif self._resume and self._load_from is not None: - # resume from the specified checkpoint - resume_from = self._load_from - - if resume_from is not None: - self.resume(resume_from) - self._has_loaded = True - elif self._load_from is not None: - self.load_checkpoint(self._load_from) - self._has_loaded = True - - def train(self) -> nn.Module: - """Launch training. - - Returns: - nn.Module: The model after training. 
- """ - if is_model_wrapper(self.model): - ori_model = self.model.module - else: - ori_model = self.model - assert hasattr(ori_model, 'train_step'), ( - 'If you want to train your model, please make sure your model ' - 'has implemented `train_step`.') - - if self._val_loop is not None: - assert hasattr(ori_model, 'val_step'), ( - 'If you want to validate your model, please make sure your ' - 'model has implemented `val_step`.') - - if self._train_loop is None: - raise RuntimeError( - '`self._train_loop` should not be None when calling train ' - 'method. Please provide `train_dataloader`, `train_cfg`, ' - '`optimizer` and `param_scheduler` arguments when ' - 'initializing runner.') - - self._train_loop = self.build_train_loop( - self._train_loop) # type: ignore - - # `build_optimizer` should be called before `build_param_scheduler` - # because the latter depends on the former - self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) - # Automatically scaling lr by linear scaling rule - self.scale_lr(self.optim_wrapper, self.auto_scale_lr) - - if self.param_schedulers is not None: - self.param_schedulers = self.build_param_scheduler( # type: ignore - self.param_schedulers) # type: ignore - - if self._val_loop is not None: - self._val_loop = self.build_val_loop( - self._val_loop) # type: ignore - # TODO: add a contextmanager to avoid calling `before_run` many times - self.call_hook('before_run') - - # initialize the model weights - self._init_model_weights() - - # try to enable activation_checkpointing feature - modules = self.cfg.get('activation_checkpointing', None) - if modules is not None: - self.logger.info(f'Enabling the "activation_checkpointing" feature' - f' for sub-modules: {modules}') - turn_on_activation_checkpointing(ori_model, modules) - - # try to enable efficient_conv_bn_eval feature - modules = self.cfg.get('efficient_conv_bn_eval', None) - if modules is not None: - self.logger.info(f'Enabling the "efficient_conv_bn_eval" feature' - f' for sub-modules: {modules}') - turn_on_efficient_conv_bn_eval(ori_model, modules) - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - # Initiate inner count of `optim_wrapper`. - self.optim_wrapper.initialize_count_status( - self.model, - self._train_loop.iter, # type: ignore - self._train_loop.max_iters) # type: ignore - - # Maybe compile the model according to options in self.cfg.compile - # This must be called **AFTER** model has been wrapped. - self._maybe_compile('train_step') - - model = self.train_loop.run() # type: ignore - self.call_hook('after_run') - return model - - def val(self) -> dict: - """Launch validation. - - Returns: - dict: A dict of metrics on validation set. - """ - if self._val_loop is None: - raise RuntimeError( - '`self._val_loop` should not be None when calling val method.' - 'Please provide `val_dataloader`, `val_cfg` and ' - '`val_evaluator` arguments when initializing runner.') - - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - - self.call_hook('before_run') - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - metrics = self.val_loop.run() # type: ignore - self.call_hook('after_run') - return metrics - - def test(self) -> dict: - """Launch test. - - Returns: - dict: A dict of metrics on testing set. - """ - if self._test_loop is None: - raise RuntimeError( - '`self._test_loop` should not be None when calling test ' - 'method. 
Please provide `test_dataloader`, `test_cfg` and ' - '`test_evaluator` arguments when initializing runner.') - - self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - - self.call_hook('before_run') - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - metrics = self.test_loop.run() # type: ignore - self.call_hook('after_run') - return metrics - - def call_hook(self, fn_name: str, **kwargs) -> None: - """Call all hooks. - - Args: - fn_name (str): The function name in each hook to be called, such as - "before_train_epoch". - **kwargs: Keyword arguments passed to hook. - """ - for hook in self._hooks: - # support adding additional custom hook methods - if hasattr(hook, fn_name): - try: - getattr(hook, fn_name)(self, **kwargs) - except TypeError as e: - raise TypeError(f'{e} in {hook}') from None - - def register_hook( - self, - hook: Union[Hook, Dict], - priority: Optional[Union[str, int, Priority]] = None) -> None: - """Register a hook into the hook list. - - The hook will be inserted into a priority queue, with the specified - priority (See :class:`Priority` for details of priorities). - For hooks with the same priority, they will be triggered in the same - order as they are registered. - - Priority of hook will be decided with the following priority: - - - ``priority`` argument. If ``priority`` is given, it will be priority - of hook. - - If ``hook`` argument is a dict and ``priority`` in it, the priority - will be the value of ``hook['priority']``. - - If ``hook`` argument is a dict but ``priority`` not in it or ``hook`` - is an instance of ``hook``, the priority will be ``hook.priority``. - - Args: - hook (:obj:`Hook` or dict): The hook to be registered. - priority (int or str or :obj:`Priority`, optional): Hook priority. - Lower value means higher priority. - """ - if not isinstance(hook, (Hook, dict)): - raise TypeError( - f'hook should be an instance of Hook or dict, but got {hook}') - - _priority = None - if isinstance(hook, dict): - if 'priority' in hook: - _priority = hook.pop('priority') - - hook_obj = HOOKS.build(hook) - else: - hook_obj = hook - - if priority is not None: - hook_obj.priority = priority - elif _priority is not None: - hook_obj.priority = _priority - - inserted = False - for i in range(len(self._hooks) - 1, -1, -1): - if get_priority(hook_obj.priority) >= get_priority( - self._hooks[i].priority): - self._hooks.insert(i + 1, hook_obj) - inserted = True - break - if not inserted: - self._hooks.insert(0, hook_obj) - - def register_default_hooks( - self, - hooks: Optional[Dict[str, Union[Hook, Dict]]] = None) -> None: - """Register default hooks into hook list. - - ``hooks`` will be registered into runner to execute some default - actions like updating model parameters or saving checkpoints. 
- - Default hooks and their priorities: - - +----------------------+-------------------------+ - | Hooks | Priority | - +======================+=========================+ - | RuntimeInfoHook | VERY_HIGH (10) | - +----------------------+-------------------------+ - | IterTimerHook | NORMAL (50) | - +----------------------+-------------------------+ - | DistSamplerSeedHook | NORMAL (50) | - +----------------------+-------------------------+ - | LoggerHook | BELOW_NORMAL (60) | - +----------------------+-------------------------+ - | ParamSchedulerHook | LOW (70) | - +----------------------+-------------------------+ - | CheckpointHook | VERY_LOW (90) | - +----------------------+-------------------------+ - - If ``hooks`` is None, above hooks will be registered by - default:: - - default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - - If not None, ``hooks`` will be merged into ``default_hooks``. - If there are None value in default_hooks, the corresponding item will - be popped from ``default_hooks``:: - - hooks = dict(timer=None) - - The final registered default hooks will be :obj:`RuntimeInfoHook`, - :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`, - :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`. - - Args: - hooks (dict[str, Hook or dict], optional): Default hooks or configs - to be registered. - """ - default_hooks: dict = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - if hooks is not None: - for name, hook in hooks.items(): - if name in default_hooks and hook is None: - # remove hook from _default_hooks - default_hooks.pop(name) - else: - assert hook is not None - default_hooks[name] = hook - - for hook in default_hooks.values(): - self.register_hook(hook) - - def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None: - """Register custom hooks into hook list. - - Args: - hooks (list[Hook | dict]): List of hooks or configs to be - registered. - """ - for hook in hooks: - self.register_hook(hook) - - def register_hooks( - self, - default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, - custom_hooks: Optional[List[Union[Hook, Dict]]] = None) -> None: - """Register default hooks and custom hooks into hook list. - - Args: - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks - to execute default actions like updating model parameters and - saving checkpoints. Defaults to None. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - """ - self.register_default_hooks(default_hooks) - - if custom_hooks is not None: - self.register_custom_hooks(custom_hooks) - - def resume(self, - filename: str, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: Union[str, Callable] = 'default') -> None: - """Resume model from checkpoint. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. 
- resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - """ - if map_location == 'default': - device = get_device() - checkpoint = self.load_checkpoint(filename, map_location=device) - else: - checkpoint = self.load_checkpoint( - filename, map_location=map_location) - - self.train_loop._epoch = checkpoint['meta']['epoch'] - self.train_loop._iter = checkpoint['meta']['iter'] - - # check whether the number of GPU used for current experiment - # is consistent with resuming from checkpoint - if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') - previous_gpu_ids = config.get('gpu_ids', None) - if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 - and len(previous_gpu_ids) != self._world_size): - # TODO, should we modify the iteration? - if (self.auto_scale_lr is None - or not self.auto_scale_lr.get('enable', False)): - raise RuntimeError( - 'Number of GPUs used for current experiment is not ' - 'consistent with the checkpoint being resumed from. ' - 'This will result in poor performance due to the ' - 'learning rate. You must set the ' - '`auto_scale_lr` parameter for Runner and make ' - '`auto_scale_lr["enable"]=True`.') - else: - self.logger.info( - 'Number of GPU used for current experiment is not ' - 'consistent with resuming from checkpoint but the ' - 'leaning rate will be adjusted according to the ' - f'setting in auto_scale_lr={self.auto_scale_lr}') - - # resume random seed - resumed_seed = checkpoint['meta'].get('seed', None) - current_seed = self._randomness_cfg.get('seed') - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning(f'The value of random seed in the ' - f'checkpoint "{resumed_seed}" is ' - f'different from the value in ' - f'`randomness` config "{current_seed}"') - self._randomness_cfg.update(seed=resumed_seed) - self.set_randomness(**self._randomness_cfg) - - resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None) - dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None) - - # `resumed_dataset_meta` and `dataset_meta` could be object like - # np.ndarray, which cannot be directly judged as equal or not, - # therefore we just compared their dumped results. 
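# A minimal sketch of why the dumped bytes are compared rather than the objects
# themselves, assuming both metainfo objects are picklable: for values such as
# ``np.ndarray`` the ``!=`` operator is element-wise and has no single truth
# value, so a direct comparison can raise, while comparing serialized bytes
# always yields a plain bool. The dict contents below are arbitrary examples.
import pickle
import numpy as np

meta_a = dict(classes=('cat', 'dog'), palette=np.array([[0, 0, 0], [255, 255, 255]]))
meta_b = dict(classes=('cat', 'dog'), palette=np.array([[0, 0, 0], [255, 255, 255]]))

# meta_a != meta_b would raise "The truth value of an array ... is ambiguous"
assert pickle.dumps(meta_a) == pickle.dumps(meta_b)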
- if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta): - self.logger.warning( - 'The dataset metainfo from the resumed checkpoint is ' - 'different from the current training dataset, please ' - 'check the correctness of the checkpoint or the training ' - 'dataset.') - - self.message_hub.load_state_dict(checkpoint['message_hub']) - - # resume optimizer - if 'optimizer' in checkpoint and resume_optimizer: - self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) - self.optim_wrapper.load_state_dict( # type: ignore - checkpoint['optimizer']) - - # resume param scheduler - if resume_param_scheduler and self.param_schedulers is None: - self.logger.warning( - '`resume_param_scheduler` is True but `self.param_schedulers` ' - 'is None, so skip resuming parameter schedulers') - resume_param_scheduler = False - if 'param_schedulers' in checkpoint and resume_param_scheduler: - self.param_schedulers = self.build_param_scheduler( # type: ignore - self.param_schedulers) # type: ignore - if isinstance(self.param_schedulers, dict): - for name, schedulers in self.param_schedulers.items(): - for scheduler, ckpt_scheduler in zip( - schedulers, checkpoint['param_schedulers'][name]): - scheduler.load_state_dict(ckpt_scheduler) - else: - for scheduler, ckpt_scheduler in zip( - self.param_schedulers, # type: ignore - checkpoint['param_schedulers']): - scheduler.load_state_dict(ckpt_scheduler) - - self._has_loaded = True - - self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}') - - def load_checkpoint(self, - filename: str, - map_location: Union[str, Callable] = 'cpu', - strict: bool = False, - revise_keys: list = [(r'^module.', '')]): - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - """ - checkpoint = _load_checkpoint(filename, map_location=map_location) - - # Add comments to describe the usage of `after_load_ckpt` - self.call_hook('after_load_checkpoint', checkpoint=checkpoint) - - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - checkpoint = _load_checkpoint_to_model( - model, checkpoint, strict, revise_keys=revise_keys) - - self._has_loaded = True - - self.logger.info(f'Load checkpoint from {filename}') - - return checkpoint - - @master_only - def save_checkpoint( - self, - out_dir: str, - filename: str, - file_client_args: Optional[dict] = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: Optional[dict] = None, - by_epoch: bool = True, - backend_args: Optional[dict] = None, - ): - """Save checkpoints. - - ``CheckpointHook`` invokes this method to save checkpoints - periodically. - - Args: - out_dir (str): The directory that checkpoints are saved. - filename (str): The checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for - details. Defaults to None. It will be deprecated in future. - Please use `backend_args` instead. 
- save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - meta (dict, optional): The meta information to be saved in the - checkpoint. Defaults to None. - by_epoch (bool): Decide the number of epoch or iteration saved in - checkpoint. Defaults to True. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - """ - if meta is None: - meta = {} - elif not isinstance(meta, dict): - raise TypeError( - f'meta should be a dict or None, but got {type(meta)}') - - if by_epoch: - # self.epoch increments 1 after - # `self.call_hook('after_train_epoch)` but `save_checkpoint` is - # called by `after_train_epoch`` method of `CheckpointHook` so - # `epoch` should be `self.epoch + 1` - meta.setdefault('epoch', self.epoch + 1) - meta.setdefault('iter', self.iter) - else: - meta.setdefault('epoch', self.epoch) - meta.setdefault('iter', self.iter + 1) - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', DeprecationWarning) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set at ' - 'the same time.') - - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) - else: - filepath = join_path( # type: ignore - out_dir, filename, backend_args=backend_args) - - meta.update( - cfg=self.cfg.pretty_text, - seed=self.seed, - experiment_name=self.experiment_name, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine_version=mmengine.__version__ + get_git_hash()) - - if hasattr(self.train_dataloader.dataset, 'metainfo'): - meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) - - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - checkpoint = { - 'meta': - meta, - 'state_dict': - weights_to_cpu(model.state_dict()), - 'message_hub': - apply_to(self.message_hub.state_dict(), - lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()), - } - # save optimizer state dict to checkpoint - if save_optimizer: - if isinstance(self.optim_wrapper, OptimWrapper): - checkpoint['optimizer'] = apply_to( - self.optim_wrapper.state_dict(), - lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()) - else: - raise TypeError( - 'self.optim_wrapper should be an `OptimWrapper` ' - 'or `OptimWrapperDict` instance, but got ' - f'{self.optim_wrapper}') - - # save param scheduler state dict - if save_param_scheduler and self.param_schedulers is None: - self.logger.warning( - '`save_param_scheduler` is True but `self.param_schedulers` ' - 'is None, so skip saving parameter schedulers') - save_param_scheduler = False - if save_param_scheduler: - if isinstance(self.param_schedulers, dict): - checkpoint['param_schedulers'] = dict() - for name, schedulers in self.param_schedulers.items(): - checkpoint['param_schedulers'][name] = [] - for scheduler in schedulers: - state_dict = scheduler.state_dict() - checkpoint['param_schedulers'][name].append(state_dict) - else: - checkpoint['param_schedulers'] = [] - for scheduler in self.param_schedulers: # type: ignore - state_dict = scheduler.state_dict() # type: ignore - checkpoint['param_schedulers'].append(state_dict) - - self.call_hook('before_save_checkpoint', checkpoint=checkpoint) - save_checkpoint( - checkpoint, - filepath, - 
file_client_args=file_client_args, - backend_args=backend_args) - - @master_only - def dump_config(self) -> None: - """Dump config to `work_dir`.""" - if self.cfg.filename is not None: - filename = osp.basename(self.cfg.filename) - else: - filename = f'{self.timestamp}.py' - self.cfg.dump(osp.join(self.work_dir, filename)) - - def _check_scheduler_cfg( - self, param_scheduler: Optional[Union[dict, list, - _ParamScheduler]]) -> None: - """Parse `param_scheduler` to a list of parameter schedulers, or a - `dict` of which each value is a list of parameter schedulers. - - If only one optimizer is used, the parsed config should be a - list of parameter scheduler configs or instances. If multiple - optimizers are used, the parsed config should be `dict`. - Its key should be consistent with the optimizer `dict` and its value - should be a list of parameter scheduler configs or instances. See - :meth:`build_param_scheduler` for more details. - - Examples: - >>> # valid scheduler: - >>> # empty scheduler - >>> scheduler = None - >>> # Single scheduler - >>> scheduler = dict(type='MultiStepLR', milestones=[1, 2]) - >>> # Single list schedulers - >>> scheduler = [dict(type='MultiStepLR', milestones=[1, 2]), - >>> dict(type='MultiStepLR', milestones=[2, 3])] - >>> # `dict` of schedulers - >>> scheduler = dict(linear1=dict(type='MultiStepLR', milestones=[1, 2]), - >>> linear2=dict(type='MultiStepLR', milestones=[1, 2])) - >>> # `dict` of `list` of schedulers - >>> scheduler = dict(linear1=[dict(type='MultiStepLR', milestones=[1, 2])], - >>> linear2=[dict(type='MultiStepLR', milestones=[1, 2])]) - >>> # Single built scheduler - >>> from mmengine.optim import MultiStepLR - >>> scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer) - >>> # Single built list schedulers - >>> scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)] - >>> # dict of built scheduler - >>> scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer), - >>> linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer)) - >>> # dict of built list schedulers - >>> scheduler = dict(linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)], - >>> linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)]) - - Args: - param_scheduler (dict or list): The original parameter scheduler. - """ # noqa: E501 - if param_scheduler is None: - return - if isinstance(param_scheduler, _ParamScheduler): - return - if is_seq_of(param_scheduler, _ParamScheduler): - return - - if is_seq_of(param_scheduler, dict): - for _param_scheduler in param_scheduler: - assert 'type' in _param_scheduler, ( - 'Each parameter scheduler should contain the key type, ' - f'but got {_param_scheduler}') - elif isinstance(param_scheduler, dict): - if 'type' not in param_scheduler: - for key, _param_scheduler in param_scheduler.items(): - assert isinstance( - _param_scheduler, - (dict, tuple, list, _ParamScheduler)), ( - 'Each value of `param_scheduler` should be a ' - f'dict or a list, but got {_param_scheduler} with ' - f'type {type(_ParamScheduler)}') - - else: - raise TypeError( - '`param_scheduler` should be a `_ParamScheduler`, `dict`, ' - f'list or a tuple, but got {type(param_scheduler)}. If ' - '`param_scheduler` is a list of dict, it means a list of ' - 'scheduler configs for single optimizer. If it is a dict and ' - 'contains key `type`, it means a scheduler config for a ' - 'single optimizer. 
If it does not contain key `type`, it ' - 'means multiple lists of schedulers for multiple optimizers.') - - def _log_env(self, env_cfg: dict) -> None: - """Logging environment information of the current task. - - Args: - env_cfg (dict): The environment config of the runner. - """ - # Collect and log environment information. - env = collect_env() - runtime_env = OrderedDict() - runtime_env.update(env_cfg) - runtime_env.update(self._randomness_cfg) - runtime_env['seed'] = self._seed - runtime_env['Distributed launcher'] = self._launcher - runtime_env['Distributed training'] = self._distributed - runtime_env['GPU number'] = self._world_size - - env_info = '\n ' + '\n '.join(f'{k}: {v}' - for k, v in env.items()) - runtime_env_info = '\n ' + '\n '.join( - f'{k}: {v}' for k, v in runtime_env.items()) - dash_line = '-' * 60 - self.logger.info('\n' + dash_line + '\nSystem environment:' + - env_info + '\n' - '\nRuntime environment:' + runtime_env_info + '\n' + - dash_line + '\n') - - if self.cfg._cfg_dict: - self.logger.info(f'Config:\n{self.cfg.pretty_text}') - - def _maybe_compile(self, target: str) -> None: - """Use `torch.compile` to optimize model/wrapped_model.""" - compile_cfg = self.cfg.get('compile', None) - if compile_cfg is None: - # no compile options given, won't compile - return - - if isinstance(compile_cfg, bool): - if not compile_cfg: - # compile=False, compilation is disabled - return - # compile=True, use default configurations - compile_cfg = dict() - - assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( - 'PyTorch >= 2.0.0 is required to enable torch.compile') - assert isinstance(compile_cfg, dict), ( - f'`compile` should be a dict or bool, got {type(compile_cfg)}') - - func = getattr(self.model, target) - compiled_func = torch.compile(func, **compile_cfg) - setattr(self.model, target, compiled_func) - self.logger.info('Model has been "compiled". The first few iterations' - ' will be slow, please be patient.') diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py deleted file mode 100644 index b91025eb07..0000000000 --- a/mmengine/runner/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import random -from typing import List, Optional, Tuple - -import numpy as np -import torch -from torch.utils.data import DataLoader - -from mmengine.device import is_cuda_available, is_musa_available -from mmengine.dist import get_rank, sync_random_seed -from mmengine.logging import print_log -from mmengine.utils import digit_version, is_list_of -from mmengine.utils.dl_utils import TORCH_VERSION - - -def calc_dynamic_intervals( - start_interval: int, - dynamic_interval_list: Optional[List[Tuple[int, int]]] = None -) -> Tuple[List[int], List[int]]: - """Calculate dynamic intervals. - - Args: - start_interval (int): The interval used in the beginning. - dynamic_interval_list (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is a interval. The interval is used after the - corresponding milestone. Defaults to None. - - Returns: - Tuple[List[int], List[int]]: a list of milestone and its corresponding - intervals. 
- """ - if dynamic_interval_list is None: - return [0], [start_interval] - - assert is_list_of(dynamic_interval_list, tuple) - - dynamic_milestones = [0] - dynamic_milestones.extend( - [dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) - dynamic_intervals = [start_interval] - dynamic_intervals.extend( - [dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) - return dynamic_milestones, dynamic_intervals - - -def set_random_seed(seed: Optional[int] = None, - deterministic: bool = False, - diff_rank_seed: bool = False) -> int: - """Set random seed. - - Args: - seed (int, optional): Seed to be used. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - diff_rank_seed (bool): Whether to add rank number to the random seed to - have different random seed in different threads. Defaults to False. - """ - if seed is None: - seed = sync_random_seed() - - if diff_rank_seed: - rank = get_rank() - seed += rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - # torch.cuda.manual_seed(seed) - if is_cuda_available(): - torch.cuda.manual_seed_all(seed) - elif is_musa_available(): - torch.musa.manual_seed_all(seed) - # os.environ['PYTHONHASHSEED'] = str(seed) - if deterministic: - if torch.backends.cudnn.benchmark: - print_log( - 'torch.backends.cudnn.benchmark is going to be set as ' - '`False` to cause cuDNN to deterministically select an ' - 'algorithm', - logger='current', - level=logging.WARNING) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - if digit_version(TORCH_VERSION) >= digit_version('1.10.0'): - torch.use_deterministic_algorithms(True) - return seed - - -def _get_batch_size(dataloader: dict): - if isinstance(dataloader, dict): - if 'batch_size' in dataloader: - return dataloader['batch_size'] - elif ('batch_sampler' in dataloader - and 'batch_size' in dataloader['batch_sampler']): - return dataloader['batch_sampler']['batch_size'] - else: - raise ValueError('Please set batch_size in `Dataloader` or ' - '`batch_sampler`') - elif isinstance(dataloader, DataLoader): - return dataloader.batch_sampler.batch_size - else: - raise ValueError('dataloader should be a dict or a Dataloader ' - f'instance, but got {type(dataloader)}') diff --git a/mmengine/structures/__init__.py b/mmengine/structures/__init__.py deleted file mode 100644 index d4d94fd1f7..0000000000 --- a/mmengine/structures/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base_data_element import BaseDataElement -from .instance_data import InstanceData -from .label_data import LabelData -from .pixel_data import PixelData - -__all__ = ['BaseDataElement', 'InstanceData', 'LabelData', 'PixelData'] diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py deleted file mode 100644 index 8ac5a3d27d..0000000000 --- a/mmengine/structures/base_data_element.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Any, Iterator, Optional, Tuple, Type, Union - -import numpy as np -import torch - - -class BaseDataElement: - """A base data interface that supports Tensor-like and dict-like - operations. 
- - A typical data elements refer to predicted results or ground truth labels - on a task, such as predicted bboxes, instance masks, semantic - segmentation masks, etc. Because groundtruth labels and predicted results - often have similar properties (for example, the predicted bboxes and the - groundtruth bboxes), MMEngine uses the same abstract data interface to - encapsulate predicted results and groundtruth labels, and it is recommended - to use different name conventions to distinguish them, such as using - ``gt_instances`` and ``pred_instances`` to distinguish between labels and - predicted results. Additionally, we distinguish data elements at instance - level, pixel level, and label level. Each of these types has its own - characteristics. Therefore, MMEngine defines the base class - ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and - ``LabelData`` inheriting from ``BaseDataElement`` to represent different - types of ground truth labels or predictions. - - Another common data element is sample data. A sample data consists of input - data (such as an image) and its annotations and predictions. In general, - an image can have multiple types of annotations and/or predictions at the - same time (for example, both pixel-level semantic segmentation annotations - and instance-level detection bboxes annotations). All labels and - predictions of a training sample are often passed between Dataset, Model, - Visualizer, and Evaluator components. In order to simplify the interface - between components, we can treat them as a large data element and - encapsulate them. Such data elements are generally called XXDataSample in - the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` - allows `BaseDataElement` as its attribute. Such a class generally - encapsulates all the data of a sample in the algorithm library, and its - attributes generally are various types of data elements. For example, - MMDetection is assigned by the BaseDataElement to encapsulate all the data - elements of the sample labeling and prediction of a sample in the - algorithm library. - - The attributes in ``BaseDataElement`` are divided into two parts, - the ``metainfo`` and the ``data`` respectively. - - - ``metainfo``: Usually contains the - information about the image such as filename, - image_shape, pad_shape, etc. The attributes can be accessed or - modified by dict-like or object-like operations, such as - ``.`` (for data access and modification), ``in``, ``del``, - ``pop(str)``, ``get(str)``, ``metainfo_keys()``, - ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for - set or change key-value pairs in metainfo). - - - ``data``: Annotations or model predictions are - stored. The attributes can be accessed or modified by - dict-like or object-like operations, such as - ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, - ``values()``, ``items()``. Users can also apply tensor-like - methods to all :obj:`torch.Tensor` in the ``data_fields``, - such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, - ``to_tensor()``, ``.detach()``. - - Args: - metainfo (dict, optional): A dict contains the meta information - of single image, such as ``dict(img_shape=(512, 512, 3), - scale_factor=(1, 1, 1, 1))``. Defaults to None. - kwargs (dict, optional): A dict contains annotations of single image or - model predictions. Defaults to None. 
- - Examples: - >>> import torch - >>> from mmengine.structures import BaseDataElement - >>> gt_instances = BaseDataElement() - >>> bboxes = torch.rand((5, 4)) - >>> scores = torch.rand((5,)) - >>> img_id = 0 - >>> img_shape = (800, 1333) - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=img_id, img_shape=img_shape), - ... bboxes=bboxes, scores=scores) - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) - - >>> # new - >>> gt_instances1 = gt_instances.new( - ... metainfo=dict(img_id=1, img_shape=(640, 640)), - ... bboxes=torch.rand((5, 4)), - ... scores=torch.rand((5,))) - >>> gt_instances2 = gt_instances1.new() - - >>> # add and process property - >>> gt_instances = BaseDataElement() - >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) - >>> assert 'img_shape' in gt_instances.metainfo_keys() - >>> assert 'img_shape' in gt_instances - >>> assert 'img_shape' not in gt_instances.keys() - >>> assert 'img_shape' in gt_instances.all_keys() - >>> print(gt_instances.img_shape) - (100, 100) - >>> gt_instances.scores = torch.rand((5,)) - >>> assert 'scores' in gt_instances.keys() - >>> assert 'scores' in gt_instances - >>> assert 'scores' in gt_instances.all_keys() - >>> assert 'scores' not in gt_instances.metainfo_keys() - >>> print(gt_instances.scores) - tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) - >>> gt_instances.bboxes = torch.rand((5, 4)) - >>> assert 'bboxes' in gt_instances.keys() - >>> assert 'bboxes' in gt_instances - >>> assert 'bboxes' in gt_instances.all_keys() - >>> assert 'bboxes' not in gt_instances.metainfo_keys() - >>> print(gt_instances.bboxes) - tensor([[0.0900, 0.0424, 0.1755, 0.4469], - [0.8648, 0.0592, 0.3484, 0.0913], - [0.5808, 0.1909, 0.6165, 0.7088], - [0.5490, 0.4209, 0.9416, 0.2374], - [0.3652, 0.1218, 0.8805, 0.7523]]) - - >>> # delete and change property - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=0, img_shape=(640, 640)), - ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) - >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) - >>> gt_instances.img_shape # (1280, 1280) - >>> gt_instances.bboxes = gt_instances.bboxes * 2 - >>> gt_instances.get('img_shape', None) # (1280, 1280) - >>> gt_instances.get('bboxes', None) # 6x4 tensor - >>> del gt_instances.img_shape - >>> del gt_instances.bboxes - >>> assert 'img_shape' not in gt_instances - >>> assert 'bboxes' not in gt_instances - >>> gt_instances.pop('img_shape', None) # None - >>> gt_instances.pop('bboxes', None) # None - - >>> # Tensor-like - >>> cuda_instances = gt_instances.cuda() - >>> cuda_instances = gt_instances.to('cuda:0') - >>> cpu_instances = cuda_instances.cpu() - >>> cpu_instances = cuda_instances.to('cpu') - >>> fp16_instances = cuda_instances.to( - ... device=None, dtype=torch.float16, non_blocking=False, - ... copy=False, memory_format=torch.preserve_format) - >>> cpu_instances = cuda_instances.detach() - >>> np_instances = cpu_instances.numpy() - - >>> # print - >>> metainfo = dict(img_shape=(800, 1196, 3)) - >>> gt_instances = BaseDataElement( - ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) - >>> sample = BaseDataElement(metainfo=metainfo, - ... gt_instances=gt_instances) - >>> print(sample) - - ) at 0x7f0fea49e130> - - >>> # inheritance - >>> class DetDataSample(BaseDataElement): - ... @property - ... def proposals(self): - ... return self._proposals - ... @proposals.setter - ... def proposals(self, value): - ... 
self.set_field(value, '_proposals', dtype=BaseDataElement) - ... @proposals.deleter - ... def proposals(self): - ... del self._proposals - ... @property - ... def gt_instances(self): - ... return self._gt_instances - ... @gt_instances.setter - ... def gt_instances(self, value): - ... self.set_field(value, '_gt_instances', - ... dtype=BaseDataElement) - ... @gt_instances.deleter - ... def gt_instances(self): - ... del self._gt_instances - ... @property - ... def pred_instances(self): - ... return self._pred_instances - ... @pred_instances.setter - ... def pred_instances(self, value): - ... self.set_field(value, '_pred_instances', - ... dtype=BaseDataElement) - ... @pred_instances.deleter - ... def pred_instances(self): - ... del self._pred_instances - >>> det_sample = DetDataSample() - >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) - >>> det_sample.proposals = proposals - >>> assert 'proposals' in det_sample - >>> assert det_sample.proposals == proposals - >>> del det_sample.proposals - >>> assert 'proposals' not in det_sample - >>> with self.assertRaises(AssertionError): - ... det_sample.proposals = torch.rand((5, 4)) - """ - - def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: - - self._metainfo_fields: set = set() - self._data_fields: set = set() - - if metainfo is not None: - self.set_metainfo(metainfo=metainfo) - if kwargs: - self.set_data(kwargs) - - def set_metainfo(self, metainfo: dict) -> None: - """Set or change key-value pairs in ``metainfo_field`` by parameter - ``metainfo``. - - Args: - metainfo (dict): A dict contains the meta information - of image, such as ``img_shape``, ``scale_factor``, etc. - """ - assert isinstance( - metainfo, - dict), f'metainfo should be a ``dict`` but got {type(metainfo)}' - meta = copy.deepcopy(metainfo) - for k, v in meta.items(): - self.set_field(name=k, value=v, field_type='metainfo', dtype=None) - - def set_data(self, data: dict) -> None: - """Set or change key-value pairs in ``data_field`` by parameter - ``data``. - - Args: - data (dict): A dict contains annotations of image or - model predictions. - """ - assert isinstance(data, - dict), f'data should be a `dict` but got {data}' - for k, v in data.items(): - # Use `setattr()` rather than `self.set_field` to allow `set_data` - # to set property method. - setattr(self, k, v) - - def update(self, instance: 'BaseDataElement') -> None: - """The update() method updates the BaseDataElement with the elements - from another BaseDataElement object. - - Args: - instance (BaseDataElement): Another BaseDataElement object for - update the current object. - """ - assert isinstance( - instance, BaseDataElement - ), f'instance should be a `BaseDataElement` but got {type(instance)}' - self.set_metainfo(dict(instance.metainfo_items())) - self.set_data(dict(instance.items())) - - def new(self, - *, - metainfo: Optional[dict] = None, - **kwargs) -> 'BaseDataElement': - """Return a new data element with same type. If ``metainfo`` and - ``data`` are None, the new data element will have same metainfo and - data. If metainfo or data is not None, the new result will overwrite it - with the input value. - - Args: - metainfo (dict, optional): A dict contains the meta information - of image, such as ``img_shape``, ``scale_factor``, etc. - Defaults to None. - kwargs (dict): A dict contains annotations of image or - model predictions. - - Returns: - BaseDataElement: A new data element with same type. 
- """ - new_data = self.__class__() - - if metainfo is not None: - new_data.set_metainfo(metainfo) - else: - new_data.set_metainfo(dict(self.metainfo_items())) - if kwargs: - new_data.set_data(kwargs) - else: - new_data.set_data(dict(self.items())) - return new_data - - def clone(self): - """Deep copy the current data element. - - Returns: - BaseDataElement: The copy of current data element. - """ - clone_data = self.__class__() - clone_data.set_metainfo(dict(self.metainfo_items())) - clone_data.set_data(dict(self.items())) - return clone_data - - def keys(self) -> list: - """ - Returns: - list: Contains all keys in data_fields. - """ - # We assume that the name of the attribute related to property is - # '_' + the name of the property. We use this rule to filter out - # private keys. - # TODO: Use a more robust way to solve this problem - private_keys = { - '_' + key - for key in self._data_fields - if isinstance(getattr(type(self), key, None), property) - } - return list(self._data_fields - private_keys) - - def metainfo_keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo_fields. - """ - return list(self._metainfo_fields) - - def values(self) -> list: - """ - Returns: - list: Contains all values in data. - """ - return [getattr(self, k) for k in self.keys()] - - def metainfo_values(self) -> list: - """ - Returns: - list: Contains all values in metainfo. - """ - return [getattr(self, k) for k in self.metainfo_keys()] - - def all_keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo and data. - """ - return self.metainfo_keys() + self.keys() - - def all_values(self) -> list: - """ - Returns: - list: Contains all values in metainfo and data. - """ - return self.metainfo_values() + self.values() - - def all_items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``metainfo`` and ``data``. - """ - for k in self.all_keys(): - yield (k, getattr(self, k)) - - def items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``data``. - """ - for k in self.keys(): - yield (k, getattr(self, k)) - - def metainfo_items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``metainfo``. - """ - for k in self.metainfo_keys(): - yield (k, getattr(self, k)) - - @property - def metainfo(self) -> dict: - """dict: A dict contains metainfo of current data element.""" - return dict(self.metainfo_items()) - - def __setattr__(self, name: str, value: Any): - """Setattr is only used to set data.""" - if name in ('_metainfo_fields', '_data_fields'): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f'{name} has been used as a ' - 'private attribute, which is immutable.') - else: - self.set_field( - name=name, value=value, field_type='data', dtype=None) - - def __delattr__(self, item: str): - """Delete the item in dataelement. - - Args: - item (str): The key to delete. 
- """ - if item in ('_metainfo_fields', '_data_fields'): - raise AttributeError(f'{item} has been used as a ' - 'private attribute, which is immutable.') - super().__delattr__(item) - if item in self._metainfo_fields: - self._metainfo_fields.remove(item) - elif item in self._data_fields: - self._data_fields.remove(item) - - # dict-like methods - __delitem__ = __delattr__ - - def get(self, key, default=None) -> Any: - """Get property in data and metainfo as the same as python.""" - # Use `getattr()` rather than `self.__dict__.get()` to allow getting - # properties. - return getattr(self, key, default) - - def pop(self, *args) -> Any: - """Pop property in data and metainfo as the same as python.""" - assert len(args) < 3, '``pop`` get more than 2 arguments' - name = args[0] - if name in self._metainfo_fields: - self._metainfo_fields.remove(args[0]) - return self.__dict__.pop(*args) - - elif name in self._data_fields: - self._data_fields.remove(args[0]) - return self.__dict__.pop(*args) - - # with default value - elif len(args) == 2: - return args[1] - else: - # don't just use 'self.__dict__.pop(*args)' for only popping key in - # metainfo or data - raise KeyError(f'{args[0]} is not contained in metainfo or data') - - def __contains__(self, item: str) -> bool: - """Whether the item is in dataelement. - - Args: - item (str): The key to inquire. - """ - return item in self._data_fields or item in self._metainfo_fields - - def set_field(self, - value: Any, - name: str, - dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, - field_type: str = 'data') -> None: - """Special method for set union field, used as property.setter - functions.""" - assert field_type in ['metainfo', 'data'] - if dtype is not None: - assert isinstance( - value, - dtype), f'{value} should be a {dtype} but got {type(value)}' - - if field_type == 'metainfo': - if name in self._data_fields: - raise AttributeError( - f'Cannot set {name} to be a field of metainfo ' - f'because {name} is already a data field') - self._metainfo_fields.add(name) - else: - if name in self._metainfo_fields: - raise AttributeError( - f'Cannot set {name} to be a field of data ' - f'because {name} is already a metainfo field') - self._data_fields.add(name) - super().__setattr__(name, value) - - # Tensor-like methods - def to(self, *args, **kwargs) -> 'BaseDataElement': - """Apply same name function to all tensors in data_fields.""" - new_data = self.new() - for k, v in self.items(): - if hasattr(v, 'to'): - v = v.to(*args, **kwargs) - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def cpu(self) -> 'BaseDataElement': - """Convert all tensors to CPU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cpu() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def cuda(self) -> 'BaseDataElement': - """Convert all tensors to GPU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cuda() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def musa(self) -> 'BaseDataElement': - """Convert all tensors to musa in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.musa() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def npu(self) -> 'BaseDataElement': - """Convert all tensors to NPU in data.""" - 
new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.npu() - data = {k: v} - new_data.set_data(data) - return new_data - - def mlu(self) -> 'BaseDataElement': - """Convert all tensors to MLU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.mlu() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def detach(self) -> 'BaseDataElement': - """Detach all tensors in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def numpy(self) -> 'BaseDataElement': - """Convert all tensors to np.ndarray in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach().cpu().numpy() - data = {k: v} - new_data.set_data(data) - return new_data - - def to_tensor(self) -> 'BaseDataElement': - """Convert all np.ndarray to tensor in data.""" - new_data = self.new() - for k, v in self.items(): - data = {} - if isinstance(v, np.ndarray): - v = torch.from_numpy(v) - data[k] = v - elif isinstance(v, BaseDataElement): - v = v.to_tensor() - data[k] = v - new_data.set_data(data) - return new_data - - def to_dict(self) -> dict: - """Convert BaseDataElement to dict.""" - return { - k: v.to_dict() if isinstance(v, BaseDataElement) else v - for k, v in self.all_items() - } - - def __repr__(self) -> str: - """Represent the object.""" - - def _addindent(s_: str, num_spaces: int) -> str: - """This func is modified from `pytorch` https://github.com/pytorch/ - pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu - les/module.py#L29. - - Args: - s_ (str): The string to add spaces. - num_spaces (int): The num of space to add. - - Returns: - str: The string after add indent. - """ - s = s_.split('\n') - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) # type: ignore - s = first + '\n' + s # type: ignore - return s # type: ignore - - def dump(obj: Any) -> str: - """Represent the object. - - Args: - obj (Any): The obj to represent. - - Returns: - str: The represented str. - """ - _repr = '' - if isinstance(obj, dict): - for k, v in obj.items(): - _repr += f'\n{k}: {_addindent(dump(v), 4)}' - elif isinstance(obj, BaseDataElement): - _repr += '\n\n META INFORMATION' - metainfo_items = dict(obj.metainfo_items()) - _repr += _addindent(dump(metainfo_items), 4) - _repr += '\n\n DATA FIELDS' - items = dict(obj.items()) - _repr += _addindent(dump(items), 4) - classname = obj.__class__.__name__ - _repr = f'<{classname}({_repr}\n) at {hex(id(obj))}>' - else: - _repr += repr(obj) - return _repr - - return dump(self) diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py deleted file mode 100644 index 8633b86037..0000000000 --- a/mmengine/structures/instance_data.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
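For reviewers' reference: the ``clone``, ``update`` and ``to_dict`` methods removed above have no doctest coverage in the class docstring. A minimal sketch of how they interact, assuming a checkout from before this removal (values are illustrative only):

import torch
from mmengine.structures import BaseDataElement

sample = BaseDataElement(
    metainfo=dict(img_id=0, img_shape=(640, 640)),
    scores=torch.rand(5))
extra = BaseDataElement(
    metainfo=dict(img_shape=(1280, 1280)),
    bboxes=torch.rand(5, 4))

merged = sample.clone()   # copy of both metainfo and data fields
merged.update(extra)      # fields from `extra` overwrite/extend the copy

assert merged.img_shape == (1280, 1280)
assert set(merged.keys()) == {'scores', 'bboxes'}

# to_dict() flattens metainfo and data into one plain dict.
as_dict = merged.to_dict()
assert 'img_id' in as_dict and 'bboxes' in as_dict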
-import itertools -from collections.abc import Sized -from typing import Any, List, Union - -import numpy as np -import torch - -from mmengine.device import get_device -from .base_data_element import BaseDataElement - -BoolTypeTensor: Union[Any] -LongTypeTensor: Union[Any] - -if get_device() == 'npu': - BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor] -elif get_device() == 'mlu': - BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor] -elif get_device() == 'musa': - BoolTypeTensor = Union[torch.BoolTensor, torch.musa.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.musa.LongTensor] -else: - BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] - -IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor, - BoolTypeTensor, np.ndarray] - - -# Modified from -# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa -class InstanceData(BaseDataElement): - """Data structure for instance-level annotations or predictions. - - Subclass of :class:`BaseDataElement`. All value in `data_fields` - should have the same length. This design refer to - https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 - InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value - in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, - and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. - - Examples: - >>> # custom data structure - >>> class TmpObject: - ... def __init__(self, tmp) -> None: - ... assert isinstance(tmp, list) - ... self.tmp = tmp - ... def __len__(self): - ... return len(self.tmp) - ... def __getitem__(self, item): - ... if isinstance(item, int): - ... if item >= len(self) or item < -len(self): # type:ignore - ... raise IndexError(f'Index {item} out of range!') - ... else: - ... # keep the dimension - ... item = slice(item, None, len(self)) - ... return TmpObject(self.tmp[item]) - ... @staticmethod - ... def cat(tmp_objs): - ... assert all(isinstance(results, TmpObject) for results in tmp_objs) - ... if len(tmp_objs) == 1: - ... return tmp_objs[0] - ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] - ... tmp_list = list(itertools.chain(*tmp_list)) - ... new_data = TmpObject(tmp_list) - ... return new_data - ... def __repr__(self): - ... 
return str(self.tmp) - >>> from mmengine.structures import InstanceData - >>> import numpy as np - >>> import torch - >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) - >>> instance_data = InstanceData(metainfo=img_meta) - >>> 'img_shape' in instance_data - True - >>> instance_data.det_labels = torch.LongTensor([2, 3]) - >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) - >>> instance_data.bboxes = torch.rand((2, 4)) - >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) - >>> len(instance_data) - 2 - >>> print(instance_data) - - >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] - >>> sorted_results.det_scores - tensor([0.7000, 0.8000]) - >>> print(instance_data[instance_data.det_scores > 0.75]) - - >>> print(instance_data[instance_data.det_scores > 1]) - - >>> print(instance_data.cat([instance_data, instance_data])) - - """ - - def __setattr__(self, name: str, value: Sized): - """Setattr is only used to set data. - - The value must have the attribute of `__len__` and have the same length - of `InstanceData`. - """ - if name in ('_metainfo_fields', '_data_fields'): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f'{name} has been used as a ' - 'private attribute, which is immutable.') - - else: - assert isinstance(value, - Sized), 'value must contain `__len__` attribute' - - if len(self) > 0: - assert len(value) == len(self), 'The length of ' \ - f'values {len(value)} is ' \ - 'not consistent with ' \ - 'the length of this ' \ - ':obj:`InstanceData` ' \ - f'{len(self)}' - super().__setattr__(name, value) - - __setitem__ = __setattr__ - - def __getitem__(self, item: IndexType) -> 'InstanceData': - """ - Args: - item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, - :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): - Get the corresponding values according to item. - - Returns: - :obj:`InstanceData`: Corresponding values. - """ - assert isinstance(item, IndexType.__args__) - if isinstance(item, list): - item = np.array(item) - if isinstance(item, np.ndarray): - # The default int type of numpy is platform dependent, int32 for - # windows and int64 for linux. `torch.Tensor` requires the index - # should be int64, therefore we simply convert it to int64 here. - # More details in https://github.com/numpy/numpy/issues/9464 - item = item.astype(np.int64) if item.dtype == np.int32 else item - item = torch.from_numpy(item) - - if isinstance(item, str): - return getattr(self, item) - - if isinstance(item, int): - if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f'Index {item} out of range!') - else: - # keep the dimension - item = slice(item, None, len(self)) - - new_data = self.__class__(metainfo=self.metainfo) - if isinstance(item, torch.Tensor): - assert item.dim() == 1, 'Only support to get the' \ - ' values along the first dimension.' - if isinstance(item, BoolTypeTensor.__args__): - assert len(item) == len(self), 'The shape of the ' \ - 'input(BoolTensor) ' \ - f'{len(item)} ' \ - 'does not match the shape ' \ - 'of the indexed tensor ' \ - 'in results_field ' \ - f'{len(self)} at ' \ - 'first dimension.' 
- - for k, v in self.items(): - if isinstance(v, torch.Tensor): - new_data[k] = v[item] - elif isinstance(v, np.ndarray): - new_data[k] = v[item.cpu().numpy()] - elif isinstance( - v, (str, list, tuple)) or (hasattr(v, '__getitem__') - and hasattr(v, 'cat')): - # convert to indexes from BoolTensor - if isinstance(item, BoolTypeTensor.__args__): - indexes = torch.nonzero(item).view( - -1).cpu().numpy().tolist() - else: - indexes = item.cpu().numpy().tolist() - slice_list = [] - if indexes: - for index in indexes: - slice_list.append(slice(index, None, len(v))) - else: - slice_list.append(slice(None, 0, None)) - r_list = [v[s] for s in slice_list] - if isinstance(v, (str, list, tuple)): - new_value = r_list[0] - for r in r_list[1:]: - new_value = new_value + r - else: - new_value = v.cat(r_list) - new_data[k] = new_value - else: - raise ValueError( - f'The type of `{k}` is `{type(v)}`, which has no ' - 'attribute of `cat`, so it does not ' - 'support slice with `bool`') - - else: - # item is a slice - for k, v in self.items(): - new_data[k] = v[item] - return new_data # type:ignore - - @staticmethod - def cat(instances_list: List['InstanceData']) -> 'InstanceData': - """Concat the instances of all :obj:`InstanceData` in the list. - - Note: To ensure that cat returns as expected, make sure that - all elements in the list must have exactly the same keys. - - Args: - instances_list (list[:obj:`InstanceData`]): A list - of :obj:`InstanceData`. - - Returns: - :obj:`InstanceData` - """ - assert all( - isinstance(results, InstanceData) for results in instances_list) - assert len(instances_list) > 0 - if len(instances_list) == 1: - return instances_list[0] - - # metainfo and data_fields must be exactly the - # same for each element to avoid exceptions. - field_keys_list = [ - instances.all_keys() for instances in instances_list - ] - assert len({len(field_keys) for field_keys in field_keys_list}) \ - == 1 and len(set(itertools.chain(*field_keys_list))) \ - == len(field_keys_list[0]), 'There are different keys in ' \ - '`instances_list`, which may ' \ - 'cause the cat operation ' \ - 'to fail. Please make sure all ' \ - 'elements in `instances_list` ' \ - 'have the exact same key.' - - new_data = instances_list[0].__class__( - metainfo=instances_list[0].metainfo) - for k in instances_list[0].keys(): - values = [results[k] for results in instances_list] - v0 = values[0] - if isinstance(v0, torch.Tensor): - new_values = torch.cat(values, dim=0) - elif isinstance(v0, np.ndarray): - new_values = np.concatenate(values, axis=0) - elif isinstance(v0, (str, list, tuple)): - new_values = v0[:] - for v in values[1:]: - new_values += v - elif hasattr(v0, 'cat'): - new_values = v0.cat(values) - else: - raise ValueError( - f'The type of `{k}` is `{type(v0)}` which has no ' - 'attribute of `cat`') - new_data[k] = new_values - return new_data # type:ignore - - def __len__(self) -> int: - """int: The length of InstanceData.""" - if len(self._data_fields) > 0: - return len(self.values()[0]) - else: - return 0 diff --git a/mmengine/structures/label_data.py b/mmengine/structures/label_data.py deleted file mode 100644 index de178e07a0..0000000000 --- a/mmengine/structures/label_data.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
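Since the repr outputs in the ``InstanceData`` docstring above are elided, a compact sketch of the indexing and ``cat`` behaviour being deleted (same pre-removal assumption, illustrative values):

import torch
from mmengine.structures import InstanceData

pred = InstanceData(metainfo=dict(img_shape=(800, 1196, 3)))
pred.bboxes = torch.rand(4, 4)
pred.scores = torch.tensor([0.9, 0.3, 0.8, 0.1])
pred.labels = torch.tensor([0, 1, 0, 2])

assert len(pred) == 4                    # every data field shares one length
top = pred[pred.scores > 0.5]            # boolean-mask indexing
assert len(top) == 2
first = pred[0]                          # int indexing keeps the first dim
assert first.bboxes.shape == (1, 4)

merged = InstanceData.cat([pred, pred])  # all elements must share keys
assert len(merged) == 8
assert merged.metainfo == pred.metainfo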
- -import torch - -from .base_data_element import BaseDataElement - - -class LabelData(BaseDataElement): - """Data structure for label-level annotations or predictions.""" - - @staticmethod - def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor: - """Convert the one-hot input to label. - - Args: - onehot (torch.Tensor, optional): The one-hot input. The format - of input must be one-hot. - - Returns: - torch.Tensor: The converted results. - """ - assert isinstance(onehot, torch.Tensor) - if (onehot.ndim == 1 and onehot.max().item() <= 1 - and onehot.min().item() >= 0): - return onehot.nonzero().squeeze(-1) - else: - raise ValueError( - 'input is not one-hot and can not convert to label') - - @staticmethod - def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor: - """Convert the label-format input to one-hot. - - Args: - label (torch.Tensor): The label-format input. The format - of item must be label-format. - num_classes (int): The number of classes. - - Returns: - torch.Tensor: The converted results. - """ - assert isinstance(label, torch.Tensor) - onehot = label.new_zeros((num_classes, )) - assert max(label, default=torch.tensor(0)).item() < num_classes - onehot[label] = 1 - return onehot diff --git a/mmengine/structures/pixel_data.py b/mmengine/structures/pixel_data.py deleted file mode 100644 index d550f5c0c6..0000000000 --- a/mmengine/structures/pixel_data.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import List, Sequence, Union - -import numpy as np -import torch - -from .base_data_element import BaseDataElement - - -class PixelData(BaseDataElement): - """Data structure for pixel-level annotations or predictions. - - All data items in ``data_fields`` of ``PixelData`` meet the following - requirements: - - - They all have 3 dimensions in orders of channel, height, and width. - - They should have the same height and width. - - Examples: - >>> metainfo = dict( - ... img_id=random.randint(0, 100), - ... img_shape=(random.randint(400, 600), random.randint(400, 600))) - >>> image = np.random.randint(0, 255, (4, 20, 40)) - >>> featmap = torch.randint(0, 255, (10, 20, 40)) - >>> pixel_data = PixelData(metainfo=metainfo, - ... image=image, - ... featmap=featmap) - >>> print(pixel_data.shape) - (20, 40) - - >>> # slice - >>> slice_data = pixel_data[10:20, 20:40] - >>> assert slice_data.shape == (10, 20) - >>> slice_data = pixel_data[10, 20] - >>> assert slice_data.shape == (1, 1) - - >>> # set - >>> pixel_data.map3 = torch.randint(0, 255, (20, 40)) - >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40) - >>> with self.assertRaises(AssertionError): - ... # The dimension must be 3 or 2 - ... pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40)) - """ - - def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]): - """Set attributes of ``PixelData``. - - If the dimension of value is 2 and its shape meet the demand, it - will automatically expand its channel-dimension. - - Args: - name (str): The key to access the value, stored in `PixelData`. - value (Union[torch.Tensor, np.ndarray]): The value to store in. - The type of value must be `torch.Tensor` or `np.ndarray`, - and its shape must meet the requirements of `PixelData`. 
- """ - if name in ('_metainfo_fields', '_data_fields'): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f'{name} has been used as a ' - 'private attribute, which is immutable.') - - else: - assert isinstance(value, (torch.Tensor, np.ndarray)), \ - f'Can not set {type(value)}, only support' \ - f' {(torch.Tensor, np.ndarray)}' - - if self.shape: - assert tuple(value.shape[-2:]) == self.shape, ( - 'The height and width of ' - f'values {tuple(value.shape[-2:])} is ' - 'not consistent with ' - 'the shape of this ' - ':obj:`PixelData` ' - f'{self.shape}') - assert value.ndim in [ - 2, 3 - ], f'The dim of value must be 2 or 3, but got {value.ndim}' - if value.ndim == 2: - value = value[None] - warnings.warn('The shape of value will convert from ' - f'{value.shape[-2:]} to {value.shape}') - super().__setattr__(name, value) - - # TODO torch.Long/bool - def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData': - """ - Args: - item (Sequence[Union[int, slice]]): Get the corresponding values - according to item. - - Returns: - :obj:`PixelData`: Corresponding values. - """ - - new_data = self.__class__(metainfo=self.metainfo) - if isinstance(item, tuple): - - assert len(item) == 2, 'Only support to slice height and width' - tmp_item: List[slice] = list() - for index, single_item in enumerate(item[::-1]): - if isinstance(single_item, int): - tmp_item.insert( - 0, slice(single_item, None, self.shape[-index - 1])) - elif isinstance(single_item, slice): - tmp_item.insert(0, single_item) - else: - raise TypeError( - 'The type of element in input must be int or slice, ' - f'but got {type(single_item)}') - tmp_item.insert(0, slice(None, None, None)) - item = tuple(tmp_item) - for k, v in self.items(): - setattr(new_data, k, v[item]) - else: - raise TypeError( - f'Unsupported type {type(item)} for slicing PixelData') - return new_data - - @property - def shape(self): - """The shape of pixel data.""" - if len(self._data_fields) > 0: - return tuple(self.values()[0].shape[-2:]) - else: - return None - - # TODO padding, resize diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py deleted file mode 100644 index a7e4da3543..0000000000 --- a/mmengine/testing/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .compare import (assert_allclose, assert_attrs_equal, - assert_dict_contains_subset, assert_dict_has_keys, - assert_is_norm_layer, assert_keys_equal, - assert_params_all_zeros, check_python_script) -from .runner_test_case import RunnerTestCase - -__all__ = [ - 'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal', - 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer', - 'assert_params_all_zeros', 'check_python_script', 'RunnerTestCase' -] diff --git a/mmengine/testing/_internal/__init__.py b/mmengine/testing/_internal/__init__.py deleted file mode 100644 index f4528659a8..0000000000 --- a/mmengine/testing/_internal/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .distributed import MultiProcessTestCase - -__all__ = ['MultiProcessTestCase'] diff --git a/mmengine/testing/_internal/distributed.py b/mmengine/testing/_internal/distributed.py deleted file mode 100644 index 56adb52280..0000000000 --- a/mmengine/testing/_internal/distributed.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
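``LabelData`` and ``PixelData`` above are small, so a short usage sketch of the API being dropped may be enough context before the testing utilities (pre-removal API, illustrative shapes):

import torch
from mmengine.structures import LabelData, PixelData

# LabelData: static converters between label indices and one-hot vectors.
label = torch.tensor([1, 3])
onehot = LabelData.label_to_onehot(label, num_classes=5)
assert onehot.tolist() == [0, 1, 0, 1, 0]
assert LabelData.onehot_to_label(onehot).tolist() == [1, 3]

# PixelData: every field is (C, H, W); 2-D inputs gain a channel dim
# (with a warning), and all fields must share the same H and W.
pixel = PixelData(metainfo=dict(img_id=0))
pixel.seg_map = torch.randint(0, 21, (1, 20, 40))
pixel.heatmap = torch.rand(20, 40)   # stored as (1, 20, 40)
assert pixel.shape == (20, 40)
assert pixel.heatmap.shape == (1, 20, 40)
crop = pixel[:10, :16]               # slices height and width only
assert crop.shape == (10, 16)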
-# Copyright (c) https://github.com/pytorch/pytorch -# Modified from https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_distributed.py # noqa: E501 -import faulthandler -import logging -import multiprocessing -import sys -import tempfile -import threading -import time -import traceback -import types -import unittest -from enum import Enum -from functools import wraps -from typing import NamedTuple -from unittest import TestCase - -import torch -from torch.multiprocessing import active_children - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class TestSkip(NamedTuple): - exit_code: int - message: str - - -TEST_SKIPS = { - 'backend_unavailable': - TestSkip(10, 'Skipped because distributed backend is not available.'), - 'no_cuda': - TestSkip(11, 'CUDA is not available.'), - 'multi-gpu-2': - TestSkip(12, 'Need at least 2 CUDA device'), - 'generic': - TestSkip( - 13, 'Test skipped at subprocess level, look at subprocess log for ' - 'skip reason'), -} - -# [How does MultiProcessTestCase work?] -# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by -# default `world_size()` returns 2. Let's take `test_rpc_spawn.py` as an -# example which inherits from this class. Its `Setup()` methods calls into -# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()` -# subprocesses. During the spawn, the main process passes the test name to -# subprocesses, and the name is acquired from self.id(). The subprocesses -# then use the provided test function name to retrieve the function attribute -# from the test instance and run it. The main process simply waits for all -# subprocesses to join. - - -class MultiProcessTestCase(TestCase): - MAIN_PROCESS_RANK = -1 - - # This exit code is used to indicate that the test code had an error and - # exited abnormally. There are certain tests that might use sys.exit() to - # simulate failures and in those cases, we can't have an exit code of 0, - # but we still want to ensure we didn't run into any other errors. - TEST_ERROR_EXIT_CODE = 10 - - # do not early terminate for distributed tests. - def _should_stop_test_suite(self) -> bool: - return False - - def prepare_subprocess(self): - pass - - @property - def world_size(self) -> int: - return 2 - - @property - def timeout(self) -> int: - return 1000 - - def join_or_run(self, fn): - - @wraps(fn) - def wrapper(self): - if self.rank == self.MAIN_PROCESS_RANK: - self._join_processes(fn) - else: - fn() - - return types.MethodType(wrapper, self) - - # The main process spawns N subprocesses that run the test. - # Constructor patches current instance test method to - # assume the role of the main process and join its subprocesses, - # or run the underlying test function. - def __init__(self, - method_name: str = 'runTest', - methodName: str = 'runTest') -> None: - # methodName is the correct naming in unittest - # and testslide uses keyword arguments. - # So we need to use both to 1) not break BC and, 2) support testslide. 
- if methodName != 'runTest': - method_name = methodName - super().__init__(method_name) - try: - fn = getattr(self, method_name) - setattr(self, method_name, self.join_or_run(fn)) - except AttributeError as e: - if methodName != 'runTest': - # we allow instantiation with no explicit method name - # but not an *incorrect* or missing method name - raise ValueError( - f'no such test method in {self.__class__}: {methodName}' - ) from e - - def setUp(self) -> None: - super().setUp() - self.skip_return_code_checks = [] # type: ignore[var-annotated] - self.processes = [] # type: ignore[var-annotated] - self.rank = self.MAIN_PROCESS_RANK - self.file_name = tempfile.NamedTemporaryFile(delete=False).name - # pid to pipe consisting of error message from process. - self.pid_to_pipe = {} # type: ignore[var-annotated] - - def tearDown(self) -> None: - super().tearDown() - for p in self.processes: - p.terminate() - # Each Process instance holds a few open file descriptors. The unittest - # runner creates a new TestCase instance for each test method and keeps - # it alive until the end of the entire suite. We must thus reset the - # processes to prevent an effective file descriptor leak. - self.processes = [] - - def _current_test_name(self) -> str: - # self.id() - # e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' - return self.id().split('.')[-1] - - def _start_processes(self, proc) -> None: - self.processes = [] - for rank in range(int(self.world_size)): - parent_conn, child_conn = torch.multiprocessing.Pipe() - process = proc( - target=self.__class__._run, - name='process ' + str(rank), - args=(rank, self._current_test_name(), self.file_name, - child_conn), - ) - process.start() - self.pid_to_pipe[process.pid] = parent_conn - self.processes.append(process) - - def _spawn_processes(self) -> None: - proc = torch.multiprocessing.get_context('spawn').Process - self._start_processes(proc) - - class Event(Enum): - GET_TRACEBACK = 1 - - @staticmethod - def _event_listener(parent_pipe, signal_pipe, rank: int): - while True: - ready_pipes = multiprocessing.connection.wait( - [parent_pipe, signal_pipe]) - - if parent_pipe in ready_pipes: - - if parent_pipe.closed: - return - - event = parent_pipe.recv() - - if event == MultiProcessTestCase.Event.GET_TRACEBACK: - # Return traceback to the parent process. - with tempfile.NamedTemporaryFile(mode='r+') as tmp_file: - faulthandler.dump_traceback(tmp_file) - # Flush buffers and seek to read from the beginning - tmp_file.flush() - tmp_file.seek(0) - parent_pipe.send(tmp_file.read()) - - if signal_pipe in ready_pipes: - return - - @classmethod - def _run(cls, rank: int, test_name: str, file_name: str, - parent_pipe) -> None: - self = cls(test_name) - try: - self.prepare_subprocess() - except Exception: - raise sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) - self.rank = rank - self.file_name = file_name - self.run_test(test_name, parent_pipe) - - def run_test(self, test_name: str, parent_pipe) -> None: - # Start event listener thread. - signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe( - duplex=False) - event_listener_thread = threading.Thread( - target=MultiProcessTestCase._event_listener, - args=(parent_pipe, signal_recv_pipe, self.rank), - daemon=True, - ) - event_listener_thread.start() - - # self.id() == e.g. '__main__.TestDistributed.test_get_rank' - # We're retrieving a corresponding test and executing it. 
- try: - getattr(self, test_name)() - except unittest.SkipTest as se: - logger.info(f'Process {self.rank} skipping test {test_name} for ' - f'following reason: {str(se)}') - sys.exit(TEST_SKIPS['generic'].exit_code) - except Exception: - logger.error( - f'Caught exception: \n{traceback.format_exc()} exiting ' - f'process {self.rank} with exit code: ' - f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}') - # Send error to parent process. - parent_pipe.send(traceback.format_exc()) - sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) - finally: - if signal_send_pipe is not None: - signal_send_pipe.send(None) - - assert event_listener_thread is not None - event_listener_thread.join() - # Close pipe after done with test. - parent_pipe.close() - - def _get_timedout_process_traceback(self) -> None: - pipes = [] - for i, process in enumerate(self.processes): - if process.exitcode is None: - pipe = self.pid_to_pipe[process.pid] - try: - pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK) - pipes.append((i, pipe)) - except ConnectionError as e: - logger.error( - 'Encountered error while trying to get traceback ' - f'for process {i}: {e}') - - # Wait for results. - for rank, pipe in pipes: - try: - # Wait for traceback - if pipe.poll(5): - if pipe.closed: - logger.info( - f'Pipe closed for process {rank}, cannot retrieve ' - 'traceback') - continue - - traceback = pipe.recv() - logger.error(f'Process {rank} timed out with traceback: ' - f'\n\n{traceback}') - else: - logger.error('Could not retrieve traceback for timed out ' - f'process: {rank}') - except ConnectionError as e: - logger.error( - 'Encountered error while trying to get traceback for ' - f'process {rank}: {e}') - - def _join_processes(self, fn) -> None: - start_time = time.time() - subprocess_error = False - try: - while True: - # check to see if any subprocess exited with an error early. - for (i, p) in enumerate(self.processes): - # This is the exit code processes exit with if they - # encountered an exception. - if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE: - print( - f'Process {i} terminated with exit code ' - f'{p.exitcode}, terminating remaining processes.') - _active_children = active_children() - for ac in _active_children: - ac.terminate() - subprocess_error = True - break - if subprocess_error: - break - # All processes have joined cleanly if they all a valid - # exitcode - if all([p.exitcode is not None for p in self.processes]): - break - # Check if we should time out the test. If so, we terminate - # each process. - elapsed = time.time() - start_time - if elapsed > self.timeout: - self._get_timedout_process_traceback() - print(f'Timing out after {self.timeout} seconds and ' - 'killing subprocesses.') - for p in self.processes: - p.terminate() - break - # Sleep to avoid excessive busy polling. 
- time.sleep(0.1) - - elapsed_time = time.time() - start_time - - if fn in self.skip_return_code_checks: - self._check_no_test_errors(elapsed_time) - else: - self._check_return_codes(elapsed_time) - finally: - # Close all pipes - for pid, pipe in self.pid_to_pipe.items(): - pipe.close() - - def _check_no_test_errors(self, elapsed_time) -> None: - """Checks that we didn't have any errors thrown in the child - processes.""" - for i, p in enumerate(self.processes): - if p.exitcode is None: - raise RuntimeError( - 'Process {} timed out after {} seconds'.format( - i, elapsed_time)) - self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode) - - def _check_return_codes(self, elapsed_time) -> None: - """Checks that the return codes of all spawned processes match, and - skips tests if they returned a return code indicating a skipping - condition.""" - first_process = self.processes[0] - # first, we check if there are errors in actual processes - # (via TEST_ERROR_EXIT CODE), and raise an exception for those. - # the reason we do this is to attempt to raise a more helpful error - # message than "Process x terminated/timed out" - # TODO: we should pipe the exception of the failed subprocess here. - # Currently, the actual exception is displayed as a logging output. - errored_processes = [ - (i, p) for i, p in enumerate(self.processes) - if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE - ] - if errored_processes: - error = '' - for i, process in errored_processes: - # Get error from pipe. - error_message = self.pid_to_pipe[process.pid].recv() - error += ( - 'Process {} exited with error code {} and exception:\n{}\n' - .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, - error_message)) - raise RuntimeError(error) - # If no process exited uncleanly, we check for timeouts, and then - # ensure each process exited cleanly. - for i, p in enumerate(self.processes): - if p.exitcode is None: - raise RuntimeError( - f'Process {i} terminated or timed out after ' - '{elapsed_time} seconds') - - for skip in TEST_SKIPS.values(): - if first_process.exitcode == skip.exit_code: - raise unittest.SkipTest(skip.message) - - # Skip the unittest since the raised error maybe not caused by - # the tested function. For example, in CI environment, the tested - # method could be terminated by system signal for the limited - # resources. - self.skipTest(f'Skip test {self._testMethodName} due to ' - 'the program abort') - - @property - def is_master(self) -> bool: - return self.rank == 0 diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py deleted file mode 100644 index 14c7a97ba7..0000000000 --- a/mmengine/testing/compare.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import sys -from collections.abc import Iterable -from runpy import run_path -from shlex import split -from typing import Any, Callable, Dict, List, Optional, Union -from unittest.mock import patch - -from torch.nn import GroupNorm, LayerNorm -from torch.testing import assert_allclose as _assert_allclose - -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION -from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm - - -def assert_allclose( - actual: Any, - expected: Any, - rtol: Optional[float] = None, - atol: Optional[float] = None, - equal_nan: bool = True, - msg: Optional[Union[str, Callable]] = '', -) -> None: - """Asserts that ``actual`` and ``expected`` are close. A wrapper function - of ``torch.testing.assert_allclose``. 
- - Args: - actual (Any): Actual input. - expected (Any): Expected input. - rtol (Optional[float]): Relative tolerance. If specified ``atol`` must - also be specified. If omitted, default values based on the - :attr:`~torch.Tensor.dtype` are selected with the below table. - atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` - must also be specified. If omitted, default values based on the - :attr:`~torch.Tensor.dtype` are selected with the below table. - equal_nan (bool): If ``True``, two ``NaN`` values will be considered - equal. - msg (Optional[Union[str, Callable]]): Optional error message to use if - the values of corresponding tensors mismatch. Unused when PyTorch - < 1.6. - """ - if 'parrots' not in TORCH_VERSION and \ - digit_version(TORCH_VERSION) >= digit_version('1.6'): - _assert_allclose( - actual, - expected, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - msg=msg) - else: - # torch.testing.assert_allclose has no ``msg`` argument - # when PyTorch < 1.6 - _assert_allclose( - actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) - - -def check_python_script(cmd): - """Run the python cmd script with `__main__`. The difference between - `os.system` is that, this function exectues code in the current process, so - that it can be tracked by coverage tools. Currently it supports two forms: - - - ./tests/data/scripts/hello.py zz - - python tests/data/scripts/hello.py zz - """ - args = split(cmd) - if args[0] == 'python': - args = args[1:] - with patch.object(sys, 'argv', args): - run_path(args[0], run_name='__main__') - - -def _any(judge_result): - """Since built-in ``any`` works only when the element of iterable is not - iterable, implement the function.""" - if not isinstance(judge_result, Iterable): - return judge_result - - try: - for element in judge_result: - if _any(element): - return True - except TypeError: - # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) - if judge_result: - return True - return False - - -def assert_dict_contains_subset(dict_obj: Dict[Any, Any], - expected_subset: Dict[Any, Any]) -> bool: - """Check if the dict_obj contains the expected_subset. - - Args: - dict_obj (Dict[Any, Any]): Dict object to be checked. - expected_subset (Dict[Any, Any]): Subset expected to be contained in - dict_obj. - - Returns: - bool: Whether the dict_obj contains the expected_subset. - """ - - for key, value in expected_subset.items(): - if key not in dict_obj.keys() or _any(dict_obj[key] != value): - return False - return True - - -def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: - """Check if attribute of class object is correct. - - Args: - obj (object): Class object to be checked. - expected_attrs (Dict[str, Any]): Dict of the expected attrs. - - Returns: - bool: Whether the attribute of class object is correct. - """ - for attr, value in expected_attrs.items(): - if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): - return False - return True - - -def assert_dict_has_keys(obj: Dict[str, Any], - expected_keys: List[str]) -> bool: - """Check if the obj has all the expected_keys. - - Args: - obj (Dict[str, Any]): Object to be checked. - expected_keys (List[str]): Keys expected to contained in the keys of - the obj. - - Returns: - bool: Whether the obj has the expected keys. - """ - return set(expected_keys).issubset(set(obj.keys())) - - -def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: - """Check if target_keys is equal to result_keys. 
- - Args: - result_keys (List[str]): Result keys to be checked. - target_keys (List[str]): Target keys to be checked. - - Returns: - bool: Whether target_keys is equal to result_keys. - """ - return set(result_keys) == set(target_keys) - - -def assert_is_norm_layer(module) -> bool: - """Check if the module is a norm layer. - - Args: - module (nn.Module): The module to be checked. - - Returns: - bool: Whether the module is a norm layer. - """ - - norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) - return isinstance(module, norm_layer_candidates) - - -def assert_params_all_zeros(module) -> bool: - """Check if the parameters of the module is all zeros. - - Args: - module (nn.Module): The module to be checked. - - Returns: - bool: Whether the parameters of the module is all zeros. - """ - weight_data = module.weight.data - is_weight_zero = weight_data.allclose( - weight_data.new_zeros(weight_data.size())) - - if hasattr(module, 'bias') and module.bias is not None: - bias_data = module.bias.data - is_bias_zero = bias_data.allclose( - bias_data.new_zeros(bias_data.size())) - else: - is_bias_zero = True - - return is_weight_zero and is_bias_zero diff --git a/mmengine/testing/runner_test_case.py b/mmengine/testing/runner_test_case.py deleted file mode 100644 index f64594acef..0000000000 --- a/mmengine/testing/runner_test_case.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -import os -import shutil -import tempfile -import time -from unittest import TestCase -from uuid import uuid4 - -import torch -import torch.nn as nn -from torch.distributed import destroy_process_group -from torch.utils.data import Dataset - -import mmengine.hooks # noqa F401 -import mmengine.optim # noqa F401 -from mmengine.config import Config -from mmengine.dist import is_distributed -from mmengine.evaluator import BaseMetric -from mmengine.logging import MessageHub, MMLogger -from mmengine.model import BaseModel -from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope -from mmengine.runner import Runner -from mmengine.visualization import Visualizer - - -class ToyModel(BaseModel): - - def __init__(self, data_preprocessor=None): - super().__init__(data_preprocessor=data_preprocessor) - self.linear1 = nn.Linear(2, 2) - self.linear2 = nn.Linear(2, 1) - - def forward(self, inputs, data_samples=None, mode='tensor'): - if isinstance(inputs, list): - inputs = torch.stack(inputs) - if isinstance(data_samples, list): - data_samples = torch.stack(data_samples) - outputs = self.linear1(inputs) - outputs = self.linear2(outputs) - - if mode == 'tensor': - return outputs - elif mode == 'loss': - loss = (data_samples - outputs).sum() - outputs = dict(loss=loss) - return outputs - elif mode == 'predict': - return outputs - - -class ToyDataset(Dataset): - METAINFO = dict() # type: ignore - data = torch.randn(12, 2) - label = torch.ones(12) - - @property - def metainfo(self): - return self.METAINFO - - def __len__(self): - return self.data.size(0) - - def __getitem__(self, index): - return dict(inputs=self.data[index], data_samples=self.label[index]) - - -class ToyMetric(BaseMetric): - - def __init__(self, collect_device='cpu', dummy_metrics=None): - super().__init__(collect_device=collect_device) - self.dummy_metrics = dummy_metrics - - def process(self, data_batch, predictions): - result = {'acc': 1} - self.results.append(result) - - def compute_metrics(self, results): - return dict(acc=1) - - -class RunnerTestCase(TestCase): - """A test 
case to build runner easily. - - `RunnerTestCase` will do the following things: - - 1. Registers a toy model, a toy metric, and a toy dataset, which can be - used to run the `Runner` successfully. - 2. Provides epoch based and iteration based cfg to build runner. - 3. Provides `build_runner` method to build runner easily. - 4. Clean the global variable used by the runner. - """ - dist_cfg = dict( - MASTER_ADDR='127.0.0.1', - MASTER_PORT=29600, - RANK='0', - WORLD_SIZE='1', - LOCAL_RANK='0') - - def setUp(self) -> None: - self.temp_dir = tempfile.TemporaryDirectory() - # Prevent from registering module with the same name by other unit - # test. These registries will be cleared in `tearDown` - MODELS.register_module(module=ToyModel, force=True) - METRICS.register_module(module=ToyMetric, force=True) - DATASETS.register_module(module=ToyDataset, force=True) - epoch_based_cfg = dict( - work_dir=self.temp_dir.name, - model=dict(type='ToyModel'), - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - val_evaluator=[dict(type='ToyMetric')], - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - test_evaluator=[dict(type='ToyMetric')], - optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)), - train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), - val_cfg=dict(), - test_cfg=dict(), - default_hooks=dict(logger=dict(type='LoggerHook', interval=1)), - custom_hooks=[], - env_cfg=dict(dist_cfg=dict(backend='nccl')), - experiment_name='test1') - self.epoch_based_cfg = Config(epoch_based_cfg) - - # prepare iter based cfg. - self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg) - self.iter_based_cfg.train_dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), - batch_size=3, - num_workers=0) - self.iter_based_cfg.log_processor = dict(by_epoch=False) - - self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) - self.iter_based_cfg.default_hooks = dict( - logger=dict(type='LoggerHook', interval=1), - checkpoint=dict( - type='CheckpointHook', interval=12, by_epoch=False)) - - def tearDown(self): - # `FileHandler` should be closed in Windows, otherwise we cannot - # delete the temporary directory - logging.shutdown() - MMLogger._instance_dict.clear() - Visualizer._instance_dict.clear() - DefaultScope._instance_dict.clear() - MessageHub._instance_dict.clear() - MODELS.module_dict.pop('ToyModel', None) - METRICS.module_dict.pop('ToyMetric', None) - DATASETS.module_dict.pop('ToyDataset', None) - self.temp_dir.cleanup() - if is_distributed(): - destroy_process_group() - - def build_runner(self, cfg: Config): - cfg.experiment_name = self.experiment_name - runner = Runner.from_cfg(cfg) - return runner - - @property - def experiment_name(self): - # Since runners could be built too fast to have a unique experiment - # name(timestamp is the same), here we use uuid to make sure each - # runner has the unique experiment name. 
- return f'{self._testMethodName}_{time.time()} + ' \ - f'{uuid4()}' - - def setup_dist_env(self): - self.dist_cfg['MASTER_PORT'] += 1 - os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT']) - os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR'] - os.environ['RANK'] = self.dist_cfg['RANK'] - os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE'] - os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK'] - - def clear_work_dir(self): - logging.shutdown() - for filename in os.listdir(self.temp_dir.name): - filepath = os.path.join(self.temp_dir.name, filename) - if os.path.isfile(filepath): - os.remove(filepath) - else: - shutil.rmtree(filepath) diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py deleted file mode 100644 index 3de9099907..0000000000 --- a/mmengine/utils/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .manager import ManagerMeta, ManagerMixin -from .misc import (apply_to, check_prerequisites, concat_list, - deprecated_api_warning, deprecated_function, - get_object_from_string, has_method, - import_modules_from_strings, is_list_of, - is_method_overridden, is_seq_of, is_str, is_tuple_of, - iter_cast, list_cast, requires_executable, requires_package, - slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, - to_ntuple, tuple_cast) -from .package_utils import (call_command, get_installed_path, install_package, - is_installed) -from .path import (check_file_exist, fopen, is_abs, is_filepath, - mkdir_or_exist, scandir, symlink) -from .progressbar import (ProgressBar, track_iter_progress, - track_parallel_progress, track_progress) -from .progressbar_rich import track_progress_rich -from .timer import Timer, TimerError, check_time -from .version_utils import digit_version, get_git_hash - -__all__ = [ - 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', - 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', - 'check_prerequisites', 'requires_package', 'requires_executable', - 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink', - 'scandir', 'deprecated_api_warning', 'import_modules_from_strings', - 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_installed', 'call_command', 'get_installed_path', 'install_package', - 'is_abs', 'is_method_overridden', 'has_method', 'digit_version', - 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', - 'TimerError', 'ProgressBar', 'track_iter_progress', - 'track_parallel_progress', 'track_progress', 'deprecated_function', - 'apply_to', 'track_progress_rich', 'get_object_from_string' -] diff --git a/mmengine/utils/dl_utils/__init__.py b/mmengine/utils/dl_utils/__init__.py deleted file mode 100644 index 305ea89890..0000000000 --- a/mmengine/utils/dl_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
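``RunnerTestCase`` registers toy components in ``setUp`` and clears the global registries in ``tearDown``, so downstream repositories typically just subclass it. A hypothetical sketch of the pattern this deletion breaks (class and test names are made up; assumes a pre-removal mmengine install):

from mmengine.testing import RunnerTestCase


class TestToyTraining(RunnerTestCase):
    """Hypothetical downstream test reusing the toy model/dataset/metric."""

    def test_epoch_based_training(self):
        # `epoch_based_cfg` is prepared in RunnerTestCase.setUp().
        runner = self.build_runner(self.epoch_based_cfg)
        runner.train()

    def test_iter_based_training(self):
        cfg = self.iter_based_cfg
        cfg.train_cfg = dict(by_epoch=False, max_iters=3)
        runner = self.build_runner(cfg)
        runner.train()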
- -from .collect_env import collect_env -from .hub import load_url -from .misc import has_batch_norm, is_norm, mmcv_full_available, tensor2imgs -from .parrots_wrapper import TORCH_VERSION -from .setup_env import set_multi_processing -from .time_counter import TimeCounter -from .torch_ops import torch_meshgrid -from .trace import is_jit_tracing - -__all__ = [ - 'load_url', 'TORCH_VERSION', 'set_multi_processing', 'has_batch_norm', - 'is_norm', 'tensor2imgs', 'mmcv_full_available', 'collect_env', - 'torch_meshgrid', 'is_jit_tracing', 'TimeCounter' -] diff --git a/mmengine/utils/dl_utils/collect_env.py b/mmengine/utils/dl_utils/collect_env.py deleted file mode 100644 index 0ee99abad2..0000000000 --- a/mmengine/utils/dl_utils/collect_env.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""This file holding some environment constant for sharing by other files.""" -import os -import os.path as osp -import subprocess -import sys -from collections import OrderedDict, defaultdict - -import numpy as np -import torch - -import mmengine -from mmengine.device import is_cuda_available, is_musa_available -from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch - - -def _get_cuda_home(): - if TORCH_VERSION == 'parrots': - from parrots.utils.build_extension import CUDA_HOME - else: - if is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - CUDA_HOME = ROCM_HOME - else: - from torch.utils.cpp_extension import CUDA_HOME - return CUDA_HOME - - -def _get_musa_home(): - return os.environ.get('MUSA_HOME') - - -def collect_env(): - """Collect the information of the running environments. - - Returns: - dict: The environment information. The following fields are contained. - - - sys.platform: The variable of ``sys.platform``. - - Python: Python version. - - CUDA available: Bool, indicating if CUDA is available. - - GPU devices: Device type of each GPU. - - CUDA_HOME (optional): The env var ``CUDA_HOME``. - - NVCC (optional): NVCC version. - - GCC: GCC version, "n/a" if GCC is not installed. - - MSVC: Microsoft Virtual C++ Compiler version, Windows only. - - PyTorch: PyTorch version. - - PyTorch compiling details: The output of \ - ``torch.__config__.show()``. - - TorchVision (optional): TorchVision version. - - OpenCV (optional): OpenCV version. - - MMENGINE: MMENGINE version. 
- """ - from distutils import errors - - env_info = OrderedDict() - env_info['sys.platform'] = sys.platform - env_info['Python'] = sys.version.replace('\n', '') - - cuda_available = is_cuda_available() - musa_available = is_musa_available() - env_info['CUDA available'] = cuda_available - env_info['MUSA available'] = musa_available - env_info['numpy_random_seed'] = np.random.get_state()[1][0] - - if cuda_available: - devices = defaultdict(list) - for k in range(torch.cuda.device_count()): - devices[torch.cuda.get_device_name(k)].append(str(k)) - for name, device_ids in devices.items(): - env_info['GPU ' + ','.join(device_ids)] = name - - CUDA_HOME = _get_cuda_home() - env_info['CUDA_HOME'] = CUDA_HOME - - if CUDA_HOME is not None and osp.isdir(CUDA_HOME): - if CUDA_HOME == '/opt/rocm': - try: - nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc') - nvcc = subprocess.check_output( - f'"{nvcc}" --version', shell=True) - nvcc = nvcc.decode('utf-8').strip() - release = nvcc.rfind('HIP version:') - build = nvcc.rfind('') - nvcc = nvcc[release:build].strip() - except subprocess.SubprocessError: - nvcc = 'Not Available' - else: - try: - nvcc = osp.join(CUDA_HOME, 'bin/nvcc') - nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True) - nvcc = nvcc.decode('utf-8').strip() - release = nvcc.rfind('Cuda compilation tools') - build = nvcc.rfind('Build ') - nvcc = nvcc[release:build].strip() - except subprocess.SubprocessError: - nvcc = 'Not Available' - env_info['NVCC'] = nvcc - elif musa_available: - devices = defaultdict(list) - for k in range(torch.musa.device_count()): - devices[torch.musa.get_device_name(k)].append(str(k)) - for name, device_ids in devices.items(): - env_info['GPU ' + ','.join(device_ids)] = name - - MUSA_HOME = _get_musa_home() - env_info['MUSA_HOME'] = MUSA_HOME - - if MUSA_HOME is not None and osp.isdir(MUSA_HOME): - try: - mcc = osp.join(MUSA_HOME, 'bin/mcc') - subprocess.check_output(f'"{mcc}" -v', shell=True) - except subprocess.SubprocessError: - mcc = 'Not Available' - env_info['mcc'] = mcc - try: - # Check C++ Compiler. - # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...', - # indicating the compiler used, we use this to get the compiler name - import io - import sysconfig - cc = sysconfig.get_config_var('CC') - if cc: - cc = osp.basename(cc.split()[0]) - cc_info = subprocess.check_output(f'{cc} --version', shell=True) - env_info['GCC'] = cc_info.decode('utf-8').partition( - '\n')[0].strip() - else: - # on Windows, cl.exe is not in PATH. We need to find the path. - # distutils.ccompiler.new_compiler() returns a msvccompiler - # object and after initialization, path to cl.exe is found. - import locale - import os - from distutils.ccompiler import new_compiler - ccompiler = new_compiler() - ccompiler.initialize() - cc = subprocess.check_output( - f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True) - encoding = os.device_encoding( - sys.stdout.fileno()) or locale.getpreferredencoding() - env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip() - env_info['GCC'] = 'n/a' - except (subprocess.CalledProcessError, errors.DistutilsPlatformError): - env_info['GCC'] = 'n/a' - except io.UnsupportedOperation as e: - # JupyterLab on Windows changes sys.stdout, which has no `fileno` attr - # Refer to: https://github.com/open-mmlab/mmengine/issues/931 - # TODO: find a solution to get compiler info in Windows JupyterLab, - # while preserving backward-compatibility in other systems. 
- env_info['MSVC'] = f'n/a, reason: {str(e)}' - - env_info['PyTorch'] = torch.__version__ - env_info['PyTorch compiling details'] = get_build_config() - - try: - import torchvision - env_info['TorchVision'] = torchvision.__version__ - except ModuleNotFoundError: - pass - - try: - import cv2 - env_info['OpenCV'] = cv2.__version__ - except ImportError: - pass - - env_info['MMEngine'] = mmengine.__version__ - - return env_info diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py deleted file mode 100644 index 7f7f1a087d..0000000000 --- a/mmengine/utils/dl_utils/hub.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based -# file format. It will cause RuntimeError when a checkpoint was saved in -# torch >= 1.6.0 but loaded in torch < 1.7.0. -# More details at https://github.com/open-mmlab/mmpose/issues/904 - -from ..path import mkdir_or_exist -from ..version_utils import digit_version -from .parrots_wrapper import TORCH_VERSION - -if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( - '1.7.0'): - # Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py - import os - import sys - import warnings - import zipfile - from urllib.parse import urlparse - - import torch - from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file - - # Hub used to support automatically extracts from zipfile manually - # compressed by users. The legacy zip format expects only one file from - # torch.save() < 1.6 in the zip. We should remove this support since - # zipfile is now default zipfile format for torch.save(). - def _is_legacy_zip_format(filename): - if zipfile.is_zipfile(filename): - infolist = zipfile.ZipFile(filename).infolist() - return len(infolist) == 1 and not infolist[0].is_dir() - return False - - def _legacy_zip_load(filename, model_dir, map_location): - warnings.warn( - 'Falling back to the old format < 1.6. This support will' - ' be deprecated in favor of default zipfile format ' - 'introduced in 1.6. Please redo torch.save() to save it ' - 'in the new zipfile format.', DeprecationWarning) - # Note: extractall() defaults to overwrite file if exists. No need to - # clean up beforehand. We deliberately don't handle tarfile here - # since our legacy serialization format was in tar. - # E.g. resnet18-5c106cde.pth which is widely used. - with zipfile.ZipFile(filename) as f: - members = f.infolist() - if len(members) != 1: - raise RuntimeError( - 'Only one file(not dir) is allowed in the zipfile') - f.extractall(model_dir) - extraced_name = members[0].filename - extracted_file = os.path.join(model_dir, extraced_name) - return torch.load(extracted_file, map_location=map_location) - - def load_url(url, - model_dir=None, - map_location=None, - progress=True, - check_hash=False, - file_name=None): - r"""Loads the Torch serialized object at the given URL. - - If downloaded file is a zip file, it will be automatically decompressed - If the object is already present in `model_dir`, it's deserialized and - returned. - The default value of ``model_dir`` is ``/checkpoints`` where - ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. 
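
Editor's note: a minimal usage sketch for the collect_env helper whose deletion is recorded above. The import path is an assumption (this hunk only shows the removal side, so the helper may have been moved rather than dropped); adjust it to wherever the function lives after this change.

# Hedged sketch: assumes collect_env remains importable from
# mmengine.utils.dl_utils after this refactor.
from mmengine.utils.dl_utils import collect_env

env = collect_env()  # OrderedDict mapping field name -> value
for name, value in env.items():
    print(f'{name}: {value}')
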
- Args: - url (str): URL of the object to download - model_dir (str, optional): directory in which to save the object - map_location (optional): a function or a dict specifying how to - remap storage locations (see torch.load) - progress (bool, optional): whether or not to display a progress bar - to stderr. Defaults to True - check_hash(bool, optional): If True, the filename part of the URL - should follow the naming convention ``filename-.ext`` - where ```` is the first eight or more digits of the - SHA256 hash of the contents of the file. The hash is used to - ensure unique names and to verify the contents of the file. - Defaults to False - file_name (str, optional): name for the downloaded file. Filename - from ``url`` will be used if not set. Defaults to None. - Example: - >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106' - ... 'cde.pth') - >>> state_dict = torch.hub.load_state_dict_from_url(url) - """ - # Issue warning to move data if old env is set - if os.getenv('TORCH_MODEL_ZOO'): - warnings.warn( - 'TORCH_MODEL_ZOO is deprecated, please use env ' - 'TORCH_HOME instead', DeprecationWarning) - - if model_dir is None: - torch_home = _get_torch_home() - model_dir = os.path.join(torch_home, 'checkpoints') - - mkdir_or_exist(model_dir) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - if file_name is not None: - filename = file_name - cached_file = os.path.join(model_dir, filename) - if not os.path.exists(cached_file): - sys.stderr.write('Downloading: "{}" to {}\n'.format( - url, cached_file)) - hash_prefix = None - if check_hash: - r = HASH_REGEX.search(filename) # r is Optional[Match[str]] - hash_prefix = r.group(1) if r else None - download_url_to_file( - url, cached_file, hash_prefix, progress=progress) - - if _is_legacy_zip_format(cached_file): - return _legacy_zip_load(cached_file, model_dir, map_location) - - try: - return torch.load(cached_file, map_location=map_location) - except RuntimeError as error: - if digit_version(TORCH_VERSION) < digit_version('1.5.0'): - warnings.warn( - f'If the error is the same as "{cached_file} is a zip ' - 'archive (did you mean to use torch.jit.load()?)", you can' - ' upgrade your torch to 1.5.0 or higher (current torch ' - f'version is {TORCH_VERSION}). The error was raised ' - ' because the checkpoint was saved in torch>=1.6.0 but ' - 'loaded in torch<1.5.') - raise error -else: - from torch.utils.model_zoo import load_url # type: ignore # noqa: F401 diff --git a/mmengine/utils/dl_utils/misc.py b/mmengine/utils/dl_utils/misc.py deleted file mode 100644 index ce52d22c3b..0000000000 --- a/mmengine/utils/dl_utils/misc.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pkgutil -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from ..misc import is_tuple_of -from .parrots_wrapper import _BatchNorm, _InstanceNorm - - -def is_norm(layer: nn.Module, - exclude: Optional[Union[type, Tuple[type]]] = None) -> bool: - """Check if a layer is a normalization layer. - - Args: - layer (nn.Module): The layer to be checked. - exclude (type, tuple[type], optional): Types to be excluded. - - Returns: - bool: Whether the layer is a norm layer. 
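
Editor's note: on PyTorch >= 1.7 the deleted hub.py simply re-exported torch.utils.model_zoo.load_url, so downstream code can fall back to the stock helper. A sketch of the equivalent call (the resnet18 URL mirrors the docstring example above; downloading is only illustrative):

import torch

# Caches the file under $TORCH_HOME/checkpoints and, with check_hash=True,
# verifies the hash prefix embedded in the filename.
url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
state_dict = torch.hub.load_state_dict_from_url(
    url, map_location='cpu', check_hash=True)
print(len(state_dict), 'tensors loaded')
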
- """ - if exclude is not None: - if not isinstance(exclude, tuple): - exclude = (exclude, ) - if not is_tuple_of(exclude, type): - raise TypeError( - f'"exclude" must be either None or type or a tuple of types, ' - f'but got {type(exclude)}: {exclude}') - - if exclude and isinstance(layer, exclude): - return False - - all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) - return isinstance(layer, all_norm_bases) - - -def tensor2imgs(tensor: torch.Tensor, - mean: Optional[Tuple[float, float, float]] = None, - std: Optional[Tuple[float, float, float]] = None, - to_bgr: bool = True): - """Convert tensor to 3-channel images or 1-channel gray images. - - Args: - tensor (torch.Tensor): Tensor that contains multiple images, shape ( - N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format - should be RGB. - mean (tuple[float], optional): Mean of images. If None, - (0, 0, 0) will be used for tensor with 3-channel, - while (0, ) for tensor with 1-channel. Defaults to None. - std (tuple[float], optional): Standard deviation of images. If None, - (1, 1, 1) will be used for tensor with 3-channel, - while (1, ) for tensor with 1-channel. Defaults to None. - to_bgr (bool): For the tensor with 3 channel, convert its format to - BGR. For the tensor with 1 channel, it must be False. Defaults to - True. - - Returns: - list[np.ndarray]: A list that contains multiple images. - """ - - assert torch.is_tensor(tensor) and tensor.ndim == 4 - channels = tensor.size(1) - assert channels in [1, 3] - if mean is None: - mean = (0, ) * channels - if std is None: - std = (1, ) * channels - assert (channels == len(mean) == len(std) == 3) or \ - (channels == len(mean) == len(std) == 1 and not to_bgr) - mean = tensor.new_tensor(mean).view(1, -1) - std = tensor.new_tensor(std).view(1, -1) - tensor = tensor.permute(0, 2, 3, 1) * std + mean - imgs = tensor.detach().cpu().numpy() - if to_bgr and channels == 3: - imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR - imgs = [np.ascontiguousarray(img) for img in imgs] - return imgs - - -def has_batch_norm(model: nn.Module) -> bool: - """Detect whether model has a BatchNormalization layer. - - Args: - model (nn.Module): training model. - - Returns: - bool: whether model has a BatchNormalization layer - """ - if isinstance(model, _BatchNorm): - return True - for m in model.children(): - if has_batch_norm(m): - return True - return False - - -def mmcv_full_available() -> bool: - """Check whether mmcv-full is installed. - - Returns: - bool: True if mmcv-full is installed else False. - """ - try: - import mmcv # noqa: F401 - except ImportError: - return False - ext_loader = pkgutil.find_loader('mmcv._ext') - return ext_loader is not None diff --git a/mmengine/utils/dl_utils/parrots_wrapper.py b/mmengine/utils/dl_utils/parrots_wrapper.py deleted file mode 100644 index 9bd8e5443a..0000000000 --- a/mmengine/utils/dl_utils/parrots_wrapper.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from functools import partial -from typing import Optional - -import torch - -TORCH_VERSION = torch.__version__ - - -def is_rocm_pytorch() -> bool: - """Check whether the PyTorch is compiled on ROCm.""" - is_rocm = False - if TORCH_VERSION != 'parrots': - try: - from torch.utils.cpp_extension import ROCM_HOME - is_rocm = True if ((torch.version.hip is not None) and - (ROCM_HOME is not None)) else False - except ImportError: - pass - return is_rocm - - -def _get_cuda_home() -> Optional[str]: - """Obtain the path of CUDA home.""" - if TORCH_VERSION == 'parrots': - from parrots.utils.build_extension import CUDA_HOME - else: - if is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - CUDA_HOME = ROCM_HOME - else: - from torch.utils.cpp_extension import CUDA_HOME - return CUDA_HOME - - -def get_build_config(): - """Obtain the build information of PyTorch or Parrots.""" - if TORCH_VERSION == 'parrots': - from parrots.config import get_build_info - return get_build_info() - else: - return torch.__config__.show() - - -def _get_conv() -> tuple: - """A wrapper to obtain base classes of Conv layers from PyTorch or - Parrots.""" - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin - else: - from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin - return _ConvNd, _ConvTransposeMixin - - -def _get_dataloader() -> tuple: - """A wrapper to obtain DataLoader class from PyTorch or Parrots.""" - if TORCH_VERSION == 'parrots': - from torch.utils.data import DataLoader, PoolDataLoader - else: - from torch.utils.data import DataLoader - PoolDataLoader = DataLoader - return DataLoader, PoolDataLoader - - -def _get_extension(): - """A wrapper to obtain extension class from PyTorch or Parrots.""" - if TORCH_VERSION == 'parrots': - from parrots.utils.build_extension import BuildExtension, Extension - CppExtension = partial(Extension, cuda=False) - CUDAExtension = partial(Extension, cuda=True) - else: - from torch.utils.cpp_extension import (BuildExtension, CppExtension, - CUDAExtension) - return BuildExtension, CppExtension, CUDAExtension - - -def _get_pool() -> tuple: - """A wrapper to obtain base classes of pooling layers from PyTorch or - Parrots.""" - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, - _MaxPoolNd) - else: - from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, - _MaxPoolNd) - return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd - - -def _get_norm() -> tuple: - """A wrapper to obtain base classes of normalization layers from PyTorch or - Parrots.""" - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm - SyncBatchNorm_ = torch.nn.SyncBatchNorm2d - else: - from torch.nn.modules.batchnorm import _BatchNorm - from torch.nn.modules.instancenorm import _InstanceNorm - SyncBatchNorm_ = torch.nn.SyncBatchNorm - return _BatchNorm, _InstanceNorm, SyncBatchNorm_ - - -_ConvNd, _ConvTransposeMixin = _get_conv() -DataLoader, PoolDataLoader = _get_dataloader() -_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() -_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() - - -class SyncBatchNorm(SyncBatchNorm_): # type: ignore - - def _check_input_dim(self, input): - if TORCH_VERSION == 'parrots': - if input.dim() < 2: - raise ValueError( - f'expected at least 2D input (got {input.dim()}D input)') - else: - super()._check_input_dim(input) diff 
--git a/mmengine/utils/dl_utils/setup_env.py b/mmengine/utils/dl_utils/setup_env.py deleted file mode 100644 index 8c23a56a13..0000000000 --- a/mmengine/utils/dl_utils/setup_env.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import platform -import warnings - -import torch.multiprocessing as mp - - -def set_multi_processing(mp_start_method: str = 'fork', - opencv_num_threads: int = 0, - distributed: bool = False) -> None: - """Set multi-processing related environment. - - Args: - mp_start_method (str): Set the method which should be used to start - child processes. Defaults to 'fork'. - opencv_num_threads (int): Number of threads for opencv. - Defaults to 0. - distributed (bool): True if distributed environment. - Defaults to False. - """ - # set multi-process start method as `fork` to speed up the training - if platform.system() != 'Windows': - current_method = mp.get_start_method(allow_none=True) - if (current_method is not None and current_method != mp_start_method): - warnings.warn( - f'Multi-processing start method `{mp_start_method}` is ' - f'different from the previous setting `{current_method}`.' - f'It will be force set to `{mp_start_method}`. You can ' - 'change this behavior by changing `mp_start_method` in ' - 'your config.') - mp.set_start_method(mp_start_method, force=True) - - try: - import cv2 - - # disable opencv multithreading to avoid system being overloaded - cv2.setNumThreads(opencv_num_threads) - except ImportError: - pass - - # setup OMP threads - # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa - if 'OMP_NUM_THREADS' not in os.environ and distributed: - omp_num_threads = 1 - warnings.warn( - 'Setting OMP_NUM_THREADS environment variable for each process' - f' to be {omp_num_threads} in default, to avoid your system ' - 'being overloaded, please further tune the variable for ' - 'optimal performance in your application as needed.') - os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) - - # setup MKL threads - if 'MKL_NUM_THREADS' not in os.environ and distributed: - mkl_num_threads = 1 - warnings.warn( - 'Setting MKL_NUM_THREADS environment variable for each process' - f' to be {mkl_num_threads} in default, to avoid your system ' - 'being overloaded, please further tune the variable for ' - 'optimal performance in your application as needed.') - os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) diff --git a/mmengine/utils/dl_utils/time_counter.py b/mmengine/utils/dl_utils/time_counter.py deleted file mode 100644 index 05c008da45..0000000000 --- a/mmengine/utils/dl_utils/time_counter.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import time -from typing import Optional, Union - -import torch - -from mmengine.device import is_cuda_available, is_musa_available -from mmengine.dist.utils import master_only -from mmengine.logging import MMLogger, print_log - - -class TimeCounter: - """A tool that counts the average running time of a function or a method. - Users can use it as a decorator or context manager to calculate the average - running time of code blocks. - - Args: - log_interval (int): The interval of logging. Defaults to 1. - warmup_interval (int): The interval of warmup. Defaults to 1. - with_sync (bool): Whether to synchronize cuda. Defaults to True. - tag (str, optional): Function tag. Used to distinguish between - different functions or methods being called. Defaults to None. 
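
Editor's note: a standalone sketch of what the set_multi_processing helper removed above does at its core, written against torch.multiprocessing directly so it does not depend on the mmengine import surviving this refactor. The real helper only pins the thread counts when distributed=True; here that condition is assumed.

import os
import platform
import torch.multiprocessing as mp

# Pick the 'fork' start method (non-Windows only) and bound the OMP/MKL
# thread pools before launching distributed training workers.
if platform.system() != 'Windows':
    mp.set_start_method('fork', force=True)
os.environ.setdefault('OMP_NUM_THREADS', '1')
os.environ.setdefault('MKL_NUM_THREADS', '1')
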
- logger (MMLogger, optional): Formatted logger used to record messages. - Defaults to None. - - Examples: - >>> import time - >>> from mmengine.utils.dl_utils import TimeCounter - >>> @TimeCounter() - ... def fun1(): - ... time.sleep(0.1) - ... fun1() - [fun1]-time per run averaged in the past 1 runs: 100.0 ms - - >>> @@TimeCounter(log_interval=2, tag='fun') - ... def fun2(): - ... time.sleep(0.2) - >>> for _ in range(3): - ... fun2() - [fun]-time per run averaged in the past 2 runs: 200.0 ms - - >>> with TimeCounter(tag='fun3'): - ... time.sleep(0.3) - [fun3]-time per run averaged in the past 1 runs: 300.0 ms - """ - - instance_dict: dict = dict() - - log_interval: int - warmup_interval: int - logger: Optional[MMLogger] - __count: int - __pure_inf_time: float - - def __new__(cls, - log_interval: int = 1, - warmup_interval: int = 1, - with_sync: bool = True, - tag: Optional[str] = None, - logger: Optional[MMLogger] = None): - assert warmup_interval >= 1 - if tag is not None and tag in cls.instance_dict: - return cls.instance_dict[tag] - - instance = super().__new__(cls) - cls.instance_dict[tag] = instance - - instance.log_interval = log_interval - instance.warmup_interval = warmup_interval - instance.with_sync = with_sync # type: ignore - instance.tag = tag - instance.logger = logger - - instance.__count = 0 - instance.__pure_inf_time = 0. - instance.__start_time = 0. - - return instance - - @master_only - def __call__(self, fn): - if self.tag is None: - self.tag = fn.__name__ - - def wrapper(*args, **kwargs): - self.__count += 1 - - if self.with_sync: - if is_cuda_available(): - torch.cuda.synchronize() - elif is_musa_available(): - torch.musa.synchronize() - start_time = time.perf_counter() - - result = fn(*args, **kwargs) - - if self.with_sync: - if is_cuda_available(): - torch.cuda.synchronize() - elif is_musa_available(): - torch.musa.synchronize() - elapsed = time.perf_counter() - start_time - self.print_time(elapsed) - - return result - - return wrapper - - @master_only - def __enter__(self): - assert self.tag is not None, 'In order to clearly distinguish ' \ - 'printing information in different ' \ - 'contexts, please specify the ' \ - 'tag parameter' - - self.__count += 1 - - if self.with_sync and torch.cuda.is_available(): - torch.cuda.synchronize() - self.__start_time = time.perf_counter() - - @master_only - def __exit__(self, exc_type, exc_val, exc_tb): - if self.with_sync and torch.cuda.is_available(): - torch.cuda.synchronize() - elapsed = time.perf_counter() - self.__start_time - self.print_time(elapsed) - - def print_time(self, elapsed: Union[int, float]) -> None: - """Print times per count.""" - if self.__count >= self.warmup_interval: - self.__pure_inf_time += elapsed - - if self.__count % self.log_interval == 0: - times_per_count = 1000 * self.__pure_inf_time / ( - self.__count - self.warmup_interval + 1) - print_log( - f'[{self.tag}]-time per run averaged in the past ' - f'{self.__count} runs: {times_per_count:.1f} ms', - self.logger) diff --git a/mmengine/utils/dl_utils/torch_ops.py b/mmengine/utils/dl_utils/torch_ops.py deleted file mode 100644 index 2550ae6986..0000000000 --- a/mmengine/utils/dl_utils/torch_ops.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
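
Editor's note: a usage sketch of the TimeCounter tool deleted above, mirroring its docstring (the docstring's ">>> @@TimeCounter" line is a doctest typo for a single "@"). The import path is an assumption since this hunk only records the deletion.

import time
# Assumes TimeCounter is still importable from mmengine.utils.dl_utils.
from mmengine.utils.dl_utils import TimeCounter

@TimeCounter(log_interval=2, tag='step')  # decorator form
def step():
    time.sleep(0.05)

for _ in range(4):
    step()  # prints the running average every 2 calls

with TimeCounter(tag='block'):  # context-manager form; tag is required here
    time.sleep(0.1)
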
-import torch - -from ..version_utils import digit_version -from .parrots_wrapper import TORCH_VERSION - -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) - - -def torch_meshgrid(*tensors): - """A wrapper of torch.meshgrid to compat different PyTorch versions. - - Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``. - So we implement a wrapper here to avoid warning when using high-version - PyTorch and avoid compatibility issues when using previous versions of - PyTorch. - - Args: - tensors (List[Tensor]): List of scalars or 1 dimensional tensors. - - Returns: - Sequence[Tensor]: Sequence of meshgrid tensors. - """ - if _torch_version_meshgrid_indexing: - return torch.meshgrid(*tensors, indexing='ij') - else: - return torch.meshgrid(*tensors) # Uses indexing='ij' by default diff --git a/mmengine/utils/dl_utils/trace.py b/mmengine/utils/dl_utils/trace.py deleted file mode 100644 index c12bebf5d1..0000000000 --- a/mmengine/utils/dl_utils/trace.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings - -import torch - -from ..version_utils import digit_version - - -def is_jit_tracing() -> bool: - if (torch.__version__ != 'parrots' - and digit_version(torch.__version__) >= digit_version('1.6.0')): - on_trace = torch.jit.is_tracing() - # In PyTorch 1.6, torch.jit.is_tracing has a bug. - # Refers to https://github.com/pytorch/pytorch/issues/42448 - if isinstance(on_trace, bool): - return on_trace - else: - return torch._C._is_tracing() - else: - warnings.warn( - 'torch.jit.is_tracing is only supported after v1.6.0. ' - 'Therefore is_tracing returns False automatically. Please ' - 'set on_trace manually if you are using trace.', UserWarning) - return False diff --git a/mmengine/utils/dl_utils/visualize.py b/mmengine/utils/dl_utils/visualize.py deleted file mode 100644 index f3361e1d50..0000000000 --- a/mmengine/utils/dl_utils/visualize.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
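
Editor's note: the version gate in the deleted torch_ops.py boils down to passing indexing='ij' only when PyTorch supports it (>= 1.10). A standalone sketch, assuming the `packaging` package is available for version parsing:

import torch
from packaging.version import parse as parse_version

def torch_meshgrid(*tensors):
    # Newer PyTorch warns unless `indexing` is given explicitly; older
    # releases do not accept the argument at all and already use 'ij'.
    if parse_version(torch.__version__.split('+')[0]) >= parse_version('1.10'):
        return torch.meshgrid(*tensors, indexing='ij')
    return torch.meshgrid(*tensors)

xs, ys = torch_meshgrid(torch.arange(3), torch.arange(2))
print(xs.shape, ys.shape)  # torch.Size([3, 2]) torch.Size([3, 2])
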
-from unittest.mock import patch - -import torch -import torch.nn as nn - -from mmengine.model import BaseModel -from mmengine.registry import MODELS - - -@MODELS.register_module() -class ToyModel(BaseModel): - - def __init__(self, *args, **kwargs): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, *args, **kwargs): - return {'loss': torch.tensor(0.0)} - - -def update_params_step(self, loss): - pass - - -def runtimeinfo_step(self, runner, batch_idx, data_batch=None): - runner.message_hub.update_info('iter', runner.iter) - lr_dict = runner.optim_wrapper.get_lr() - for name, lr in lr_dict.items(): - runner.message_hub.update_scalar(f'train/{name}', lr[0]) - - momentum_dict = runner.optim_wrapper.get_momentum() - for name, momentum in momentum_dict.items(): - runner.message_hub.update_scalar(f'train/{name}', momentum[0]) - - -@patch('mmengine.optim.optimizer.OptimWrapper.update_params', - update_params_step) -@patch('mmengine.hooks.RuntimeInfoHook.before_train_iter', runtimeinfo_step) -def fake_run(cfg): - from mmengine.runner import Runner - cfg.pop('model') - cfg.pop('visualizer') - cfg.pop('val_dataloader') - cfg.pop('val_evaluator') - cfg.pop('val_cfg') - cfg.pop('test_dataloader') - cfg.pop('test_evaluator') - cfg.pop('test_cfg') - extra_cfg = dict( - model=dict(type='ToyModel'), - visualizer=dict( - type='Visualizer', - vis_backends=[ - dict(type='TensorboardVisBackend', save_dir='temp_dir') - ]), - ) - cfg.merge_from_dict(extra_cfg) - # build the runner from config - runner = Runner.from_cfg(cfg) - - # start training - runner.train() diff --git a/mmengine/utils/manager.py b/mmengine/utils/manager.py deleted file mode 100644 index 70b45f2d8e..0000000000 --- a/mmengine/utils/manager.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import threading -import warnings -from collections import OrderedDict -from typing import Type, TypeVar - -_lock = threading.RLock() -T = TypeVar('T') - - -def _accquire_lock() -> None: - """Acquire the module-level lock for serializing access to shared data. - - This should be released with _release_lock(). - """ - if _lock: - _lock.acquire() - - -def _release_lock() -> None: - """Release the module-level lock acquired by calling _accquire_lock().""" - if _lock: - _lock.release() - - -class ManagerMeta(type): - """The metaclass for global accessible class. - - The subclasses inheriting from ``ManagerMeta`` will manage their - own ``_instance_dict`` and root instances. The constructors of subclasses - must contain the ``name`` argument. - - Examples: - >>> class SubClass1(metaclass=ManagerMeta): - >>> def __init__(self, *args, **kwargs): - >>> pass - AssertionError: .__init__ must have the - name argument. - >>> class SubClass2(metaclass=ManagerMeta): - >>> def __init__(self, name): - >>> pass - >>> # valid format. - """ - - def __init__(cls, *args): - cls._instance_dict = OrderedDict() - params = inspect.getfullargspec(cls) - params_names = params[0] if params[0] else [] - assert 'name' in params_names, f'{cls} must have the `name` argument' - super().__init__(*args) - - -class ManagerMixin(metaclass=ManagerMeta): - """``ManagerMixin`` is the base class for classes that have global access - requirements. - - The subclasses inheriting from ``ManagerMixin`` can get their - global instances. 
- - Examples: - >>> class GlobalAccessible(ManagerMixin): - >>> def __init__(self, name=''): - >>> super().__init__(name) - >>> - >>> GlobalAccessible.get_instance('name') - >>> instance_1 = GlobalAccessible.get_instance('name') - >>> instance_2 = GlobalAccessible.get_instance('name') - >>> assert id(instance_1) == id(instance_2) - - Args: - name (str): Name of the instance. Defaults to ''. - """ - - def __init__(self, name: str = '', **kwargs): - assert isinstance(name, str) and name, \ - 'name argument must be an non-empty string.' - self._instance_name = name - - @classmethod - def get_instance(cls: Type[T], name: str, **kwargs) -> T: - """Get subclass instance by name if the name exists. - - If corresponding name instance has not been created, ``get_instance`` - will create an instance, otherwise ``get_instance`` will return the - corresponding instance. - - Examples - >>> instance1 = GlobalAccessible.get_instance('name1') - >>> # Create name1 instance. - >>> instance.instance_name - name1 - >>> instance2 = GlobalAccessible.get_instance('name1') - >>> # Get name1 instance. - >>> assert id(instance1) == id(instance2) - - Args: - name (str): Name of instance. Defaults to ''. - - Returns: - object: Corresponding name instance, the latest instance, or root - instance. - """ - _accquire_lock() - assert isinstance(name, str), \ - f'type of name should be str, but got {type(cls)}' - instance_dict = cls._instance_dict # type: ignore - # Get the instance by name. - if name not in instance_dict: - instance = cls(name=name, **kwargs) # type: ignore - instance_dict[name] = instance # type: ignore - elif kwargs: - warnings.warn( - f'{cls} instance named of {name} has been created, ' - 'the method `get_instance` should not accept any other ' - 'arguments') - # Get latest instantiated instance or root instance. - _release_lock() - return instance_dict[name] - - @classmethod - def get_current_instance(cls): - """Get latest created instance. - - Before calling ``get_current_instance``, The subclass must have called - ``get_instance(xxx)`` at least once. - - Examples - >>> instance = GlobalAccessible.get_current_instance() - AssertionError: At least one of name and current needs to be set - >>> instance = GlobalAccessible.get_instance('name1') - >>> instance.instance_name - name1 - >>> instance = GlobalAccessible.get_current_instance() - >>> instance.instance_name - name1 - - Returns: - object: Latest created instance. - """ - _accquire_lock() - if not cls._instance_dict: - raise RuntimeError( - f'Before calling {cls.__name__}.get_current_instance(), you ' - 'should call get_instance(name=xxx) at least once.') - name = next(iter(reversed(cls._instance_dict))) - _release_lock() - return cls._instance_dict[name] - - @classmethod - def check_instance_created(cls, name: str) -> bool: - """Check whether the name corresponding instance exists. - - Args: - name (str): Name of instance. - - Returns: - bool: Whether the name corresponding instance exists. - """ - return name in cls._instance_dict - - @property - def instance_name(self) -> str: - """Get the name of instance. - - Returns: - str: Name of instance. - """ - return self._instance_name diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py deleted file mode 100644 index 23ae707a56..0000000000 --- a/mmengine/utils/misc.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
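
Editor's note: a sketch of the global-instance pattern provided by the ManagerMixin removed above, following its own docstring example. The top-level mmengine.utils import is assumed to survive this refactor.

from mmengine.utils import ManagerMixin  # assumed to stay exported

class GlobalAccessible(ManagerMixin):
    def __init__(self, name='', value=0):
        super().__init__(name)
        self.value = value

a = GlobalAccessible.get_instance('exp1', value=42)  # created on first call
b = GlobalAccessible.get_instance('exp1')            # same object afterwards
assert a is b and a.value == 42
assert GlobalAccessible.get_current_instance() is a
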
-import collections.abc -import functools -import itertools -import logging -import re -import subprocess -import textwrap -import warnings -from collections import abc -from importlib import import_module -from inspect import getfullargspec, ismodule -from itertools import repeat -from typing import Any, Callable, Optional, Type, Union - - -# From PyTorch internals -def _ntuple(n): - - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def is_str(x): - """Whether the input is an string instance. - - Note: This method is deprecated since python 2 is no longer supported. - """ - return isinstance(x, str) - - -def import_modules_from_strings(imports, allow_failed_imports=False): - """Import modules from the given list of strings. - - Args: - imports (list | str | None): The given module names to be imported. - allow_failed_imports (bool): If True, the failed imports will return - None. Otherwise, an ImportError is raise. Defaults to False. - - Returns: - list[module] | module | None: The imported modules. - - Examples: - >>> osp, sys = import_modules_from_strings( - ... ['os.path', 'sys']) - >>> import os.path as osp_ - >>> import sys as sys_ - >>> assert osp == osp_ - >>> assert sys == sys_ - """ - if not imports: - return - single_import = False - if isinstance(imports, str): - single_import = True - imports = [imports] - if not isinstance(imports, list): - raise TypeError( - f'custom_imports must be a list but got type {type(imports)}') - imported = [] - for imp in imports: - if not isinstance(imp, str): - raise TypeError( - f'{imp} is of type {type(imp)} and cannot be imported.') - try: - imported_tmp = import_module(imp) - except ImportError: - if allow_failed_imports: - warnings.warn(f'{imp} failed to import and is ignored.', - UserWarning) - imported_tmp = None - else: - raise ImportError(f'Failed to import {imp}') - imported.append(imported_tmp) - if single_import: - imported = imported[0] - return imported - - -def iter_cast(inputs, dst_type, return_type=None): - """Cast elements of an iterable object into some type. - - Args: - inputs (Iterable): The input object. - dst_type (type): Destination type. - return_type (type, optional): If specified, the output object will be - converted to this type, otherwise an iterator. - - Returns: - iterator or specified type: The converted object. - """ - if not isinstance(inputs, abc.Iterable): - raise TypeError('inputs must be an iterable object') - if not isinstance(dst_type, type): - raise TypeError('"dst_type" must be a valid type') - - out_iterable = map(dst_type, inputs) - - if return_type is None: - return out_iterable - else: - return return_type(out_iterable) - - -def list_cast(inputs, dst_type): - """Cast elements of an iterable object into a list of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=list) - - -def tuple_cast(inputs, dst_type): - """Cast elements of an iterable object into a tuple of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=tuple) - - -def is_seq_of(seq: Any, - expected_type: Union[Type, tuple], - seq_type: Optional[Type] = None) -> bool: - """Check whether it is a sequence of some type. - - Args: - seq (Sequence): The sequence to be checked. - expected_type (type or tuple): Expected type of sequence items. 
- seq_type (type, optional): Expected sequence type. Defaults to None. - - Returns: - bool: Return True if ``seq`` is valid else False. - - Examples: - >>> from mmengine.utils import is_seq_of - >>> seq = ['a', 'b', 'c'] - >>> is_seq_of(seq, str) - True - >>> is_seq_of(seq, int) - False - """ - if seq_type is None: - exp_seq_type = abc.Sequence - else: - assert isinstance(seq_type, type) - exp_seq_type = seq_type - if not isinstance(seq, exp_seq_type): - return False - for item in seq: - if not isinstance(item, expected_type): - return False - return True - - -def is_list_of(seq, expected_type): - """Check whether it is a list of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=list) - - -def is_tuple_of(seq, expected_type): - """Check whether it is a tuple of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=tuple) - - -def slice_list(in_list, lens): - """Slice a list into several sub lists by a list of given length. - - Args: - in_list (list): The list to be sliced. - lens(int or list): The expected length of each out list. - - Returns: - list: A list of sliced list. - """ - if isinstance(lens, int): - assert len(in_list) % lens == 0 - lens = [lens] * int(len(in_list) / lens) - if not isinstance(lens, list): - raise TypeError('"indices" must be an integer or a list of integers') - elif sum(lens) != len(in_list): - raise ValueError('sum of lens and list length does not ' - f'match: {sum(lens)} != {len(in_list)}') - out_list = [] - idx = 0 - for i in range(len(lens)): - out_list.append(in_list[idx:idx + lens[i]]) - idx += lens[i] - return out_list - - -def concat_list(in_list): - """Concatenate a list of list into a single list. - - Args: - in_list (list): The list of list to be merged. - - Returns: - list: The concatenated flat list. - """ - return list(itertools.chain(*in_list)) - - -def apply_to(data: Any, expr: Callable, apply_func: Callable): - """Apply function to each element in dict, list or tuple that matches with - the expression. - - For examples, if you want to convert each element in a list of dict from - `np.ndarray` to `Tensor`. You can use the following code: - - Examples: - >>> from mmengine.utils import apply_to - >>> import numpy as np - >>> import torch - >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]} - >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x)) - >>> print(result) # {'array': [tensor(1)]} - - Args: - data (Any): Data to be applied. - expr (Callable): Expression to tell which data should be applied with - the function. It should return a boolean. - apply_func (Callable): Function applied to data. - - Returns: - Any: The data after applying. 
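
Editor's note: a few quick checks illustrating the container helpers deleted above; behaviour is taken from their docstrings, and the mmengine.utils exports are assumed to remain in place.

from mmengine.utils import is_seq_of, is_list_of, slice_list, concat_list

assert is_seq_of(['a', 'b', 'c'], str)
assert not is_seq_of(['a', 'b', 'c'], int)
assert is_list_of([1, 2, 3], int)

chunks = slice_list([1, 2, 3, 4, 5, 6], [2, 4])  # -> [[1, 2], [3, 4, 5, 6]]
assert concat_list(chunks) == [1, 2, 3, 4, 5, 6]
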
- """ # noqa: E501 - if isinstance(data, dict): - # Keep the original dict type - res = type(data)() - for key, value in data.items(): - res[key] = apply_to(value, expr, apply_func) - return res - elif isinstance(data, tuple) and hasattr(data, '_fields'): - # namedtuple - return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # noqa: E501 # yapf:disable - elif isinstance(data, (tuple, list)): - return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # noqa: E501 # yapf:disable - elif expr(data): - return apply_func(data) - else: - return data - - -def check_prerequisites( - prerequisites, - checker, - msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' - 'found, please install them first.'): # yapf: disable - """A decorator factory to check if prerequisites are satisfied. - - Args: - prerequisites (str of list[str]): Prerequisites to be checked. - checker (callable): The checker method that returns True if a - prerequisite is meet, False otherwise. - msg_tmpl (str): The message template with two variables. - - Returns: - decorator: A specific decorator. - """ - - def wrap(func): - - @functools.wraps(func) - def wrapped_func(*args, **kwargs): - requirements = [prerequisites] if isinstance( - prerequisites, str) else prerequisites - missing = [] - for item in requirements: - if not checker(item): - missing.append(item) - if missing: - print(msg_tmpl.format(', '.join(missing), func.__name__)) - raise RuntimeError('Prerequisites not meet.') - else: - return func(*args, **kwargs) - - return wrapped_func - - return wrap - - -def _check_py_package(package): - try: - import_module(package) - except ImportError: - return False - else: - return True - - -def _check_executable(cmd): - if subprocess.call(f'which {cmd}', shell=True) != 0: - return False - else: - return True - - -def requires_package(prerequisites): - """A decorator to check if some python packages are installed. - - Example: - >>> @requires_package('numpy') - >>> func(arg1, args): - >>> return numpy.zeros(1) - array([0.]) - >>> @requires_package(['numpy', 'non_package']) - >>> func(arg1, args): - >>> return numpy.zeros(1) - ImportError - """ - return check_prerequisites(prerequisites, checker=_check_py_package) - - -def requires_executable(prerequisites): - """A decorator to check if some executable files are installed. - - Example: - >>> @requires_executable('ffmpeg') - >>> func(arg1, args): - >>> print(1) - 1 - """ - return check_prerequisites(prerequisites, checker=_check_executable) - - -def deprecated_api_warning(name_dict: dict, - cls_name: Optional[str] = None) -> Callable: - """A decorator to check if some arguments are deprecate and try to replace - deprecate src_arg_name to dst_arg_name. - - Args: - name_dict(dict): - key (str): Deprecate argument names. - val (str): Expected argument names. - - Returns: - func: New function. 
- """ - - def api_warning_wrapper(old_func): - - @functools.wraps(old_func) - def new_func(*args, **kwargs): - # get the arg spec of the decorated method - args_info = getfullargspec(old_func) - # get name of the function - func_name = old_func.__name__ - if cls_name is not None: - func_name = f'{cls_name}.{func_name}' - if args: - arg_names = args_info.args[:len(args)] - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in arg_names: - warnings.warn( - f'"{src_arg_name}" is deprecated in ' - f'`{func_name}`, please use "{dst_arg_name}" ' - 'instead', DeprecationWarning) - arg_names[arg_names.index(src_arg_name)] = dst_arg_name - if kwargs: - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in kwargs: - assert dst_arg_name not in kwargs, ( - f'The expected behavior is to replace ' - f'the deprecated key `{src_arg_name}` to ' - f'new key `{dst_arg_name}`, but got them ' - f'in the arguments at the same time, which ' - f'is confusing. `{src_arg_name} will be ' - f'deprecated in the future, please ' - f'use `{dst_arg_name}` instead.') - - warnings.warn( - f'"{src_arg_name}" is deprecated in ' - f'`{func_name}`, please use "{dst_arg_name}" ' - 'instead', DeprecationWarning) - kwargs[dst_arg_name] = kwargs.pop(src_arg_name) - - # apply converted arguments to the decorated method - output = old_func(*args, **kwargs) - return output - - return new_func - - return api_warning_wrapper - - -def is_method_overridden(method: str, base_class: type, - derived_class: Union[type, Any]) -> bool: - """Check if a method of base class is overridden in derived class. - - Args: - method (str): the method name to check. - base_class (type): the class of the base class. - derived_class (type | Any): the class or instance of the derived class. - """ - assert isinstance(base_class, type), \ - "base_class doesn't accept instance, Please pass class instead." - - if not isinstance(derived_class, type): - derived_class = derived_class.__class__ - - base_method = getattr(base_class, method) - derived_method = getattr(derived_class, method) - return derived_method != base_method - - -def has_method(obj: object, method: str) -> bool: - """Check whether the object has a method. - - Args: - method (str): The method name to check. - obj (object): The object to check. - - Returns: - bool: True if the object has the method else False. - """ - return hasattr(obj, method) and callable(getattr(obj, method)) - - -def deprecated_function(since: str, removed_in: str, - instructions: str) -> Callable: - """Marks functions as deprecated. - - Throw a warning when a deprecated function is called, and add a note in the - docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py - - Args: - since (str): The version when the function was first deprecated. - removed_in (str): The version when the function will be removed. - instructions (str): The action users should take. - - Returns: - Callable: A new function, which will be deprecated soon. - """ # noqa: E501 - from mmengine import print_log - - def decorator(function): - - @functools.wraps(function) - def wrapper(*args, **kwargs): - print_log( - f"'{function.__module__}.{function.__name__}' " - f'is deprecated in version {since} and will be ' - f'removed in version {removed_in}. Please {instructions}.', - logger='current', - level=logging.WARNING, - ) - return function(*args, **kwargs) - - indent = ' ' - # Add a deprecation note to the docstring. 
- docstring = function.__doc__ or '' - # Add a note to the docstring. - deprecation_note = textwrap.dedent(f"""\ - .. deprecated:: {since} - Deprecated and will be removed in version {removed_in}. - Please {instructions}. - """) - # Split docstring at first occurrence of newline - pattern = '\n\n' - summary_and_body = re.split(pattern, docstring, 1) - - if len(summary_and_body) > 1: - summary, body = summary_and_body - body = textwrap.indent(textwrap.dedent(body), indent) - summary = '\n'.join( - [textwrap.dedent(string) for string in summary.split('\n')]) - summary = textwrap.indent(summary, prefix=indent) - # Dedent the body. We cannot do this with the presence of the - # summary because the body contains leading whitespaces when the - # summary does not. - new_docstring_parts = [ - deprecation_note, '\n\n', summary, '\n\n', body - ] - else: - summary = summary_and_body[0] - summary = '\n'.join( - [textwrap.dedent(string) for string in summary.split('\n')]) - summary = textwrap.indent(summary, prefix=indent) - new_docstring_parts = [deprecation_note, '\n\n', summary] - - wrapper.__doc__ = ''.join(new_docstring_parts) - - return wrapper - - return decorator - - -def get_object_from_string(obj_name: str): - """Get object from name. - - Args: - obj_name (str): The name of the object. - - Examples: - >>> get_object_from_string('torch.optim.sgd.SGD') - >>> torch.optim.sgd.SGD - """ - parts = iter(obj_name.split('.')) - module_name = next(parts) - # import module - while True: - try: - module = import_module(module_name) - part = next(parts) - # mmcv.ops has nms.py and nms function at the same time. So the - # function will have a higher priority - obj = getattr(module, part, None) - if obj is not None and not ismodule(obj): - break - module_name = f'{module_name}.{part}' - except StopIteration: - # if obj is a module - return module - except ImportError: - return None - - # get class or attribute from module - obj = module - while True: - try: - obj = getattr(obj, part) - part = next(parts) - except StopIteration: - return obj - except AttributeError: - return None diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py deleted file mode 100644 index 1816f47f07..0000000000 --- a/mmengine/utils/package_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -import subprocess - - -def is_installed(package: str) -> bool: - """Check package whether installed. - - Args: - package (str): Name of package to be checked. - """ - # When executing `import mmengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. - import importlib.util - - import pkg_resources - from pkg_resources import get_distribution - - # refresh the pkg_resources - # more datails at https://github.com/pypa/setuptools/issues/373 - importlib.reload(pkg_resources) - try: - get_distribution(package) - return True - except pkg_resources.DistributionNotFound: - spec = importlib.util.find_spec(package) - if spec is None: - return False - elif spec.origin is not None: - return True - else: - return False - - -def get_installed_path(package: str) -> str: - """Get installed path of package. - - Args: - package (str): Name of package. 
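
Editor's note: a sketch of resolving dotted paths with the get_object_from_string helper deleted above; the torch.optim example follows its docstring, and the mmengine.utils export is an assumption.

import torch
from mmengine.utils import get_object_from_string  # assumed export

sgd_cls = get_object_from_string('torch.optim.SGD')
assert sgd_cls is torch.optim.SGD

# Unresolvable names return None instead of raising.
assert get_object_from_string('torch.optim.DoesNotExist') is None
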
- - Example: - >>> get_installed_path('mmcls') - >>> '.../lib/python3.7/site-packages/mmcls' - """ - import importlib.util - - from pkg_resources import DistributionNotFound, get_distribution - - # if the package name is not the same as module name, module name should be - # inferred. For example, mmcv-full is the package name, but mmcv is module - # name. If we want to get the installed path of mmcv-full, we should concat - # the pkg.location and module name - try: - pkg = get_distribution(package) - except DistributionNotFound as e: - # if the package is not installed, package path set in PYTHONPATH - # can be detected by `find_spec` - spec = importlib.util.find_spec(package) - if spec is not None: - if spec.origin is not None: - return osp.dirname(spec.origin) - else: - # `get_installed_path` cannot get the installed path of - # namespace packages - raise RuntimeError( - f'{package} is a namespace package, which is invalid ' - 'for `get_install_path`') - else: - raise e - - possible_path = osp.join(pkg.location, package) # type: ignore - if osp.exists(possible_path): - return possible_path - else: - return osp.join(pkg.location, package2module(package)) # type: ignore - - -def package2module(package: str): - """Infer module name from package. - - Args: - package (str): Package to infer module name. - """ - from pkg_resources import get_distribution - pkg = get_distribution(package) - if pkg.has_metadata('top_level.txt'): - module_name = pkg.get_metadata('top_level.txt').split('\n')[0] - return module_name - else: - raise ValueError(f'can not infer the module name of {package}') - - -def call_command(cmd: list) -> None: - try: - subprocess.check_call(cmd) - except Exception as e: - raise e # type: ignore - - -def install_package(package: str): - if not is_installed(package): - call_command(['python', '-m', 'pip', 'install', package]) diff --git a/mmengine/utils/path.py b/mmengine/utils/path.py deleted file mode 100644 index 307d053f2f..0000000000 --- a/mmengine/utils/path.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -from pathlib import Path - -from .misc import is_str - - -def is_filepath(x): - return is_str(x) or isinstance(x, Path) - - -def fopen(filepath, *args, **kwargs): - if is_str(filepath): - return open(filepath, *args, **kwargs) - elif isinstance(filepath, Path): - return filepath.open(*args, **kwargs) - raise ValueError('`filepath` should be a string or a Path') - - -def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): - if not osp.isfile(filename): - raise FileNotFoundError(msg_tmpl.format(filename)) - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == '': - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def symlink(src, dst, overwrite=True, **kwargs): - if os.path.lexists(dst) and overwrite: - os.remove(dst) - os.symlink(src, dst, **kwargs) - - -def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): - """Scan a directory to find the interested files. - - Args: - dir_path (str | :obj:`Path`): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Defaults to None. - recursive (bool, optional): If set to True, recursively scan the - directory. Defaults to False. - case_sensitive (bool, optional) : If set to False, ignore the case of - suffix. Defaults to True. - - Returns: - A generator for all the interested files with relative paths. 
- """ - if isinstance(dir_path, (str, Path)): - dir_path = str(dir_path) - else: - raise TypeError('"dir_path" must be a string or Path object') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - if suffix is not None and not case_sensitive: - suffix = suffix.lower() if isinstance(suffix, str) else tuple( - item.lower() for item in suffix) - - root = dir_path - - def _scandir(dir_path, suffix, recursive, case_sensitive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - _rel_path = rel_path if case_sensitive else rel_path.lower() - if suffix is None or _rel_path.endswith(suffix): - yield rel_path - elif recursive and os.path.isdir(entry.path): - # scan recursively if entry.path is a directory - yield from _scandir(entry.path, suffix, recursive, - case_sensitive) - - return _scandir(dir_path, suffix, recursive, case_sensitive) - - -def find_vcs_root(path, markers=('.git', )): - """Finds the root directory (including itself) of specified markers. - - Args: - path (str): Path of directory or file. - markers (list[str], optional): List of file or directory names. - - Returns: - The directory contained one of the markers or None if not found. - """ - if osp.isfile(path): - path = osp.dirname(path) - - prev, cur = None, osp.abspath(osp.expanduser(path)) - while cur != prev: - if any(osp.exists(osp.join(cur, marker)) for marker in markers): - return cur - prev, cur = cur, osp.split(cur)[0] - return None - - -def is_abs(path: str) -> bool: - """Check if path is an absolute path in different backends. - - Args: - path (str): path of directory or file. - - Returns: - bool: whether path is an absolute path. - """ - if osp.isabs(path) or path.startswith(('http://', 'https://', 's3://')): - return True - else: - return False diff --git a/mmengine/utils/progressbar.py b/mmengine/utils/progressbar.py deleted file mode 100644 index 47e710603b..0000000000 --- a/mmengine/utils/progressbar.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import sys -from collections.abc import Iterable -from multiprocessing import Pool -from shutil import get_terminal_size -from typing import Callable, Optional, Sequence - -from .timer import Timer - - -class ProgressBar: - """A progress bar which can print the progress. - - Args: - task_num (int): Number of total steps. Defaults to 0. - bar_width (int): Width of the progress bar. Defaults to 50. - start (bool): Whether to start the progress bar in the constructor. - Defaults to True. - file (callable): Progress bar output mode. Defaults to "sys.stdout". - - Examples: - >>> import mmengine - >>> import time - >>> bar = mmengine.ProgressBar(10) - >>> for i in range(10): - >>> bar.update() - >>> time.sleep(1) - """ - - def __init__(self, - task_num: int = 0, - bar_width: int = 50, - start: bool = True, - file=sys.stdout): - self.task_num = task_num - self.bar_width = bar_width - self.completed = 0 - self.file = file - if start: - self.start() - - @property - def terminal_width(self): - width, _ = get_terminal_size() - return width - - def start(self): - if self.task_num > 0: - self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' - 'elapsed: 0s, ETA:') - else: - self.file.write('completed: 0, elapsed: 0s') - self.file.flush() - self.timer = Timer() - - def update(self, num_tasks: int = 1): - """Update progressbar. 
- - Args: - num_tasks (int): Update step size. - """ - assert num_tasks > 0 - self.completed += num_tasks - elapsed = self.timer.since_start() - if elapsed > 0: - fps = self.completed / elapsed - else: - fps = float('inf') - if self.task_num > 0: - percentage = self.completed / float(self.task_num) - eta = int(elapsed * (1 - percentage) / percentage + 0.5) - msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ - f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ - f'ETA: {eta:5}s' - - bar_width = min(self.bar_width, - int(self.terminal_width - len(msg)) + 2, - int(self.terminal_width * 0.6)) - bar_width = max(2, bar_width) - mark_width = int(bar_width * percentage) - bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) - self.file.write(msg.format(bar_chars)) - else: - self.file.write( - f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' - f' {fps:.1f} tasks/s') - self.file.flush() - - -def track_progress(func: Callable, - tasks: Sequence, - bar_width: int = 50, - file=sys.stdout, - **kwargs): - """Track the progress of tasks execution with a progress bar. - - Tasks are done with a simple for-loop. - - Args: - func (callable): The function to be applied to each task. - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - bar_width (int): Width of progress bar. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be a tuple object or a sequence object, but got ' - f'{type(tasks)}') - prog_bar = ProgressBar(task_num, bar_width, file=file) - results = [] - for task in tasks: - results.append(func(task, **kwargs)) - prog_bar.update() - prog_bar.file.write('\n') - return results - - -def init_pool(process_num, initializer=None, initargs=None): - if initializer is None: - return Pool(process_num) - elif initargs is None: - return Pool(process_num, initializer) - else: - if not isinstance(initargs, tuple): - raise TypeError('"initargs" must be a tuple') - return Pool(process_num, initializer, initargs) - - -def track_parallel_progress(func: Callable, - tasks: Sequence, - nproc: int, - initializer: Optional[Callable] = None, - initargs: Optional[tuple] = None, - bar_width: int = 50, - chunksize: int = 1, - skip_first: bool = False, - keep_order: bool = True, - file=sys.stdout): - """Track the progress of parallel task execution with a progress bar. - - The built-in :mod:`multiprocessing` module is used for process pools and - tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. - - Args: - func (callable): The function to be applied to each task. - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - nproc (int): Process (worker) number. - initializer (None or callable): Refer to :class:`multiprocessing.Pool` - for details. - initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for - details. - chunksize (int): Refer to :class:`multiprocessing.Pool` for details. - bar_width (int): Width of progress bar. 
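
Editor's note: a usage sketch for the simple progress utilities deleted above; track_progress drives a ProgressBar over a plain for-loop. Imports assumed to remain available.

import time
from mmengine.utils import ProgressBar, track_progress  # assumed exports

def square(x):
    time.sleep(0.01)
    return x * x

# Functional form: apply `square` to each task and collect the results.
results = track_progress(square, list(range(20)), bar_width=40)
assert results[3] == 9

# Manual form: update the bar yourself.
bar = ProgressBar(task_num=5)
for _ in range(5):
    time.sleep(0.01)
    bar.update()
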
- skip_first (bool): Whether to skip the first sample for each worker - when estimating fps, since the initialization step may takes - longer. - keep_order (bool): If True, :func:`Pool.imap` is used, otherwise - :func:`Pool.imap_unordered` is used. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be a tuple object or a sequence object, but got ' - f'{type(tasks)}') - pool = init_pool(nproc, initializer, initargs) - start = not skip_first - task_num -= nproc * chunksize * int(skip_first) - prog_bar = ProgressBar(task_num, bar_width, start, file=file) - results = [] - if keep_order: - gen = pool.imap(func, tasks, chunksize) - else: - gen = pool.imap_unordered(func, tasks, chunksize) - for result in gen: - results.append(result) - if skip_first: - if len(results) < nproc * chunksize: - continue - elif len(results) == nproc * chunksize: - prog_bar.start() - continue - prog_bar.update() - prog_bar.file.write('\n') - pool.close() - pool.join() - return results - - -def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout): - """Track the progress of tasks iteration or enumeration with a progress - bar. - - Tasks are yielded with a simple for-loop. - - Args: - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - bar_width (int): Width of progress bar. - - Yields: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be a tuple object or a sequence object, but got ' - f'{type(tasks)}') - prog_bar = ProgressBar(task_num, bar_width, file=file) - for task in tasks: - yield task - prog_bar.update() - prog_bar.file.write('\n') diff --git a/mmengine/utils/progressbar_rich.py b/mmengine/utils/progressbar_rich.py deleted file mode 100644 index f8e04d8041..0000000000 --- a/mmengine/utils/progressbar_rich.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from multiprocessing import Pool -from typing import Callable, Iterable, Optional, Sized - -from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, - TaskProgressColumn, TextColumn, TimeRemainingColumn) -from rich.text import Text - - -class _Worker: - """Function wrapper for ``track_progress_rich``""" - - def __init__(self, func) -> None: - self.func = func - - def __call__(self, inputs): - inputs, idx = inputs - if not isinstance(inputs, (tuple, list)): - inputs = (inputs, ) - - return self.func(*inputs), idx - - -class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): - """Skip calculating remaining time for the first few times. - - Args: - skip_times (int): The number of times to skip. Defaults to 0. 
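
Editor's note: the parallel and iterator variants removed above, sketched with a picklable top-level worker as multiprocessing requires; imports assumed to remain valid.

import time
from mmengine.utils import track_parallel_progress, track_iter_progress  # assumed

def slow_double(x):
    time.sleep(0.01)
    return 2 * x

if __name__ == '__main__':  # guard needed on spawn-based platforms
    out = track_parallel_progress(slow_double, list(range(16)), nproc=4)
    assert sorted(out) == [2 * i for i in range(16)]

    # Iterate with a bar instead of mapping a function over the tasks.
    total = sum(x for x in track_iter_progress(list(range(10))))
    assert total == 45
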
- """ - - def __init__(self, *args, skip_times=0, **kwargs): - super().__init__(*args, **kwargs) - self.skip_times = skip_times - - def render(self, task: Task) -> Text: - """Show time remaining.""" - if task.completed <= self.skip_times: - return Text('-:--:--', style='progress.remaining') - return super().render(task) - - -def _tasks_with_index(tasks): - """Add index to tasks.""" - for idx, task in enumerate(tasks): - yield task, idx - - -def track_progress_rich(func: Callable, - tasks: Iterable = tuple(), - task_num: Optional[int] = None, - nproc: int = 1, - chunksize: int = 1, - description: str = 'Processing', - color: str = 'blue') -> list: - """Track the progress of parallel task execution with a progress bar. The - built-in :mod:`multiprocessing` module is used for process pools and tasks - are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. - - Args: - func (callable): The function to be applied to each task. - tasks (Iterable or Sized): A tuple of tasks. There are several cases - for different format tasks: - - When ``func`` accepts no arguments: tasks should be an empty - tuple, and ``task_num`` must be specified. - - When ``func`` accepts only one argument: tasks should be a tuple - containing the argument. - - When ``func`` accepts multiple arguments: tasks should be a - tuple, with each element representing a set of arguments. - If an element is a ``dict``, it will be parsed as a set of - keyword-only arguments. - Defaults to an empty tuple. - task_num (int, optional): If ``tasks`` is an iterator which does not - have length, the number of tasks can be provided by ``task_num``. - Defaults to None. - nproc (int): Process (worker) number, if nuproc is 1, - use single process. Defaults to 1. - chunksize (int): Refer to :class:`multiprocessing.Pool` for details. - Defaults to 1. - description (str): The description of progress bar. - Defaults to "Process". - color (str): The color of progress bar. Defaults to "blue". - - Examples: - >>> import time - - >>> def func(x): - ... time.sleep(1) - ... return x**2 - >>> track_progress_rich(func, range(10), nproc=2) - - Returns: - list: The task results. - """ - if not callable(func): - raise TypeError('func must be a callable object') - if not isinstance(tasks, Iterable): - raise TypeError( - f'tasks must be an iterable object, but got {type(tasks)}') - if isinstance(tasks, Sized): - if len(tasks) == 0: - if task_num is None: - raise ValueError('If tasks is an empty iterable, ' - 'task_num must be set') - else: - tasks = tuple(tuple() for _ in range(task_num)) - else: - if task_num is not None and task_num != len(tasks): - raise ValueError('task_num does not match the length of tasks') - task_num = len(tasks) - - if nproc <= 0: - raise ValueError('nproc must be a positive number') - - skip_times = nproc * chunksize if nproc > 1 else 0 - prog_bar = Progress( - TextColumn('{task.description}'), - BarColumn(), - _SkipFirstTimeRemainingColumn(skip_times=skip_times), - MofNCompleteColumn(), - TaskProgressColumn(show_speed=True), - ) - - worker = _Worker(func) - task_id = prog_bar.add_task( - total=task_num, color=color, description=description) - tasks = _tasks_with_index(tasks) - - # Use single process when nproc is 1, else use multiprocess. 
- with prog_bar: - if nproc == 1: - results = [] - for task in tasks: - results.append(worker(task)[0]) - prog_bar.update(task_id, advance=1, refresh=True) - else: - with Pool(nproc) as pool: - results = [] - unordered_results = [] - gen = pool.imap_unordered(worker, tasks, chunksize) - try: - for result in gen: - result, idx = result - unordered_results.append((result, idx)) - results.append(None) - prog_bar.update(task_id, advance=1, refresh=True) - except Exception as e: - prog_bar.stop() - raise e - for result, idx in unordered_results: - results[idx] = result - return results diff --git a/mmengine/utils/timer.py b/mmengine/utils/timer.py deleted file mode 100644 index 087a969cfa..0000000000 --- a/mmengine/utils/timer.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from time import time - - -class TimerError(Exception): - - def __init__(self, message): - self.message = message - super().__init__(message) - - -class Timer: - """A flexible Timer class. - - Examples: - >>> import time - >>> import mmcv - >>> with mmcv.Timer(): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - 1.000 - >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - it takes 1.0 seconds - >>> timer = mmcv.Timer() - >>> time.sleep(0.5) - >>> print(timer.since_start()) - 0.500 - >>> time.sleep(0.5) - >>> print(timer.since_last_check()) - 0.500 - >>> print(timer.since_start()) - 1.000 - """ - - def __init__(self, start=True, print_tmpl=None): - self._is_running = False - self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' - if start: - self.start() - - @property - def is_running(self): - """bool: indicate whether the timer is running""" - return self._is_running - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - print(self.print_tmpl.format(self.since_last_check())) - self._is_running = False - - def start(self): - """Start the timer.""" - if not self._is_running: - self._t_start = time() - self._is_running = True - self._t_last = time() - - def since_start(self): - """Total time since the timer is started. - - Returns: - float: Time in seconds. - """ - if not self._is_running: - raise TimerError('timer is not running') - self._t_last = time() - return self._t_last - self._t_start - - def since_last_check(self): - """Time since the last checking. - - Either :func:`since_start` or :func:`since_last_check` is a checking - operation. - - Returns: - float: Time in seconds. - """ - if not self._is_running: - raise TimerError('timer is not running') - dur = time() - self._t_last - self._t_last = time() - return dur - - -_g_timers = {} # global timers - - -def check_time(timer_id): - """Add check points in a single line. - - This method is suitable for running a task on a list of items. A timer will - be registered when the method is called for the first time. - - Examples: - >>> import time - >>> import mmcv - >>> for i in range(1, 6): - >>> # simulate a code block - >>> time.sleep(i) - >>> mmcv.check_time('task1') - 2.000 - 3.000 - 4.000 - 5.000 - - Args: - str: Timer identifier. 
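The docstring examples above still reference ``mmcv``; in this codebase the equivalent (pre-removal) usage would have looked roughly like the following sketch:

import time

from mmengine.utils import Timer, check_time  # pre-removal import paths

with Timer(print_tmpl='block took {:.1f}s'):
    time.sleep(0.2)  # prints "block took 0.2s" when the block exits

for i in range(1, 4):
    time.sleep(0.1)
    # The first call registers the timer and returns 0; later calls return
    # the time elapsed since the previous check.
    print(check_time('loop'))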
- """ - if timer_id not in _g_timers: - _g_timers[timer_id] = Timer() - return 0 - else: - return _g_timers[timer_id].since_last_check() diff --git a/mmengine/utils/version_utils.py b/mmengine/utils/version_utils.py deleted file mode 100644 index 620180547a..0000000000 --- a/mmengine/utils/version_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import subprocess -import warnings - -from packaging.version import parse - - -def digit_version(version_str: str, length: int = 4): - """Convert a version string into a tuple of integers. - - This method is usually used for comparing two versions. For pre-release - versions: alpha < beta < rc. - - Args: - version_str (str): The version string. - length (int): The maximum number of version levels. Defaults to 4. - - Returns: - tuple[int]: The version info in digits (integers). - """ - assert 'parrots' not in version_str - version = parse(version_str) - assert version.release, f'failed to parse version {version_str}' - release = list(version.release) - release = release[:length] - if len(release) < length: - release = release + [0] * (length - len(release)) - if version.is_prerelease: - mapping = {'a': -3, 'b': -2, 'rc': -1} - val = -4 - # version.pre can be None - if version.pre: - if version.pre[0] not in mapping: - warnings.warn(f'unknown prerelease version {version.pre[0]}, ' - 'version checking may go wrong') - else: - val = mapping[version.pre[0]] - release.extend([val, version.pre[-1]]) - else: - release.extend([val, 0]) - - elif version.is_postrelease: - release.extend([1, version.post]) # type: ignore - else: - release.extend([0, 0]) - return tuple(release) - - -def _minimal_ext_cmd(cmd): - # construct minimal environment - env = {} - for k in ['SYSTEMROOT', 'PATH', 'HOME']: - v = os.environ.get(k) - if v is not None: - env[k] = v - # LANGUAGE is used on win32 - env['LANGUAGE'] = 'C' - env['LANG'] = 'C' - env['LC_ALL'] = 'C' - out, err = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env=env).communicate() - return out - - -def get_git_hash(fallback='unknown', digits=None): - """Get the git hash of the current repo. - - Args: - fallback (str, optional): The fallback string when git hash is - unavailable. Defaults to 'unknown'. - digits (int, optional): kept digits of the hash. Defaults to None, - meaning all digits are kept. - - Returns: - str: Git commit hash. - """ - - if digits is not None and not isinstance(digits, int): - raise TypeError('digits must be None or an integer') - - try: - out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) - sha = out.strip().decode('ascii') - if digits is not None: - sha = sha[:digits] - except OSError: - sha = fallback - - return sha diff --git a/mmengine/version.py b/mmengine/version.py deleted file mode 100644 index dbf60e04c3..0000000000 --- a/mmengine/version.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -__version__ = '0.10.7' - - -def parse_version_info(version_str): - """Parse the version information. - - Args: - version_str (str): version string like '0.1.0'. - - Returns: - tuple: version information contains major, minor, micro version. 
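For context, a short sketch of how the removed version helpers were used, relying on the pre-release ordering documented above (pre-removal import paths):

from mmengine.utils import digit_version, get_git_hash

# Pre-releases sort below the final release: alpha < beta < rc < release.
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('2.0.0') > digit_version('1.9.9')

# Short commit hash of the current checkout ('unknown' if git is unavailable).
print(get_git_hash(digits=7))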
- """ - version_info = [] - for x in version_str.split('.'): - if x.isdigit(): - version_info.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') - version_info.append(int(patch_version[0])) - version_info.append(f'rc{patch_version[1]}') - return tuple(version_info) - - -version_info = parse_version_info(__version__) diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py deleted file mode 100644 index 8f59452c54..0000000000 --- a/mmengine/visualization/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .vis_backend import (AimVisBackend, BaseVisBackend, ClearMLVisBackend, - DVCLiveVisBackend, LocalVisBackend, MLflowVisBackend, - NeptuneVisBackend, TensorboardVisBackend, - WandbVisBackend) -from .visualizer import Visualizer - -__all__ = [ - 'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend', - 'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend', - 'NeptuneVisBackend', 'DVCLiveVisBackend', 'AimVisBackend' -] diff --git a/mmengine/visualization/utils.py b/mmengine/visualization/utils.py deleted file mode 100644 index 3e6b7d8ba9..0000000000 --- a/mmengine/visualization/utils.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union - -import cv2 -import numpy as np -import torch - -if TYPE_CHECKING: - from matplotlib.backends.backend_agg import FigureCanvasAgg - - -def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray: - """If the type of value is torch.Tensor, convert the value to np.ndarray. - - Args: - value (np.ndarray, torch.Tensor): value. - - Returns: - Any: value. - """ - if isinstance(value, torch.Tensor): - value = value.detach().cpu().numpy() - return value - - -def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]], - expand_dim: int) -> List[Any]: - """If the type of ``value`` is ``valid_type``, convert the value to list - and expand to ``expand_dim``. - - Args: - value (Any): value. - valid_type (Union[Type, Tuple[Type, ...]): valid type. - expand_dim (int): expand dim. - - Returns: - List[Any]: value. - """ - if isinstance(value, valid_type): - value = [value] * expand_dim - return value - - -def check_type(name: str, value: Any, - valid_type: Union[Type, Tuple[Type, ...]]) -> None: - """Check whether the type of value is in ``valid_type``. - - Args: - name (str): value name. - value (Any): value. - valid_type (Type, Tuple[Type, ...]): expected type. - """ - if not isinstance(value, valid_type): - raise TypeError(f'`{name}` should be {valid_type} ' - f' but got {type(value)}') - - -def check_length(name: str, value: Any, valid_length: int) -> None: - """If type of the ``value`` is list, check whether its length is equal with - or greater than ``valid_length``. - - Args: - name (str): value name. - value (Any): value. - valid_length (int): expected length. - """ - if isinstance(value, list): - if len(value) < valid_length: - raise AssertionError( - f'The length of {name} must equal with or ' - f'greater than {valid_length}, but got {len(value)}') - - -def check_type_and_length(name: str, value: Any, - valid_type: Union[Type, Tuple[Type, ...]], - valid_length: int) -> None: - """Check whether the type of value is in ``valid_type``. If type of the - ``value`` is list, check whether its length is equal with or greater than - ``valid_length``. - - Args: - value (Any): value. 
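A small illustrative check of the removed ``parse_version_info`` helper defined earlier in this hunk (pre-removal import path):

from mmengine.version import parse_version_info, version_info

assert parse_version_info('0.10.7') == (0, 10, 7)
# Release-candidate suffixes are kept as a trailing string component.
assert parse_version_info('1.0.0rc1') == (1, 0, 0, 'rc1')
print(version_info)  # the parsed tuple for the installed mmengine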
- legal_type (Type, Tuple[Type, ...]): legal type. - valid_length (int): expected length. - - Returns: - List[Any]: value. - """ - check_type(name, value, valid_type) - check_length(name, value, valid_length) - - -def color_val_matplotlib( - colors: Union[str, tuple, List[Union[str, tuple]]] -) -> Union[str, tuple, List[Union[str, tuple]]]: - """Convert various input in RGB order to normalized RGB matplotlib color - tuples, - Args: - colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs - Returns: - Union[str, tuple, List[Union[str, tuple]]]: A tuple of 3 normalized - floats indicating RGB channels. - """ - if isinstance(colors, str): - return colors - elif isinstance(colors, tuple): - assert len(colors) == 3 - for channel in colors: - assert 0 <= channel <= 255 - colors = [channel / 255 for channel in colors] - return tuple(colors) - elif isinstance(colors, list): - colors = [ - color_val_matplotlib(color) # type:ignore - for color in colors - ] - return colors - else: - raise TypeError(f'Invalid type for color: {type(colors)}') - - -def color_str2rgb(color: str) -> tuple: - """Convert Matplotlib str color to an RGB color which range is 0 to 255, - silently dropping the alpha channel. - - Args: - color (str): Matplotlib color. - - Returns: - tuple: RGB color. - """ - import matplotlib - rgb_color: tuple = matplotlib.colors.to_rgb(color) - rgb_color = tuple(int(c * 255) for c in rgb_color) - return rgb_color - - -def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor], - img: Optional[np.ndarray] = None, - alpha: float = 0.5) -> np.ndarray: - """Convert feat_map to heatmap and overlay on image, if image is not None. - - Args: - feat_map (np.ndarray, torch.Tensor): The feat_map to convert - with of shape (H, W), where H is the image height and W is - the image width. - img (np.ndarray, optional): The origin image. The format - should be RGB. Defaults to None. - alpha (float): The transparency of featmap. Defaults to 0.5. - - Returns: - np.ndarray: heatmap - """ - assert feat_map.ndim == 2 or (feat_map.ndim == 3 - and feat_map.shape[0] in [1, 3]) - if isinstance(feat_map, torch.Tensor): - feat_map = feat_map.detach().cpu().numpy() - - if feat_map.ndim == 3: - feat_map = feat_map.transpose(1, 2, 0) - - norm_img = np.zeros(feat_map.shape) - norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX) - norm_img = np.asarray(norm_img, dtype=np.uint8) - heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) - heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) - if img is not None: - heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0) - return heat_img - - -def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int: - """Show the image and wait for the user's input. - - This implementation refers to - https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py - - Args: - timeout (float): If positive, continue after ``timeout`` seconds. - Defaults to 0. - continue_key (str): The key for users to continue. Defaults to - the space key. - - Returns: - int: If zero, means time out or the user pressed ``continue_key``, - and if one, means the user closed the show figure. - """ # noqa: E501 - import matplotlib.pyplot as plt - from matplotlib.backend_bases import CloseEvent - is_inline = 'inline' in plt.get_backend() - if is_inline: - # If use inline backend, interactive input and timeout is no use. 
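A self-contained sketch of the removed drawing helpers above, using random data purely for illustration (pre-removal import path):

import numpy as np

from mmengine.visualization.utils import (color_val_matplotlib,
                                           convert_overlay_heatmap)

# An RGB tuple in [0, 255] becomes a normalized matplotlib tuple in [0, 1].
assert color_val_matplotlib((255, 0, 0)) == (1.0, 0.0, 0.0)

# Overlay a (H, W) activation map on an RGB image as a JET-colored heatmap.
img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
feat = np.random.rand(64, 64).astype(np.float32)
overlay = convert_overlay_heatmap(feat, img, alpha=0.5)  # uint8, (64, 64, 3)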
- return 0 - - if figure.canvas.manager: # type: ignore - # Ensure that the figure is shown - figure.show() # type: ignore - - while True: - - # Connect the events to the handler function call. - event = None - - def handler(ev): - # Set external event variable - nonlocal event - # Qt backend may fire two events at the same time, - # use a condition to avoid missing close event. - event = ev if not isinstance(event, CloseEvent) else event - figure.canvas.stop_event_loop() - - cids = [ - figure.canvas.mpl_connect(name, handler) # type: ignore - for name in ('key_press_event', 'close_event') - ] - - try: - figure.canvas.start_event_loop(timeout) # type: ignore - finally: # Run even on exception like ctrl-c. - # Disconnect the callbacks. - for cid in cids: - figure.canvas.mpl_disconnect(cid) # type: ignore - - if isinstance(event, CloseEvent): - return 1 # Quit for close. - elif event is None or event.key == continue_key: - return 0 # Quit for continue. - - -def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray: - """Get RGB image from ``FigureCanvasAgg``. - - Args: - canvas (FigureCanvasAgg): The canvas to get image. - - Returns: - np.ndarray: the output of image in RGB. - """ # noqa: E501 - s, (width, height) = canvas.print_to_buffer() - buffer = np.frombuffer(s, dtype='uint8') - img_rgba = buffer.reshape(height, width, 4) - rgb, alpha = np.split(img_rgba, [3], axis=2) - return rgb.astype('uint8') diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py deleted file mode 100644 index b752ec85a7..0000000000 --- a/mmengine/visualization/vis_backend.py +++ /dev/null @@ -1,1448 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import functools -import logging -import os -import os.path as osp -import platform -import warnings -from abc import ABCMeta, abstractmethod -from collections.abc import MutableMapping -from typing import Any, Callable, List, Optional, Sequence, Union - -import cv2 -import numpy as np -import torch - -from mmengine.config import Config, ConfigDict -from mmengine.fileio import dump -from mmengine.hooks.logger_hook import SUFFIX_TYPE -from mmengine.logging import MMLogger, print_log -from mmengine.registry import VISBACKENDS -from mmengine.utils import digit_version, scandir -from mmengine.utils.dl_utils import TORCH_VERSION - - -def force_init_env(old_func: Callable) -> Any: - """Those methods decorated by ``force_init_env`` will be forced to call - ``_init_env`` if the instance has not been fully initiated. This function - will decorated all the `add_xxx` method and `experiment` method, because - `VisBackend` is initialized only when used its API. - - Args: - old_func (Callable): Decorated function, make sure the first arg is an - instance with ``_init_env`` method. - - Returns: - Any: Depends on old_func. - """ - - @functools.wraps(old_func) - def wrapper(obj: object, *args, **kwargs): - # The instance must have `_init_env` method. 
- if not hasattr(obj, '_init_env'): - raise AttributeError(f'{type(obj)} does not have _init_env ' - 'method.') - # If instance does not have `_env_initialized` attribute or - # `_env_initialized` is False, call `_init_env` and set - # `_env_initialized` to True - if not getattr(obj, '_env_initialized', False): - print_log( - 'Attribute `_env_initialized` is not defined in ' - f'{type(obj)} or `{type(obj)}._env_initialized is ' - 'False, `_init_env` will be called and ' - f'{type(obj)}._env_initialized will be set to True', - logger='current', - level=logging.DEBUG) - obj._init_env() # type: ignore - obj._env_initialized = True # type: ignore - - return old_func(obj, *args, **kwargs) - - return wrapper - - -class BaseVisBackend(metaclass=ABCMeta): - """Base class for visualization backend. - - All backends must inherit ``BaseVisBackend`` and implement - the required functions. - - Args: - save_dir (str, optional): The root directory to save - the files produced by the backend. - """ - - def __init__(self, save_dir: str): - self._save_dir = save_dir - self._env_initialized = False - - @property - @abstractmethod - def experiment(self) -> Any: - """Return the experiment object associated with this visualization - backend. - - The experiment attribute can get the visualization backend, such as - wandb, tensorboard. If you want to write other data, such as writing a - table, you can directly get the visualization backend through - experiment. - """ - pass - - @abstractmethod - def _init_env(self) -> Any: - """Setup env for VisBackend.""" - pass - - def add_config(self, config: Config, **kwargs) -> None: - """Record the config. - - Args: - config (Config): The Config object - """ - pass - - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], - **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - pass - - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - pass - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record the scalar. - - Args: - name (str): The scalar identifier. - value (int, float): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - pass - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - pass - - def close(self) -> None: - """Close an opened object.""" - pass - - -@VISBACKENDS.register_module() -class LocalVisBackend(BaseVisBackend): - """Local visualization backend class. - - It can write image, config, scalars, etc. - to the local hard disk. You can get the drawing backend - through the experiment property for custom drawing. 
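The decorator above lets a backend defer expensive setup until its API is first touched. A hedged sketch of a custom backend following that pattern; ``PrintVisBackend`` is illustrative and not part of MMEngine (pre-removal module path):

from mmengine.registry import VISBACKENDS
from mmengine.visualization.vis_backend import (BaseVisBackend,
                                                force_init_env)


@VISBACKENDS.register_module()
class PrintVisBackend(BaseVisBackend):
    """Toy backend that prints scalars instead of persisting them."""

    def _init_env(self):
        # Called lazily by ``force_init_env`` on the first API call.
        print(f'setting up backend in {self._save_dir}')

    @property  # type: ignore
    @force_init_env
    def experiment(self):
        return self

    @force_init_env
    def add_scalar(self, name, value, step=0, **kwargs):
        print(f'[{step}] {name} = {value}')


backend = PrintVisBackend(save_dir='work_dir')
backend.add_scalar('loss', 0.25, step=1)  # triggers _init_env exactly once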
- - Examples: - >>> from mmengine.visualization import LocalVisBackend - >>> import numpy as np - >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> local_vis_backend.add_image('img', img) - >>> local_vis_backend.add_scalar('mAP', 0.6) - >>> local_vis_backend.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> local_vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. If it is none, it means no data - is stored. - img_save_dir (str): The directory to save images. - Defaults to 'vis_image'. - config_save_file (str): The file name to save config. - Defaults to 'config.py'. - scalar_save_file (str): The file name to save scalar values. - Defaults to 'scalars.json'. - """ - - def __init__(self, - save_dir: str, - img_save_dir: str = 'vis_image', - config_save_file: str = 'config.py', - scalar_save_file: str = 'scalars.json'): - assert config_save_file.split('.')[-1] == 'py' - assert scalar_save_file.split('.')[-1] == 'json' - super().__init__(save_dir) - self._img_save_dir = img_save_dir - self._config_save_file = config_save_file - self._scalar_save_file = scalar_save_file - - def _init_env(self): - """Init save dir.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) - self._img_save_dir = osp.join( - self._save_dir, # type: ignore - self._img_save_dir) - self._config_save_file = osp.join( - self._save_dir, # type: ignore - self._config_save_file) - self._scalar_save_file = osp.join( - self._save_dir, # type: ignore - self._scalar_save_file) - - @property # type: ignore - @force_init_env - def experiment(self) -> 'LocalVisBackend': - """Return the experiment object associated with this visualization - backend.""" - return self - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to disk. - - Args: - config (Config): The Config object - """ - assert isinstance(config, Config) - config.dump(self._config_save_file) - - @force_init_env - def add_image(self, - name: str, - image: np.array, - step: int = 0, - **kwargs) -> None: - """Record the image to disk. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - assert image.dtype == np.uint8 - drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - os.makedirs(self._img_save_dir, exist_ok=True) - save_file_name = f'{name}_{step}.png' - cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to disk. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - if isinstance(value, torch.Tensor): - value = value.item() - self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalars to disk. - - The scalar dict will be written to the default and - specified files if ``file_path`` is specified. 
- - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. The value must be dumped - into json format. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the ``file_path`` file at the same time - if the ``file_path`` parameter is specified. - Defaults to None. - """ - assert isinstance(scalar_dict, dict) - scalar_dict = copy.deepcopy(scalar_dict) - scalar_dict.setdefault('step', step) - - if file_path is not None: - assert file_path.split('.')[-1] == 'json' - new_save_file_path = osp.join( - self._save_dir, # type: ignore - file_path) - assert new_save_file_path != self._scalar_save_file, \ - '``file_path`` and ``scalar_save_file`` have the ' \ - 'same name, please set ``file_path`` to another value' - self._dump(scalar_dict, new_save_file_path, 'json') - self._dump(scalar_dict, self._scalar_save_file, 'json') - - def _dump(self, value_dict: dict, file_path: str, - file_format: str) -> None: - """Dump dict to file. - - Args: - value_dict (dict) : The dict data to saved. - file_path (str): The file path to save data. - file_format (str): The file format to save data. - """ - with open(file_path, 'a+') as f: - dump(value_dict, f, file_format=file_format) - f.write('\n') - - -@VISBACKENDS.register_module() -class WandbVisBackend(BaseVisBackend): - """Wandb visualization backend class. - - Examples: - >>> from mmengine.visualization import WandbVisBackend - >>> import numpy as np - >>> wandb_vis_backend = WandbVisBackend() - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> wandb_vis_backend.add_image('img', img) - >>> wandb_vis_backend.add_scaler('mAP', 0.6) - >>> wandb_vis_backend.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> wandb_vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - init_kwargs (dict, optional): wandb initialization - input parameters. - See `wandb.init `_ for - details. Defaults to None. - define_metric_cfg (dict or list[dict], optional): - When a dict is set, it is a dict of metrics and summary for - ``wandb.define_metric``. - The key is metric and the value is summary. - When a list is set, each dict should be a valid argument of - the ``define_metric``. - For example, ``define_metric_cfg={'coco/bbox_mAP': 'max'}``, - means the maximum value of ``coco/bbox_mAP`` is logged on wandb UI. - When ``define_metric_cfg=[dict(name='loss', - step_metric='epoch')]``, - the "loss" will be plotted against the epoch. - See `wandb define_metric `_ for details. - Defaults to None. - commit (bool, optional) Save the metrics dict to the wandb server - and increment the step. If false `wandb.log` just updates the - current metrics dict with the row argument and metrics won't be - saved until `wandb.log` is called with `commit=True`. - Defaults to True. - log_code_name (str, optional) The name of code artifact. - By default, the artifact will be named - source-$PROJECT_ID-$ENTRYPOINT_RELPATH. See - `wandb log_code `_ - for details. Defaults to None. - `New in version 0.3.0.` - watch_kwargs (optional, dict): Agurments for ``wandb.watch``. 
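In practice these backends are rarely constructed by hand; they are configured through the registry. A hedged, config-style sketch (argument names follow the documentation above, values are placeholders):

# Typical visualizer configuration consumed by a Runner.
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(
        type='WandbVisBackend',
        init_kwargs=dict(project='demo-project'),    # forwarded to wandb.init
        define_metric_cfg={'coco/bbox_mAP': 'max'},  # track the running max
        commit=True),
]
visualizer = dict(
    type='Visualizer', vis_backends=vis_backends, save_dir='./work_dir/vis')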
- `New in version 0.4.0.` - """ - - def __init__(self, - save_dir: str, - init_kwargs: Optional[dict] = None, - define_metric_cfg: Union[dict, list, None] = None, - commit: Optional[bool] = True, - log_code_name: Optional[str] = None, - watch_kwargs: Optional[dict] = None): - super().__init__(save_dir) - self._init_kwargs = init_kwargs - self._define_metric_cfg = define_metric_cfg - self._commit = commit - self._log_code_name = log_code_name - self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {} - - def _init_env(self): - """Setup env for wandb.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - if self._init_kwargs is None: - self._init_kwargs = {'dir': self._save_dir} - else: - self._init_kwargs.setdefault('dir', self._save_dir) - try: - import wandb - except ImportError: - raise ImportError( - 'Please run "pip install wandb" to install wandb') - - wandb.init(**self._init_kwargs) - if self._define_metric_cfg is not None: - if isinstance(self._define_metric_cfg, dict): - for metric, summary in self._define_metric_cfg.items(): - wandb.define_metric(metric, summary=summary) - elif isinstance(self._define_metric_cfg, list): - for metric_cfg in self._define_metric_cfg: - wandb.define_metric(**metric_cfg) - else: - raise ValueError('define_metric_cfg should be dict or list') - self._wandb = wandb - - @property # type: ignore - @force_init_env - def experiment(self): - """Return wandb object. - - The experiment attribute can get the wandb backend, If you want to - write other data, such as writing a table, you can directly get the - wandb backend through experiment. - """ - return self._wandb - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to wandb. - - Args: - config (Config): The Config object - """ - assert isinstance(self._init_kwargs, dict) - allow_val_change = self._init_kwargs.get('allow_val_change', False) - self._wandb.config.update( - config.to_dict(), allow_val_change=allow_val_change) - self._wandb.run.log_code(name=self._log_code_name) - - @force_init_env - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], - **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self._wandb.watch(model, **self._watch_kwargs) - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image to wandb. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. - """ - image = self._wandb.Image(image) - self._wandb.log({name: image}, commit=self._commit) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to wandb. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. - """ - self._wandb.log({name: value}, commit=self._commit) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to wandb. 
- - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - self._wandb.log(scalar_dict, commit=self._commit) - - def close(self) -> None: - """Close an opened wandb object.""" - if hasattr(self, '_wandb'): - self._wandb.join() - - -@VISBACKENDS.register_module() -class TensorboardVisBackend(BaseVisBackend): - """Tensorboard visualization backend class. - - It can write images, config, scalars, etc. to a - tensorboard file. - - Examples: - >>> from mmengine.visualization import TensorboardVisBackend - >>> import numpy as np - >>> vis_backend = TensorboardVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img', img) - >>> vis_backend.add_scaler('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str): The root directory to save the files - produced by the backend. - """ - - def __init__(self, save_dir: str): - super().__init__(save_dir) - - def _init_env(self): - """Setup env for Tensorboard.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - if TORCH_VERSION == 'parrots': - try: - from tensorboardX import SummaryWriter - except ImportError: - raise ImportError('Please install tensorboardX to use ' - 'TensorboardLoggerHook.') - else: - try: - from torch.utils.tensorboard import SummaryWriter - except ImportError: - raise ImportError( - 'Please run "pip install future tensorboard" to install ' - 'the dependencies to use torch.utils.tensorboard ' - '(applicable to PyTorch 1.1 or higher)') - self._tensorboard = SummaryWriter(self._save_dir) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Tensorboard object.""" - return self._tensorboard - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to tensorboard. - - Args: - config (Config): The Config object - """ - self._tensorboard.add_text('config', config.pretty_text) - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image to tensorboard. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Defaults to 0. - """ - self._tensorboard.add_image(name, image, step, dataformats='HWC') - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to tensorboard. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - if isinstance(value, - (int, float, torch.Tensor, np.ndarray, np.number)): - self._tensorboard.add_scalar(name, value, step) - else: - warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, ' - f'int or float are expected. skip it!') - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to tensorboard. 
- - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - assert isinstance(scalar_dict, dict) - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step) - - def close(self): - """Close an opened tensorboard object.""" - if hasattr(self, '_tensorboard'): - self._tensorboard.close() - - -@VISBACKENDS.register_module() -class MLflowVisBackend(BaseVisBackend): - """MLflow visualization backend class. - - It can write images, config, scalars, etc. to a - mlflow file. - - Examples: - >>> from mmengine.visualization import MLflowVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> vis_backend = MLflowVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img.png', img) - >>> vis_backend.add_scalar('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str): The root directory to save the files - produced by the backend. - exp_name (str, optional): The experiment name. Defaults to None. - run_name (str, optional): The run name. Defaults to None. - tags (dict, optional): The tags to be added to the experiment. - Defaults to None. - params (dict, optional): The params to be added to the experiment. - Defaults to None. - tracking_uri (str, optional): The tracking uri. Defaults to None. - artifact_suffix (Tuple[str] or str, optional): The artifact suffix. - Defaults to ('.json', '.log', '.py', 'yaml'). - tracked_config_keys (dict, optional): The top level keys of config that - will be added to the experiment. If it is None, which means all - the config will be added. Defaults to None. - `New in version 0.7.4.` - artifact_location (str, optional): The location to store run artifacts. - If None, the server picks an appropriate default. - Defaults to None. - `New in version 0.10.4.` - """ - - def __init__(self, - save_dir: str, - exp_name: Optional[str] = None, - run_name: Optional[str] = None, - tags: Optional[dict] = None, - params: Optional[dict] = None, - tracking_uri: Optional[str] = None, - artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', - 'yaml'), - tracked_config_keys: Optional[dict] = None, - artifact_location: Optional[str] = None): - super().__init__(save_dir) - self._exp_name = exp_name - self._run_name = run_name - self._tags = tags - self._params = params - self._tracking_uri = tracking_uri - self._artifact_suffix = artifact_suffix - self._tracked_config_keys = tracked_config_keys - self._artifact_location = artifact_location - - def _init_env(self): - """Setup env for MLflow.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - - try: - import mlflow - except ImportError: - raise ImportError( - 'Please run "pip install mlflow" to install mlflow' - ) # type: ignore - self._mlflow = mlflow - - # when mlflow is imported, a default logger is created. 
- # at this time, the default logger's stream is None - # so the stream is reopened only when the stream is None - # or the stream is closed - logger = MMLogger.get_current_instance() - for handler in logger.handlers: - if handler.stream is None or handler.stream.closed: - handler.stream = open(handler.baseFilename, 'a') - - if self._tracking_uri is not None: - logger.warning( - 'Please make sure that the mlflow server is running.') - self._mlflow.set_tracking_uri(self._tracking_uri) - else: - if os.name == 'nt': - file_url = f'file:\\{os.path.abspath(self._save_dir)}' - else: - file_url = f'file://{os.path.abspath(self._save_dir)}' - self._mlflow.set_tracking_uri(file_url) - - self._exp_name = self._exp_name or 'Default' - - if self._mlflow.get_experiment_by_name(self._exp_name) is None: - self._mlflow.create_experiment( - self._exp_name, artifact_location=self._artifact_location) - - self._mlflow.set_experiment(self._exp_name) - - if self._run_name is not None: - self._mlflow.set_tag('mlflow.runName', self._run_name) - if self._tags is not None: - self._mlflow.set_tags(self._tags) - if self._params is not None: - self._mlflow.log_params(self._params) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return MLflow object.""" - return self._mlflow - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to mlflow. - - Args: - config (Config): The Config object - """ - self.cfg = config - if self._tracked_config_keys is None: - self._mlflow.log_params(self._flatten(self.cfg.to_dict())) - else: - tracked_cfg = dict() - for k in self._tracked_config_keys: - tracked_cfg[k] = self.cfg[k] - self._mlflow.log_params(self._flatten(tracked_cfg)) - self._mlflow.log_text(self.cfg.pretty_text, 'config.py') - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image to mlflow. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Default to 0. - """ - self._mlflow.log_image(image, name) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to mlflow. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._mlflow.log_metric(name, value, step) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to mlflow. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. 
- """ - assert isinstance(scalar_dict, dict) - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - self._mlflow.log_metrics(scalar_dict, step) - - def close(self) -> None: - """Close the mlflow.""" - if not hasattr(self, '_mlflow'): - return - - file_paths = dict() - for filename in scandir(self.cfg.work_dir, self._artifact_suffix, - True): - file_path = osp.join(self.cfg.work_dir, filename) - relative_path = os.path.relpath(file_path, self.cfg.work_dir) - dir_path = os.path.dirname(relative_path) - file_paths[file_path] = dir_path - - for file_path, dir_path in file_paths.items(): - self._mlflow.log_artifact(file_path, dir_path) - - self._mlflow.end_run() - - def _flatten(self, d, parent_key='', sep='.') -> dict: - """Flatten the dict.""" - items = dict() - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - items.update(self._flatten(v, new_key, sep=sep)) - elif isinstance(v, list): - if any(isinstance(x, dict) for x in v): - for i, x in enumerate(v): - items.update( - self._flatten(x, new_key + sep + str(i), sep=sep)) - else: - items[new_key] = v - else: - items[new_key] = v - return items - - -@VISBACKENDS.register_module() -class ClearMLVisBackend(BaseVisBackend): - """Clearml visualization backend class. It requires `clearml`_ to be - installed. - - Examples: - >>> from mmengine.visualization import ClearMLVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> vis_backend = ClearMLVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img.png', img) - >>> vis_backend.add_scalar('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - init_kwargs (dict, optional): A dict contains the arguments of - ``clearml.Task.init`` . See `taskinit`_ for more details. - Defaults to None - artifact_suffix (Tuple[str] or str): The artifact suffix. - Defaults to ('.py', 'pth'). - - .. _clearml: - https://clear.ml/docs/latest/docs/ - - .. _taskinit: - https://clear.ml/docs/latest/docs/references/sdk/task/#taskinit - """ - - def __init__(self, - save_dir: Optional[str] = None, - init_kwargs: Optional[dict] = None, - artifact_suffix: SUFFIX_TYPE = ('.py', '.pth')): - super().__init__(save_dir) # type: ignore - self._init_kwargs = init_kwargs - self._artifact_suffix = artifact_suffix - - def _init_env(self) -> None: - try: - import clearml - except ImportError: - raise ImportError( - 'Please run "pip install clearml" to install clearml') - - task_kwargs = self._init_kwargs or {} - self._clearml = clearml - self._task = self._clearml.Task.init(**task_kwargs) - self._logger = self._task.get_logger() - - @property # type: ignore - @force_init_env - def experiment(self): - """Return clearml object.""" - return self._clearml - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to clearml. - - Args: - config (Config): The Config object - """ - self.cfg = config - self._task.connect_configuration(config.to_dict()) - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image to clearml. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. 
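The ``_flatten`` helper above determines how nested configs show up as MLflow params; a small illustrative check (``_flatten`` is private, so this only demonstrates the key naming):

from mmengine.visualization import MLflowVisBackend  # pre-removal export

backend = MLflowVisBackend(save_dir='work_dir')  # env is initialized lazily
cfg = dict(optimizer=dict(type='SGD', lr=0.1), stages=[1, 2, 3])
# Nested keys are joined with '.' before being logged as MLflow params.
assert backend._flatten(cfg) == {
    'optimizer.type': 'SGD',
    'optimizer.lr': 0.1,
    'stages': [1, 2, 3],
}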
- step (int): Global step value to record. Defaults to 0. - """ - self._logger.report_image( - title=name, series=name, iteration=step, image=image) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to clearml. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - self._logger.report_scalar( - title=name, series=name, value=value, iteration=step) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to clearml. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - for key, value in scalar_dict.items(): - self._logger.report_scalar( - title=key, series=key, value=value, iteration=step) - - def close(self) -> None: - """Close the clearml.""" - if not hasattr(self, '_clearml'): - return - - file_paths: List[str] = list() - if (hasattr(self, 'cfg') - and osp.isdir(getattr(self.cfg, 'work_dir', ''))): - for filename in scandir(self.cfg.work_dir, self._artifact_suffix, - False): - file_path = osp.join(self.cfg.work_dir, filename) - file_paths.append(file_path) - - for file_path in file_paths: - self._task.upload_artifact(os.path.basename(file_path), file_path) - self._task.close() - - -@VISBACKENDS.register_module() -class NeptuneVisBackend(BaseVisBackend): - """Neptune visualization backend class. - - Examples: - >>> from mmengine.visualization import NeptuneVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> init_kwargs = {'project': 'your_project_name'} - >>> neptune_vis_backend = NeptuneVisBackend(init_kwargs=init_kwargs) - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> neptune_vis_backend.add_image('img', img) - >>> neptune_vis_backend.add_scalar('mAP', 0.6) - >>> neptune_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> neptune_vis_backend.add_config(cfg) - - Note: - `New in version 0.9.0.` - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. NeptuneVisBackend does - not require this argument. Defaults to None. - init_kwargs (dict, optional): Neptune initialization parameters. - Defaults to None. - - - project (str): Name of a project in a form of - `namespace/project_name`. If `project` is not specified, - the value of `NEPTUNE_PROJECT` environment variable - will be taken. - - api_token (str): User's API token. If api_token is not api_token, - the value of `NEPTUNE_API_TOKEN` environment variable will - be taken. Note: It is strongly recommended to use - `NEPTUNE_API_TOKEN` environment variable rather than - placing your API token here. - - If 'project' and 'api_token are not specified in `init_kwargs`, - the 'mode' will be set to 'offline'. - See `neptune.init_run - `_ for - details. 
- """ - - def __init__(self, - save_dir: Optional[str] = None, - init_kwargs: Optional[dict] = None): - super().__init__(save_dir) # type:ignore - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for neptune.""" - try: - import neptune - except ImportError: - raise ImportError( - 'Please run "pip install -U neptune" to install neptune') - if self._init_kwargs is None: - self._init_kwargs = {'mode': 'offline'} - - self._neptune = neptune.init_run(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Neptune object.""" - return self._neptune - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to neptune. - - Args: - config (Config): The Config object - """ - from neptune.types import File - self._neptune['config'].upload(File.from_content(config.pretty_text)) - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - from neptune.types import File - - # values in the array need to be in the [0, 1] range - img = image.astype(np.float32) / 255.0 - self._neptune['images'].append( - File.as_image(img), name=name, step=step) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record the scalar. - - Args: - name (str): The scalar identifier. - value (int, float): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - self._neptune[name].append(value, step=step) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - assert isinstance(scalar_dict, dict) - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - - for k, v in scalar_dict.items(): - self._neptune[k].append(v, step=step) - - def close(self) -> None: - """Close an opened object.""" - if hasattr(self, '_neptune'): - self._neptune.stop() - - -@VISBACKENDS.register_module() -class DVCLiveVisBackend(BaseVisBackend): - """DVCLive visualization backend class. - - Examples: - >>> from mmengine.visualization import DVCLiveVisBackend - >>> import numpy as np - >>> dvclive_vis_backend = DVCLiveVisBackend(save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> dvclive_vis_backend.add_image('img', img) - >>> dvclive_vis_backend.add_scalar('mAP', 0.6) - >>> dvclive_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> dvclive_vis_backend.add_config(cfg) - - Note: - `New in version 0.9.0.` - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - artifact_suffix (Tuple[str] or str, optional): The artifact suffix. - Defaults to ('.json', '.py', 'yaml'). - init_kwargs (dict, optional): DVCLive initialization parameters. - See `DVCLive `_ for details. 
- Defaults to None. - """ - - def __init__(self, - save_dir: str, - artifact_suffix: SUFFIX_TYPE = ('.json', '.py', 'yaml'), - init_kwargs: Optional[dict] = None): - super().__init__(save_dir) - self._artifact_suffix = artifact_suffix - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for dvclive.""" - if digit_version(platform.python_version()) < digit_version('3.8'): - raise RuntimeError('Please use Python 3.8 or higher version ' - 'to use DVCLiveVisBackend.') - - try: - import pygit2 - from dvclive import Live - except ImportError: - raise ImportError( - 'Please run "pip install dvclive" to install dvclive') - # if no git info, init dvc without git to avoid SCMError - try: - path = pygit2.discover_repository(os.fspath(os.curdir), True, '') - pygit2.Repository(path).default_signature - except KeyError: - os.system('dvc init -f --no-scm') - - if self._init_kwargs is None: - self._init_kwargs = {} - self._init_kwargs.setdefault('dir', self._save_dir) - self._init_kwargs.setdefault('save_dvc_exp', True) - self._init_kwargs.setdefault('cache_images', True) - - self._dvclive = Live(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return dvclive object. - - The experiment attribute can get the dvclive backend, If you want to - write other data, such as writing a table, you can directly get the - dvclive backend through experiment. - """ - return self._dvclive - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to dvclive. - - Args: - config (Config): The Config object - """ - assert isinstance(config, Config) - self.cfg = config - self._dvclive.log_params(self._to_dvc_paramlike(self.cfg.to_dict())) - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image to dvclive. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Useless parameter. Dvclive does not - need this parameter. Defaults to 0. - """ - assert image.dtype == np.uint8 - save_file_name = f'{name}.png' - - self._dvclive.log_image(save_file_name, image) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to dvclive. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - if isinstance(value, torch.Tensor): - value = value.numpy() - self._dvclive.step = step - self._dvclive.log_metric(name, value) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to dvclive. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. 
- """ - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step, **kwargs) - - def close(self) -> None: - """Close an opened dvclive object.""" - if not hasattr(self, '_dvclive'): - return - - file_paths = dict() - for filename in scandir(self._save_dir, self._artifact_suffix, True): - file_path = osp.join(self._save_dir, filename) - relative_path = os.path.relpath(file_path, self._save_dir) - dir_path = os.path.dirname(relative_path) - file_paths[file_path] = dir_path - - for file_path, dir_path in file_paths.items(): - self._dvclive.log_artifact(file_path, dir_path) - - self._dvclive.end() - - def _to_dvc_paramlike(self, - value: Union[int, float, dict, list, tuple, Config, - ConfigDict, torch.Tensor, np.ndarray]): - """Convert the input value to a DVC `ParamLike` recursively. - - Or the `log_params` method of dvclive will raise an error. - """ - - if isinstance(value, (dict, Config, ConfigDict)): - return {k: self._to_dvc_paramlike(v) for k, v in value.items()} - elif isinstance(value, (tuple, list)): - return [self._to_dvc_paramlike(item) for item in value] - elif isinstance(value, (torch.Tensor, np.ndarray)): - return value.tolist() - elif isinstance(value, np.generic): - return value.item() - else: - return value - - -@VISBACKENDS.register_module() -class AimVisBackend(BaseVisBackend): - """Aim visualization backend class. - - Examples: - >>> from mmengine.visualization import AimVisBackend - >>> import numpy as np - >>> aim_vis_backend = AimVisBackend() - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> aim_vis_backend.add_image('img', img) - >>> aim_vis_backend.add_scalar('mAP', 0.6) - >>> aim_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> aim_vis_backend.add_config(cfg) - - Note: - 1. `New in version 0.9.0.` - 2. Refer to - `Github issue `_ , - Aim is not unable to be install on Windows for now. - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - init_kwargs (dict, optional): Aim initialization parameters. See - `Aim `_ - for details. Defaults to None. - """ - - def __init__(self, - save_dir: Optional[str] = None, - init_kwargs: Optional[dict] = None): - super().__init__(save_dir) # type:ignore - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for Aim.""" - try: - from aim import Run - except ImportError: - raise ImportError('Please run "pip install aim" to install aim') - - from datetime import datetime - - if self._save_dir is not None: - path_list = os.path.normpath(self._save_dir).split(os.sep) - exp_name = f'{path_list[-2]}_{path_list[-1]}' - else: - exp_name = datetime.now().strftime('%Y%m%d_%H%M%S') - - if self._init_kwargs is None: - self._init_kwargs = {} - self._init_kwargs.setdefault('experiment', exp_name) - self._aim_run = Run(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Aim object.""" - return self._aim_run - - @force_init_env - def add_config(self, config, **kwargs) -> None: - """Record the config to Aim. - - Args: - config (Config): The Config object - """ - if isinstance(config, Config): - config = config.to_dict() - self._aim_run['hparams'] = config - - @force_init_env - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. 
- step (int): Global step value to record. Defaults to 0. - """ - from aim import Image - self._aim_run.track(name=name, value=Image(image), step=step) - - @force_init_env - def add_scalar(self, - name: str, - value: Union[int, float, torch.Tensor, np.ndarray], - step: int = 0, - **kwargs) -> None: - """Record the scalar data to Aim. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._aim_run.track(name=name, value=value, step=step) - - @force_init_env - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalar's data to wandb. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - for key, value in scalar_dict.items(): - self._aim_run.track(name=key, value=value, step=step) - - def close(self) -> None: - """Close the Aim.""" - if not hasattr(self, '_aim_run'): - return - - self._aim_run.close() diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py deleted file mode 100644 index 6979395aca..0000000000 --- a/mmengine/visualization/visualizer.py +++ /dev/null @@ -1,1186 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import os.path as osp -import warnings -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union - -if TYPE_CHECKING: - from matplotlib.font_manager import FontProperties - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F - -from mmengine.config import Config -from mmengine.dist import master_only -from mmengine.registry import VISBACKENDS, VISUALIZERS -from mmengine.structures import BaseDataElement -from mmengine.utils import ManagerMixin, is_seq_of -from mmengine.visualization.utils import (check_type, check_type_and_length, - color_str2rgb, color_val_matplotlib, - convert_overlay_heatmap, - img_from_canvas, tensor2ndarray, - value2list, wait_continue) -from mmengine.visualization.vis_backend import BaseVisBackend - -VisBackendsType = Union[List[Union[List, BaseDataElement]], BaseDataElement, - dict, None] - - -@VISUALIZERS.register_module() -class Visualizer(ManagerMixin): - """MMEngine provides a Visualizer class that uses the ``Matplotlib`` - library as the backend. It has the following functions: - - - Basic drawing methods - - - draw_bboxes: draw single or multiple bounding boxes - - draw_texts: draw single or multiple text boxes - - draw_points: draw single or multiple points - - draw_lines: draw single or multiple line segments - - draw_circles: draw single or multiple circles - - draw_polygons: draw single or multiple polygons - - draw_binary_masks: draw single or multiple binary masks - - draw_featmap: draw feature map - - - Basic visualizer backend methods - - - add_configs: write config to all vis storage backends - - add_graph: write model graph to all vis storage backends - - add_image: write image to all vis storage backends - - add_scalar: write scalar to all vis storage backends - - add_scalars: write scalars to all vis storage backends - - add_datasample: write datasample to all vis storage \ - backends. 
The abstract drawing interface used by the user - - - Basic info methods - - - set_image: sets the original image data - - get_image: get the image data in Numpy format after drawing - - show: visualization - - close: close all resources that have been opened - - get_backend: get the specified vis backend - - - All the basic drawing methods support chain calls, which is convenient for - overlaydrawing and display. Each downstream algorithm library can inherit - ``Visualizer`` and implement the add_datasample logic. For example, - ``DetLocalVisualizer`` in MMDetection inherits from ``Visualizer`` - and implements functions, such as visual detection boxes, instance masks, - and semantic segmentation maps in the add_datasample interface. - - Args: - name (str): Name of the instance. Defaults to 'visualizer'. - image (np.ndarray, optional): the origin image to draw. The format - should be RGB. Defaults to None. - vis_backends (list, optional): Visual backend config list. - Defaults to None. - save_dir (str, optional): Save file dir for all storage backends. - If it is None, the backend storage will not save any data. - fig_save_cfg (dict): Keyword parameters of figure for saving. - Defaults to empty dict. - fig_show_cfg (dict): Keyword parameters of figure for showing. - Defaults to empty dict. - - Examples: - >>> # Basic info methods - >>> vis = Visualizer() - >>> vis.set_image(image) - >>> vis.get_image() - >>> vis.show() - - >>> # Basic drawing methods - >>> vis = Visualizer(image=image) - >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g') - >>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]), - >>> edge_colors=['g', 'r']) - >>> vis.draw_lines(x_datas=np.array([1, 3]), - >>> y_datas=np.array([1, 3]), - >>> colors='r', line_widths=1) - >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]), - >>> y_datas=np.array([[1, 3], [2, 4]]), - >>> colors=['r', 'r'], line_widths=[1, 2]) - >>> vis.draw_texts(text='MMEngine', - >>> position=np.array([2, 2]), - >>> colors='b') - >>> vis.draw_texts(text=['MMEngine','OpenMMLab'], - >>> position=np.array([[2, 2], [5, 5]]), - >>> colors=['b', 'b']) - >>> vis.draw_circles(circle_coord=np.array([2, 2]), radius=np.array[1]) - >>> vis.draw_circles(circle_coord=np.array([[2, 2], [3, 5]), - >>> radius=np.array[1, 2], colors=['g', 'r']) - >>> square = np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) - >>> vis.draw_polygons(polygons=square, edge_colors='g') - >>> squares = [np.array([[0, 0], [100, 0], [100, 100], [0, 100]]), - >>> np.array([[0, 0], [50, 0], [50, 50], [0, 50]])] - >>> vis.draw_polygons(polygons=squares, edge_colors=['g', 'r']) - >>> vis.draw_binary_masks(binary_mask, alpha=0.6) - >>> heatmap = vis.draw_featmap(featmap, img, - >>> channel_reduction='select_max') - >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, - >>> topk=8, arrangement=(4, 2)) - >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, - >>> topk=-1) - - >>> # chain calls - >>> vis.draw_bboxes().draw_texts().draw_circle().draw_binary_masks() - - >>> # Backend related methods - >>> vis = Visualizer(vis_backends=[dict(type='LocalVisBackend')], - >>> save_dir='temp_dir') - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis.add_config(cfg) - >>> image=np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - >>> vis.add_image('image',image) - >>> vis.add_scaler('mAP', 0.6) - >>> vis.add_scalars({'loss': 0.1,'acc':0.8}) - - >>> # inherit - >>> class DetLocalVisualizer(Visualizer): - >>> def add_datasample(self, - >>> name, - 
>>> image: np.ndarray, - >>> gt_sample: - >>> Optional['BaseDataElement'] = None, - >>> pred_sample: - >>> Optional['BaseDataElement'] = None, - >>> draw_gt: bool = True, - >>> draw_pred: bool = True, - >>> show: bool = False, - >>> wait_time: int = 0, - >>> step: int = 0) -> None: - >>> pass - """ - - def __init__( - self, - name='visualizer', - image: Optional[np.ndarray] = None, - vis_backends: VisBackendsType = None, - save_dir: Optional[str] = None, - fig_save_cfg=dict(frameon=False), - fig_show_cfg=dict(frameon=False) - ) -> None: - super().__init__(name) - self._dataset_meta: Optional[dict] = None - self._vis_backends: Dict[str, BaseVisBackend] = {} - - if vis_backends is None: - vis_backends = [] - - if isinstance(vis_backends, (dict, BaseVisBackend)): - vis_backends = [vis_backends] # type: ignore - - if not is_seq_of(vis_backends, (dict, BaseVisBackend)): - raise TypeError('vis_backends must be a list of dicts or a list ' - 'of BaseBackend instances') - if save_dir is not None: - save_dir = osp.join(save_dir, 'vis_data') - - for vis_backend in vis_backends: # type: ignore - name = None - if isinstance(vis_backend, dict): - name = vis_backend.pop('name', None) - vis_backend.setdefault('save_dir', save_dir) - vis_backend = VISBACKENDS.build(vis_backend) - - # If vis_backend requires `save_dir` (with no default value) - # but is initialized with None, then don't add this - # vis_backend to the visualizer. - save_dir_arg = inspect.signature( - vis_backend.__class__.__init__).parameters.get('save_dir') - if (save_dir_arg is not None - and save_dir_arg.default is save_dir_arg.empty - and getattr(vis_backend, '_save_dir') is None): - warnings.warn(f'Failed to add {vis_backend.__class__}, ' - 'please provide the `save_dir` argument.') - continue - - type_name = vis_backend.__class__.__name__ - name = name or type_name - - if name in self._vis_backends: - raise RuntimeError(f'vis_backend name {name} already exists') - self._vis_backends[name] = vis_backend # type: ignore - - self.fig_save = None - self.fig_save_cfg = fig_save_cfg - self.fig_show_cfg = fig_show_cfg - - (self.fig_save_canvas, self.fig_save, - self.ax_save) = self._initialize_fig(fig_save_cfg) - self.dpi = self.fig_save.get_dpi() - - if image is not None: - self.set_image(image) - - @property # type: ignore - @master_only - def dataset_meta(self) -> Optional[dict]: - """Optional[dict]: Meta info of the dataset.""" - return self._dataset_meta - - @dataset_meta.setter # type: ignore - @master_only - def dataset_meta(self, dataset_meta: dict) -> None: - """Set the dataset meta info to the Visualizer.""" - self._dataset_meta = dataset_meta - - @master_only - def show(self, - drawn_img: Optional[np.ndarray] = None, - win_name: str = 'image', - wait_time: float = 0., - continue_key: str = ' ', - backend: str = 'matplotlib') -> None: - """Show the drawn image. - - Args: - drawn_img (np.ndarray, optional): The image to show. If drawn_img - is None, it will show the image got by Visualizer. Defaults - to None. - win_name (str): The image title. Defaults to 'image'. - wait_time (float): Delay in seconds. 0 is the special - value that means "forever". Defaults to 0. - continue_key (str): The key for users to continue. Defaults to - the space key. - backend (str): The backend to show the image. Defaults to - 'matplotlib'. 
`New in version 0.7.3.` - """ - if backend == 'matplotlib': - import matplotlib.pyplot as plt - is_inline = 'inline' in plt.get_backend() - img = self.get_image() if drawn_img is None else drawn_img - self._init_manager(win_name) - fig = self.manager.canvas.figure - # remove white edges by set subplot margin - fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - fig.clear() - ax = fig.add_subplot() - ax.axis(False) - ax.imshow(img) - self.manager.canvas.draw() - - # Find a better way for inline to show the image - if is_inline: - return fig - wait_continue(fig, timeout=wait_time, continue_key=continue_key) - elif backend == 'cv2': - # Keep images are shown in the same window, and the title of window - # will be updated with `win_name`. - cv2.namedWindow(winname=f'{id(self)}') - cv2.setWindowTitle(f'{id(self)}', win_name) - cv2.imshow( - str(id(self)), - self.get_image() if drawn_img is None else drawn_img) - cv2.waitKey(int(np.ceil(wait_time * 1000))) - else: - raise ValueError('backend should be "matplotlib" or "cv2", ' - f'but got {backend} instead') - - @master_only - def set_image(self, image: np.ndarray) -> None: - """Set the image to draw. - - Args: - image (np.ndarray): The image to draw. - """ - assert image is not None - image = image.astype('uint8') - self._image = image - self.width, self.height = image.shape[1], image.shape[0] - self._default_font_size = max( - np.sqrt(self.height * self.width) // 90, 10) - - # add a small 1e-2 to avoid precision lost due to matplotlib's - # truncation (https://github.com/matplotlib/matplotlib/issues/15363) - self.fig_save.set_size_inches( # type: ignore - (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi) - # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) - self.ax_save.cla() - self.ax_save.axis(False) - self.ax_save.imshow( - image, - extent=(0, self.width, self.height, 0), - interpolation='none') - - @master_only - def get_image(self) -> np.ndarray: - """Get the drawn image. The format is RGB. - - Returns: - np.ndarray: the drawn image which channel is RGB. - """ - assert self._image is not None, 'Please set image using `set_image`' - return img_from_canvas(self.fig_save_canvas) # type: ignore - - def _initialize_fig(self, fig_cfg) -> tuple: - """Build figure according to fig_cfg. - - Args: - fig_cfg (dict): The config to build figure. - - Returns: - tuple: build canvas figure and axes. - """ - from matplotlib.backends.backend_agg import FigureCanvasAgg - from matplotlib.figure import Figure - fig = Figure(**fig_cfg) - ax = fig.add_subplot() - ax.axis(False) - - # remove white edges by set subplot margin - fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - canvas = FigureCanvasAgg(fig) - return canvas, fig, ax - - def _init_manager(self, win_name: str) -> None: - """Initialize the matplot manager. - - Args: - win_name (str): The window name. - """ - from matplotlib.figure import Figure - from matplotlib.pyplot import new_figure_manager - if getattr(self, 'manager', None) is None: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) - - try: - self.manager.set_window_title(win_name) - except Exception: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) - self.manager.set_window_title(win_name) - - @master_only - def get_backend(self, name) -> 'BaseVisBackend': - """Get vis backend by name. - - Args: - name (str): The name of vis backend - - Returns: - BaseVisBackend: The vis backend. 
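A short sketch of the image round-trip shown in the removed class (set_image / get_image / show); the cv2 backend choice is the 0.7.3 addition documented in the show() docstring, and the call assumes a display is available:

    import numpy as np
    from mmengine.visualization import Visualizer

    vis = Visualizer()
    image = np.random.randint(0, 256, size=(64, 64, 3)).astype(np.uint8)
    vis.set_image(image)            # becomes the drawing canvas
    drawn = vis.get_image()         # RGB ndarray rendered from the Agg canvas
    vis.show(drawn, win_name='demo', wait_time=1, backend='cv2')  # or backend='matplotlib'
    vis.close()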
- """ - return self._vis_backends.get(name) # type: ignore - - def _is_posion_valid(self, position: np.ndarray) -> bool: - """Judge whether the position is in image. - - Args: - position (np.ndarray): The position to judge which last dim must - be two and the format is [x, y]. - - Returns: - bool: Whether the position is in image. - """ - flag = (position[..., 0] < self.width).all() and \ - (position[..., 0] >= 0).all() and \ - (position[..., 1] < self.height).all() and \ - (position[..., 1] >= 0).all() - return flag - - @master_only - def draw_points(self, - positions: Union[np.ndarray, torch.Tensor], - colors: Union[str, tuple, List[str], List[tuple]] = 'g', - marker: Optional[str] = None, - sizes: Optional[Union[np.ndarray, torch.Tensor]] = None): - """Draw single or multiple points. - - Args: - positions (Union[np.ndarray, torch.Tensor]): Positions to draw. - colors (Union[str, tuple, List[str], List[tuple]]): The colors - of points. ``colors`` can have the same length with points or - just single value. If ``colors`` is single value, all the - points will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g. - marker (str, optional): The marker style. - See :mod:`matplotlib.markers` for more information about - marker styles. Defaults to None. - sizes (Optional[Union[np.ndarray, torch.Tensor]]): The marker size. - Defaults to None. - """ - check_type('positions', positions, (np.ndarray, torch.Tensor)) - positions = tensor2ndarray(positions) - - if len(positions.shape) == 1: - positions = positions[None] - assert positions.shape[-1] == 2, ( - 'The shape of `positions` should be (N, 2), ' - f'but got {positions.shape}') - colors = color_val_matplotlib(colors) # type: ignore - self.ax_save.scatter( - positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) - return self - - @master_only - def draw_texts( - self, - texts: Union[str, List[str]], - positions: Union[np.ndarray, torch.Tensor], - font_sizes: Optional[Union[int, List[int]]] = None, - colors: Union[str, tuple, List[str], List[tuple]] = 'g', - vertical_alignments: Union[str, List[str]] = 'top', - horizontal_alignments: Union[str, List[str]] = 'left', - font_families: Union[str, List[str]] = 'sans-serif', - bboxes: Optional[Union[dict, List[dict]]] = None, - font_properties: Optional[Union['FontProperties', - List['FontProperties']]] = None - ) -> 'Visualizer': - """Draw single or multiple text boxes. - - Args: - texts (Union[str, List[str]]): Texts to draw. - positions (Union[np.ndarray, torch.Tensor]): The position to draw - the texts, which should have the same length with texts and - each dim contain x and y. - font_sizes (Union[int, List[int]], optional): The font size of - texts. ``font_sizes`` can have the same length with texts or - just single value. If ``font_sizes`` is single value, all the - texts will have the same font size. Defaults to None. - colors (Union[str, tuple, List[str], List[tuple]]): The colors - of texts. ``colors`` can have the same length with texts or - just single value. If ``colors`` is single value, all the - texts will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g. - vertical_alignments (Union[str, List[str]]): The verticalalignment - of texts. verticalalignment controls whether the y positional - argument for the text indicates the bottom, center or top side - of the text bounding box. 
- ``vertical_alignments`` can have the same length with - texts or just single value. If ``vertical_alignments`` is - single value, all the texts will have the same - verticalalignment. verticalalignment can be 'center' or - 'top', 'bottom' or 'baseline'. Defaults to 'top'. - horizontal_alignments (Union[str, List[str]]): The - horizontalalignment of texts. Horizontalalignment controls - whether the x positional argument for the text indicates the - left, center or right side of the text bounding box. - ``horizontal_alignments`` can have - the same length with texts or just single value. - If ``horizontal_alignments`` is single value, all the texts - will have the same horizontalalignment. Horizontalalignment - can be 'center','right' or 'left'. Defaults to 'left'. - font_families (Union[str, List[str]]): The font family of - texts. ``font_families`` can have the same length with texts or - just single value. If ``font_families`` is single value, all - the texts will have the same font family. - font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy' - or 'monospace'. Defaults to 'sans-serif'. - bboxes (Union[dict, List[dict]], optional): The bounding box of the - texts. If bboxes is None, there are no bounding box around - texts. ``bboxes`` can have the same length with texts or - just single value. If ``bboxes`` is single value, all - the texts will have the same bbox. Reference to - https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch - for more details. Defaults to None. - font_properties (Union[FontProperties, List[FontProperties]], optional): - The font properties of texts. FontProperties is - a ``font_manager.FontProperties()`` object. - If you want to draw Chinese texts, you need to prepare - a font file that can show Chinese characters properly. - For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on. - Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')`` - ``font_properties`` can have the same length with texts or - just single value. If ``font_properties`` is single value, - all the texts will have the same font properties. - Defaults to None. 
- `New in version 0.6.0.` - """ # noqa: E501 - from matplotlib.font_manager import FontProperties - check_type('texts', texts, (str, list)) - if isinstance(texts, str): - texts = [texts] - num_text = len(texts) - check_type('positions', positions, (np.ndarray, torch.Tensor)) - positions = tensor2ndarray(positions) - if len(positions.shape) == 1: - positions = positions[None] - assert positions.shape == (num_text, 2), ( - '`positions` should have the shape of ' - f'({num_text}, 2), but got {positions.shape}') - if not self._is_posion_valid(positions): - warnings.warn( - 'Warning: The text is out of bounds,' - ' the drawn text may not be in the image', UserWarning) - positions = positions.tolist() - - if font_sizes is None: - font_sizes = self._default_font_size - check_type_and_length('font_sizes', font_sizes, (int, float, list), - num_text) - font_sizes = value2list(font_sizes, (int, float), num_text) - - check_type_and_length('colors', colors, (str, tuple, list), num_text) - colors = value2list(colors, (str, tuple), num_text) - colors = color_val_matplotlib(colors) # type: ignore - - check_type_and_length('vertical_alignments', vertical_alignments, - (str, list), num_text) - vertical_alignments = value2list(vertical_alignments, str, num_text) - - check_type_and_length('horizontal_alignments', horizontal_alignments, - (str, list), num_text) - horizontal_alignments = value2list(horizontal_alignments, str, - num_text) - - check_type_and_length('font_families', font_families, (str, list), - num_text) - font_families = value2list(font_families, str, num_text) - - if font_properties is None: - font_properties = [None for _ in range(num_text)] # type: ignore - else: - check_type_and_length('font_properties', font_properties, - (FontProperties, list), num_text) - font_properties = value2list(font_properties, FontProperties, - num_text) - - if bboxes is None: - bboxes = [None for _ in range(num_text)] # type: ignore - else: - check_type_and_length('bboxes', bboxes, (dict, list), num_text) - bboxes = value2list(bboxes, dict, num_text) - - for i in range(num_text): - self.ax_save.text( - positions[i][0], - positions[i][1], - texts[i], - size=font_sizes[i], # type: ignore - bbox=bboxes[i], # type: ignore - verticalalignment=vertical_alignments[i], - horizontalalignment=horizontal_alignments[i], - family=font_families[i], - fontproperties=font_properties[i], - color=colors[i]) - return self - - @master_only - def draw_lines( - self, - x_datas: Union[np.ndarray, torch.Tensor], - y_datas: Union[np.ndarray, torch.Tensor], - colors: Union[str, tuple, List[str], List[tuple]] = 'g', - line_styles: Union[str, List[str]] = '-', - line_widths: Union[Union[int, float], List[Union[int, float]]] = 2 - ) -> 'Visualizer': - """Draw single or multiple line segments. - - Args: - x_datas (Union[np.ndarray, torch.Tensor]): The x coordinate of - each line' start and end points. - y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of - each line' start and end points. - colors (Union[str, tuple, List[str], List[tuple]]): The colors of - lines. ``colors`` can have the same length with lines or just - single value. If ``colors`` is single value, all the lines - will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. 
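A sketch of the font_properties option introduced in 0.6.0; the font path is a placeholder that would be replaced with a real .ttf/.ttc file, and a single value is broadcast to every text as described above:

    import numpy as np
    from matplotlib.font_manager import FontProperties
    from mmengine.visualization import Visualizer

    vis = Visualizer(image=np.zeros((100, 200, 3), dtype=np.uint8))
    font = FontProperties(fname='path/to/simhei.ttf')  # placeholder path to a font file
    vis.draw_texts(
        texts=['MMEngine', 'OpenMMLab'],
        positions=np.array([[10, 10], [10, 50]]),
        colors=['b', 'g'],
        font_properties=font)  # one FontProperties applied to both texts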
If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - """ - from matplotlib.collections import LineCollection - check_type('x_datas', x_datas, (np.ndarray, torch.Tensor)) - x_datas = tensor2ndarray(x_datas) - check_type('y_datas', y_datas, (np.ndarray, torch.Tensor)) - y_datas = tensor2ndarray(y_datas) - assert x_datas.shape == y_datas.shape, ( - '`x_datas` and `y_datas` should have the same shape') - assert x_datas.shape[-1] == 2, ( - f'The shape of `x_datas` should be (N, 2), but got {x_datas.shape}' - ) - if len(x_datas.shape) == 1: - x_datas = x_datas[None] - y_datas = y_datas[None] - colors = color_val_matplotlib(colors) # type: ignore - lines = np.concatenate( - (x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) - if not self._is_posion_valid(lines): - warnings.warn( - 'Warning: The line is out of bounds,' - ' the drawn line may not be in the image', UserWarning) - line_collect = LineCollection( - lines.tolist(), - colors=colors, - linestyles=line_styles, - linewidths=line_widths) - self.ax_save.add_collection(line_collect) - return self - - @master_only - def draw_circles( - self, - center: Union[np.ndarray, torch.Tensor], - radius: Union[np.ndarray, torch.Tensor], - edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', - line_styles: Union[str, List[str]] = '-', - line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', - alpha: Union[float, int] = 0.8, - ) -> 'Visualizer': - """Draw single or multiple circles. - - Args: - center (Union[np.ndarray, torch.Tensor]): The x coordinate of - each line' start and end points. - radius (Union[np.ndarray, torch.Tensor]): The y coordinate of - each line' start and end points. - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of circles. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, - all the lines will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to None. - alpha (Union[int, float]): The transparency of circles. - Defaults to 0.8. 
- """ - from matplotlib.collections import PatchCollection - from matplotlib.patches import Circle - check_type('center', center, (np.ndarray, torch.Tensor)) - center = tensor2ndarray(center) - check_type('radius', radius, (np.ndarray, torch.Tensor)) - radius = tensor2ndarray(radius) - if len(center.shape) == 1: - center = center[None] - assert center.shape == (radius.shape[0], 2), ( - 'The shape of `center` should be (radius.shape, 2), ' - f'but got {center.shape}') - if not (self._is_posion_valid(center - - np.tile(radius.reshape((-1, 1)), (1, 2))) - and self._is_posion_valid( - center + np.tile(radius.reshape((-1, 1)), (1, 2)))): - warnings.warn( - 'Warning: The circle is out of bounds,' - ' the drawn circle may not be in the image', UserWarning) - - center = center.tolist() - radius = radius.tolist() - edge_colors = color_val_matplotlib(edge_colors) # type: ignore - face_colors = color_val_matplotlib(face_colors) # type: ignore - circles = [] - for i in range(len(center)): - circles.append(Circle(tuple(center[i]), radius[i])) - - if isinstance(line_widths, (int, float)): - line_widths = [line_widths] * len(circles) - line_widths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in line_widths - ] - p = PatchCollection( - circles, - alpha=alpha, - facecolors=face_colors, - edgecolors=edge_colors, - linewidths=line_widths, - linestyles=line_styles) - self.ax_save.add_collection(p) - return self - - @master_only - def draw_bboxes( - self, - bboxes: Union[np.ndarray, torch.Tensor], - edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', - line_styles: Union[str, List[str]] = '-', - line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', - alpha: Union[int, float] = 0.8, - ) -> 'Visualizer': - """Draw single or multiple bboxes. - - Args: - bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with - the format of(x1,y1,x2,y2). - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of bboxes. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, all - the lines will have the same colors. Refer to `matplotlib. - colors` for full list of formats that are accepted. - Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to None. - alpha (Union[int, float]): The transparency of bboxes. - Defaults to 0.8. 
- """ - check_type('bboxes', bboxes, (np.ndarray, torch.Tensor)) - bboxes = tensor2ndarray(bboxes) - - if len(bboxes.shape) == 1: - bboxes = bboxes[None] - assert bboxes.shape[-1] == 4, ( - f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= - bboxes[:, 3]).all() - if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): - warnings.warn( - 'Warning: The bbox is out of bounds,' - ' the drawn bbox may not be in the image', UserWarning) - poly = np.stack( - (bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 1], - bboxes[:, 2], bboxes[:, 3], bboxes[:, 0], bboxes[:, 3]), - axis=-1).reshape(-1, 4, 2) - poly = [p for p in poly] - return self.draw_polygons( - poly, - alpha=alpha, - edge_colors=edge_colors, - line_styles=line_styles, - line_widths=line_widths, - face_colors=face_colors) - - @master_only - def draw_polygons( - self, - polygons: Union[Union[np.ndarray, torch.Tensor], - List[Union[np.ndarray, torch.Tensor]]], - edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', - line_styles: Union[str, List[str]] = '-', - line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', - alpha: Union[int, float] = 0.8, - ) -> 'Visualizer': - """Draw single or multiple bboxes. - - Args: - polygons (Union[Union[np.ndarray, torch.Tensor],\ - List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw - with the format of (x1,y1,x2,y2,...,xn,yn). - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of polygons. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, - all the lines will have the same colors. Refer to - `matplotlib.colors` for full list of formats that are accepted. - Defaults to 'g. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to None. - alpha (Union[int, float]): The transparency of polygons. - Defaults to 0.8. 
- """ - from matplotlib.collections import PolyCollection - check_type('polygons', polygons, (list, np.ndarray, torch.Tensor)) - edge_colors = color_val_matplotlib(edge_colors) # type: ignore - face_colors = color_val_matplotlib(face_colors) # type: ignore - - if isinstance(polygons, (np.ndarray, torch.Tensor)): - polygons = [polygons] - if isinstance(polygons, list): - for polygon in polygons: - assert polygon.shape[1] == 2, ( - 'The shape of each polygon in `polygons` should be (M, 2),' - f' but got {polygon.shape}') - polygons = [tensor2ndarray(polygon) for polygon in polygons] - for polygon in polygons: - if not self._is_posion_valid(polygon): - warnings.warn( - 'Warning: The polygon is out of bounds,' - ' the drawn polygon may not be in the image', UserWarning) - if isinstance(line_widths, (int, float)): - line_widths = [line_widths] * len(polygons) - line_widths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in line_widths - ] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor=face_colors, - linestyles=line_styles, - edgecolors=edge_colors, - linewidths=line_widths) - - self.ax_save.add_collection(polygon_collection) - return self - - @master_only - def draw_binary_masks( - self, - binary_masks: Union[np.ndarray, torch.Tensor], - colors: Union[str, tuple, List[str], List[tuple]] = 'g', - alphas: Union[float, List[float]] = 0.8) -> 'Visualizer': - """Draw single or multiple binary masks. - - Args: - binary_masks (np.ndarray, torch.Tensor): The binary_masks to draw - with of shape (N, H, W), where H is the image height and W is - the image width. Each value in the array is either a 0 or 1 - value of uint8 type. - colors (np.ndarray): The colors which binary_masks will convert to. - ``colors`` can have the same length with binary_masks or just - single value. If ``colors`` is single value, all the - binary_masks will convert to the same colors. The colors format - is RGB. Defaults to np.array([0, 255, 0]). - alphas (Union[int, List[int]]): The transparency of masks. - Defaults to 0.8. - """ - check_type('binary_masks', binary_masks, (np.ndarray, torch.Tensor)) - binary_masks = tensor2ndarray(binary_masks) - assert binary_masks.dtype == np.bool_, ( - 'The dtype of binary_masks should be np.bool_, ' - f'but got {binary_masks.dtype}') - binary_masks = binary_masks.astype('uint8') * 255 - img = self.get_image() - if binary_masks.ndim == 2: - binary_masks = binary_masks[None] - assert img.shape[:2] == binary_masks.shape[ - 1:], '`binary_masks` must have ' \ - 'the same shape with image' - binary_mask_len = binary_masks.shape[0] - - check_type_and_length('colors', colors, (str, tuple, list), - binary_mask_len) - colors = value2list(colors, (str, tuple), binary_mask_len) - colors = [ - color_str2rgb(color) if isinstance(color, str) else color - for color in colors - ] - for color in colors: - assert len(color) == 3 - for channel in color: - assert 0 <= channel <= 255 # type: ignore - - if isinstance(alphas, float): - alphas = [alphas] * binary_mask_len - - for binary_mask, color, alpha in zip(binary_masks, colors, alphas): - binary_mask_complement = cv2.bitwise_not(binary_mask) - rgb = np.zeros_like(img) - rgb[...] 
= color - rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask) - img_complement = cv2.bitwise_and( - img, img, mask=binary_mask_complement) - rgb = rgb + img_complement - img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0) - self.ax_save.imshow( - img, - extent=(0, self.width, self.height, 0), - interpolation='nearest') - return self - - @staticmethod - @master_only - def draw_featmap(featmap: torch.Tensor, - overlaid_image: Optional[np.ndarray] = None, - channel_reduction: Optional[str] = 'squeeze_mean', - topk: int = 20, - arrangement: Tuple[int, int] = (4, 5), - resize_shape: Optional[tuple] = None, - alpha: float = 0.5) -> np.ndarray: - """Draw featmap. - - - If `overlaid_image` is not None, the final output image will be the - weighted sum of img and featmap. - - - If `resize_shape` is specified, `featmap` and `overlaid_image` - are interpolated. - - - If `resize_shape` is None and `overlaid_image` is not None, - the feature map will be interpolated to the spatial size of the image - in the case where the spatial dimensions of `overlaid_image` and - `featmap` are different. - - - If `channel_reduction` is "squeeze_mean" and "select_max", - it will compress featmap to single channel image and weighted - sum to `overlaid_image`. - - - If `channel_reduction` is None - - - If topk <= 0, featmap is assert to be one or three - channel and treated as image and will be weighted sum - to ``overlaid_image``. - - If topk > 0, it will select topk channel to show by the sum of - each channel. At the same time, you can specify the `arrangement` - to set the window layout. - - Args: - featmap (torch.Tensor): The featmap to draw which format is - (C, H, W). - overlaid_image (np.ndarray, optional): The overlaid image. - Defaults to None. - channel_reduction (str, optional): Reduce multiple channels to a - single channel. The optional value is 'squeeze_mean' - or 'select_max'. Defaults to 'squeeze_mean'. - topk (int): If channel_reduction is not None and topk > 0, - it will select topk channel to show by the sum of each channel. - if topk <= 0, tensor_chw is assert to be one or three. - Defaults to 20. - arrangement (Tuple[int, int]): The arrangement of featmap when - channel_reduction is None and topk > 0. Defaults to (4, 5). - resize_shape (tuple, optional): The shape to scale the feature map. - Defaults to None. - alpha (Union[int, List[int]]): The transparency of featmap. - Defaults to 0.5. - - Returns: - np.ndarray: RGB image. - """ - import matplotlib.pyplot as plt - assert isinstance(featmap, - torch.Tensor), (f'`featmap` should be torch.Tensor,' - f' but got {type(featmap)}') - assert featmap.ndim == 3, f'Input dimension must be 3, ' \ - f'but got {featmap.ndim}' - featmap = featmap.detach().cpu() - - if overlaid_image is not None: - if overlaid_image.ndim == 2: - overlaid_image = cv2.cvtColor(overlaid_image, - cv2.COLOR_GRAY2RGB) - - if overlaid_image.shape[:2] != featmap.shape[1:]: - warnings.warn( - f'Since the spatial dimensions of ' - f'overlaid_image: {overlaid_image.shape[:2]} and ' - f'featmap: {featmap.shape[1:]} are not same, ' - f'the feature map will be interpolated. 
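The per-mask loop above paints the mask color onto the image and alpha-blends it. The same blend for a single mask, stripped of the Visualizer plumbing, looks roughly like this:

    import cv2
    import numpy as np

    img = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)   # RGB image
    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:6, 2:6] = 255                                              # mask region
    color, alpha = (0, 255, 0), 0.8

    rgb = np.zeros_like(img)
    rgb[...] = color
    rgb = cv2.bitwise_and(rgb, rgb, mask=mask)                          # color inside the mask
    background = cv2.bitwise_and(img, img, mask=cv2.bitwise_not(mask))  # image outside the mask
    # the two regions do not overlap, so adding them cannot overflow uint8
    blended = cv2.addWeighted(img, 1 - alpha, rgb + background, alpha, 0)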
' - f'This may cause mismatch problems !') - if resize_shape is None: - featmap = F.interpolate( - featmap[None], - overlaid_image.shape[:2], - mode='bilinear', - align_corners=False)[0] - - if resize_shape is not None: - featmap = F.interpolate( - featmap[None], - resize_shape, - mode='bilinear', - align_corners=False)[0] - if overlaid_image is not None: - overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1]) - - if channel_reduction is not None: - assert channel_reduction in [ - 'squeeze_mean', 'select_max'], \ - f'Mode only support "squeeze_mean", "select_max", ' \ - f'but got {channel_reduction}' - if channel_reduction == 'select_max': - sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) - _, indices = torch.topk(sum_channel_featmap, 1) - feat_map = featmap[indices] - else: - feat_map = torch.mean(featmap, dim=0) - return convert_overlay_heatmap(feat_map, overlaid_image, alpha) - elif topk <= 0: - featmap_channel = featmap.shape[0] - assert featmap_channel in [ - 1, 3 - ], ('The input tensor channel dimension must be 1 or 3 ' - 'when topk is less than 1, but the channel ' - f'dimension you input is {featmap_channel}, you can use the' - ' channel_reduction parameter or set topk greater than ' - '0 to solve the error') - return convert_overlay_heatmap(featmap, overlaid_image, alpha) - else: - row, col = arrangement - channel, height, width = featmap.shape - assert row * col >= topk, 'The product of row and col in ' \ - 'the `arrangement` is less than ' \ - 'topk, please set the ' \ - '`arrangement` correctly' - - # Extract the feature map of topk - topk = min(channel, topk) - sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) - _, indices = torch.topk(sum_channel_featmap, topk) - topk_featmap = featmap[indices] - - fig = plt.figure(frameon=False) - # Set the window layout - fig.subplots_adjust( - left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) - dpi = fig.get_dpi() - fig.set_size_inches((width * col + 1e-2) / dpi, - (height * row + 1e-2) / dpi) - for i in range(topk): - axes = fig.add_subplot(row, col, i + 1) - axes.axis('off') - axes.text(2, 15, f'channel: {indices[i]}', fontsize=10) - axes.imshow( - convert_overlay_heatmap(topk_featmap[i], overlaid_image, - alpha)) - image = img_from_canvas(fig.canvas) - plt.close(fig) - return image - - @master_only - def add_config(self, config: Config, **kwargs): - """Record the config. - - Args: - config (Config): The Config object. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_config(config, **kwargs) - - @master_only - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], - **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_graph(model, data_batch, **kwargs) - - @master_only - def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_image(name, image, step) # type: ignore - - @master_only - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record the scalar data. - - Args: - name (str): The scalar identifier. 
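The three branches of draw_featmap described above (channel reduction, topk <= 0, topk grid) can be exercised as in the class docstring; a compact sketch, assuming a CHW feature tensor and an RGB image of matching spatial size:

    import numpy as np
    import torch
    from mmengine.visualization import Visualizer

    featmap = torch.rand(16, 32, 32)                                    # (C, H, W)
    img = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.uint8)  # overlay image

    # single-channel heatmap via channel reduction
    heat = Visualizer.draw_featmap(featmap, img, channel_reduction='select_max')
    # grid of the top-8 channels, laid out 4 x 2
    grid = Visualizer.draw_featmap(featmap, img, channel_reduction=None,
                                   topk=8, arrangement=(4, 2))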
- value (float, int): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore - - @master_only - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_scalars(scalar_dict, step, file_path, **kwargs) - - @master_only - def add_datasample(self, - name, - image: np.ndarray, - data_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - show: bool = False, - wait_time: int = 0, - step: int = 0) -> None: - """Draw datasample.""" - pass - - def close(self) -> None: - """Close an opened object.""" - for vis_backend in self._vis_backends.values(): - vis_backend.close() - - @classmethod - def get_instance(cls, name: str, **kwargs) -> 'Visualizer': - """Make subclass can get latest created instance by - ``Visualizer.get_current_instance()``. - - Downstream codebase may need to get the latest created instance - without knowing the specific Visualizer type. For example, mmdetection - builds visualizer in runner and some component which cannot access - runner wants to get latest created visualizer. In this case, - the component does not know which type of visualizer has been built - and cannot get target instance. Therefore, :class:`Visualizer` - overrides the :meth:`get_instance` and its subclass will register - the created instance to :attr:`_instance_dict` additionally. - :meth:`get_current_instance` will return the latest created subclass - instance. - - Examples: - >>> class DetLocalVisualizer(Visualizer): - >>> def __init__(self, name): - >>> super().__init__(name) - >>> - >>> visualizer1 = DetLocalVisualizer.get_instance('name1') - >>> visualizer2 = Visualizer.get_current_instance() - >>> visualizer3 = DetLocalVisualizer.get_current_instance() - >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3) - - Args: - name (str): Name of instance. - - Returns: - object: Corresponding name instance. 
- """ - instance = super().get_instance(name, **kwargs) - Visualizer._instance_dict[name] = instance - return instance diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..b2a7737fae --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,101 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mmengine" +version = "0.10.7" +description = "Engine of OpenMMLab projects" +readme = "README.md" +requires-python = ">=3.10" +license-files = ["LICENSE"] +authors = [ + { name = "MMEngine Authors", email = "openmmlab@gmail.com" }, +] +urls = { "Homepage" = "https://github.com/open-mmlab/mmengine" } + +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Utilities" +] + +dependencies = [ + "addict", + "matplotlib", + "numpy<2.0", + "pyyaml", + "regex;sys_platform=='win32'", + "rich", + "termcolor" +] + +[project.optional-dependencies] +all = [ + "aim<=3.17.5;sys_platform!='win32'", + "bitsandbytes", + "clearml", + "coverage", + "dadaptation", + "dvclive", + "lion-pytorch", + "lmdb", + "mlflow", + "parameterized", + "pydantic==1.10.9", + "pytest", + "transformers" +] +tests = [ + "pytest" +] + +[tool.setuptools] +include-package-data = true + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.ruff] +line-length = 120 +target-version = "py312" +fix = true + +[tool.setuptools.packages.find] +where = ["src"] +include = ["mmengine*"] + +[tool.ruff.lint] +ignore = ["C408", "C901", "E501", "E741", "F402", "F823", "SIM1", "SIM300", "SIM212", "SIM905", "UP009", "UP015", "UP031", "UP028", "UP004", "UP045", "UP007", "UP035"] +select = ["C", "E", "F", "I", "W", "RUF013", "PERF102", "PLC1802", "PLC0208", "SIM", "UP"] +extend-safe-fixes = ["UP006"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 + +[tool.coverage.run] +source = ["transformers"] +omit = [ + "*/convert_*", + "*/__main__.py" +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "raise", + "except", + "register_parameter" +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" \ No newline at end of file diff --git a/setup.py b/setup.py index 4165915e46..e7b427b947 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,24 @@ import os import re -from setuptools import find_packages, setup # type: ignore from pkg_resources import DistributionNotFound, get_distribution +from setuptools import find_packages, setup # type: ignore def readme(): - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: content = f.read() return content -version_file = 'mmengine/version.py' +# version_file = 'mmengine/version.py' def choose_requirement(primary, secondary): """If some version of primary requirement installed, return primary, else return secondary.""" try: - name = re.split(r'[!<>=]', primary)[0] + name = re.split(r"[!<>=]", primary)[0] get_distribution(name) except DistributionNotFound: return secondary @@ -26,13 +26,7 @@ def choose_requirement(primary, secondary): return str(primary) -def get_version(): - with open(version_file) as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] - - -def 
parse_requirements(fname='requirements/runtime.txt', with_version=True): +def parse_requirements(fname="requirements/runtime.txt", with_version=True): """Parse the package dependencies listed in a requirements file but strips specific versioning information. @@ -49,109 +43,106 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True): import re import sys from os.path import exists + require_fpath = fname def parse_line(line): """Parse information from a line in a requirements text file.""" - if line.startswith('-r '): + if line.startswith("-r "): # Allow specifying requirements in other files - target = line.split(' ')[1] + target = line.split(" ")[1] for info in parse_require_file(target): yield info else: - info = {'line': line} - if line.startswith('-e '): - info['package'] = line.split('#egg=')[1] + info = {"line": line} + if line.startswith("-e "): + info["package"] = line.split("#egg=")[1] else: # Remove versioning from the package - pat = '(' + '|'.join(['>=', '==', '>']) + ')' + pat = "(" + "|".join([">=", "==", ">"]) + ")" parts = re.split(pat, line, maxsplit=1) parts = [p.strip() for p in parts] - info['package'] = parts[0] + info["package"] = parts[0] if len(parts) > 1: op, rest = parts[1:] - if ';' in rest: + if ";" in rest: # Handle platform specific dependencies # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, - rest.split(';')) - info['platform_deps'] = platform_deps + version, platform_deps = map(str.strip, rest.split(";")) + info["platform_deps"] = platform_deps else: version = rest # NOQA - info['version'] = (op, version) + info["version"] = (op, version) yield info def parse_require_file(fpath): with open(fpath) as f: for line in f.readlines(): line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): yield from parse_line(line) def gen_packages_items(): if exists(require_fpath): for info in parse_require_file(require_fpath): - parts = [info['package']] - if with_version and 'version' in info: - parts.extend(info['version']) - if not sys.version.startswith('3.4'): + parts = [info["package"]] + if with_version and "version" in info: + parts.extend(info["version"]) + if not sys.version.startswith("3.4"): # apparently package_deps are broken in 3.4 - platform_deps = info.get('platform_deps') + platform_deps = info.get("platform_deps") if platform_deps is not None: - parts.append(';' + platform_deps) - item = ''.join(parts) + parts.append(";" + platform_deps) + item = "".join(parts) yield item packages = list(gen_packages_items()) return packages -if int(os.getenv('MMENGINE_LITE', '0')) == 1: - install_requires = parse_requirements('requirements/runtime_lite.txt') +if int(os.getenv("MMENGINE_LITE", "0")) == 1: + install_requires = parse_requirements("requirements/runtime_lite.txt") else: install_requires = parse_requirements() try: # OpenCV installed via conda. 
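# Illustration of parse_requirements above (file name and contents are hypothetical):
#
#   requirements/runtime.txt
#     numpy>=1.20
#     regex>=2020.1;sys_platform=='win32'
#
#   parse_requirements('requirements/runtime.txt', with_version=True)
#     -> ['numpy>=1.20', "regex>=2020.1;sys_platform=='win32'"]
#   parse_requirements('requirements/runtime.txt', with_version=False)
#     -> ['numpy', "regex;sys_platform=='win32'"]
#
# i.e. version pins are kept or stripped, while platform markers survive either way.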
import cv2 # NOQA: F401 - major, minor, *rest = cv2.__version__.split('.') + + major, minor, *rest = cv2.__version__.split(".") if int(major) < 3: - raise RuntimeError( - f'OpenCV >=3 is required but {cv2.__version__} is installed') + raise RuntimeError(f"OpenCV >=3 is required but {cv2.__version__} is installed") except ImportError: # If first not installed install second package - CHOOSE_INSTALL_REQUIRES = [('opencv-python-headless>=3', - 'opencv-python>=3')] + CHOOSE_INSTALL_REQUIRES = [("opencv-python-headless>=3", "opencv-python>=3")] for main, secondary in CHOOSE_INSTALL_REQUIRES: install_requires.append(choose_requirement(main, secondary)) setup( - name='mmengine' - if os.getenv('MMENGINE_LITE', '0') == '0' else 'mmengine-lite', - version=get_version(), - description='Engine of OpenMMLab projects', + name="mmengine" if os.getenv("MMENGINE_LITE", "0") == "0" else "mmengine-lite", + description="Engine of OpenMMLab projects", long_description=readme(), - long_description_content_type='text/markdown', - url='https://github.com/open-mmlab/mmengine', - author='MMEngine Authors', - author_email='openmmlab@gmail.com', + long_description_content_type="text/markdown", + url="https://github.com/open-mmlab/mmengine", + author="MMEngine Authors", + author_email="openmmlab@gmail.com", packages=find_packages(), include_package_data=True, classifiers=[ - 'Development Status :: 4 - Beta', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Utilities', + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Utilities", ], - python_requires='>=3.7', + python_requires=">=3.7", install_requires=install_requires, extras_require={ - 'all': parse_requirements('requirements.txt'), - 'tests': parse_requirements('requirements/tests.txt'), + "all": parse_requirements("requirements.txt"), + "tests": parse_requirements("requirements/tests.txt"), }, ) diff --git a/tests/data/config/lazy_module_config/_base_/base_model.py b/tests/data/config/lazy_module_config/_base_/base_model.py index 8e3a9dab7a..fb9ca4e11e 100644 --- a/tests/data/config/lazy_module_config/_base_/base_model.py +++ b/tests/data/config/lazy_module_config/_base_/base_model.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.testing.runner_test_case import ToyModel + model = dict(type=ToyModel) diff --git a/tests/data/config/lazy_module_config/_base_/default_runtime.py b/tests/data/config/lazy_module_config/_base_/default_runtime.py index d8ab215548..8d52038330 100644 --- a/tests/data/config/lazy_module_config/_base_/default_runtime.py +++ b/tests/data/config/lazy_module_config/_base_/default_runtime.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
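With get_version() removed from setup.py and the version now declared statically in pyproject.toml, the installed version can be read from the distribution metadata instead of mmengine/version.py; a minimal sketch, assuming the package built from this pyproject is installed:

    from importlib.metadata import version

    print(version('mmengine'))  # e.g. '0.10.7', taken from the installed distribution metadata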
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, - LoggerHook, ParamSchedulerHook) +from mmengine.hooks import CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook -default_scope = 'test_config' + +default_scope = "test_config" # configure default hooks default_hooks = dict( @@ -18,13 +18,13 @@ # whether to enable cudnn benchmark cudnn_benchmark=False, # set multi process parameters - mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0), # set distributed parameters - dist_cfg=dict(backend='nccl'), + dist_cfg=dict(backend="nccl"), ) # set log level -log_level = 'INFO' +log_level = "INFO" # load from which checkpoint load_from = None diff --git a/tests/data/config/lazy_module_config/_base_/scheduler.py b/tests/data/config/lazy_module_config/_base_/scheduler.py index a9a4c15af8..e2858f554c 100644 --- a/tests/data/config/lazy_module_config/_base_/scheduler.py +++ b/tests/data/config/lazy_module_config/_base_/scheduler.py @@ -3,12 +3,11 @@ from mmengine.optim.scheduler import MultiStepLR + # optimizer -optim_wrapper = dict( - optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +optim_wrapper = dict(optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) # learning policy -param_scheduler = dict( - type=MultiStepLR, by_epoch=True, milestones=[1, 2], gamma=0.1) +param_scheduler = dict(type=MultiStepLR, by_epoch=True, milestones=[1, 2], gamma=0.1) # train, val, test setting train_cfg = dict(by_epoch=True, max_epochs=5, val_interval=1) diff --git a/tests/data/config/lazy_module_config/error_mix_using1.py b/tests/data/config/lazy_module_config/error_mix_using1.py index b7017ef0f2..6f8bb43c8c 100644 --- a/tests/data/config/lazy_module_config/error_mix_using1.py +++ b/tests/data/config/lazy_module_config/error_mix_using1.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './error_mix_using3.py' +_base_ = "./error_mix_using3.py" diff --git a/tests/data/config/lazy_module_config/error_mix_using2.py b/tests/data/config/lazy_module_config/error_mix_using2.py index c08d1b6716..4e4714ab9e 100644 --- a/tests/data/config/lazy_module_config/error_mix_using2.py +++ b/tests/data/config/lazy_module_config/error_mix_using2.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.config import read_base + with read_base(): from ...config.py_config.test_base_variables import * diff --git a/tests/data/config/lazy_module_config/error_mix_using3.py b/tests/data/config/lazy_module_config/error_mix_using3.py index 70418146ec..ef101fec61 100644 --- a/tests/data/config/lazy_module_config/error_mix_using3.py +++ b/tests/data/config/lazy_module_config/error_mix_using3.py @@ -1,2 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. -import numpy as np diff --git a/tests/data/config/lazy_module_config/load_mmdet_config.py b/tests/data/config/lazy_module_config/load_mmdet_config.py index c9eae6ba7b..2bf147dd9c 100644 --- a/tests/data/config/lazy_module_config/load_mmdet_config.py +++ b/tests/data/config/lazy_module_config/load_mmdet_config.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
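These lazy-import (read_base) test configs are consumed through Config.fromfile like any other config; a minimal sketch, assuming the path below exists in the repository's test data tree:

    from mmengine.config import Config

    # fromfile auto-detects the lazy-import style used by these files.
    cfg = Config.fromfile('tests/data/config/lazy_module_config/toy_model.py')
    print(cfg.train_dataloader.batch_size)  # -> 3, as declared in toy_model.py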
 from mmengine.config import read_base
+
 
 with read_base():
     from mmdet.configs.retinanet.retinanet_r50_caffe_fpn_1x_coco import *
-    from mmdet.configs.retinanet.retinanet_r101_caffe_fpn_1x_coco import \
-        model as r101
+    from mmdet.configs.retinanet.retinanet_r101_caffe_fpn_1x_coco import model as r101
 
 model = r101
diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py
index a8803dde24..990177cf2c 100644
--- a/tests/data/config/lazy_module_config/test_ast_transform.py
+++ b/tests/data/config/lazy_module_config/test_ast_transform.py
@@ -1,15 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
-from importlib.util import find_spec as find_module
-
-import numpy
-import numpy.compat
-import numpy.linalg as linalg
-
-from mmengine.config import Config
-from mmengine.fileio import LocalBackend as local
-from mmengine.fileio import PetrelBackend
-from ._base_.default_runtime import default_scope as scope
-from ._base_.scheduler import val_cfg
+
+from rich.progress import Progress
+
+start = Progress.start
diff --git a/tests/data/config/lazy_module_config/test_mix_builtin.py b/tests/data/config/lazy_module_config/test_mix_builtin.py
index e36da58a3b..33701915ed 100644
--- a/tests/data/config/lazy_module_config/test_mix_builtin.py
+++ b/tests/data/config/lazy_module_config/test_mix_builtin.py
@@ -2,15 +2,12 @@
 import os.path as osp
 from functools import partial
 from itertools import chain
-from os.path import basename
+from os.path import basename, splitext
 from os.path import exists as ex
-from os.path import splitext
 
-import numpy as np
 
-path = osp.join('a', 'b')
-name, suffix = splitext('a/b.py')
+path = osp.join("a", "b")
+name, suffix = splitext("a/b.py")
 chained = list(chain([1, 2], [3, 4]))
 existed = ex(__file__)
 cfgname = partial(basename, __file__)()
-
diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py
index a9d2a3f64a..d440098078 100644
--- a/tests/data/config/lazy_module_config/toy_model.py
+++ b/tests/data/config/lazy_module_config/toy_model.py
@@ -6,6 +6,7 @@
 from mmengine.runner import FlexibleRunner
 from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
 
+
 with read_base():
     from ._base_.base_model import *
     from ._base_.default_runtime import *
@@ -15,24 +16,18 @@
 
 train_dataloader = dict(
-    dataset=dict(type=ToyDataset),
-    sampler=dict(type=DefaultSampler, shuffle=True),
-    batch_size=3,
-    num_workers=0)
+    dataset=dict(type=ToyDataset), sampler=dict(type=DefaultSampler, shuffle=True), batch_size=3, num_workers=0
+)
 
 val_dataloader = dict(
-    dataset=dict(type=ToyDataset),
-    sampler=dict(type=DefaultSampler, shuffle=False),
-    batch_size=3,
-    num_workers=0)
+    dataset=dict(type=ToyDataset), sampler=dict(type=DefaultSampler, shuffle=False), batch_size=3, num_workers=0
+)
 
 val_evaluator = [dict(type=ToyMetric)]
 
 test_dataloader = dict(
-    dataset=dict(type=ToyDataset),
-    sampler=dict(type=DefaultSampler, shuffle=False),
-    batch_size=3,
-    num_workers=0)
+    dataset=dict(type=ToyDataset), sampler=dict(type=DefaultSampler, shuffle=False), batch_size=3, num_workers=0
+)
 
 test_evaluator = [dict(type=ToyMetric)]
@@ -43,7 +38,8 @@
         momentum=0.0002,
         update_buffers=True,
         strict_load=False,
-        priority=49)
+        priority=49,
+    )
 ]
 
 runner_type = FlexibleRunner
diff --git a/tests/data/config/py_config/base.py b/tests/data/config/py_config/base.py
index 2364e1d10b..979522aae3 100644
--- a/tests/data/config/py_config/base.py
+++ b/tests/data/config/py_config/base.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 item1 = [1, 2]
-item2 = {'a': 0}
+item2 = {"a": 0}
 item3 = True
-item4 = 'test'
+item4 = "test"
diff --git a/tests/data/config/py_config/config.py b/tests/data/config/py_config/config.py
index 65c03bf884..1318ce9822 100644
--- a/tests/data/config/py_config/config.py
+++ b/tests/data/config/py_config/config.py
@@ -2,4 +2,4 @@
 test_int = 1
 test_list = [1, 2, 3]
 # include type, optimizer can be initiated by build_from_cfg
-optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
+optimizer = dict(type="SGD", lr=0.1, momentum=0.9, weight_decay=0.0001)
diff --git a/tests/data/config/py_config/simple.config.py b/tests/data/config/py_config/simple.config.py
index 2364e1d10b..979522aae3 100644
--- a/tests/data/config/py_config/simple.config.py
+++ b/tests/data/config/py_config/simple.config.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 item1 = [1, 2]
-item2 = {'a': 0}
+item2 = {"a": 0}
 item3 = True
-item4 = 'test'
+item4 = "test"
diff --git a/tests/data/config/py_config/simple_config.py b/tests/data/config/py_config/simple_config.py
index 2364e1d10b..979522aae3 100644
--- a/tests/data/config/py_config/simple_config.py
+++ b/tests/data/config/py_config/simple_config.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 item1 = [1, 2]
-item2 = {'a': 0}
+item2 = {"a": 0}
 item3 = True
-item4 = 'test'
+item4 = "test"
diff --git a/tests/data/config/py_config/test_base_variables.py b/tests/data/config/py_config/test_base_variables.py
index 4d20d7f025..6fdebbefab 100644
--- a/tests/data/config/py_config/test_base_variables.py
+++ b/tests/data/config/py_config/test_base_variables.py
@@ -1,11 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-_base_ = [
-    './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json',
-    './base4.py'
-]
+_base_ = ["./base1.py", "../yaml_config/base2.yaml", "../json_config/base3.json", "./base4.py"]
 
 item3 = False
-item4 = 'test'
-item8 = '{{fileBasename}}'
+item4 = "test"
+item8 = "{{fileBasename}}"
 item9 = {{_base_.item2}}
 item10 = {{_base_.item7.b.c}}
diff --git a/tests/data/config/py_config/test_base_variables_nested.py b/tests/data/config/py_config/test_base_variables_nested.py
index ea9a6004a8..bc4743039d 100644
--- a/tests/data/config/py_config/test_base_variables_nested.py
+++ b/tests/data/config/py_config/test_base_variables_nested.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-_base_ = ['./test_base_variables.py']
-base = '_base_.item8'
+_base_ = ["./test_base_variables.py"]
+base = "_base_.item8"
 item11 = {{_base_.item8}}
 item12 = {{_base_.item9}}
 item13 = {{_base_.item10}}
@@ -10,4 +10,5 @@
     b=[{{_base_.item3}}],
     c=[{{_base_.item4}}],
     d=[[dict(e={{_base_.item5.a}})], {{_base_.item6}}],
-    e={{_base_.item1}})
+    e={{_base_.item1}},
+)
diff --git a/tests/data/config/py_config/test_code_in_config.py b/tests/data/config/py_config/test_code_in_config.py
index d39c7cffe2..46ade7a7e3 100644
--- a/tests/data/config/py_config/test_code_in_config.py
+++ b/tests/data/config/py_config/test_code_in_config.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Config # isort:skip -cfg = Config.fromfile('tests/data/config/py_config/simple_config.py') +cfg = Config.fromfile("tests/data/config/py_config/simple_config.py") item5 = cfg.item1[0] + cfg.item2.a diff --git a/tests/data/config/py_config/test_custom_class.py b/tests/data/config/py_config/test_custom_class.py index ad706b087e..5728ef1a87 100644 --- a/tests/data/config/py_config/test_custom_class.py +++ b/tests/data/config/py_config/test_custom_class.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -class A: - ... +class A: ... + item_a = dict(a=A) diff --git a/tests/data/config/py_config/test_custom_import.py b/tests/data/config/py_config/test_custom_import.py index d485a19005..4db94457ee 100644 --- a/tests/data/config/py_config/test_custom_import.py +++ b/tests/data/config/py_config/test_custom_import.py @@ -1,3 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -custom_imports = dict( - imports=['test_custom_import_module'], allow_failed_imports=False) +custom_imports = dict(imports=["test_custom_import_module"], allow_failed_imports=False) diff --git a/tests/data/config/py_config/test_custom_import_module.py b/tests/data/config/py_config/test_custom_import_module.py index 853b31ae6b..7b87375aef 100644 --- a/tests/data/config/py_config/test_custom_import_module.py +++ b/tests/data/config/py_config/test_custom_import_module.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -os.environ['TEST_VALUE'] = 'test' + +os.environ["TEST_VALUE"] = "test" diff --git a/tests/data/config/py_config/test_deprecated.py b/tests/data/config/py_config/test_deprecated.py index 7c82380428..41e4eb1349 100644 --- a/tests/data/config/py_config/test_deprecated.py +++ b/tests/data/config/py_config/test_deprecated.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './base.py' +_base_ = "./base.py" -_deprecation_ = dict( - expected='tests/data/config/py_config/base.py', reference='') +_deprecation_ = dict(expected="tests/data/config/py_config/base.py", reference="") diff --git a/tests/data/config/py_config/test_deprecated_base.py b/tests/data/config/py_config/test_deprecated_base.py index 828a384c86..95f0df9040 100644 --- a/tests/data/config/py_config/test_deprecated_base.py +++ b/tests/data/config/py_config/test_deprecated_base.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './test_deprecated.py' +_base_ = "./test_deprecated.py" diff --git a/tests/data/config/py_config/test_dump_pickle_support.py b/tests/data/config/py_config/test_dump_pickle_support.py index 6050ce10b1..61760e3d37 100644 --- a/tests/data/config/py_config/test_dump_pickle_support.py +++ b/tests/data/config/py_config/test_dump_pickle_support.py @@ -4,25 +4,17 @@ def func(): - return 'string with \tescape\\ characters\n' + return "string with \tescape\\ characters\n" test_item1 = [1, 2] bool_item2 = True -str_item3 = 'test' +str_item3 = "test" dict_item4 = dict( - a={ - 'c/d': 'path/d', - 'f': 's3//f', - 6: '2333', - '2333': 'number' - }, - b={'8': 543}, - c={9: 678}, - d={'a': 0}, - f=dict(a='69')) -dict_item5 = {'x/x': {'a.0': 233}} -dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]} + a={"c/d": "path/d", "f": "s3//f", 6: "2333", "2333": "number"}, b={"8": 543}, c={9: 678}, d={"a": 0}, f=dict(a="69") +) +dict_item5 = {"x/x": {"a.0": 233}} +dict_list_item6 = {"x/x": [{"a.0": 1.0, "b.0": 2.0}, {"c/3": 3.0}]} # Test windows path and escape. 
-str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in +str_item_7 = osp.join(osp.expanduser("~"), "folder") # with backslash in str_item_8 = func() diff --git a/tests/data/config/py_config/test_environment_var.py b/tests/data/config/py_config/test_environment_var.py index 89508fb133..40d4ac89aa 100644 --- a/tests/data/config/py_config/test_environment_var.py +++ b/tests/data/config/py_config/test_environment_var.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -item1 = '{{ $ITEM1: }}' -item2 = '{{ $ITEM2:default_value }}' -item3 = {{' $ITEM3:80 '}} +item1 = "{{ $ITEM1: }}" +item2 = "{{ $ITEM2:default_value }}" +item3 = {{" $ITEM3:80 "}} diff --git a/tests/data/config/py_config/test_get_external_cfg.py b/tests/data/config/py_config/test_get_external_cfg.py index 7598ce0cb6..a35ff974ac 100644 --- a/tests/data/config/py_config/test_get_external_cfg.py +++ b/tests/data/config/py_config/test_get_external_cfg.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. _base_ = [ - 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', - 'mmdet::_base_/datasets/coco_detection.py', - 'mmdet::_base_/schedules/schedule_1x.py', - 'mmdet::_base_/default_runtime.py' + "mmdet::_base_/models/faster-rcnn_r50_fpn.py", + "mmdet::_base_/datasets/coco_detection.py", + "mmdet::_base_/schedules/schedule_1x.py", + "mmdet::_base_/default_runtime.py", ] diff --git a/tests/data/config/py_config/test_get_external_cfg2.py b/tests/data/config/py_config/test_get_external_cfg2.py index 7e72bdbf27..c53f98efe8 100644 --- a/tests/data/config/py_config/test_get_external_cfg2.py +++ b/tests/data/config/py_config/test_get_external_cfg2.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py' +_base_ = "mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py" diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py index 5ae261350a..724e5fda42 100644 --- a/tests/data/config/py_config/test_get_external_cfg3.py +++ b/tests/data/config/py_config/test_get_external_cfg3.py @@ -1,18 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. _base_ = [ - 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', - 'mmdet::_base_/datasets/coco_detection.py', - 'mmdet::_base_/schedules/schedule_1x.py', - 'mmdet::_base_/default_runtime.py', - './test_get_external_cfg_base.py' + "mmdet::_base_/models/faster-rcnn_r50_fpn.py", + "mmdet::_base_/datasets/coco_detection.py", + "mmdet::_base_/schedules/schedule_1x.py", + "mmdet::_base_/default_runtime.py", + "./test_get_external_cfg_base.py", ] -custom_hooks = [dict(type='mmdet.DetVisualizationHook')] +custom_hooks = [dict(type="mmdet.DetVisualizationHook")] -model = dict( - roi_head=dict( - bbox_head=dict( - loss_cls=dict(_delete_=True, type='test.ToyLoss') - ) - ) -) +model = dict(roi_head=dict(bbox_head=dict(loss_cls=dict(_delete_=True, type="test.ToyLoss")))) diff --git a/tests/data/config/py_config/test_get_external_cfg_base.py b/tests/data/config/py_config/test_get_external_cfg_base.py index d680ef0a6b..b3750b5b16 100644 --- a/tests/data/config/py_config/test_get_external_cfg_base.py +++ b/tests/data/config/py_config/test_get_external_cfg_base.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-toy_model = dict(type='ToyModel') +toy_model = dict(type="ToyModel") diff --git a/tests/data/config/py_config/test_merge_delete.py b/tests/data/config/py_config/test_merge_delete.py index f8a1eaf64c..8895ea3c3f 100644 --- a/tests/data/config/py_config/test_merge_delete.py +++ b/tests/data/config/py_config/test_merge_delete.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './base.py' -item1 = {'a': 0, '_delete_': True} -item2 = {'b': 0} +_base_ = "./base.py" +item1 = {"a": 0, "_delete_": True} +item2 = {"b": 0} diff --git a/tests/data/config/py_config/test_merge_from_base_error.py b/tests/data/config/py_config/test_merge_from_base_error.py index 1340e4bd27..615bcd3836 100644 --- a/tests/data/config/py_config/test_merge_from_base_error.py +++ b/tests/data/config/py_config/test_merge_from_base_error.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './base.py' -item3 = {'a': 1} +_base_ = "./base.py" +item3 = {"a": 1} diff --git a/tests/data/config/py_config/test_merge_from_base_single.py b/tests/data/config/py_config/test_merge_from_base_single.py index 19edcf82d0..1e516f17f7 100644 --- a/tests/data/config/py_config/test_merge_from_base_single.py +++ b/tests/data/config/py_config/test_merge_from_base_single.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './base.py' +_base_ = "./base.py" item1 = [2, 3] -item2 = {'a': 1} +item2 = {"a": 1} item3 = False -item4 = 'test_base' +item4 = "test_base" diff --git a/tests/data/config/py_config/test_merge_from_dict.py b/tests/data/config/py_config/test_merge_from_dict.py index cca07539c8..7c07ee3933 100644 --- a/tests/data/config/py_config/test_merge_from_dict.py +++ b/tests/data/config/py_config/test_merge_from_dict.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -item = [{'a': 0}, {'b': 0, 'c': 0}] +item = [{"a": 0}, {"b": 0, "c": 0}] diff --git a/tests/data/config/py_config/test_merge_from_multiple_bases.py b/tests/data/config/py_config/test_merge_from_multiple_bases.py index da575c39bc..06bc5bc55a 100644 --- a/tests/data/config/py_config/test_merge_from_multiple_bases.py +++ b/tests/data/config/py_config/test_merge_from_multiple_bases.py @@ -1,9 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = [ - './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json', - './base4.py' -] +_base_ = ["./base1.py", "../yaml_config/base2.yaml", "../json_config/base3.json", "./base4.py"] item3 = False -item4 = 'test' +item4 = "test" item_bool = True item_float = 1.0 diff --git a/tests/data/config/py_config/test_merge_from_multiple_error.py b/tests/data/config/py_config/test_merge_from_multiple_error.py index b38596d953..fe4fe4e54f 100644 --- a/tests/data/config/py_config/test_merge_from_multiple_error.py +++ b/tests/data/config/py_config/test_merge_from_multiple_error.py @@ -1,7 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-_base_ = [ - './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json', - 'simple_config.py' -] +_base_ = ["./base1.py", "../yaml_config/base2.yaml", "../json_config/base3.json", "simple_config.py"] item3 = False -item4 = 'test' +item4 = "test" diff --git a/tests/data/config/py_config/test_merge_intermediate_variable_base.py b/tests/data/config/py_config/test_merge_intermediate_variable_base.py index f31a46a15d..32234c4287 100644 --- a/tests/data/config/py_config/test_merge_intermediate_variable_base.py +++ b/tests/data/config/py_config/test_merge_intermediate_variable_base.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. item1 = [1, 2] -item2 = {'a': 0} +item2 = {"a": 0} item3 = True -item4 = 'test' -item_cfg = {'b': 1} -item5 = {'cfg': item_cfg} -item6 = {'cfg': item_cfg} +item4 = "test" +item_cfg = {"b": 1} +item5 = {"cfg": item_cfg} +item6 = {"cfg": item_cfg} diff --git a/tests/data/config/py_config/test_merge_intermediate_variable_child.py b/tests/data/config/py_config/test_merge_intermediate_variable_child.py index 17325b13bc..7162eeb629 100644 --- a/tests/data/config/py_config/test_merge_intermediate_variable_child.py +++ b/tests/data/config/py_config/test_merge_intermediate_variable_child.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './test_merge_intermediate_variable_base.py' -item_cfg = {'b': 2} -item6 = {'cfg': item_cfg} +_base_ = "./test_merge_intermediate_variable_base.py" +item_cfg = {"b": 2} +item6 = {"cfg": item_cfg} diff --git a/tests/data/config/py_config/test_merge_recursive_bases.py b/tests/data/config/py_config/test_merge_recursive_bases.py index 6d2218bab5..50f9e8c853 100644 --- a/tests/data/config/py_config/test_merge_recursive_bases.py +++ b/tests/data/config/py_config/test_merge_recursive_bases.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = './test_merge_from_base_single.py' -item4 = 'test_recursive_bases' +_base_ = "./test_merge_from_base_single.py" +item4 = "test_recursive_bases" diff --git a/tests/data/config/py_config/test_pre_substitute_base_vars.py b/tests/data/config/py_config/test_pre_substitute_base_vars.py index 72d67ab404..87978246cb 100644 --- a/tests/data/config/py_config/test_pre_substitute_base_vars.py +++ b/tests/data/config/py_config/test_pre_substitute_base_vars.py @@ -1,11 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = ['./test_base_variables_nested.py'] +_base_ = ["./test_base_variables_nested.py"] item21 = {{_base_.item11}} item22 = item21 item23 = {{_base_.item10}} item24 = item23 -item25 = dict( - a=dict(b=item24), - b=[item24], - c=[[dict(e=item22)], {{_base_.item6}}], - e=item21) +item25 = dict(a=dict(b=item24), b=[item24], c=[[dict(e=item22)], {{_base_.item6}}], e=item21) diff --git a/tests/data/config/py_config/test_predefined_var.py b/tests/data/config/py_config/test_predefined_var.py index 82594590cf..65fd288f00 100644 --- a/tests/data/config/py_config/test_predefined_var.py +++ b/tests/data/config/py_config/test_predefined_var.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-item1 = '{{fileBasename}}' -item2 = '{{ fileDirname}}' -item3 = 'abc_{{ fileBasenameNoExtension }}' +item1 = "{{fileBasename}}" +item2 = "{{ fileDirname}}" +item3 = "abc_{{ fileBasenameNoExtension }}" diff --git a/tests/data/config/py_config/test_py_base.py b/tests/data/config/py_config/test_py_base.py index 8073705726..c7815d701f 100644 --- a/tests/data/config/py_config/test_py_base.py +++ b/tests/data/config/py_config/test_py_base.py @@ -1,11 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = [ - './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json', - './base4.py' -] +_base_ = ["./base1.py", "../yaml_config/base2.yaml", "../json_config/base3.json", "./base4.py"] item2 = dict(b=[5, 6]) item3 = False -item4 = 'test' +item4 = "test" _base_.item6[0] = dict(c=0) -item8 = '{{fileBasename}}' -item9, item10, item11 = _base_.item7['b']['c'] +item8 = "{{fileBasename}}" +item9, item10, item11 = _base_.item7["b"]["c"] diff --git a/tests/data/config/py_config/test_py_modify_key.py b/tests/data/config/py_config/test_py_modify_key.py index f2dbbf03b1..2c32923231 100644 --- a/tests/data/config/py_config/test_py_modify_key.py +++ b/tests/data/config/py_config/test_py_modify_key.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. # Support modify value in config. item1 = dict() -item1['a'] = 1 +item1["a"] = 1 diff --git a/tests/data/config/py_config/test_py_nested_path.py b/tests/data/config/py_config/test_py_nested_path.py index b233616bd4..449977cb68 100644 --- a/tests/data/config/py_config/test_py_nested_path.py +++ b/tests/data/config/py_config/test_py_nested_path.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -_base_ = ['./test_py_base.py'] +_base_ = ["./test_py_base.py"] item12 = _base_.item8 item13 = _base_.item9 item14 = _base_.item1 @@ -7,5 +7,6 @@ a=dict(b=_base_.item2), b=[_base_.item3], c=[_base_.item4], - d=[[dict(e=_base_.item5['a'])], _base_.item6], - e=_base_.item1) + d=[[dict(e=_base_.item5["a"])], _base_.item6], + e=_base_.item1, +) diff --git a/tests/data/config/py_config/test_reserved_key.py b/tests/data/config/py_config/test_reserved_key.py index 34d4ebe2f8..5ff9f06d55 100644 --- a/tests/data/config/py_config/test_reserved_key.py +++ b/tests/data/config/py_config/test_reserved_key.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-filename = 'reserved.py' +filename = "reserved.py" diff --git a/tests/data/scripts/hello.py b/tests/data/scripts/hello.py index 2ed1a1e319..86510ffa53 100644 --- a/tests/data/scripts/hello.py +++ b/tests/data/scripts/hello.py @@ -6,8 +6,8 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Say hello.') - parser.add_argument('name', help='To whom.') + parser = argparse.ArgumentParser(description="Say hello.") + parser.add_argument("name", help="To whom.") args = parser.parse_args() @@ -16,10 +16,10 @@ def parse_args(): def main(): args = parse_args() - print(f'hello {args.name}!') - if args.name == 'agent': - warnings.warn('I have a secret!') + print(f"hello {args.name}!") + if args.name == "agent": + warnings.warn("I have a secret!", stacklevel=2) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_analysis/test_activation_count.py b/tests/test_analysis/test_activation_count.py index eb41da23aa..4382d5982e 100644 --- a/tests/test_analysis/test_activation_count.py +++ b/tests/test_analysis/test_activation_count.py @@ -6,7 +6,7 @@ import typing import unittest from collections import Counter, defaultdict -from typing import Any, Dict, List, Tuple +from typing import Any import torch import torch.nn as nn @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv3(x) return x - def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]: + def get_gt_activation(self, x: torch.Tensor) -> tuple[int, int, int]: x = self.conv1(x) count1 = prod(list(x.size())) x = self.conv2(x) @@ -55,13 +55,13 @@ def setUp(self) -> None: # we are testing the right thing. lin = nn.Linear(10, 10) lin_x: torch.Tensor = torch.randn(10, 10) - trace = torch.jit.trace(lin, (lin_x, )) + trace = torch.jit.trace(lin, (lin_x,)) node_kinds = [node.kind() for node in trace.graph.nodes()] - assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds - if 'aten::addmm' in node_kinds: - self.lin_op = 'addmm' + assert "aten::addmm" in node_kinds or "aten::linear" in node_kinds + if "aten::addmm" in node_kinds: + self.lin_op = "addmm" else: - self.lin_op = 'linear' + self.lin_op = "linear" def test_conv2d(self) -> None: """Test the activation count for convolutions.""" @@ -70,15 +70,15 @@ def test_conv2d(self) -> None: spatial_dim = 32 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) conv_net = SmallConvNet(input_dim) - ac_dict, _ = activation_count(conv_net, (x, )) + ac_dict, _ = activation_count(conv_net, (x,)) gt_count = sum(conv_net.get_gt_activation(x)) gt_dict = defaultdict(float) - gt_dict['conv'] = gt_count / 1e6 + gt_dict["conv"] = gt_count / 1e6 self.assertDictEqual( gt_dict, ac_dict, - 'conv_net with 3 layers failed to pass the activation count test.', + "conv_net with 3 layers failed to pass the activation count test.", ) def test_linear(self) -> None: @@ -88,34 +88,32 @@ def test_linear(self) -> None: output_dim = 20 linear = nn.Linear(input_dim, output_dim) x = torch.randn(batch_size, input_dim) - ac_dict, _ = activation_count(linear, (x, )) + ac_dict, _ = activation_count(linear, (x,)) gt_count = batch_size * output_dim gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_count / 1e6 - self.assertEqual(gt_dict, ac_dict, - 'FC layer failed to pass the activation count test.') + self.assertEqual(gt_dict, ac_dict, "FC layer failed to pass the activation count test.") def test_supported_ops(self) -> None: """Test the activation count for user provided handles.""" - def dummy_handle(inputs: List[Any], - 
outputs: List[Any]) -> typing.Counter[str]: - return Counter({'conv': 100}) + def dummy_handle(inputs: list[Any], outputs: list[Any]) -> typing.Counter[str]: + return Counter({"conv": 100}) batch_size = 1 input_dim = 3 spatial_dim = 32 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) conv_net = SmallConvNet(input_dim) - sp_ops: Dict[str, Handle] = {'aten::_convolution': dummy_handle} - ac_dict, _ = activation_count(conv_net, (x, ), sp_ops) + sp_ops: dict[str, Handle] = {"aten::_convolution": dummy_handle} + ac_dict, _ = activation_count(conv_net, (x,), sp_ops) gt_dict = defaultdict(float) conv_layers = 3 - gt_dict['conv'] = 100 * conv_layers / 1e6 + gt_dict["conv"] = 100 * conv_layers / 1e6 self.assertDictEqual( gt_dict, ac_dict, - 'conv_net with 3 layers failed to pass the activation count test.', + "conv_net with 3 layers failed to pass the activation count test.", ) def test_activation_count_class(self) -> None: @@ -126,10 +124,12 @@ def test_activation_count_class(self) -> None: netLinear = nn.Linear(input_dim, output_dim) x = torch.randn(batch_size, input_dim) gt_count = batch_size * output_dim - gt_dict = Counter({ - '': gt_count, - }) - acts_counter = ActivationAnalyzer(netLinear, (x, )) + gt_dict = Counter( + { + "": gt_count, + } + ) + acts_counter = ActivationAnalyzer(netLinear, (x,)) self.assertEqual(acts_counter.by_module(), gt_dict) batch_size = 1 @@ -137,13 +137,15 @@ def test_activation_count_class(self) -> None: spatial_dim = 32 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) conv_net = SmallConvNet(input_dim) - acts_counter = ActivationAnalyzer(conv_net, (x, )) + acts_counter = ActivationAnalyzer(conv_net, (x,)) gt_counts = conv_net.get_gt_activation(x) - gt_dict = Counter({ - '': sum(gt_counts), - 'conv1': gt_counts[0], - 'conv2': gt_counts[1], - 'conv3': gt_counts[2], - }) + gt_dict = Counter( + { + "": sum(gt_counts), + "conv1": gt_counts[0], + "conv2": gt_counts[1], + "conv3": gt_counts[2], + } + ) self.assertDictEqual(gt_dict, acts_counter.by_module()) diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py index 20749a0bab..ab8d33b948 100644 --- a/tests/test_analysis/test_flop_count.py +++ b/tests/test_analysis/test_flop_count.py @@ -7,7 +7,7 @@ import typing import unittest from collections import Counter, defaultdict -from typing import Any, Dict, Tuple +from typing import Any import torch import torch.nn as nn @@ -20,7 +20,6 @@ class _CustomOp(Function): - @staticmethod def forward(ctx, input: torch.Tensor) -> torch.Tensor: return input @@ -74,14 +73,12 @@ def __init__( ) -> None: super().__init__() if transpose: - conv_layers = [ - nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d - ] - kwargs = {'output_padding': output_padding} + conv_layers = [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] + kwargs = {"output_padding": output_padding} else: conv_layers = [nn.Conv1d, nn.Conv2d, nn.Conv3d] - assert (output_padding == 0), 'output_padding is not supported for' - ' un-transposed convolutions.' + assert output_padding == 0, "output_padding is not supported for" + " un-transposed convolutions." kwargs = {} ConvLayer = conv_layers[conv_dim - 1] @@ -180,13 +177,13 @@ def setUp(self) -> None: # we are testing the right thing. 
lin = nn.Linear(10, 10) lin_x: torch.Tensor = torch.randn(10, 10) - trace = torch.jit.trace(lin, (lin_x, )) + trace = torch.jit.trace(lin, (lin_x,)) node_kinds = [node.kind() for node in trace.graph.nodes()] - assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds - if 'aten::addmm' in node_kinds: - self.lin_op = 'addmm' + assert "aten::addmm" in node_kinds or "aten::linear" in node_kinds + if "aten::addmm" in node_kinds: + self.lin_op = "addmm" else: - self.lin_op = 'linear' + self.lin_op = "linear" def test_customized_ops(self) -> None: """Test the use of customized operation handles. @@ -198,39 +195,33 @@ def test_customized_ops(self) -> None: """ # New handle for a new operation. - def dummy_sigmoid_flop_jit( - inputs: typing.List[Any], - outputs: typing.List[Any]) -> typing.Counter[str]: + def dummy_sigmoid_flop_jit(inputs: list[Any], outputs: list[Any]) -> typing.Counter[str]: """A dummy handle function for sigmoid. Note the handle here does not compute actual flop count. This is used for test only. """ flop_dict = Counter() # type: Counter - flop_dict['sigmoid'] = 10000 + flop_dict["sigmoid"] = 10000 return flop_dict batch_size = 10 input_dim = 5 output_dim = 4 custom_net = CustomNet(input_dim, output_dim) - custom_ops: Dict[str, Handle] = { - 'aten::sigmoid': dummy_sigmoid_flop_jit - } + custom_ops: dict[str, Handle] = {"aten::sigmoid": dummy_sigmoid_flop_jit} x = torch.rand(batch_size, input_dim) - flop_dict1, _ = flop_count(custom_net, (x, ), supported_ops=custom_ops) + flop_dict1, _ = flop_count(custom_net, (x,), supported_ops=custom_ops) flop_sigmoid = 10000 / 1e9 self.assertEqual( - flop_dict1['sigmoid'], + flop_dict1["sigmoid"], flop_sigmoid, - 'Customized operation handle failed to pass the flop count test.', + "Customized operation handle failed to pass the flop count test.", ) # New handle that overwrites a default handle addmm. So now the new # handle counts flops for the fully connected layer. - def addmm_dummy_flop_jit( - inputs: typing.List[object], - outputs: typing.List[object]) -> typing.Counter[str]: + def addmm_dummy_flop_jit(inputs: list[object], outputs: list[object]) -> typing.Counter[str]: """A dummy handle function for fully connected layers. This overwrites the default handle. 
Note the handle here does not @@ -240,16 +231,13 @@ def addmm_dummy_flop_jit( flop_dict[self.lin_op] = 400000 return flop_dict - custom_ops2: Dict[str, Handle] = { - f'aten::{self.lin_op}': addmm_dummy_flop_jit - } - flop_dict2, _ = flop_count( - custom_net, (x, ), supported_ops=custom_ops2) + custom_ops2: dict[str, Handle] = {f"aten::{self.lin_op}": addmm_dummy_flop_jit} + flop_dict2, _ = flop_count(custom_net, (x,), supported_ops=custom_ops2) flop = 400000 / 1e9 self.assertEqual( flop_dict2[self.lin_op], flop, - 'Customized operation handle failed to pass the flop count test.', + "Customized operation handle failed to pass the flop count test.", ) def test_nn(self) -> None: @@ -259,12 +247,11 @@ def test_nn(self) -> None: input_dim = 8 output_dim = 4 x = torch.randn(batch_size, input_dim) - flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x, )) + flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x,)) gt_flop = batch_size * input_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop - self.assertDictEqual(flop_dict, gt_dict, - 'nn.Linear failed to pass the flop count test.') + self.assertDictEqual(flop_dict, gt_dict, "nn.Linear failed to pass the flop count test.") def test_skip_ops(self) -> None: """Test the return of skipped operations.""" @@ -273,12 +260,10 @@ def test_skip_ops(self) -> None: output_dim = 4 custom_net = CustomNet(input_dim, output_dim) x = torch.rand(batch_size, input_dim) - _, skip_dict = flop_count(custom_net, (x, )) + _, skip_dict = flop_count(custom_net, (x,)) gt_dict = Counter() # type: Counter - gt_dict['aten::sigmoid'] = 1 - self.assertDictEqual( - skip_dict, gt_dict, - 'Skipped operations failed to pass the flop count test.') + gt_dict["aten::sigmoid"] = 1 + self.assertDictEqual(skip_dict, gt_dict, "Skipped operations failed to pass the flop count test.") def test_linear(self) -> None: """Test a network with a single fully connected layer.""" @@ -287,32 +272,32 @@ def test_linear(self) -> None: output_dim = 20 linear_net = LinearNet(input_dim, output_dim) x = torch.randn(batch_size, input_dim) - flop_dict, _ = flop_count(linear_net, (x, )) + flop_dict, _ = flop_count(linear_net, (x,)) gt_flop = batch_size * input_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'Fully connected layer failed to pass the flop count test.', + "Fully connected layer failed to pass the flop count test.", ) # Test with #input_dims>2 - if self.lin_op != 'linear': + if self.lin_op != "linear": # Skip this test if nn.Linear doesn't use aten::linear # TODO: Stop skipping when multidimensional aten::matmul # flop counting is implemented return extra_dim = 5 x = torch.randn(batch_size, extra_dim, input_dim) - flop_dict, _ = flop_count(linear_net, (x, )) + flop_dict, _ = flop_count(linear_net, (x,)) gt_flop = batch_size * input_dim * extra_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'Fully connected layer failed to pass the flop count test.', + "Fully connected layer failed to pass the flop count test.", ) def test_conv(self) -> None: @@ -347,33 +332,34 @@ def _test_conv( transpose, output_padding, ) - assert conv_dim in [ - 1, 2, 3 - ], 'Convolution dimension needs to be 1, 2, or 3' + assert conv_dim in [1, 2, 3], "Convolution dimension needs to be 1, 2, or 3" if conv_dim == 1: x = torch.randn(batch_size, input_dim, spatial_dim) elif conv_dim == 2: - x = torch.randn(batch_size, 
input_dim, spatial_dim, - spatial_dim) + x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) else: - x = torch.randn(batch_size, input_dim, spatial_dim, - spatial_dim, spatial_dim) + x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim, spatial_dim) - flop_dict, _ = flop_count(convNet, (x, )) + flop_dict, _ = flop_count(convNet, (x,)) if transpose: spatial_size = spatial_dim else: - spatial_size = ( - (spatial_dim + 2 * padding) - kernel_size) // stride + 1 + spatial_size = ((spatial_dim + 2 * padding) - kernel_size) // stride + 1 gt_flop = ( - batch_size * input_dim * output_dim * (kernel_size**conv_dim) * - (spatial_size**conv_dim) / group_size / 1e9) + batch_size + * input_dim + * output_dim + * (kernel_size**conv_dim) + * (spatial_size**conv_dim) + / group_size + / 1e9 + ) gt_dict = defaultdict(float) - gt_dict['conv'] = gt_flop + gt_dict["conv"] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'Convolution layer failed to pass the flop count test.', + "Convolution layer failed to pass the flop count test.", ) # Test flop count for 2d convolution. @@ -590,17 +576,13 @@ def _test_conv( flop_dict, _ = flop_count(m_net, (x, y)) gt_flop = m * n * p / 1e9 gt_dict = defaultdict(float) - gt_dict['matmul'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + gt_dict["matmul"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "Matmul operation failed to pass the flop count test.") # Test with single dimension y y = torch.randn(n) - gt_dict['matmul'] = m * n * 1 / 1e9 + gt_dict["matmul"] = m * n * 1 / 1e9 flop_dict, _ = flop_count(m_net, (x, y)) - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + self.assertDictEqual(flop_dict, gt_dict, "Matmul operation failed to pass the flop count test.") def test_matmul_broadcast(self) -> None: """Test flop count for operation matmul.""" @@ -613,40 +595,32 @@ def test_matmul_broadcast(self) -> None: flop_dict, _ = flop_count(m_net, (x, y)) gt_flop = m * n * p / 1e9 gt_dict = defaultdict(float) - gt_dict['matmul'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + gt_dict["matmul"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "Matmul operation failed to pass the flop count test.") x = torch.randn(2, 2, m, n) y = torch.randn(2, 2, n, p) flop_dict, _ = flop_count(m_net, (x, y)) gt_flop = 4 * m * n * p / 1e9 gt_dict = defaultdict(float) - gt_dict['matmul'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + gt_dict["matmul"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "Matmul operation failed to pass the flop count test.") x = torch.randn(1, m, n) y = torch.randn(n, p) flop_dict, _ = flop_count(m_net, (x, y)) gt_flop = m * n * p / 1e9 gt_dict = defaultdict(float) - gt_dict['matmul'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + gt_dict["matmul"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "Matmul operation failed to pass the flop count test.") x = torch.randn(2, m, n) y = torch.randn(n, p) flop_dict, _ = flop_count(m_net, (x, y)) gt_flop = 2 * m * n * p / 1e9 gt_dict = defaultdict(float) - gt_dict['matmul'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'Matmul operation failed to pass the flop count test.') + gt_dict["matmul"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, 
"Matmul operation failed to pass the flop count test.") def test_bmm(self) -> None: """Test flop count for operation torch.bmm. @@ -663,11 +637,11 @@ def test_bmm(self) -> None: flop_dict, _ = flop_count(e_net, (x, y)) gt_flop = n * t * p * c / 1e9 gt_dict = defaultdict(float) - gt_dict['bmm'] = gt_flop + gt_dict["bmm"] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'bmm operation nct,ncp->ntp failed to pass the flop count test.', + "bmm operation nct,ncp->ntp failed to pass the flop count test.", ) def test_einsum(self) -> None: @@ -676,7 +650,7 @@ def test_einsum(self) -> None: The first case checks torch.einsum with equation nct,ncp->ntp. The second case checks torch.einsum with equation "ntg,ncg->nct". """ - equation = 'nct,ncp->ntp' + equation = "nct,ncp->ntp" n = 1 c = 5 t = 2 @@ -687,14 +661,14 @@ def test_einsum(self) -> None: flop_dict, _ = flop_count(e_net, (x, y)) gt_flop = n * t * p * c / 1e9 gt_dict = defaultdict(float) - gt_dict['einsum'] = gt_flop + gt_dict["einsum"] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'Einsum operation nct,ncp->ntp failed to pass flop count test.', + "Einsum operation nct,ncp->ntp failed to pass flop count test.", ) - equation = 'ntg,ncg->nct' + equation = "ntg,ncg->nct" g = 6 e_net = EinsumNet(equation) x = torch.randn(n, t, g) @@ -702,11 +676,11 @@ def test_einsum(self) -> None: flop_dict, _ = flop_count(e_net, (x, y)) gt_flop = n * t * g * c / 1e9 gt_dict = defaultdict(float) - gt_dict['einsum'] = gt_flop + gt_dict["einsum"] = gt_flop self.assertDictEqual( flop_dict, gt_dict, - 'Einsum operation ntg,ncg->nct failed to pass flop count test.', + "Einsum operation ntg,ncg->nct failed to pass flop count test.", ) def test_batchnorm(self) -> None: @@ -719,13 +693,11 @@ def test_batchnorm(self) -> None: input_dim = 10 batch_1d = nn.BatchNorm1d(input_dim, affine=False).eval() x = torch.randn(batch_size, input_dim) - flop_dict, _ = flop_count(batch_1d, (x, )) + flop_dict, _ = flop_count(batch_1d, (x,)) gt_flop = batch_size * input_dim / 1e9 gt_dict = defaultdict(float) - gt_dict['batch_norm'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'BatchNorm1d failed to pass the flop count test.') + gt_dict["batch_norm"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "BatchNorm1d failed to pass the flop count test.") # Test for BatchNorm2d. batch_size = 10 @@ -734,14 +706,11 @@ def test_batchnorm(self) -> None: spatial_dim_y = 5 batch_2d = nn.BatchNorm2d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y) - flop_dict, _ = flop_count(batch_2d, (x, )) - gt_flop = 4 * batch_size * input_dim * spatial_dim_x * \ - spatial_dim_y / 1e9 + flop_dict, _ = flop_count(batch_2d, (x,)) + gt_flop = 4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y / 1e9 gt_dict = defaultdict(float) - gt_dict['batch_norm'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'BatchNorm2d failed to pass the flop count test.') + gt_dict["batch_norm"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "BatchNorm2d failed to pass the flop count test.") # Test for BatchNorm3d. 
batch_size = 10 @@ -750,16 +719,12 @@ def test_batchnorm(self) -> None: spatial_dim_y = 5 spatial_dim_z = 5 batch_3d = nn.BatchNorm3d(input_dim, affine=False) - x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y, - spatial_dim_z) - flop_dict, _ = flop_count(batch_3d, (x, )) - gt_flop = (4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y * - spatial_dim_z / 1e9) + x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y, spatial_dim_z) + flop_dict, _ = flop_count(batch_3d, (x,)) + gt_flop = 4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y * spatial_dim_z / 1e9 gt_dict = defaultdict(float) - gt_dict['batch_norm'] = gt_flop - self.assertDictEqual( - flop_dict, gt_dict, - 'BatchNorm3d failed to pass the flop count test.') + gt_dict["batch_norm"] = gt_flop + self.assertDictEqual(flop_dict, gt_dict, "BatchNorm3d failed to pass the flop count test.") def test_threeNet(self) -> None: """Test a network with more than one layer. @@ -774,20 +739,19 @@ def test_threeNet(self) -> None: linear_dim = 3 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) three_net = ThreeNet(input_dim, conv_dim, linear_dim) - flop1 = batch_size * conv_dim * input_dim * spatial_dim * \ - spatial_dim / 1e9 + flop1 = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim / 1e9 flop_linear1 = batch_size * conv_dim * linear_dim / 1e9 flop_linear2 = batch_size * linear_dim * 1 / 1e9 flop2 = flop_linear1 + flop_linear2 - flop_dict, _ = flop_count(three_net, (x, )) + flop_dict, _ = flop_count(three_net, (x,)) gt_dict = defaultdict(float) - gt_dict['conv'] = flop1 + gt_dict["conv"] = flop1 gt_dict[self.lin_op] = flop2 - gt_dict['adaptive_avg_pool2d'] = 2e-6 + gt_dict["adaptive_avg_pool2d"] = 2e-6 self.assertDictEqual( flop_dict, gt_dict, - 'The three-layer network failed to pass the flop count test.', + "The three-layer network failed to pass the flop count test.", ) def test_flop_counter_class(self) -> None: @@ -802,26 +766,28 @@ def test_flop_counter_class(self) -> None: flop1 = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim flop_linear1 = batch_size * conv_dim * linear_dim flop_linear2 = batch_size * linear_dim * 1 - flop_counter = FlopAnalyzer(three_net, (x, )) - gt_dict = Counter({ - 'conv': flop1, - 'linear1': flop_linear1, - 'linear2': flop_linear2, - 'pool': flop1 // input_dim, - }) - gt_dict[''] = sum(gt_dict.values()) + flop_counter = FlopAnalyzer(three_net, (x,)) + gt_dict = Counter( + { + "conv": flop1, + "linear1": flop_linear1, + "linear2": flop_linear2, + "pool": flop1 // input_dim, + } + ) + gt_dict[""] = sum(gt_dict.values()) self.assertEqual(flop_counter.by_module(), gt_dict) def test_autograd_function(self): # test support on custom autograd function class Mod(nn.Module): - def forward(self, x): return _CustomOp.apply(x) - flop = FlopAnalyzer(Mod(), (torch.rand(4, 5), )).set_op_handle( - 'prim::PythonOp._CustomOp', lambda *args, **kwargs: 42) + flop = FlopAnalyzer(Mod(), (torch.rand(4, 5),)).set_op_handle( + "prim::PythonOp._CustomOp", lambda *args, **kwargs: 42 + ) self.assertEqual(flop.total(), 42) def test_scripted_function(self): @@ -831,39 +797,34 @@ def func(x): return x @ x class Mod(nn.Module): - def forward(self, x): f = torch.jit.script(func) return f(x * x) - flop = FlopAnalyzer(Mod(), (torch.rand(5, 5), )) + flop = FlopAnalyzer(Mod(), (torch.rand(5, 5),)) _ = flop.total() - self.assertIn('prim::CallFunction', flop.unsupported_ops()) + self.assertIn("prim::CallFunction", flop.unsupported_ops()) class 
TestFlopCountHandles(unittest.TestCase): - - def _count_function(self, func, inputs, name) -> Tuple[Any, Any]: + def _count_function(self, func, inputs, name) -> tuple[Any, Any]: tensor_inputs = [x for x in inputs if isinstance(x, torch.Tensor)] def f(*args): return func(*inputs) - graph = torch.jit.trace( - f, tuple(tensor_inputs), check_trace=False).graph + graph = torch.jit.trace(f, tuple(tensor_inputs), check_trace=False).graph nodes = [k for k in graph.nodes() if k.kind() == name] self.assertEqual(len(nodes), 1) node = nodes[0] return list(node.inputs()), list(node.outputs()) def test_batch_norm(self): - op_name = 'aten::batch_norm' + op_name = "aten::batch_norm" counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] vec = torch.rand(2) - nodes = self._count_function( - F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec), - op_name) + nodes = self._count_function(F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec), op_name) self.assertEqual(counter(*nodes), 32) nodes = self._count_function( @@ -882,43 +843,36 @@ def test_batch_norm(self): self.assertEqual(counter(*nodes), 80) def test_group_norm(self): - op_name = 'aten::group_norm' + op_name = "aten::group_norm" counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] vec = torch.rand(2) - nodes = self._count_function(F.group_norm, - (torch.rand(2, 2, 2, 2), 2, vec, vec), - op_name) + nodes = self._count_function(F.group_norm, (torch.rand(2, 2, 2, 2), 2, vec, vec), op_name) self.assertEqual(counter(*nodes), 80) - nodes = self._count_function(F.group_norm, - (torch.rand(2, 2, 2, 2), 2, None, None), - op_name) + nodes = self._count_function(F.group_norm, (torch.rand(2, 2, 2, 2), 2, None, None), op_name) self.assertEqual(counter(*nodes), 64) def test_upsample(self): - op_name = 'aten::upsample_bilinear2d' + op_name = "aten::upsample_bilinear2d" counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] - nodes = self._count_function( - F.interpolate, - (torch.rand(2, 2, 2, 2), None, 2, 'bilinear', False), op_name) + nodes = self._count_function(F.interpolate, (torch.rand(2, 2, 2, 2), None, 2, "bilinear", False), op_name) self.assertEqual(counter(*nodes), 2**4 * 4 * 4) def test_complicated_einsum(self): - op_name = 'aten::einsum' + op_name = "aten::einsum" counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] nodes = self._count_function( torch.einsum, - ('nc,nchw->hw', torch.rand(3, 4), torch.rand(3, 4, 2, 3)), + ("nc,nchw->hw", torch.rand(3, 4), torch.rand(3, 4, 2, 3)), op_name, ) self.assertEqual(counter(*nodes), 72.0) def test_torch_mm(self): - for op_name, func in zip(['aten::mm', 'aten::matmul'], - [torch.mm, torch.matmul]): + for op_name, func in zip(["aten::mm", "aten::matmul"], [torch.mm, torch.matmul], strict=False): counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] nodes = self._count_function( diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index be10309d0f..ce62e19a51 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -9,7 +9,7 @@ import unittest import warnings from collections import Counter -from typing import Any, Dict, List +from typing import Any import torch import torch.nn as nn @@ -17,14 +17,13 @@ from mmengine import MMLogger from mmengine.analysis import FlopAnalyzer from mmengine.analysis.jit_analysis import JitModelAnalysis -from mmengine.analysis.jit_handles import (Handle, addmm_flop_jit, - conv_flop_jit, linear_flop_jit) +from mmengine.analysis.jit_handles import Handle, addmm_flop_jit, conv_flop_jit, linear_flop_jit class 
NestedNetInnerModule(nn.Module): """A submodule for the nested net test module below.""" - def __init__(self, lin_op: str = 'addmm') -> None: + def __init__(self, lin_op: str = "addmm") -> None: super().__init__() conv_input_size = (2, 5) conv_in = 2 @@ -44,21 +43,20 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out - conv_flops = Counter({'conv': conv_flops_}) + conv_flops = Counter({"conv": conv_flops_}) model_flops = conv_flops + fc_flops - self.flops: 'Dict[str, typing.Counter[str]]' = { - '': model_flops, - 'fc': fc_flops, - 'conv': conv_flops, + self.flops: dict[str, typing.Counter[str]] = { + "": model_flops, + "fc": fc_flops, + "conv": conv_flops, } - self.name_to_module: 'Dict[str, nn.Module]' = { - '': self, - 'fc': self.fc, - 'conv': self.conv, + self.name_to_module: dict[str, nn.Module] = { + "": self, + "fc": self.fc, + "conv": self.conv, } def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -73,7 +71,7 @@ class NestedNet(nn.Module): """A network with nested submodules for testing the ability to correctly capture scope information.""" - def __init__(self, lin_op: str = 'addmm') -> None: + def __init__(self, lin_op: str = "addmm") -> None: super().__init__() self.input_size = (4, 5) @@ -95,35 +93,34 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (self.input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (self.input_size[1] + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out - conv_flops = Counter({'conv': conv_flops_}) - - model_flops = conv_flops + fc_flops + self.submod.flops[''] - self.flops: 'Dict[str, typing.Counter[str]]' = { - '': model_flops, - 'fc': fc_flops, - 'conv': conv_flops, - 'submod': self.submod.flops[''], - 'submod.fc': self.submod.flops['fc'], - 'submod.conv': self.submod.flops['conv'], + conv_flops = Counter({"conv": conv_flops_}) + + model_flops = conv_flops + fc_flops + self.submod.flops[""] + self.flops: dict[str, typing.Counter[str]] = { + "": model_flops, + "fc": fc_flops, + "conv": conv_flops, + "submod": self.submod.flops[""], + "submod.fc": self.submod.flops["fc"], + "submod.conv": self.submod.flops["conv"], } - self.name_to_module: 'Dict[str, nn.Module]' = { - '': self, - 'fc': self.fc, - 'conv': self.conv, - 'submod': self.submod, - 'submod.fc': self.submod.name_to_module['fc'], - 'submod.conv': self.submod.name_to_module['conv'], + self.name_to_module: dict[str, nn.Module] = { + "": self, + "fc": self.fc, + "conv": self.conv, + "submod": self.submod, + "submod.fc": self.submod.name_to_module["fc"], + "submod.conv": self.submod.name_to_module["conv"], } def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = torch.flatten(x, 1) x = self.fc(x) - x = self.submod(x)**2 + x = self.submod(x) ** 2 return x @@ -132,7 +129,7 @@ class UnusedNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc1_in, fc1_out = 10, 10 fc2_in, fc2_out = 10, 1 unused_in, unused_out = 20, 20 @@ -140,7 +137,7 @@ def __init__(self) -> None: self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out) self.fc2 = nn.Linear(in_features=fc2_in, 
out_features=fc2_out) self.unused = nn.Linear(in_features=unused_in, out_features=unused_out) - self.act: 'nn.Module' = nn.ReLU() + self.act: nn.Module = nn.ReLU() self.fc1_flops: int = fc1_in * fc1_out self.fc2_flops: int = fc2_in * fc2_out @@ -155,7 +152,7 @@ class RepeatedNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc1_in, fc1_out = 10, 10 fc2_in, fc2_out = 10, 10 self.fc1_num = 3 @@ -180,7 +177,7 @@ class NonForwardInnerModule(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc_in, fc_out = 10, 1 self.fc = nn.Linear(in_features=fc_in, out_features=fc_out) @@ -199,7 +196,7 @@ class NonForwardNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc_in, fc_out = 10, 10 self.submod = NonForwardInnerModule() @@ -230,7 +227,7 @@ class SharedModuleNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc1_in, fc1_out = 10, 10 fc2_in, fc2_out = 10, 1 @@ -238,8 +235,8 @@ def __init__(self) -> None: self.submod1 = SharedInnerModule(inner) self.submod2 = SharedInnerModule(inner) multiname = nn.Linear(in_features=fc2_in, out_features=fc2_out) - self.multiname1: 'nn.Module' = multiname - self.multiname2: 'nn.Module' = multiname + self.multiname1: nn.Module = multiname + self.multiname2: nn.Module = multiname self.multiname_flops: int = fc2_in * fc2_out self.shared_flops: int = fc1_in * fc1_out @@ -255,7 +252,7 @@ class RecursiveScopeNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc_in, fc_out = 10, 1 self.fc = nn.Linear(in_features=fc_in, out_features=fc_out) @@ -277,7 +274,7 @@ class TraceWarningNet(nn.Module): def __init__(self) -> None: super().__init__() - self.input_size = (10, ) + self.input_size = (10,) fc1_in, fc1_out = 10, 1 fc2_in, fc2_out = 10, 10 @@ -289,7 +286,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.fc1(x).item() - warnings.warn('Dummy RuntimeWarning.', RuntimeWarning) + warnings.warn("Dummy RuntimeWarning.", RuntimeWarning, stacklevel=2) if y < 0.0: x = self.fc2(x) return x + 2 @@ -307,20 +304,20 @@ def setUp(self) -> None: # we are testing the right thing. 
lin = nn.Linear(10, 10) lin_x: torch.Tensor = torch.randn(10, 10) - trace = torch.jit.trace(lin, (lin_x, )) + trace = torch.jit.trace(lin, (lin_x,)) node_kinds = [node.kind() for node in trace.graph.nodes()] - assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds - if 'aten::addmm' in node_kinds: - self.lin_op = 'addmm' + assert "aten::addmm" in node_kinds or "aten::linear" in node_kinds + if "aten::addmm" in node_kinds: + self.lin_op = "addmm" else: - self.lin_op = 'linear' + self.lin_op = "linear" def test_total(self) -> None: """Tests that JitModelAnalysis.total(module) returns the correct counts for string and module inputs.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) analyzer.unsupported_ops_warnings(enabled=False) @@ -336,15 +333,12 @@ def test_by_module(self) -> None: in the correctly structured dictionary.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) analyzer.unsupported_ops_warnings(enabled=False) - flops = { - name: sum(counts.values()) - for name, counts in model.flops.items() - } + flops = {name: sum(counts.values()) for name, counts in model.flops.items()} self.assertEqual(analyzer.by_module(), flops) @@ -353,7 +347,7 @@ def test_by_operator(self) -> None: counts for string and module inputs.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) analyzer.unsupported_ops_warnings(enabled=False) @@ -368,7 +362,7 @@ def test_by_module_and_operator(self) -> None: correct counts in the correct structure.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) analyzer.unsupported_ops_warnings(enabled=False) @@ -384,26 +378,26 @@ def test_unused_module(self) -> None: """ model = UnusedNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) unused_count = 0 unused_per_operator = Counter() # type: Counter model_count = model.fc1_flops + model.fc2_flops - self.assertEqual(analyzer.total('unused'), unused_count) - self.assertEqual(analyzer.by_operator('unused'), unused_per_operator) - self.assertEqual(analyzer.total(''), model_count) + self.assertEqual(analyzer.total("unused"), unused_count) + self.assertEqual(analyzer.by_operator("unused"), unused_per_operator) + self.assertEqual(analyzer.total(""), model_count) # The unused mod is recognized as never called - self.assertEqual(analyzer.uncalled_modules(), {'unused'}) + self.assertEqual(analyzer.uncalled_modules(), {"unused"}) def test_repeated_module(self) -> None: """Tests that repeated calls to the same submodule correct aggregates results to that submodule.""" model = RepeatedNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) fc1_count = model.fc1_num * model.fc1_flops @@ -411,10 +405,10 @@ def test_repeated_module(self) -> None: total_count = fc1_count + fc2_count fc1_per_operator = Counter({self.lin_op: fc1_count}) - 
self.assertEqual(analyzer.total('fc1'), fc1_count) - self.assertEqual(analyzer.total('fc2'), fc2_count) - self.assertEqual(analyzer.total(''), total_count) - self.assertEqual(analyzer.by_operator('fc1'), fc1_per_operator) + self.assertEqual(analyzer.total("fc1"), fc1_count) + self.assertEqual(analyzer.total("fc2"), fc2_count) + self.assertEqual(analyzer.total(""), total_count) + self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator) # Tests no uncalled mods self.assertEqual(analyzer.uncalled_modules(), set()) @@ -427,25 +421,23 @@ def test_non_forward_func_call(self) -> None: """ model = NonForwardNet() - inputs = (torch.randn((1, 10)), ) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('caller') + inputs = (torch.randn((1, 10)),) + analyzer = FlopAnalyzer(model=model, inputs=inputs).ancestor_mode("caller") inner_fc_count = model.submod.fc_flops total_count = model.fc_flops + inner_fc_count - self.assertEqual(analyzer.total('submod'), 0) - self.assertEqual(analyzer.total('submod.fc'), inner_fc_count) - self.assertEqual(analyzer.total(''), total_count) + self.assertEqual(analyzer.total("submod"), 0) + self.assertEqual(analyzer.total("submod.fc"), inner_fc_count) + self.assertEqual(analyzer.total(""), total_count) # The mod not directly called is registered as such - self.assertEqual(analyzer.uncalled_modules(), {'submod'}) + self.assertEqual(analyzer.uncalled_modules(), {"submod"}) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('owner') - self.assertEqual(analyzer.total('submod'), inner_fc_count) - self.assertEqual(analyzer.total('submod.fc'), inner_fc_count) - self.assertEqual(analyzer.total(''), total_count) + analyzer = FlopAnalyzer(model=model, inputs=inputs).ancestor_mode("owner") + self.assertEqual(analyzer.total("submod"), inner_fc_count) + self.assertEqual(analyzer.total("submod.fc"), inner_fc_count) + self.assertEqual(analyzer.total(""), total_count) self.assertEqual(analyzer.uncalled_modules(), set()) def test_shared_module(self) -> None: @@ -453,11 +445,11 @@ def test_shared_module(self) -> None: names.""" model = SharedModuleNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = ( - FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings( - enabled=False).ancestor_mode('caller')) + FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings(enabled=False).ancestor_mode("caller") + ) # The names `submod2.submod` and `multiname2` are not included, # since only the first name of a module is made the canonical one. 
@@ -467,34 +459,30 @@ def test_shared_module(self) -> None: shared_flops = 2 * model.shared_flops # Shared under 2 submodules total_flops = multiname_flops + shared_flops flops = { - '': total_flops, - 'submod1': model.shared_flops, - 'submod1.submod': shared_flops, - 'submod2': model.shared_flops, - 'multiname1': multiname_flops, + "": total_flops, + "submod1": model.shared_flops, + "submod1.submod": shared_flops, + "submod2": model.shared_flops, + "multiname1": multiname_flops, } self.assertEqual(analyzer.by_module(), flops) # Test access by alternative name self.assertEqual( - analyzer.total('submod2.submod'), - flops['submod1.submod'], + analyzer.total("submod2.submod"), + flops["submod1.submod"], ) self.assertEqual( - analyzer.total('multiname2'), - flops['multiname1'], + analyzer.total("multiname2"), + flops["multiname1"], ) # Test getting canonical name - self.assertEqual( - analyzer.canonical_module_name('multiname2'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('multiname1'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('submod2.submod'), 'submod1.submod') - self.assertEqual( - analyzer.canonical_module_name('submod1.submod'), 'submod1.submod') + self.assertEqual(analyzer.canonical_module_name("multiname2"), "multiname1") + self.assertEqual(analyzer.canonical_module_name("multiname1"), "multiname1") + self.assertEqual(analyzer.canonical_module_name("submod2.submod"), "submod1.submod") + self.assertEqual(analyzer.canonical_module_name("submod1.submod"), "submod1.submod") # Tests no uncalled modules self.assertEqual(analyzer.uncalled_modules(), set()) @@ -503,12 +491,12 @@ def test_recursive_scope(self) -> None: """Tests that an op is only counted once per module, even if it is in the scope of that module multiple times.""" model = RecursiveScopeNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model, inputs) self.assertEqual(analyzer.total(), model.flops) - self.assertEqual(analyzer.total('fc'), model.flops) + self.assertEqual(analyzer.total("fc"), model.flops) # Tests no uncalled modules self.assertEqual(analyzer.uncalled_modules(), set()) @@ -517,19 +505,13 @@ def test_data_parallel(self) -> None: """Tests that a model wrapped in DataParallel still returns results labeled by the correct scopes.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) # Find flops for wrapper - flops = { - 'module' + ('.' if name else '') + name: flop - for name, flop in model.flops.items() - } - flops[''] = model.flops[''] - name_to_module = { - 'module' + ('.' if name else '') + name: mod - for name, mod in model.name_to_module.items() - } - name_to_module[''] = model.name_to_module[''] + flops = {"module" + ("." if name else "") + name: flop for name, flop in model.flops.items()} + flops[""] = model.flops[""] + name_to_module = {"module" + ("." 
if name else "") + name: mod for name, mod in model.name_to_module.items()} + name_to_module[""] = model.name_to_module[""] model = torch.nn.DataParallel(model).cpu() analyzer = FlopAnalyzer(model=model, inputs=inputs) @@ -550,8 +532,8 @@ def test_data_parallel(self) -> None: def test_data_parallel_root_scope(self) -> None: # A test case discussed in D32227000 model = nn.DataParallel(nn.Linear(10, 10)).cpu() - for mode in ['caller', 'owner']: - flop = FlopAnalyzer(model, (torch.randn(10, 10), )) + for mode in ["caller", "owner"]: + flop = FlopAnalyzer(model, (torch.randn(10, 10),)) flop.ancestor_mode(mode) self.assertEqual(flop.total(), 1000) @@ -559,37 +541,36 @@ def test_unsupported_ops(self) -> None: """Tests per-module recording of unsupported operations.""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle( - 'aten::addmm', - addmm_flop_jit, - 'aten::linear', - linear_flop_jit, - ) + analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle( + "aten::addmm", + addmm_flop_jit, + "aten::linear", + linear_flop_jit, + ) analyzer.total() - skipped_inner_conv = Counter({'aten::_convolution': 1}) + skipped_inner_conv = Counter({"aten::_convolution": 1}) skipped_inner_fc = Counter() # type: Counter - skipped_inner = Counter({'aten::add': 1, 'aten::mul': 1}) + skipped_inner = Counter({"aten::add": 1, "aten::mul": 1}) skipped_inner += skipped_inner_fc skipped_inner += skipped_inner_conv - skipped_outer_conv = Counter({'aten::_convolution': 1}) + skipped_outer_conv = Counter({"aten::_convolution": 1}) skipped_outer_fc = Counter() # type: Counter - skipped_outer = Counter({'aten::pow': 1}) + skipped_outer = Counter({"aten::pow": 1}) skipped_outer += skipped_outer_conv skipped_outer += skipped_outer_fc skipped_outer += skipped_inner skipped = { - '': skipped_outer, - 'conv': skipped_outer_conv, - 'fc': skipped_outer_fc, - 'submod': skipped_inner, - 'submod.conv': skipped_inner_conv, - 'submod.fc': skipped_inner_fc, + "": skipped_outer, + "conv": skipped_outer_conv, + "fc": skipped_outer_fc, + "submod": skipped_inner, + "submod.conv": skipped_inner_conv, + "submod.fc": skipped_inner_fc, } # Access by string @@ -600,47 +581,41 @@ def test_unsupported_ops(self) -> None: def test_changing_handles(self) -> None: """Tests .set_op_handle(), .clear_op_handles()""" model = NestedNet(lin_op=self.lin_op) - inputs = (torch.randn((1, *model.input_size)), ) - op_handles: 'Dict[str, Handle]' = { - 'aten::addmm': addmm_flop_jit, - 'aten::linear': linear_flop_jit, + inputs = (torch.randn((1, *model.input_size)),) + op_handles: dict[str, Handle] = { + "aten::addmm": addmm_flop_jit, + "aten::linear": linear_flop_jit, } - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle(**op_handles) + analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle(**op_handles) analyzer.unsupported_ops_warnings(enabled=False) # Request a result once to cache flop counts - _ = analyzer.total('') + _ = analyzer.total("") # Add an op handle - analyzer.set_op_handle('aten::_convolution', conv_flop_jit) + analyzer.set_op_handle("aten::_convolution", conv_flop_jit) self.assertEqual(analyzer.by_module_and_operator(), model.flops) # Overwrite an op handle def make_dummy_op(name: str, output: int) -> Handle: - - def dummy_ops_handle(inputs: List[Any], - outputs: List[Any]) -> typing.Counter[str]: + def dummy_ops_handle(inputs: list[Any], 
outputs: list[Any]) -> typing.Counter[str]: return Counter({name: output}) return dummy_ops_handle - dummy_name = 'dummy_op' + dummy_name = "dummy_op" dummy_out = 1000 - analyzer.set_op_handle(f'aten::{self.lin_op}', - make_dummy_op(dummy_name, dummy_out)) + analyzer.set_op_handle(f"aten::{self.lin_op}", make_dummy_op(dummy_name, dummy_out)) dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter( - {op: flop - for op, flop in counts.items() if op != self.lin_op}) - dummy_flops[''][dummy_name] = 2 * dummy_out - dummy_flops['fc'][dummy_name] = dummy_out - dummy_flops['submod'][dummy_name] = dummy_out - dummy_flops['submod.fc'][dummy_name] = dummy_out + dummy_flops[name] = Counter({op: flop for op, flop in counts.items() if op != self.lin_op}) + dummy_flops[""][dummy_name] = 2 * dummy_out + dummy_flops["fc"][dummy_name] = dummy_out + dummy_flops["submod"][dummy_name] = dummy_out + dummy_flops["submod.fc"][dummy_name] = dummy_out self.assertEqual(analyzer.by_module_and_operator(), dummy_flops) @@ -655,16 +630,19 @@ def test_copy(self) -> None: """Tests .copy(...)""" model = RepeatedNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = ( - JitModelAnalysis(model=model, inputs=inputs).set_op_handle( - 'aten::addmm', + JitModelAnalysis(model=model, inputs=inputs) + .set_op_handle( + "aten::addmm", addmm_flop_jit, - 'aten::linear', + "aten::linear", linear_flop_jit, - ).unsupported_ops_warnings(enabled=False).tracer_warnings( - mode='none')) + ) + .unsupported_ops_warnings(enabled=False) + .tracer_warnings(mode="none") + ) repeated_net_flops = model.fc1_num * model.fc1_flops repeated_net_flops += model.fc2_num * model.fc2_flops @@ -698,9 +676,8 @@ def test_copy(self) -> None: # Copy with new model and inputs new_model = NonForwardNet() bs = 5 - new_inputs = (torch.randn((bs, *new_model.input_size)), ) - analyzer_new = analyzer.copy( - new_model=new_model, new_inputs=new_inputs) + new_inputs = (torch.randn((bs, *new_model.input_size)),) + analyzer_new = analyzer.copy(new_model=new_model, new_inputs=new_inputs) non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops @@ -720,32 +697,32 @@ def test_copy(self) -> None: def test_disable_warnings(self) -> None: """Tests .unsupported_ops_warnings(...) 
and .tracer_warnings(...)""" model = TraceWarningNet() - inputs = (torch.randn((1, *model.input_size)), ) + inputs = (torch.randn((1, *model.input_size)),) analyzer = FlopAnalyzer(model=model, inputs=inputs) # Tracer warnings - analyzer.tracer_warnings(mode='all') + analyzer.tracer_warnings(mode="all") analyzer._stats = None # Manually clear cache so trace is rerun self.assertWarns(torch.jit.TracerWarning, analyzer.total) analyzer._stats = None # Manually clear cache so trace is rerun self.assertWarns(RuntimeWarning, analyzer.total) - analyzer.tracer_warnings(mode='none') + analyzer.tracer_warnings(mode="none") analyzer._stats = None # Manually clear cache so trace is rerun with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") _ = analyzer.total() if w: warning_types = [s.category for s in w] self.assertFalse(torch.jit.TracerWarning in warning_types) self.assertFalse(RuntimeWarning in warning_types) - analyzer.tracer_warnings(mode='no_tracer_warning') + analyzer.tracer_warnings(mode="no_tracer_warning") analyzer._stats = None # Manually clear cache so trace is rerun self.assertWarns(RuntimeWarning, analyzer.total) analyzer._stats = None # Manually clear cache so trace is rerun with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") _ = analyzer.total() if w: warning_types = [s.category for s in w] @@ -754,15 +731,15 @@ def test_disable_warnings(self) -> None: # Unsupported ops and uncalled modules warnings logger = MMLogger.get_current_instance() - skipeed_msg = 'Unsupported operator aten::add encountered 1 time(s)' - uncalled_msg = 'never called' - uncalled_modules = 'fc1' # fc2 is called by chance + skipeed_msg = "Unsupported operator aten::add encountered 1 time(s)" + uncalled_msg = "never called" + uncalled_modules = "fc1" # fc2 is called by chance analyzer.uncalled_modules_warnings(enabled=False) analyzer.unsupported_ops_warnings(enabled=False) analyzer._stats = None # Manually clear cache so trace is rerun with self.assertLogs(logger, logging.WARN) as cm: - logger.warning('Dummy warning.') + logger.warning("Dummy warning.") _ = analyzer.total() self.assertFalse(any(skipeed_msg in s for s in cm.output)) self.assertFalse(any(uncalled_msg in s for s in cm.output)) @@ -782,7 +759,6 @@ def test_skip_uncalled_containers_warnings(self) -> None: # uncalled containers should not warn class A(nn.Module): - def forward(self, x): return self.submod[0](x) + 1 @@ -793,7 +769,7 @@ def forward(self, x): logger = MMLogger.get_current_instance() with self.assertLogs(logger, logging.WARN) as cm: - logger.warning('Dummy warning.') + logger.warning("Dummy warning.") _ = analyzer.total() - uncalled_string = 'Module never called: submod' + uncalled_string = "Module never called: submod" self.assertFalse(any(uncalled_string in s for s in cm.output)) diff --git a/tests/test_analysis/test_param_count.py b/tests/test_analysis/test_param_count.py index 59db4df0c4..3e6f74dd20 100644 --- a/tests/test_analysis/test_param_count.py +++ b/tests/test_analysis/test_param_count.py @@ -7,12 +7,10 @@ from torch import nn -from mmengine.analysis.complexity_analysis import (parameter_count, - parameter_count_table) +from mmengine.analysis.complexity_analysis import parameter_count, parameter_count_table class NetWithReuse(nn.Module): - def __init__(self, reuse: bool = False) -> None: super().__init__() self.conv1 = nn.Conv2d(100, 100, 3) @@ -22,7 +20,6 @@ def __init__(self, reuse: bool = False) -> None: 
class NetWithDupPrefix(nn.Module): - def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(100, 100, 3) @@ -30,22 +27,20 @@ def __init__(self) -> None: class TestParamCount(unittest.TestCase): - def test_param(self) -> None: net = NetWithReuse() count = parameter_count(net) - self.assertTrue(count[''], 180200) - self.assertTrue(count['conv2'], 90100) + self.assertTrue(count[""], 180200) + self.assertTrue(count["conv2"], 90100) def test_param_with_reuse(self) -> None: net = NetWithReuse(reuse=True) count = parameter_count(net) - self.assertTrue(count[''], 90200) - self.assertTrue(count['conv2'], 100) + self.assertTrue(count[""], 90200) + self.assertTrue(count["conv2"], 100) def test_param_with_same_prefix(self) -> None: net = NetWithDupPrefix() table = parameter_count_table(net) - c = ['conv111.weight' in line for line in table.split('\n')] - self.assertEqual( - sum(c), 1) # it only appears once, despite being a prefix of conv1 + c = ["conv111.weight" in line for line in table.split("\n")] + self.assertEqual(sum(c), 1) # it only appears once, despite being a prefix of conv1 diff --git a/tests/test_analysis/test_print_helper.py b/tests/test_analysis/test_print_helper.py index 14366583d5..861c1df12f 100644 --- a/tests/test_analysis/test_print_helper.py +++ b/tests/test_analysis/test_print_helper.py @@ -11,7 +11,6 @@ class NetAcceptOneTensor(nn.Module): - def __init__(self) -> None: super().__init__() self.l1 = nn.Linear(in_features=5, out_features=6) @@ -22,7 +21,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class NetAcceptTwoTensors(nn.Module): - def __init__(self) -> None: super().__init__() self.l1 = nn.Linear(in_features=5, out_features=6) @@ -34,7 +32,6 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: class NetAcceptOneTensorAndOneScalar(nn.Module): - def __init__(self) -> None: super().__init__() self.l1 = nn.Linear(in_features=5, out_features=6) @@ -56,53 +53,45 @@ def test_get_model_complexity_info(): model = NetAcceptOneTensor() complexity_info = get_model_complexity_info(model=model, inputs=input1) flops = FlopAnalyzer(model=model, inputs=input1).total() - params = parameter_count(model=model)[''] - assert complexity_info['flops'] == flops - assert complexity_info['params'] == params + params = parameter_count(model=model)[""] + assert complexity_info["flops"] == flops + assert complexity_info["params"] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=input_shape1) - flops = FlopAnalyzer( - model=model, inputs=(torch.randn(1, *input_shape1), )).total() - assert complexity_info['flops'] == flops + complexity_info = get_model_complexity_info(model=model, input_shape=input_shape1) + flops = FlopAnalyzer(model=model, inputs=(torch.randn(1, *input_shape1),)).total() + assert complexity_info["flops"] == flops # test a network that accepts two tensors as input model = NetAcceptTwoTensors() - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, input2)) + complexity_info = get_model_complexity_info(model=model, inputs=(input1, input2)) flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total() - params = parameter_count(model=model)[''] - assert complexity_info['flops'] == flops - assert complexity_info['params'] == params + params = parameter_count(model=model)[""] + assert complexity_info["flops"] == flops + assert complexity_info["params"] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=(input_shape1, input_shape2)) + complexity_info = 
get_model_complexity_info(model=model, input_shape=(input_shape1, input_shape2)) inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2)) flops = FlopAnalyzer(model=model, inputs=inputs).total() - assert complexity_info['flops'] == flops + assert complexity_info["flops"] == flops # test a network that accepts one tensor and one scalar as input model = NetAcceptOneTensorAndOneScalar() # For pytorch<1.9, a scalar input is not acceptable for torch.jit, # wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501 - scalar = torch.tensor([ - scalar - ]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, scalar)) + scalar = torch.tensor([scalar]) if digit_version(TORCH_VERSION) < digit_version("1.9.0") else scalar + complexity_info = get_model_complexity_info(model=model, inputs=(input1, scalar)) flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total() - params = parameter_count(model=model)[''] - assert complexity_info['flops'] == flops - assert complexity_info['params'] == params + params = parameter_count(model=model)[""] + assert complexity_info["flops"] == flops + assert complexity_info["params"] == params # `get_model_complexity_info()` should throw `ValueError` # when neithor `inputs` nor `input_shape` is specified - with pytest.raises(ValueError, match='should be set'): + with pytest.raises(ValueError, match="should be set"): get_model_complexity_info(model) # `get_model_complexity_info()` should throw `ValueError` # when both `inputs` and `input_shape` are specified model = NetAcceptOneTensor() - with pytest.raises(ValueError, match='cannot be both set'): - get_model_complexity_info( - model, inputs=input1, input_shape=input_shape1) + with pytest.raises(ValueError, match="cannot be both set"): + get_model_complexity_info(model, inputs=input1, input_shape=input_shape1) diff --git a/tests/test_config/test_collect_meta.py b/tests/test_config/test_collect_meta.py index 1e058f9f7b..8b273da9c3 100644 --- a/tests/test_config/test_collect_meta.py +++ b/tests/test_config/test_collect_meta.py @@ -3,41 +3,40 @@ import pytest -from mmengine.config.utils import (_get_external_cfg_base_path, - _get_package_and_cfg_path) +from mmengine.config.utils import _get_external_cfg_base_path, _get_package_and_cfg_path def test_get_external_cfg_base_path(tmp_path): package_path = tmp_path - rel_cfg_path = os.path.join('cfg_dir', 'cfg_file') + rel_cfg_path = os.path.join("cfg_dir", "cfg_file") with pytest.raises(FileNotFoundError): _get_external_cfg_base_path(str(package_path), rel_cfg_path) - cfg_dir = tmp_path / '.mim' / 'configs' / 'cfg_dir' + cfg_dir = tmp_path / ".mim" / "configs" / "cfg_dir" cfg_dir.mkdir(parents=True, exist_ok=True) - f = open(cfg_dir / 'cfg_file', 'w') + f = open(cfg_dir / "cfg_file", "w") f.close() cfg_path = _get_external_cfg_base_path(str(package_path), rel_cfg_path) - assert cfg_path == f'{os.path.join(str(cfg_dir), "cfg_file")}' + assert cfg_path == f"{os.path.join(str(cfg_dir), 'cfg_file')}" def test_get_external_cfg_path(): - external_cfg_path = 'mmdet::path/cfg' + external_cfg_path = "mmdet::path/cfg" package, rel_cfg_path = _get_package_and_cfg_path(external_cfg_path) - assert package == 'mmdet' - assert rel_cfg_path == 'path/cfg' + assert package == "mmdet" + assert rel_cfg_path == "path/cfg" # external config must contain `::`. 
- external_cfg_path = 'path/cfg' + external_cfg_path = "path/cfg" with pytest.raises(ValueError): _get_package_and_cfg_path(external_cfg_path) # Use `:::` as operator will raise an error. - external_cfg_path = 'mmdet:::path/cfg' + external_cfg_path = "mmdet:::path/cfg" with pytest.raises(ValueError): _get_package_and_cfg_path(external_cfg_path) # Use `:` as operator will raise an error. - external_cfg_path = 'mmdet:path/cfg' + external_cfg_path = "mmdet:path/cfg" with pytest.raises(ValueError): _get_package_and_cfg_path(external_cfg_path) # Too much `::` - external_cfg_path = 'mmdet::path/cfg::error' + external_cfg_path = "mmdet::path/cfg::error" with pytest.raises(ValueError): _get_package_and_cfg_path(external_cfg_path) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index e783431441..36b218cbac 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -23,65 +23,56 @@ class TestConfig: - data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/') + data_path = osp.join(osp.dirname(osp.dirname(__file__)), "data/") - @pytest.mark.parametrize('file_format', ['py', 'json', 'yaml']) + @pytest.mark.parametrize("file_format", ["py", "json", "yaml"]) def test_init(self, file_format): # test init Config by __init__ cfg = Config() assert cfg.filename is None - assert cfg.text == '' + assert cfg.text == "" assert len(cfg) == 0 assert cfg._cfg_dict == {} # test `cfg_dict` parameter # `cfg_dict` is either dict or None - with pytest.raises(TypeError, match='cfg_dict must be a dict'): + with pytest.raises(TypeError, match="cfg_dict must be a dict"): Config([0, 1]) # test `filename` parameter - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') - cfg_file = osp.join( - self.data_path, - f'config/{file_format}_config/simple_config.{file_format}') + cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4="test") + cfg_file = osp.join(self.data_path, f"config/{file_format}_config/simple_config.{file_format}") cfg = Config(cfg_dict, filename=cfg_file) assert isinstance(cfg, Config) assert cfg.filename == cfg_file assert cfg.text == open(cfg_file).read() - cfg_file = osp.join( - self.data_path, - f'config/{file_format}_config/test_reserved_key.{file_format}') + cfg_file = osp.join(self.data_path, f"config/{file_format}_config/test_reserved_key.{file_format}") # reserved keys cannot be set in config - with pytest.raises( - KeyError, match='filename is reserved for config ' - 'file'): + with pytest.raises(KeyError, match="filename is reserved for config file"): Config.fromfile(cfg_file) def test_fromfile(self): # test whether import `custom_imports` from cfg_file. 
- cfg_file = osp.join(self.data_path, 'config', - 'py_config/test_custom_import.py') - sys.path.append(osp.join(self.data_path, 'config/py_config')) + cfg_file = osp.join(self.data_path, "config", "py_config/test_custom_import.py") + sys.path.append(osp.join(self.data_path, "config/py_config")) cfg = Config.fromfile(cfg_file, import_custom_modules=True) assert isinstance(cfg, Config) # If import successfully, os.environ[''TEST_VALUE''] will be # set to 'test' - assert os.environ.pop('TEST_VALUE') == 'test' + assert os.environ.pop("TEST_VALUE") == "test" sys.path.pop() Config.fromfile(cfg_file, import_custom_modules=False) - assert 'TEST_VALUE' not in os.environ - sys.modules.pop('test_custom_import_module') - with pytest.raises( - ImportError, match='Failed to import custom modules from'): + assert "TEST_VALUE" not in os.environ + sys.modules.pop("test_custom_import_module") + with pytest.raises(ImportError, match="Failed to import custom modules from"): Config.fromfile(cfg_file, import_custom_modules=True) - @pytest.mark.parametrize('file_format', ['py', 'json', 'yaml']) + @pytest.mark.parametrize("file_format", ["py", "json", "yaml"]) def test_fromstring(self, file_format): - filename = f'{file_format}_config/simple_config.{file_format}' - cfg_file = osp.join(self.data_path, 'config', filename) + filename = f"{file_format}_config/simple_config.{file_format}" + cfg_file = osp.join(self.data_path, "config", filename) file_format = osp.splitext(filename)[-1] in_cfg = Config.fromfile(cfg_file) @@ -91,19 +82,18 @@ def test_fromstring(self, file_format): # test pretty_text only supports py file format # in_cfg.pretty_text is .py format, cannot be parsed to .json - if file_format != '.py': + if file_format != ".py": with pytest.raises(Exception): Config.fromstring(in_cfg.pretty_text, file_format) # error format with pytest.raises(IOError): - Config.fromstring(cfg_str, '.xml') + Config.fromstring(cfg_str, ".xml") def test_magic_methods(self): - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') - filename = 'py_config/simple_config.py' - cfg_file = osp.join(self.data_path, 'config', filename) + cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4="test") + filename = "py_config/simple_config.py" + cfg_file = osp.join(self.data_path, "config", filename) cfg = Config.fromfile(cfg_file) # len(cfg) assert len(cfg) == 4 @@ -118,26 +108,26 @@ def test_magic_methods(self): assert name in cfg_dict assert value in cfg_dict.values() # cfg.field - assert cfg.item1 == cfg_dict['item1'] - assert cfg.item2 == cfg_dict['item2'] + assert cfg.item1 == cfg_dict["item1"] + assert cfg.item2 == cfg_dict["item2"] assert cfg.item2.a == 0 - assert cfg.item3 == cfg_dict['item3'] - assert cfg.item4 == cfg_dict['item4'] + assert cfg.item3 == cfg_dict["item3"] + assert cfg.item4 == cfg_dict["item4"] # accessing keys that do not exist will cause error with pytest.raises(AttributeError): cfg.not_exist # field in cfg, cfg[field], cfg.get() - for name in ['item1', 'item2', 'item3', 'item4']: + for name in ["item1", "item2", "item3", "item4"]: assert name in cfg assert cfg[name] == cfg_dict[name] assert cfg.get(name) == cfg_dict[name] - assert cfg.get('not_exist') is None - assert cfg.get('not_exist', 0) == 0 + assert cfg.get("not_exist") is None + assert cfg.get("not_exist", 0) == 0 # accessing keys that do not exist will cause error with pytest.raises(KeyError): - cfg['not_exist'] - assert 'item1' in cfg - assert 'not_exist' not in cfg + cfg["not_exist"] + assert "item1" in cfg + assert 
"not_exist" not in cfg # cfg.update() cfg.update(dict(item1=0)) assert cfg.item1 == 0 @@ -146,40 +136,38 @@ def test_magic_methods(self): # test __setattr__ cfg = Config() cfg.item1 = [1, 2] - cfg.item2 = {'a': 0} - cfg['item5'] = {'a': {'b': None}} - assert cfg._cfg_dict['item1'] == [1, 2] + cfg.item2 = {"a": 0} + cfg["item5"] = {"a": {"b": None}} + assert cfg._cfg_dict["item1"] == [1, 2] assert cfg.item1 == [1, 2] - assert cfg._cfg_dict['item2'] == {'a': 0} + assert cfg._cfg_dict["item2"] == {"a": 0} assert cfg.item2.a == 0 - assert cfg._cfg_dict['item5'] == {'a': {'b': None}} + assert cfg._cfg_dict["item5"] == {"a": {"b": None}} assert cfg.item5.a.b is None def test_merge_from_dict(self): - cfg_file = osp.join(self.data_path, - 'config/py_config/simple_config.py') + cfg_file = osp.join(self.data_path, "config/py_config/simple_config.py") cfg = Config.fromfile(cfg_file) - input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False} + input_options = {"item2.a": 1, "item2.b": 0.1, "item3": False} cfg.merge_from_dict(input_options) assert cfg.item2 == dict(a=1, b=0.1) assert cfg.item3 is False - cfg_file = osp.join(self.data_path, - 'config/py_config/test_merge_from_dict.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_from_dict.py") cfg = Config.fromfile(cfg_file) # Allow list keys - input_options = {'item.0.a': 1, 'item.1.b': 1} + input_options = {"item.0.a": 1, "item.1.b": 1} cfg.merge_from_dict(input_options, allow_list_keys=True) - assert cfg.item == [{'a': 1}, {'b': 1, 'c': 0}] + assert cfg.item == [{"a": 1}, {"b": 1, "c": 0}] # allow_list_keys is False - input_options = {'item.0.a': 1, 'item.1.b': 1} + input_options = {"item.0.a": 1, "item.1.b": 1} with pytest.raises(TypeError): cfg.merge_from_dict(input_options, allow_list_keys=False) # Overflowed index number - input_options = {'item.2.a': 1} + input_options = {"item.2.a": 1} with pytest.raises(KeyError): cfg.merge_from_dict(input_options, allow_list_keys=True) @@ -187,15 +175,14 @@ def test_diff(self): cfg1 = Config(dict(a=1, b=2)) cfg2 = Config(dict(a=1, b=3)) - diff_str = \ - '--- \n\n+++ \n\n@@ -1,3 +1,3 @@\n\n a = 1\n-b = 2\n+b = 3\n \n\n' + diff_str = "--- \n\n+++ \n\n@@ -1,3 +1,3 @@\n\n a = 1\n-b = 2\n+b = 3\n \n\n" assert Config.diff(cfg1, cfg2) == diff_str - cfg1_file = osp.join(self.data_path, 'config/py_config/test_diff_1.py') + cfg1_file = osp.join(self.data_path, "config/py_config/test_diff_1.py") cfg1 = Config.fromfile(cfg1_file) - cfg2_file = osp.join(self.data_path, 'config/py_config/test_diff_2.py') + cfg2_file = osp.join(self.data_path, "config/py_config/test_diff_2.py") cfg2 = Config.fromfile(cfg2_file) assert Config.diff(cfg1, cfg2) == diff_str @@ -204,9 +191,7 @@ def test_auto_argparser(self): # Temporarily make sys.argv only has one argument and keep backups tmp = sys.argv[1:] sys.argv = sys.argv[:2] - sys.argv[1] = osp.join( - self.data_path, - 'config/py_config/test_merge_from_multiple_bases.py') + sys.argv[1] = osp.join(self.data_path, "config/py_config/test_merge_from_multiple_bases.py") parser, cfg = Config.auto_argparser() args = parser.parse_args() assert args.config == sys.argv[1] @@ -218,8 +203,7 @@ def test_auto_argparser(self): sys.argv.extend(tmp) def test_dict_to_config_dict(self): - cfg_dict = dict( - a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))]) + cfg_dict = dict(a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))]) cfg_dict = Config._dict_to_config_dict(cfg_dict) assert isinstance(cfg_dict, ConfigDict) assert isinstance(cfg_dict.a, int) @@ -233,42 
+217,41 @@ def test_dict_to_config_dict(self): assert isinstance(cfg_dict.d[0].e.f[1], list) def test_dump(self, tmp_path): - file_path = 'config/py_config/test_merge_from_multiple_bases.py' + file_path = "config/py_config/test_merge_from_multiple_bases.py" cfg_file = osp.join(self.data_path, file_path) cfg = Config.fromfile(cfg_file) - dump_py = tmp_path / 'simple_config.py' + dump_py = tmp_path / "simple_config.py" cfg.dump(dump_py) assert cfg.dump() == cfg.pretty_text assert open(dump_py).read() == cfg.pretty_text # test dump json/yaml. - file_path = 'config/json_config/simple.config.json' + file_path = "config/json_config/simple.config.json" cfg_file = osp.join(self.data_path, file_path) cfg = Config.fromfile(cfg_file) - dump_json = tmp_path / 'simple_config.json' + dump_json = tmp_path / "simple_config.json" cfg.dump(dump_json) with open(dump_json) as f: assert f.read() == cfg.dump() # test pickle - file_path = 'config/py_config/test_dump_pickle_support.py' + file_path = "config/py_config/test_dump_pickle_support.py" cfg_file = osp.join(self.data_path, file_path) cfg = Config.fromfile(cfg_file) - text_cfg_filename = tmp_path / '_text_config.py' + text_cfg_filename = tmp_path / "_text_config.py" cfg.dump(text_cfg_filename) text_cfg = Config.fromfile(text_cfg_filename) - assert text_cfg.str_item_7 == osp.join(osp.expanduser('~'), 'folder') - assert text_cfg.str_item_8 == 'string with \tescape\\ characters\n' + assert text_cfg.str_item_7 == osp.join(osp.expanduser("~"), "folder") + assert text_cfg.str_item_8 == "string with \tescape\\ characters\n" assert text_cfg._cfg_dict == cfg._cfg_dict - cfg_file = osp.join(self.data_path, - 'config/py_config/test_dump_pickle_support.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_dump_pickle_support.py") cfg = Config.fromfile(cfg_file) - pkl_cfg_filename = tmp_path / '_pickle.pkl' + pkl_cfg_filename = tmp_path / "_pickle.pkl" dump(cfg, pkl_cfg_filename) pkl_cfg = load(pkl_cfg_filename) assert pkl_cfg._cfg_dict == cfg._cfg_dict @@ -277,126 +260,100 @@ def test_dump(self, tmp_path): cfg = Config(cfg_dict) assert cfg.pretty_text == cfg.dump() # Test dump python format config. - dump_file = tmp_path / 'dump_from_dict.py' + dump_file = tmp_path / "dump_from_dict.py" cfg.dump(dump_file) with open(dump_file) as f: - assert f.read() == 'a = 1\nb = 2\n' + assert f.read() == "a = 1\nb = 2\n" # Test dump json format config. - dump_file = tmp_path / 'dump_from_dict.json' + dump_file = tmp_path / "dump_from_dict.json" cfg.dump(dump_file) with open(dump_file) as f: assert f.read() == '{"a": 1, "b": 2}' # Test dump yaml format config. 
- dump_file = tmp_path / 'dump_from_dict.yaml' + dump_file = tmp_path / "dump_from_dict.yaml" cfg.dump(dump_file) with open(dump_file) as f: - assert f.read() == 'a: 1\nb: 2\n' + assert f.read() == "a: 1\nb: 2\n" def test_pretty_text(self, tmp_path): - cfg_file = osp.join( - self.data_path, - 'config/py_config/test_merge_from_multiple_bases.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_from_multiple_bases.py") cfg = Config.fromfile(cfg_file) - text_cfg_filename = tmp_path / '_text_config.py' - with open(text_cfg_filename, 'w') as f: + text_cfg_filename = tmp_path / "_text_config.py" + with open(text_cfg_filename, "w") as f: f.write(cfg.pretty_text) text_cfg = Config.fromfile(text_cfg_filename) assert text_cfg._cfg_dict == cfg._cfg_dict def test_repr(self, tmp_path): - cfg_file = osp.join(self.data_path, - 'config/py_config/simple_config.py') + cfg_file = osp.join(self.data_path, "config/py_config/simple_config.py") cfg = Config.fromfile(cfg_file) - tmp_txt = tmp_path / 'tmp.txt' - with open(tmp_txt, 'w') as f: + tmp_txt = tmp_path / "tmp.txt" + with open(tmp_txt, "w") as f: print(cfg, file=f) with open(tmp_txt) as f: - assert f.read().strip() == f'Config (path: {cfg.filename}): ' \ - f'{cfg._cfg_dict.__repr__()}' + assert f.read().strip() == f"Config (path: {cfg.filename}): {cfg._cfg_dict.__repr__()}" def test_dict_action(self): - parser = argparse.ArgumentParser(description='Train a detector') - parser.add_argument( - '--options', nargs='+', action=DictAction, help='custom options') + parser = argparse.ArgumentParser(description="Train a detector") + parser.add_argument("--options", nargs="+", action=DictAction, help="custom options") # Nested brackets - args = parser.parse_args( - ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]']) - out_dict = { - 'item2.a': ['a', 'b'], - 'item2.b': [('a', 'b'), [1, 2], False] - } + args = parser.parse_args(["--options", "item2.a=a,b", "item2.b=[(a,b), [1,2], false]"]) + out_dict = {"item2.a": ["a", "b"], "item2.b": [("a", "b"), [1, 2], False]} assert args.options == out_dict # Single Nested brackets - args = parser.parse_args(['--options', 'item2.a=[[1]]']) - out_dict = {'item2.a': [[1]]} + args = parser.parse_args(["--options", "item2.a=[[1]]"]) + out_dict = {"item2.a": [[1]]} assert args.options == out_dict # Imbalance bracket will cause error with pytest.raises(AssertionError): - parser.parse_args(['--options', 'item2.a=[(a,b), [1,2], false']) + parser.parse_args(["--options", "item2.a=[(a,b), [1,2], false"]) # Normal values - args = parser.parse_args([ - '--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false' - ]) - out_dict = { - 'item2.a': 1, - 'item2.b': 0.1, - 'item2.c': 'x', - 'item3': False - } + args = parser.parse_args(["--options", "item2.a=1", "item2.b=0.1", "item2.c=x", "item3=false"]) + out_dict = {"item2.a": 1, "item2.b": 0.1, "item2.c": "x", "item3": False} assert args.options == out_dict - cfg_file = osp.join(self.data_path, - 'config/py_config/simple_config.py') + cfg_file = osp.join(self.data_path, "config/py_config/simple_config.py") cfg = Config.fromfile(cfg_file) cfg.merge_from_dict(args.options) - assert cfg.item2 == dict(a=1, b=0.1, c='x') + assert cfg.item2 == dict(a=1, b=0.1, c="x") assert cfg.item3 is False # test multiple options - args = parser.parse_args([ - '--options', 'item1.a=1', 'item2.a=2', '--options', 'item2.a=1', - 'item3=false' - ]) - out_dict = {'item1.a': 1, 'item2.a': 1, 'item3': False} + args = parser.parse_args(["--options", "item1.a=1", "item2.a=2", 
"--options", "item2.a=1", "item3=false"]) + out_dict = {"item1.a": 1, "item2.a": 1, "item3": False} assert args.options == out_dict def test_validate_py_syntax(self, tmp_path): - tmp_cfg = tmp_path / 'tmp_config.py' - with open(tmp_cfg, 'w') as f: - f.write('dict(a=1,b=2.c=3)') + tmp_cfg = tmp_path / "tmp_config.py" + with open(tmp_cfg, "w") as f: + f.write("dict(a=1,b=2.c=3)") # Incorrect point in dict will cause error with pytest.raises(SyntaxError): Config._validate_py_syntax(tmp_cfg) - with open(tmp_cfg, 'w') as f: - f.write('[dict(a=1, b=2, c=(1, 2)]') + with open(tmp_cfg, "w") as f: + f.write("[dict(a=1, b=2, c=(1, 2)]") # Imbalance bracket will cause error with pytest.raises(SyntaxError): Config._validate_py_syntax(tmp_cfg) - with open(tmp_cfg, 'w') as f: - f.write('dict(a=1,b=2\nc=3)') + with open(tmp_cfg, "w") as f: + f.write("dict(a=1,b=2\nc=3)") # Incorrect feed line in dict will cause error with pytest.raises(SyntaxError): Config._validate_py_syntax(tmp_cfg) def test_substitute_predefined_vars(self, tmp_path): - cfg_text = 'a={{fileDirname}}\n' \ - 'b={{fileBasename}}\n' \ - 'c={{fileBasenameNoExtension}}\n' \ - 'd={{fileExtname}}\n' + cfg_text = "a={{fileDirname}}\nb={{fileBasename}}\nc={{fileBasenameNoExtension}}\nd={{fileExtname}}\n" - cfg = tmp_path / 'tmp_cfg1.py' - substituted_cfg = tmp_path / 'tmp_cfg2.py' + cfg = tmp_path / "tmp_cfg1.py" + substituted_cfg = tmp_path / "tmp_cfg2.py" file_dirname = osp.dirname(cfg) file_basename = osp.basename(cfg) file_basename_no_extension = osp.splitext(file_basename)[0] file_extname = osp.splitext(cfg)[1] - expected_text = f'a={file_dirname}\n' \ - f'b={file_basename}\n' \ - f'c={file_basename_no_extension}\n' \ - f'd={file_extname}\n' - expected_text = expected_text.replace('\\', '/') - with open(cfg, 'w') as f: + expected_text = f"a={file_dirname}\nb={file_basename}\nc={file_basename_no_extension}\nd={file_extname}\n" + expected_text = expected_text.replace("\\", "/") + with open(cfg, "w") as f: f.write(cfg_text) Config._substitute_predefined_vars(cfg, substituted_cfg) @@ -404,98 +361,87 @@ def test_substitute_predefined_vars(self, tmp_path): assert f.read() == expected_text def test_substitute_environment_vars(self, tmp_path): - cfg = tmp_path / 'tmp_cfg1.py' - substituted_cfg = tmp_path / 'tmp_cfg2.py' + cfg = tmp_path / "tmp_cfg1.py" + substituted_cfg = tmp_path / "tmp_cfg2.py" - cfg_text = 'a={{$A:}}\n' - with open(cfg, 'w') as f: + cfg_text = "a={{$A:}}\n" + with open(cfg, "w") as f: f.write(cfg_text) with pytest.raises(KeyError): Config._substitute_env_variables(cfg, substituted_cfg) - os.environ['A'] = 'text_A' + os.environ["A"] = "text_A" Config._substitute_env_variables(cfg, substituted_cfg) with open(substituted_cfg) as f: - assert f.read() == 'a=text_A\n' - os.environ.pop('A') + assert f.read() == "a=text_A\n" + os.environ.pop("A") - cfg_text = 'b={{$B:80}}\n' - with open(cfg, 'w') as f: + cfg_text = "b={{$B:80}}\n" + with open(cfg, "w") as f: f.write(cfg_text) Config._substitute_env_variables(cfg, substituted_cfg) with open(substituted_cfg) as f: - assert f.read() == 'b=80\n' + assert f.read() == "b=80\n" - os.environ['B'] = '100' + os.environ["B"] = "100" Config._substitute_env_variables(cfg, substituted_cfg) with open(substituted_cfg) as f: - assert f.read() == 'b=100\n' - os.environ.pop('B') + assert f.read() == "b=100\n" + os.environ.pop("B") cfg_text = 'c={{"$C:80"}}\n' - with open(cfg, 'w') as f: + with open(cfg, "w") as f: f.write(cfg_text) Config._substitute_env_variables(cfg, substituted_cfg) with 
open(substituted_cfg) as f: - assert f.read() == 'c=80\n' + assert f.read() == "c=80\n" def test_pre_substitute_base_vars(self, tmp_path): - cfg_path = osp.join(self.data_path, 'config', - 'py_config/test_pre_substitute_base_vars.py') - tmp_cfg = tmp_path / 'tmp_cfg.py' + cfg_path = osp.join(self.data_path, "config", "py_config/test_pre_substitute_base_vars.py") + tmp_cfg = tmp_path / "tmp_cfg.py" base_var_dict = Config._pre_substitute_base_vars(cfg_path, tmp_cfg) - assert 'item6' in base_var_dict.values() - assert 'item10' in base_var_dict.values() - assert 'item11' in base_var_dict.values() + assert "item6" in base_var_dict.values() + assert "item10" in base_var_dict.values() + assert "item11" in base_var_dict.values() sys.path.append(str(tmp_path)) - cfg_module_dict = import_module(tmp_cfg.name.strip('.py')).__dict__ - assert cfg_module_dict['item22'].startswith('_item11') - assert cfg_module_dict['item23'].startswith('_item10') - assert cfg_module_dict['item25']['c'][1].startswith('_item6') + cfg_module_dict = import_module(tmp_cfg.name.strip(".py")).__dict__ + assert cfg_module_dict["item22"].startswith("_item11") + assert cfg_module_dict["item23"].startswith("_item10") + assert cfg_module_dict["item25"]["c"][1].startswith("_item6") sys.path.pop() - cfg_path = osp.join(self.data_path, 'config', - 'json_config/test_base.json') - tmp_cfg = tmp_path / 'tmp_cfg.json' + cfg_path = osp.join(self.data_path, "config", "json_config/test_base.json") + tmp_cfg = tmp_path / "tmp_cfg.json" Config._pre_substitute_base_vars(cfg_path, tmp_cfg) cfg_module_dict = load(tmp_cfg) - assert cfg_module_dict['item9'].startswith('_item2') - assert cfg_module_dict['item10'].startswith('_item7') + assert cfg_module_dict["item9"].startswith("_item2") + assert cfg_module_dict["item10"].startswith("_item7") - cfg_path = osp.join(self.data_path, 'config', - 'yaml_config/test_base.yaml') - tmp_cfg = tmp_path / 'tmp_cfg.yaml' + cfg_path = osp.join(self.data_path, "config", "yaml_config/test_base.yaml") + tmp_cfg = tmp_path / "tmp_cfg.yaml" Config._pre_substitute_base_vars(cfg_path, tmp_cfg) cfg_module_dict = load(tmp_cfg) - assert cfg_module_dict['item9'].startswith('_item2') - assert cfg_module_dict['item10'].startswith('_item7') + assert cfg_module_dict["item9"].startswith("_item2") + assert cfg_module_dict["item10"].startswith("_item7") def test_substitute_base_vars(self): - cfg = dict( - item4='_item1.12345', - item5=dict(item3='1', item2='_item2_.fswf'), - item0=('_item0_.12ed21wq', 1)) + cfg = dict(item4="_item1.12345", item5=dict(item3="1", item2="_item2_.fswf"), item0=("_item0_.12ed21wq", 1)) cfg_base = dict(item1=0, item2=[1, 2, 3], item0=(1, 2, 3)) - base_var_dict = { - '_item1.12345': 'item1', - '_item2_.fswf': 'item2', - '_item0_.12ed21wq': 'item0' - } + base_var_dict = {"_item1.12345": "item1", "_item2_.fswf": "item2", "_item0_.12ed21wq": "item0"} cfg = Config._substitute_base_vars(cfg, base_var_dict, cfg_base) - assert cfg['item4'] == cfg_base['item1'] - assert cfg['item5']['item2'] == cfg_base['item2'] + assert cfg["item4"] == cfg_base["item1"] + assert cfg["item5"]["item2"] == cfg_base["item2"] def test_file2dict(self, tmp_path): - # test error format config - tmp_cfg = tmp_path / 'tmp_cfg.xml' - tmp_cfg.write_text('exist') + tmp_cfg = tmp_path / "tmp_cfg.xml" + tmp_cfg.write_text("exist") # invalid config format with pytest.raises(IOError): Config.fromfile(tmp_cfg) # invalid config file path with pytest.raises(FileNotFoundError): - Config.fromfile('no_such_file.py') + 
Config.fromfile("no_such_file.py") self._simple_load() self._predefined_vars() @@ -510,237 +456,194 @@ def test_file2dict(self, tmp_path): self._deprecation() def test_get_cfg_path_local(self): - filename = 'py_config/simple_config.py' - filename = osp.join(self.data_path, 'config', filename) - cfg_name = './base.py' + filename = "py_config/simple_config.py" + filename = osp.join(self.data_path, "config", filename) + cfg_name = "./base.py" cfg_path, scope = Config._get_cfg_path(cfg_name, filename) assert scope is None osp.isfile(cfg_path) @pytest.mark.skipif( - not is_installed('mmdet') or not is_installed('mmcls'), - reason='mmdet and mmcls should be installed') + not is_installed("mmdet") or not is_installed("mmpretrain"), reason="mmdet and mmcls should be installed" + ) def test_get_cfg_path_external(self): - filename = 'py_config/simple_config.py' - filename = osp.join(self.data_path, 'config', filename) + filename = "py_config/simple_config.py" + filename = osp.join(self.data_path, "config", filename) - cfg_name = 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py' + cfg_name = "mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py" cfg_path, scope = Config._get_cfg_path(cfg_name, filename) - assert scope == 'mmdet' + assert scope == "mmdet" osp.isfile(cfg_path) - cfg_name = 'mmcls::cspnet/cspresnet50_8xb32_in1k.py' + cfg_name = "mmpretrain::cspnet/cspresnet50_8xb32_in1k.py" cfg_path, scope = Config._get_cfg_path(cfg_name, filename) - assert scope == 'mmcls' + assert scope == "mmpretrain" osp.isfile(cfg_path) def _simple_load(self): # test load simple config - for file_format in ['py', 'json', 'yaml']: - for name in ['simple.config', 'simple_config']: - filename = f'{file_format}_config/{name}.{file_format}' + for file_format in ["py", "json", "yaml"]: + for name in ["simple.config", "simple_config"]: + filename = f"{file_format}_config/{name}.{file_format}" - cfg_file = osp.join(self.data_path, 'config', filename) + cfg_file = osp.join(self.data_path, "config", filename) cfg_dict, cfg_text, env_variables = Config._file2dict(cfg_file) assert isinstance(cfg_text, str) assert isinstance(cfg_dict, dict) assert isinstance(env_variables, dict) def _get_file_path(self, file_path): - if platform.system() == 'Windows': - return file_path.replace('\\', '/') + if platform.system() == "Windows": + return file_path.replace("\\", "/") else: return file_path def _predefined_vars(self): # test parse predefined_var in config - cfg_file = osp.join(self.data_path, - 'config/py_config/test_predefined_var.py') - path = osp.join(self.data_path, 'config/py_config') + cfg_file = osp.join(self.data_path, "config/py_config/test_predefined_var.py") + path = osp.join(self.data_path, "config/py_config") path = Path(path).as_posix() - cfg_dict_dst = dict( - item1='test_predefined_var.py', - item2=path, - item3='abc_test_predefined_var') + cfg_dict_dst = dict(item1="test_predefined_var.py", item2=path, item3="abc_test_predefined_var") - assert Config._file2dict(cfg_file)[0]['item1'] == cfg_dict_dst['item1'] - assert Config._file2dict(cfg_file)[0]['item2'] == cfg_dict_dst['item2'] - assert Config._file2dict(cfg_file)[0]['item3'] == cfg_dict_dst['item3'] + assert Config._file2dict(cfg_file)[0]["item1"] == cfg_dict_dst["item1"] + assert Config._file2dict(cfg_file)[0]["item2"] == cfg_dict_dst["item2"] + assert Config._file2dict(cfg_file)[0]["item3"] == cfg_dict_dst["item3"] # test `use_predefined_variable=False` cfg_dict_ori = dict( - item1='{{fileBasename}}', - item2='{{ fileDirname}}', - item3='abc_{{ 
fileBasenameNoExtension }}') + item1="{{fileBasename}}", item2="{{ fileDirname}}", item3="abc_{{ fileBasenameNoExtension }}" + ) - assert Config._file2dict(cfg_file, - False)[0]['item1'] == cfg_dict_ori['item1'] - assert Config._file2dict(cfg_file, - False)[0]['item2'] == cfg_dict_ori['item2'] - assert Config._file2dict(cfg_file, - False)[0]['item3'] == cfg_dict_ori['item3'] + assert Config._file2dict(cfg_file, False)[0]["item1"] == cfg_dict_ori["item1"] + assert Config._file2dict(cfg_file, False)[0]["item2"] == cfg_dict_ori["item2"] + assert Config._file2dict(cfg_file, False)[0]["item3"] == cfg_dict_ori["item3"] # test test_predefined_var.yaml - cfg_file = osp.join(self.data_path, - 'config/yaml_config/test_predefined_var.yaml') + cfg_file = osp.join(self.data_path, "config/yaml_config/test_predefined_var.yaml") # test `use_predefined_variable=False` - assert Config._file2dict(cfg_file, - False)[0]['item1'] == '{{ fileDirname }}' - assert Config._file2dict(cfg_file)[0]['item1'] == self._get_file_path( - osp.dirname(cfg_file)) + assert Config._file2dict(cfg_file, False)[0]["item1"] == "{{ fileDirname }}" + assert Config._file2dict(cfg_file)[0]["item1"] == self._get_file_path(osp.dirname(cfg_file)) # test test_predefined_var.json - cfg_file = osp.join(self.data_path, - 'config/json_config/test_predefined_var.json') + cfg_file = osp.join(self.data_path, "config/json_config/test_predefined_var.json") - assert Config.fromfile(cfg_file, False)['item1'] == '{{ fileDirname }}' - assert Config.fromfile(cfg_file)['item1'] == self._get_file_path( - osp.dirname(cfg_file)) + assert Config.fromfile(cfg_file, False)["item1"] == "{{ fileDirname }}" + assert Config.fromfile(cfg_file)["item1"] == self._get_file_path(osp.dirname(cfg_file)) def _environment_vars(self): # test parse predefined_var in config - cfg_file = osp.join(self.data_path, - 'config/py_config/test_environment_var.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_environment_var.py") with pytest.raises(KeyError): Config._file2dict(cfg_file) - os.environ['ITEM1'] = '60' - cfg_dict_dst = dict(item1='60', item2='default_value', item3=80) - assert Config._file2dict(cfg_file)[0]['item1'] == cfg_dict_dst['item1'] - assert Config._file2dict(cfg_file)[0]['item2'] == cfg_dict_dst['item2'] - assert Config._file2dict(cfg_file)[0]['item3'] == cfg_dict_dst['item3'] + os.environ["ITEM1"] = "60" + cfg_dict_dst = dict(item1="60", item2="default_value", item3=80) + assert Config._file2dict(cfg_file)[0]["item1"] == cfg_dict_dst["item1"] + assert Config._file2dict(cfg_file)[0]["item2"] == cfg_dict_dst["item2"] + assert Config._file2dict(cfg_file)[0]["item3"] == cfg_dict_dst["item3"] - os.environ['ITEM2'] = 'new_value' - os.environ['ITEM3'] = '50' - cfg_dict_dst = dict(item1='60', item2='new_value', item3=50) - assert Config._file2dict(cfg_file)[0]['item1'] == cfg_dict_dst['item1'] - assert Config._file2dict(cfg_file)[0]['item2'] == cfg_dict_dst['item2'] - assert Config._file2dict(cfg_file)[0]['item3'] == cfg_dict_dst['item3'] + os.environ["ITEM2"] = "new_value" + os.environ["ITEM3"] = "50" + cfg_dict_dst = dict(item1="60", item2="new_value", item3=50) + assert Config._file2dict(cfg_file)[0]["item1"] == cfg_dict_dst["item1"] + assert Config._file2dict(cfg_file)[0]["item2"] == cfg_dict_dst["item2"] + assert Config._file2dict(cfg_file)[0]["item3"] == cfg_dict_dst["item3"] - os.environ.pop('ITEM1') - os.environ.pop('ITEM2') - os.environ.pop('ITEM3') + os.environ.pop("ITEM1") + os.environ.pop("ITEM2") + os.environ.pop("ITEM3") def 
_merge_from_base(self): - cfg_file = osp.join(self.data_path, - 'config/py_config/test_merge_from_base_single.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_from_base_single.py") cfg_dict = Config._file2dict(cfg_file)[0] - assert cfg_dict['item1'] == [2, 3] - assert cfg_dict['item2']['a'] == 1 - assert cfg_dict['item3'] is False - assert cfg_dict['item4'] == 'test_base' + assert cfg_dict["item1"] == [2, 3] + assert cfg_dict["item2"]["a"] == 1 + assert cfg_dict["item3"] is False + assert cfg_dict["item4"] == "test_base" # item3 is a dict in the child config but a boolean in base config with pytest.raises(TypeError): - Config.fromfile( - osp.join(self.data_path, - 'config/py_config/test_merge_from_base_error.py')) + Config.fromfile(osp.join(self.data_path, "config/py_config/test_merge_from_base_error.py")) def _merge_from_multiple_bases(self): - cfg_file = osp.join( - self.data_path, - 'config/py_config/test_merge_from_multiple_bases.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_from_multiple_bases.py") cfg_dict = Config._file2dict(cfg_file)[0] # cfg.fcfg_dictd - assert cfg_dict['item1'] == [1, 2] - assert cfg_dict['item2']['a'] == 0 - assert cfg_dict['item3'] is False - assert cfg_dict['item4'] == 'test' - assert cfg_dict['item5'] == dict(a=0, b=1) - assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict["item1"] == [1, 2] + assert cfg_dict["item2"]["a"] == 0 + assert cfg_dict["item3"] is False + assert cfg_dict["item4"] == "test" + assert cfg_dict["item5"] == dict(a=0, b=1) + assert cfg_dict["item6"] == [dict(a=0), dict(b=1)] + assert cfg_dict["item7"] == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) # Redefine key with pytest.raises(KeyError): - Config.fromfile( - osp.join(self.data_path, - 'config/py_config/test_merge_from_multiple_error.py')) + Config.fromfile(osp.join(self.data_path, "config/py_config/test_merge_from_multiple_error.py")) def _base_variables(self): - for file in [ - 'py_config/test_base_variables.py', - 'json_config/test_base.json', 'yaml_config/test_base.yaml' - ]: - cfg_file = osp.join(self.data_path, 'config', file) + for file in ["py_config/test_base_variables.py", "json_config/test_base.json", "yaml_config/test_base.yaml"]: + cfg_file = osp.join(self.data_path, "config", file) cfg_dict = Config._file2dict(cfg_file)[0] - assert cfg_dict['item1'] == [1, 2] - assert cfg_dict['item2']['a'] == 0 - assert cfg_dict['item3'] is False - assert cfg_dict['item4'] == 'test' - assert cfg_dict['item5'] == dict(a=0, b=1) - assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert cfg_dict['item8'] == file.split('/')[-1] - assert cfg_dict['item9'] == dict(a=0) - assert cfg_dict['item10'] == [3.1, 4.2, 5.3] + assert cfg_dict["item1"] == [1, 2] + assert cfg_dict["item2"]["a"] == 0 + assert cfg_dict["item3"] is False + assert cfg_dict["item4"] == "test" + assert cfg_dict["item5"] == dict(a=0, b=1) + assert cfg_dict["item6"] == [dict(a=0), dict(b=1)] + assert cfg_dict["item7"] == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict["item8"] == file.split("/")[-1] + assert cfg_dict["item9"] == dict(a=0) + assert cfg_dict["item10"] == [3.1, 4.2, 5.3] # test nested base for file in [ - 'py_config/test_base_variables_nested.py', - 'json_config/test_base_variables_nested.json', - 'yaml_config/test_base_variables_nested.yaml' + 
"py_config/test_base_variables_nested.py", + "json_config/test_base_variables_nested.json", + "yaml_config/test_base_variables_nested.yaml", ]: - cfg_file = osp.join(self.data_path, 'config', file) + cfg_file = osp.join(self.data_path, "config", file) cfg_dict = Config._file2dict(cfg_file)[0] - assert cfg_dict['base'] == '_base_.item8' - assert cfg_dict['item1'] == [1, 2] - assert cfg_dict['item2']['a'] == 0 - assert cfg_dict['item3'] is False - assert cfg_dict['item4'] == 'test' - assert cfg_dict['item5'] == dict(a=0, b=1) - assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert cfg_dict['item8'] == 'test_base_variables.py' - assert cfg_dict['item9'] == dict(a=0) - assert cfg_dict['item10'] == [3.1, 4.2, 5.3] - assert cfg_dict['item11'] == 'test_base_variables.py' - assert cfg_dict['item12'] == dict(a=0) - assert cfg_dict['item13'] == [3.1, 4.2, 5.3] - assert cfg_dict['item14'] == [1, 2] - assert cfg_dict['item15'] == dict( - a=dict(b=dict(a=0)), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + assert cfg_dict["base"] == "_base_.item8" + assert cfg_dict["item1"] == [1, 2] + assert cfg_dict["item2"]["a"] == 0 + assert cfg_dict["item3"] is False + assert cfg_dict["item4"] == "test" + assert cfg_dict["item5"] == dict(a=0, b=1) + assert cfg_dict["item6"] == [dict(a=0), dict(b=1)] + assert cfg_dict["item7"] == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict["item8"] == "test_base_variables.py" + assert cfg_dict["item9"] == dict(a=0) + assert cfg_dict["item10"] == [3.1, 4.2, 5.3] + assert cfg_dict["item11"] == "test_base_variables.py" + assert cfg_dict["item12"] == dict(a=0) + assert cfg_dict["item13"] == [3.1, 4.2, 5.3] + assert cfg_dict["item14"] == [1, 2] + assert cfg_dict["item15"] == dict( + a=dict(b=dict(a=0)), b=[False], c=["test"], d=[[{"e": 0}], [{"a": 0}, {"b": 1}]], e=[1, 2] + ) # test reference assignment for py - cfg_file = osp.join( - self.data_path, - 'config/py_config/test_pre_substitute_base_vars.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_pre_substitute_base_vars.py") cfg_dict = Config._file2dict(cfg_file)[0] - assert cfg_dict['item21'] == 'test_base_variables.py' - assert cfg_dict['item22'] == 'test_base_variables.py' - assert cfg_dict['item23'] == [3.1, 4.2, 5.3] - assert cfg_dict['item24'] == [3.1, 4.2, 5.3] - assert cfg_dict['item25'] == dict( + assert cfg_dict["item21"] == "test_base_variables.py" + assert cfg_dict["item22"] == "test_base_variables.py" + assert cfg_dict["item23"] == [3.1, 4.2, 5.3] + assert cfg_dict["item24"] == [3.1, 4.2, 5.3] + assert cfg_dict["item25"] == dict( a=dict(b=[3.1, 4.2, 5.3]), b=[[3.1, 4.2, 5.3]], - c=[[{ - 'e': 'test_base_variables.py' - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e='test_base_variables.py') - - cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py') + c=[[{"e": "test_base_variables.py"}], [{"a": 0}, {"b": 1}]], + e="test_base_variables.py", + ) + + cfg_file = osp.join(self.data_path, "config/py_config/test_py_base.py") cfg = Config.fromfile(cfg_file) assert isinstance(cfg, Config) assert cfg.filename == cfg_file @@ -749,18 +652,17 @@ def _base_variables(self): assert cfg.item2.a == 0 assert cfg.item2.b == [5, 6] assert cfg.item3 is False - assert cfg.item4 == 'test' + assert cfg.item4 == "test" assert cfg.item5 == dict(a=0, b=1) assert cfg.item6 == [dict(c=0), dict(b=1)] assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert 
cfg.item8 == 'test_py_base.py' + assert cfg.item8 == "test_py_base.py" assert cfg.item9 == 3.1 assert cfg.item10 == 4.2 assert cfg.item11 == 5.3 # test nested base - cfg_file = osp.join(self.data_path, - 'config/py_config/test_py_nested_path.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_py_nested_path.py") cfg = Config.fromfile(cfg_file) assert isinstance(cfg, Config) assert cfg.filename == cfg_file @@ -769,41 +671,30 @@ def _base_variables(self): assert cfg.item2.a == 0 assert cfg.item2.b == [5, 6] assert cfg.item3 is False - assert cfg.item4 == 'test' + assert cfg.item4 == "test" assert cfg.item5 == dict(a=0, b=1) assert cfg.item6 == [dict(c=0), dict(b=1)] assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert cfg.item8 == 'test_py_base.py' + assert cfg.item8 == "test_py_base.py" assert cfg.item9 == 3.1 assert cfg.item10 == 4.2 assert cfg.item11 == 5.3 - assert cfg.item12 == 'test_py_base.py' + assert cfg.item12 == "test_py_base.py" assert cfg.item13 == 3.1 assert cfg.item14 == [1, 2] assert cfg.item15 == dict( - a=dict(b=dict(a=0, b=[5, 6])), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'c': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + a=dict(b=dict(a=0, b=[5, 6])), b=[False], c=["test"], d=[[{"e": 0}], [{"c": 0}, {"b": 1}]], e=[1, 2] + ) # Test use global variable in config function - cfg_file = osp.join(self.data_path, - 'config/py_config/test_py_function_global_var.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_py_function_global_var.py") cfg = Config._file2dict(cfg_file)[0] - assert cfg['item1'] == 1 - assert cfg['item2'] == 2 + assert cfg["item1"] == 1 + assert cfg["item2"] == 2 # Test support modifying the value of dict without defining base # config. - cfg_file = osp.join(self.data_path, - 'config/py_config/test_py_modify_key.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_py_modify_key.py") cfg = Config._file2dict(cfg_file)[0] assert cfg == dict(item1=dict(a=1)) @@ -811,72 +702,62 @@ def _base_variables(self): # /tmp/test.axsgr12/. This patch is to check the issue # https://github.com/open-mmlab/mmengine/issues/788 has been solved. 
class PatchedTempDirectory(tempfile.TemporaryDirectory): - - def __init__(self, *args, prefix='test.', **kwargs): + def __init__(self, *args, prefix="test.", **kwargs): super().__init__(*args, prefix=prefix, **kwargs) - with patch('mmengine.config.config.tempfile.TemporaryDirectory', - PatchedTempDirectory): - cfg_file = osp.join(self.data_path, - 'config/py_config/test_py_modify_key.py') + with patch("mmengine.config.config.tempfile.TemporaryDirectory", PatchedTempDirectory): + cfg_file = osp.join(self.data_path, "config/py_config/test_py_modify_key.py") cfg = Config._file2dict(cfg_file)[0] assert cfg == dict(item1=dict(a=1)) def _merge_recursive_bases(self): - cfg_file = osp.join(self.data_path, - 'config/py_config/test_merge_recursive_bases.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_recursive_bases.py") cfg_dict = Config._file2dict(cfg_file)[0] - assert cfg_dict['item1'] == [2, 3] - assert cfg_dict['item2']['a'] == 1 - assert cfg_dict['item3'] is False - assert cfg_dict['item4'] == 'test_recursive_bases' + assert cfg_dict["item1"] == [2, 3] + assert cfg_dict["item2"]["a"] == 1 + assert cfg_dict["item3"] is False + assert cfg_dict["item4"] == "test_recursive_bases" def _merge_delete(self): - cfg_file = osp.join(self.data_path, - 'config/py_config/test_merge_delete.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_delete.py") cfg_dict = Config._file2dict(cfg_file)[0] # cfg.field - assert cfg_dict['item1'] == dict(a=0) - assert cfg_dict['item2'] == dict(a=0, b=0) - assert cfg_dict['item3'] is True - assert cfg_dict['item4'] == 'test' - assert '_delete_' not in cfg_dict['item1'] + assert cfg_dict["item1"] == dict(a=0) + assert cfg_dict["item2"] == dict(a=0, b=0) + assert cfg_dict["item3"] is True + assert cfg_dict["item4"] == "test" + assert "_delete_" not in cfg_dict["item1"] - assert type(cfg_dict['item1']) is ConfigDict - assert type(cfg_dict['item2']) is ConfigDict + assert type(cfg_dict["item1"]) is ConfigDict + assert type(cfg_dict["item2"]) is ConfigDict def _merge_intermediate_variable(self): - - cfg_file = osp.join( - self.data_path, - 'config/py_config/test_merge_intermediate_variable_child.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_merge_intermediate_variable_child.py") cfg_dict = Config._file2dict(cfg_file)[0] # cfg.field - assert cfg_dict['item1'] == [1, 2] - assert cfg_dict['item2'] == dict(a=0) - assert cfg_dict['item3'] is True - assert cfg_dict['item4'] == 'test' - assert cfg_dict['item_cfg'] == dict(b=2) - assert cfg_dict['item5'] == dict(cfg=dict(b=1)) - assert cfg_dict['item6'] == dict(cfg=dict(b=2)) + assert cfg_dict["item1"] == [1, 2] + assert cfg_dict["item2"] == dict(a=0) + assert cfg_dict["item3"] is True + assert cfg_dict["item4"] == "test" + assert cfg_dict["item_cfg"] == dict(b=2) + assert cfg_dict["item5"] == dict(cfg=dict(b=1)) + assert cfg_dict["item6"] == dict(cfg=dict(b=2)) def _code_in_config(self): - cfg_file = osp.join(self.data_path, - 'config/py_config/test_code_in_config.py') + cfg_file = osp.join(self.data_path, "config/py_config/test_code_in_config.py") cfg = Config.fromfile(cfg_file) # cfg.field assert cfg.cfg.item1 == [1, 2] assert cfg.cfg.item2 == dict(a=0) assert cfg.cfg.item3 is True - assert cfg.cfg.item4 == 'test' + assert cfg.cfg.item4 == "test" assert cfg.item5 == 1 def _deprecation(self): deprecated_cfg_files = [ - osp.join(self.data_path, 'config', 'py_config/test_deprecated.py'), - osp.join(self.data_path, 'config', - 'py_config/test_deprecated_base.py') + 
osp.join(self.data_path, "config", "py_config/test_deprecated.py"), + osp.join(self.data_path, "config", "py_config/test_deprecated_base.py"), ] for cfg_file in deprecated_cfg_files: @@ -885,8 +766,7 @@ def _deprecation(self): assert cfg.item1 == [1, 2] def test_deepcopy(self): - cfg_file = osp.join(self.data_path, 'config', - 'py_config/test_dump_pickle_support.py') + cfg_file = osp.join(self.data_path, "config", "py_config/test_dump_pickle_support.py") cfg = Config.fromfile(cfg_file) new_cfg = copy.deepcopy(cfg) @@ -897,8 +777,7 @@ def test_deepcopy(self): assert new_cfg._text == cfg._text def test_copy(self): - cfg_file = osp.join(self.data_path, 'config', - 'py_config/test_dump_pickle_support.py') + cfg_file = osp.join(self.data_path, "config", "py_config/test_dump_pickle_support.py") cfg = Config.fromfile(cfg_file) new_cfg = copy.copy(cfg) @@ -913,40 +792,34 @@ def test_copy(self): assert new_cfg._filename == cfg._filename assert new_cfg._text == cfg._text - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed("mmdet"), reason="mmdet should be installed") def test_get_external_cfg(self): - ext_cfg_path = osp.join(self.data_path, - 'config/py_config/test_get_external_cfg.py') + ext_cfg_path = osp.join(self.data_path, "config/py_config/test_get_external_cfg.py") ext_cfg = Config.fromfile(ext_cfg_path) assert ext_cfg._cfg_dict.model.neck == dict( - type='FPN', + type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5, ) - assert '_scope_' in ext_cfg._cfg_dict.model + assert ext_cfg._cfg_dict.model._scope_ == "mmdet" - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed("mmdet"), reason="mmdet should be installed") def test_build_external_package(self): # Test load base config. - ext_cfg_path = osp.join(self.data_path, - 'config/py_config/test_get_external_cfg.py') + ext_cfg_path = osp.join(self.data_path, "config/py_config/test_get_external_cfg.py") ext_cfg = Config.fromfile(ext_cfg_path) - LOCAL_MODELS = Registry('local_model', parent=MODELS, scope='test') + LOCAL_MODELS = Registry("local_model", parent=MODELS, scope="test") LOCAL_MODELS.build(ext_cfg.model) # Test load non-base config - ext_cfg_path = osp.join(self.data_path, - 'config/py_config/test_get_external_cfg2.py') + ext_cfg_path = osp.join(self.data_path, "config/py_config/test_get_external_cfg2.py") ext_cfg = Config.fromfile(ext_cfg_path) LOCAL_MODELS.build(ext_cfg.model) # Test override base variable. - ext_cfg_path = osp.join(self.data_path, - 'config/py_config/test_get_external_cfg3.py') + ext_cfg_path = osp.join(self.data_path, "config/py_config/test_get_external_cfg3.py") ext_cfg = Config.fromfile(ext_cfg_path) @LOCAL_MODELS.register_module() @@ -957,46 +830,42 @@ class ToyLoss: class ToyModel: pass - DefaultScope.get_instance('test1', scope_name='test') - assert ext_cfg.model._scope_ == 'mmdet' + DefaultScope.get_instance("test1", scope_name="test") + assert ext_cfg.model._scope_ == "mmdet" model = LOCAL_MODELS.build(ext_cfg.model) # Local base config should not have scope. 
- assert '_scope_' not in ext_cfg.toy_model + assert "_scope_" not in ext_cfg.toy_model toy_model = LOCAL_MODELS.build(ext_cfg.toy_model) assert isinstance(toy_model, ToyModel) - assert model.backbone.style == 'pytorch' + assert model.backbone.style == "pytorch" assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss) - DefaultScope._instance_dict.pop('test1') + DefaultScope._instance_dict.pop("test1") def test_pickle(self): # Text style config - cfg_path = osp.join(self.data_path, 'config/py_config/test_py_base.py') + cfg_path = osp.join(self.data_path, "config/py_config/test_py_base.py") cfg = Config.fromfile(cfg_path) pickled = pickle.loads(pickle.dumps(cfg)) assert pickled.__dict__ == cfg.__dict__ - cfg_path = osp.join(self.data_path, - 'config/lazy_module_config/toy_model.py') + cfg_path = osp.join(self.data_path, "config/lazy_module_config/toy_model.py") cfg = Config.fromfile(cfg_path) pickled = pickle.loads(pickle.dumps(cfg)) assert pickled.__dict__ == cfg.__dict__ def test_lazy_import(self, tmp_path): - lazy_import_cfg_path = osp.join( - self.data_path, 'config/lazy_module_config/toy_model.py') + lazy_import_cfg_path = osp.join(self.data_path, "config/lazy_module_config/toy_model.py") cfg = Config.fromfile(lazy_import_cfg_path) cfg_dict = cfg.to_dict() - assert (cfg_dict['train_dataloader']['dataset']['type'] == - 'mmengine.testing.runner_test_case.ToyDataset') - assert ( - cfg_dict['custom_hooks'][0]['type'] == 'mmengine.hooks.EMAHook') + assert cfg_dict["train_dataloader"]["dataset"]["type"] == "mmengine.testing.runner_test_case.ToyDataset" + assert cfg_dict["custom_hooks"][0]["type"] == "mmengine.hooks.EMAHook" # Dumped config - dumped_cfg_path = tmp_path / 'test_dump_lazy.py' + dumped_cfg_path = tmp_path / "test_dump_lazy.py" cfg.dump(dumped_cfg_path) dumped_cfg = Config.fromfile(dumped_cfg_path) - copied_cfg_path = tmp_path / 'test_dump_copied_lazy.py' + copied_cfg_path = tmp_path / "test_dump_copied_lazy.py" cfg_copy = cfg.copy() cfg_copy.dump(copied_cfg_path) copied_cfg = Config.fromfile(copied_cfg_path) @@ -1008,7 +877,7 @@ def _compare_dict(a, b): _compare_dict(v, b[k]) elif isinstance(a, list): assert len(a) == len(b) - for item_a, item_b in zip(a, b): + for item_a, item_b in zip(a, b, strict=False): _compare_dict(item_a, item_b) else: assert str(a) == str(b) @@ -1029,81 +898,67 @@ def _compare_dict(a, b): # 'is not installed or mmdet version is too low') # catch import error correctly - error_obj = tmp_path / 'error_obj.py' + error_obj = tmp_path / "error_obj.py" error_obj.write_text("""from mmengine.fileio import error_obj""") # match pattern should be double escaped - match = str(error_obj).encode('unicode_escape').decode() + match = str(error_obj).encode("unicode_escape").decode() with pytest.raises(ImportError, match=match): cfg = Config.fromfile(str(error_obj)) cfg.error_obj - error_attr = tmp_path / 'error_attr.py' + error_attr = tmp_path / "error_attr.py" error_attr.write_text(""" import mmengine error_attr = mmengine.error_attr """) # noqa: E122 - match = str(error_attr).encode('unicode_escape').decode() + match = str(error_attr).encode("unicode_escape").decode() with pytest.raises(ImportError, match=match): cfg = Config.fromfile(str(error_attr)) cfg.error_attr - error_module = tmp_path / 'error_module.py' + error_module = tmp_path / "error_module.py" error_module.write_text("""import error_module""") - match = str(error_module).encode('unicode_escape').decode() + match = str(error_module).encode("unicode_escape").decode() with pytest.raises(ImportError, 
match=match): cfg = Config.fromfile(str(error_module)) cfg.error_module # lazy-import and non-lazy-import should not be used mixed. # current text config, base lazy-import config - with pytest.raises(RuntimeError, match='with read_base()'): - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using1.py')) + with pytest.raises(RuntimeError, match="with read_base()"): + Config.fromfile(osp.join(self.data_path, "config/lazy_module_config/error_mix_using1.py")) # Force to import in non-lazy-import mode - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using1.py'), - lazy_import=False) + Config.fromfile(osp.join(self.data_path, "config/lazy_module_config/error_mix_using1.py"), lazy_import=False) # current lazy-import config, base text config - with pytest.raises(RuntimeError, match='_base_ ='): - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using2.py')) - - cfg = Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/test_mix_builtin.py')) - assert cfg.path == osp.join('a', 'b') - assert cfg.name == 'a/b' - assert cfg.suffix == '.py' + with pytest.raises(RuntimeError, match="_base_ ="): + Config.fromfile(osp.join(self.data_path, "config/lazy_module_config/error_mix_using2.py")) + + cfg = Config.fromfile(osp.join(self.data_path, "config/lazy_module_config/test_mix_builtin.py")) + assert cfg.path == osp.join("a", "b") + assert cfg.name == "a/b" + assert cfg.suffix == ".py" assert cfg.chained == [1, 2, 3, 4] assert cfg.existed - assert cfg.cfgname == 'test_mix_builtin.py' + assert cfg.cfgname == "test_mix_builtin.py" cfg_dict = cfg.to_dict() - dumped_cfg_path = tmp_path / 'test_dump_lazy.py' + dumped_cfg_path = tmp_path / "test_dump_lazy.py" cfg.dump(dumped_cfg_path) dumped_cfg = Config.fromfile(dumped_cfg_path) - assert set(dumped_cfg.keys()) == { - 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname' - } + assert set(dumped_cfg.keys()) == {"path", "name", "suffix", "chained", "existed", "cfgname"} assert dumped_cfg.to_dict() == cfg.to_dict() class TestConfigDict(TestCase): - def test_keep_custom_dict(self): - - class CustomDict(dict): - ... + class CustomDict(dict): ... 
cfg_dict = ConfigDict(dict(a=CustomDict(b=1))) self.assertIsInstance(cfg_dict.a, CustomDict) - self.assertIsInstance(cfg_dict['a'], CustomDict) + self.assertIsInstance(cfg_dict["a"], CustomDict) self.assertIsInstance(cfg_dict.values()[0], CustomDict) self.assertIsInstance(cfg_dict.items()[0][1], CustomDict) @@ -1120,7 +975,7 @@ def test_build_lazy(self): # Part I # Keep key-value the same - raw = dict(a=1, b=dict(c=2, e=[dict(f=(2, ))])) + raw = dict(a=1, b=dict(c=2, e=[dict(f=(2,))])) cfg_dict = ConfigDict(raw) assert len(cfg_dict) == 2 @@ -1132,14 +987,8 @@ def test_build_lazy(self): # Check `items` and `values` will only return the build object raw = dict( - a=LazyObject('mmengine'), - b=dict( - c=2, - e=[ - dict( - f=dict(h=LazyObject('mmengine')), - g=LazyObject('mmengine')) - ])) + a=LazyObject("mmengine"), b=dict(c=2, e=[dict(f=dict(h=LazyObject("mmengine")), g=LazyObject("mmengine"))]) + ) cfg_dict = ConfigDict(raw) # check `items` and values self.assertDictEqual(cfg_dict._to_lazy_dict(), raw) @@ -1151,27 +1000,26 @@ def test_build_lazy(self): self.assertIs(cfg_dict.b.e[0].g, mmengine) # check get - self.assertIs(cfg_dict.get('a'), mmengine) - self.assertIs( - cfg_dict.get('b').get('e')[0].get('f').get('h'), mmengine) - self.assertIs(cfg_dict.get('b').get('e')[0].get('g'), mmengine) + self.assertIs(cfg_dict.get("a"), mmengine) + self.assertIs(cfg_dict.get("b").get("e")[0].get("f").get("h"), mmengine) + self.assertIs(cfg_dict.get("b").get("e")[0].get("g"), mmengine) # check pop - a = cfg_dict.pop('a') - b = cfg_dict.pop('b') - e = b.pop('e') - h = e[0].pop('f')['h'] - g = e[0].pop('g') + a = cfg_dict.pop("a") + b = cfg_dict.pop("b") + e = b.pop("e") + h = e[0].pop("f")["h"] + g = e[0].pop("g") self.assertIs(a, mmengine) self.assertIs(h, mmengine) self.assertIs(g, mmengine) self.assertEqual(cfg_dict, {}) - self.assertEqual(b, {'c': 2}) + self.assertEqual(b, {"c": 2}) # Part II # check update with dict and ConfigDict for dict_type in (dict, ConfigDict): - cfg_dict = ConfigDict(x=LazyObject('mmengine')) + cfg_dict = ConfigDict(x=LazyObject("mmengine")) cfg_dict.update(dict_type(raw)) self._check(cfg_dict) @@ -1180,32 +1028,25 @@ def test_build_lazy(self): self._check(new_dict) # Update the ConfigDict by __setitem__ and __setattr__ - new_dict['b']['h'] = LazyObject('mmengine') - new_dict['b']['k'] = dict(l=dict(n=LazyObject('mmengine'))) - new_dict.b.e[0].i = LazyObject('mmengine') - new_dict.b.e[0].j = dict(l=dict(n=LazyObject('mmengine'))) + new_dict["b"]["h"] = LazyObject("mmengine") + new_dict["b"]["k"] = dict(l=dict(n=LazyObject("mmengine"))) + new_dict.b.e[0].i = LazyObject("mmengine") + new_dict.b.e[0].j = dict(l=dict(n=LazyObject("mmengine"))) self._check(new_dict) def _check(self, cfg_dict): - self._recursive_check_lazy(cfg_dict, - lambda x: not isinstance(x, LazyObject)) - self._recursive_check_lazy(cfg_dict._to_lazy_dict(), - lambda x: x is not mmengine) + self._recursive_check_lazy(cfg_dict, lambda x: not isinstance(x, LazyObject)) + self._recursive_check_lazy(cfg_dict._to_lazy_dict(), lambda x: x is not mmengine) self._recursive_check_lazy( - cfg_dict._to_lazy_dict(), lambda x: not isinstance(x, ConfigDict) - if isinstance(x, dict) else True) - self._recursive_check_lazy( - cfg_dict, lambda x: isinstance(x, ConfigDict) - if isinstance(x, dict) else True) + cfg_dict._to_lazy_dict(), lambda x: not isinstance(x, ConfigDict) if isinstance(x, dict) else True + ) + self._recursive_check_lazy(cfg_dict, lambda x: isinstance(x, ConfigDict) if isinstance(x, dict) else True) def 
_recursive_check_lazy(self, cfg, expr): if isinstance(cfg, dict): - { - key: self._recursive_check_lazy(value, expr) - for key, value in cfg.items() - } + {key: self._recursive_check_lazy(value, expr) for key, value in cfg.items()} [self._recursive_check_lazy(value, expr) for value in cfg.values()] - elif isinstance(cfg, (tuple, list)): + elif isinstance(cfg, tuple | list): [self._recursive_check_lazy(value, expr) for value in cfg] else: self.assertTrue(expr(cfg)) diff --git a/tests/test_config/test_lazy.py b/tests/test_config/test_lazy.py index d69822814b..af1693cb4e 100644 --- a/tests/test_config/test_lazy.py +++ b/tests/test_config/test_lazy.py @@ -20,38 +20,33 @@ class TestImportTransformer(TestCase): - @classmethod def setUpClass(cls) -> None: cls.data_dir = osp.join( # type: ignore - osp.dirname(__file__), '..', 'data', 'config', - 'lazy_module_config') + osp.dirname(__file__), "..", "data", "config", "lazy_module_config" + ) super().setUpClass() def test_lazy_module(self): - cfg_path = osp.join(self.data_dir, 'test_ast_transform.py') + cfg_path = osp.join(self.data_dir, "test_ast_transform.py") with open(cfg_path) as f: codestr = f.read() codeobj = ast.parse(codestr) global_dict = { - 'LazyObject': LazyObject, + "LazyObject": LazyObject, } base_dict = { - '._base_.default_runtime': { - 'default_scope': 'test_config' - }, - '._base_.scheduler': { - 'val_cfg': {} - }, + "._base_.default_runtime": {"default_scope": "test_config"}, + "._base_.scheduler": {"val_cfg": {}}, } codeobj = ImportTransformer(global_dict, base_dict).visit(codeobj) codeobj, _ = _gather_abs_import_lazyobj(codeobj) codeobj = ast.fix_missing_locations(codeobj) - exec(compile(codeobj, cfg_path, mode='exec'), global_dict, global_dict) + exec(compile(codeobj, cfg_path, mode="exec"), global_dict, global_dict) # 1. absolute import # 1.1 import module as LazyObject - lazy_numpy = global_dict['numpy'] + lazy_numpy = global_dict["numpy"] self.assertIsInstance(lazy_numpy, LazyObject) # 1.2 getattr as LazyAttr @@ -70,61 +65,57 @@ def test_lazy_module(self): self.assertIs(imported_linalg, linalg) # 1.4.2 build class method from LazyAttr - start = global_dict['start'] - self.assertEqual(start.module, 'rich.progress.Progress') - self.assertEqual(str(start), 'start') + start = global_dict["start"] + self.assertEqual(start.module, "rich.progress.Progress") + self.assertEqual(str(start), "start") self.assertIs(start.build(), Progress.start) # 1.5 import ... as, and build module from LazyObject - lazy_linalg = global_dict['linalg'] + lazy_linalg = global_dict["linalg"] self.assertIsInstance(lazy_linalg, LazyObject) self.assertIs(lazy_linalg.build(), linalg) self.assertIsInstance(lazy_linalg.norm, LazyAttr) self.assertIs(lazy_linalg.norm.build(), linalg.norm) # 1.6 import built in module - imported_os = global_dict['os'] + imported_os = global_dict["os"] self.assertIs(imported_os, os) # 2. Relative import # 2.1 from ... import ... - lazy_local_backend = global_dict['local'] + lazy_local_backend = global_dict["local"] self.assertIsInstance(lazy_local_backend, LazyObject) self.assertIs(lazy_local_backend.build(), LocalBackend) # 2.2 from ... import ... as ... - lazy_petrel_backend = global_dict['PetrelBackend'] + lazy_petrel_backend = global_dict["PetrelBackend"] self.assertIsInstance(lazy_petrel_backend, LazyObject) self.assertIs(lazy_petrel_backend.build(), PetrelBackend) # 2.3 from ... 
import builtin module or obj from `mmengine.Config` - self.assertIs(global_dict['find_module'], find_spec) - self.assertIs(global_dict['Config'], Config) + self.assertIs(global_dict["find_module"], find_spec) + self.assertIs(global_dict["Config"], Config) # 3 test import base config # 3.1 simple from ... import and from ... import ... as - self.assertEqual(global_dict['scope'], 'test_config') - self.assertDictEqual(global_dict['val_cfg'], {}) + self.assertEqual(global_dict["scope"], "test_config") + self.assertDictEqual(global_dict["val_cfg"], {}) # 4. Error catching - cfg_path = osp.join(self.data_dir, - 'test_ast_transform_error_catching1.py') + cfg_path = osp.join(self.data_dir, "test_ast_transform_error_catching1.py") with open(cfg_path) as f: codestr = f.read() codeobj = ast.parse(codestr) - global_dict = {'LazyObject': LazyObject} - with self.assertRaisesRegex( - RuntimeError, - r'Illegal syntax in config! `from xxx import \*`'): + global_dict = {"LazyObject": LazyObject} + with self.assertRaisesRegex(RuntimeError, r"Illegal syntax in config! `from xxx import \*`"): codeobj = ImportTransformer(global_dict).visit(codeobj) class TestLazyObject(TestCase): - def test_init(self): - LazyObject('mmengine') - LazyObject('mmengine.fileio') - LazyObject('mmengine.fileio', 'LocalBackend') + LazyObject("mmengine") + LazyObject("mmengine.fileio") + LazyObject("mmengine.fileio", "LocalBackend") # module must be str with self.assertRaises(TypeError): @@ -132,17 +123,16 @@ def test_init(self): # imported must be a sequence of string or None with self.assertRaises(TypeError): - LazyObject('mmengine', ['error_type']) + LazyObject("mmengine", ["error_type"]) def test_build(self): - lazy_mmengine = LazyObject('mmengine') + lazy_mmengine = LazyObject("mmengine") self.assertIs(lazy_mmengine.build(), mmengine) - lazy_mmengine_fileio = LazyObject('mmengine.fileio') - self.assertIs(lazy_mmengine_fileio.build(), - import_module('mmengine.fileio')) + lazy_mmengine_fileio = LazyObject("mmengine.fileio") + self.assertIs(lazy_mmengine_fileio.build(), import_module("mmengine.fileio")) - lazy_local_backend = LazyObject('mmengine.fileio', 'LocalBackend') + lazy_local_backend = LazyObject("mmengine.fileio", "LocalBackend") self.assertIs(lazy_local_backend.build(), LocalBackend) # TODO: The commented test is required, we need to test the built @@ -166,14 +156,14 @@ def test_build(self): lazy_mmengine() with self.assertRaises(ImportError): - LazyObject('unknown').build() + LazyObject("unknown").build() class TestLazyAttr(TestCase): # Since LazyAttr should only be built from LazyObject, we only test # the build method here. def test_build(self): - lazy_mmengine = LazyObject('mmengine') + lazy_mmengine = LazyObject("mmengine") local_backend = lazy_mmengine.fileio.LocalBackend self.assertIs(local_backend.build(), LocalBackend) @@ -183,7 +173,5 @@ def test_build(self): with self.assertRaises(RuntimeError): local_backend() - with self.assertRaisesRegex( - ImportError, - 'Failed to import mmengine.fileio.LocalBackend.unknown'): + with self.assertRaisesRegex(ImportError, "Failed to import mmengine.fileio.LocalBackend.unknown"): local_backend.unknown.build() diff --git a/tests/test_data/test_data_utils.py b/tests/test_data/test_data_utils.py index 76e30e8642..e3f3553891 100644 --- a/tests/test_data/test_data_utils.py +++ b/tests/test_data/test_data_utils.py @@ -10,7 +10,6 @@ class TestDataUtils(TestCase): - def test_pseudo_collate(self): # Test with list of dict tensor inputs.
input1 = torch.randn(1, 3, 5) @@ -18,15 +17,12 @@ def test_pseudo_collate(self): label1 = torch.randn(1) label2 = torch.randn(1) - data_batch = [ - dict(inputs=input1, data_sample=label1), - dict(inputs=input2, data_sample=label2) - ] + data_batch = [dict(inputs=input1, data_sample=label1), dict(inputs=input2, data_sample=label2)] data_batch = pseudo_collate(data_batch) - self.assertTrue(torch.allclose(input1, data_batch['inputs'][0])) - self.assertTrue(torch.allclose(input2, data_batch['inputs'][1])) - self.assertTrue(torch.allclose(label1, data_batch['data_sample'][0])) - self.assertTrue(torch.allclose(label2, data_batch['data_sample'][1])) + self.assertTrue(torch.allclose(input1, data_batch["inputs"][0])) + self.assertTrue(torch.allclose(input2, data_batch["inputs"][1])) + self.assertTrue(torch.allclose(label1, data_batch["data_sample"][0])) + self.assertTrue(torch.allclose(label2, data_batch["data_sample"][1])) # Test with list of dict, and each element contains `data_sample` # inputs @@ -37,8 +33,7 @@ def test_pseudo_collate(self): dict(inputs=input2, data_sample=data_sample2), ] data_batch = pseudo_collate(data) - batch_inputs, batch_data_sample = (data_batch['inputs'], - data_batch['data_sample']) + batch_inputs, batch_data_sample = (data_batch["inputs"], data_batch["data_sample"]) # check batch_inputs self.assertTrue(is_list_of(batch_inputs, torch.Tensor)) self.assertIs(input1, batch_inputs[0]) @@ -49,41 +44,27 @@ def test_pseudo_collate(self): self.assertIs(batch_data_sample[1], data_sample2) # Test with list of tuple, each tuple is a nested dict instance - data_batch = [(dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2))), - (dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2)))] + data_batch = [ + ( + dict(inputs=input1, data_sample=data_sample1, value=1, name="1", nested=dict(data_sample=data_sample1)), + dict(inputs=input2, data_sample=data_sample2, value=2, name="2", nested=dict(data_sample=data_sample2)), + ), + ( + dict(inputs=input1, data_sample=data_sample1, value=1, name="1", nested=dict(data_sample=data_sample1)), + dict(inputs=input2, data_sample=data_sample2, value=2, name="2", nested=dict(data_sample=data_sample2)), + ), + ] data_batch = pseudo_collate(data_batch) - batch_inputs_0 = data_batch[0]['inputs'] - batch_inputs_1 = data_batch[1]['inputs'] - batch_data_sample_0 = data_batch[0]['data_sample'] - batch_data_sample_1 = data_batch[1]['data_sample'] - batch_value_0 = data_batch[0]['value'] - batch_value_1 = data_batch[1]['value'] - batch_name_0 = data_batch[0]['name'] - batch_name_1 = data_batch[1]['name'] - batch_nested_0 = data_batch[0]['nested'] - batch_nested_1 = data_batch[1]['nested'] + batch_inputs_0 = data_batch[0]["inputs"] + batch_inputs_1 = data_batch[1]["inputs"] + batch_data_sample_0 = data_batch[0]["data_sample"] + batch_data_sample_1 = data_batch[1]["data_sample"] + batch_value_0 = data_batch[0]["value"] + batch_value_1 = data_batch[1]["value"] + batch_name_0 = data_batch[0]["name"] + batch_name_1 = data_batch[1]["name"] + batch_nested_0 = data_batch[0]["nested"] + batch_nested_1 = data_batch[1]["nested"] self.assertTrue(is_list_of(batch_inputs_0, torch.Tensor)) 
self.assertTrue(is_list_of(batch_inputs_1, torch.Tensor)) @@ -100,52 +81,48 @@ def test_pseudo_collate(self): self.assertEqual(batch_value_0, [1, 1]) self.assertEqual(batch_value_1, [2, 2]) - self.assertEqual(batch_name_0, ['1', '1']) - self.assertEqual(batch_name_1, ['2', '2']) + self.assertEqual(batch_name_0, ["1", "1"]) + self.assertEqual(batch_name_1, ["2", "2"]) - self.assertIs(batch_nested_0['data_sample'][0], data_sample1) - self.assertIs(batch_nested_0['data_sample'][1], data_sample1) - self.assertIs(batch_nested_1['data_sample'][0], data_sample2) - self.assertIs(batch_nested_1['data_sample'][1], data_sample2) + self.assertIs(batch_nested_0["data_sample"][0], data_sample1) + self.assertIs(batch_nested_0["data_sample"][1], data_sample1) + self.assertIs(batch_nested_1["data_sample"][0], data_sample2) + self.assertIs(batch_nested_1["data_sample"][1], data_sample2) def test_default_collate(self): # `default_collate` has common logic with `pseudo_collate`, therefore # only test it can stack batch tensor, convert int or float to tensor. input1 = torch.randn(1, 3, 5) input2 = torch.randn(1, 3, 5) - data_batch = [( - dict(inputs=input1, value=1, array=np.array(1)), - dict(inputs=input2, value=2, array=np.array(2)), - ), - ( - dict(inputs=input1, value=1, array=np.array(1)), - dict(inputs=input2, value=2, array=np.array(2)), - )] + data_batch = [ + ( + dict(inputs=input1, value=1, array=np.array(1)), + dict(inputs=input2, value=2, array=np.array(2)), + ), + ( + dict(inputs=input1, value=1, array=np.array(1)), + dict(inputs=input2, value=2, array=np.array(2)), + ), + ] data_batch = default_collate(data_batch) - batch_inputs_0 = data_batch[0]['inputs'] - batch_inputs_1 = data_batch[1]['inputs'] - batch_value_0 = data_batch[0]['value'] - batch_value_1 = data_batch[1]['value'] - batch_array_0 = data_batch[0]['array'] - batch_array_1 = data_batch[1]['array'] + batch_inputs_0 = data_batch[0]["inputs"] + batch_inputs_1 = data_batch[1]["inputs"] + batch_value_0 = data_batch[0]["value"] + batch_value_1 = data_batch[1]["value"] + batch_array_0 = data_batch[0]["array"] + batch_array_1 = data_batch[1]["array"] self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5)) self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5)) - self.assertTrue( - torch.allclose(batch_inputs_0, torch.stack([input1, input1]))) - self.assertTrue( - torch.allclose(batch_inputs_1, torch.stack([input2, input2]))) + self.assertTrue(torch.allclose(batch_inputs_0, torch.stack([input1, input1]))) + self.assertTrue(torch.allclose(batch_inputs_1, torch.stack([input2, input2]))) target1 = torch.stack([torch.tensor(1), torch.tensor(1)]) target2 = torch.stack([torch.tensor(2), torch.tensor(2)]) - self.assertTrue( - torch.allclose(batch_value_0.to(target1.dtype), target1)) - self.assertTrue( - torch.allclose(batch_value_1.to(target2.dtype), target2)) + self.assertTrue(torch.allclose(batch_value_0.to(target1.dtype), target1)) + self.assertTrue(torch.allclose(batch_value_1.to(target2.dtype), target2)) - self.assertTrue( - torch.allclose(batch_array_0.to(target1.dtype), target1)) - self.assertTrue( - torch.allclose(batch_array_1.to(target2.dtype), target2)) + self.assertTrue(torch.allclose(batch_array_0.to(target1.dtype), target1)) + self.assertTrue(torch.allclose(batch_array_1.to(target2.dtype), target2)) diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index f4ec815ec2..1bd190c9c2 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -7,8
+7,7 @@ import torch from mmengine.config import Config, ConfigDict -from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose, - ConcatDataset, RepeatDataset, force_full_init) +from mmengine.dataset import BaseDataset, ClassBalancedDataset, Compose, ConcatDataset, RepeatDataset, force_full_init from mmengine.registry import DATASETS, TRANSFORMS @@ -18,7 +17,6 @@ def function_pipeline(data_info): @TRANSFORMS.register_module() class CallableTransform: - def __call__(self, data_info): return data_info @@ -34,10 +32,8 @@ class CustomDataset(BaseDataset): class TestBaseDataset: - def setup_method(self): - self.data_info = dict( - filename='test_img.jpg', height=604, width=640, sample_idx=0) + self.data_info = dict(filename="test_img.jpg", height=604, width=640, sample_idx=0) self.imgs = torch.rand((2, 3, 32, 32)) self.ori_meta = BaseDataset.METAINFO self.ori_parse_data_info = BaseDataset.parse_data_info @@ -51,39 +47,43 @@ def teardown_method(self): def test_init(self): # test the instantiation of self.base_dataset dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') - assert hasattr(dataset, 'data_address') + assert hasattr(dataset, "data_list") + assert hasattr(dataset, "data_address") dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path=""), + ann_file="annotations/dummy_annotation.json", + ) assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') - assert hasattr(dataset, 'data_address') + assert hasattr(dataset, "data_list") + assert hasattr(dataset, "data_address") # test the instantiation of self.base_dataset with # `serialize_data=False` dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - serialize_data=False) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + serialize_data=False, + ) assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') - assert not hasattr(dataset, 'data_address') + assert hasattr(dataset, "data_list") + assert not hasattr(dataset, "data_address") assert len(dataset) == 3 assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=True) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=True, + ) assert not dataset._fully_initialized assert not dataset.data_list @@ -91,46 +91,47 @@ def test_init(self): # existed. 
with pytest.raises(FileNotFoundError): BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/not_existed_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/not_existed_annotation.json", + ) # Use the default value of ann_file, i.e., '' with pytest.raises(TypeError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs')) + BaseDataset(data_root=osp.join(osp.dirname(__file__), "../data/"), data_prefix=dict(img_path="imgs")) # test the instantiation of self.base_dataset when the ann_file is # wrong with pytest.raises(ValueError): BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/annotation_wrong_keys.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/annotation_wrong_keys.json", + ) with pytest.raises(TypeError): BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/annotation_wrong_format.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/annotation_wrong_format.json", + ) with pytest.raises(TypeError): BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=['img']), - ann_file='annotations/annotation_wrong_format.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path=["img"]), + ann_file="annotations/annotation_wrong_format.json", + ) # test the instantiation of self.base_dataset when `parse_data_info` # return `list[dict]` - BaseDataset.parse_data_info = MagicMock( - return_value=[self.data_info, - self.data_info.copy()]) + BaseDataset.parse_data_info = MagicMock(return_value=[self.data_info, self.data_info.copy()]) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) dataset.pipeline = self.pipeline assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') - assert hasattr(dataset, 'data_address') + assert hasattr(dataset, "data_list") + assert hasattr(dataset, "data_address") assert len(dataset) == 6 assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info @@ -138,26 +139,28 @@ def test_init(self): # test the instantiation of self.base_dataset when `parse_data_info` # return unsupported data. 
with pytest.raises(TypeError): - BaseDataset.parse_data_info = MagicMock(return_value='xxx') + BaseDataset.parse_data_info = MagicMock(return_value="xxx") dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) with pytest.raises(TypeError): - BaseDataset.parse_data_info = MagicMock( - return_value=[self.data_info, 'xxx']) + BaseDataset.parse_data_info = MagicMock(return_value=[self.data_info, "xxx"]) BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) # test the instantiation of self.base_dataset without `ann_file` BaseDataset.parse_data_info = self.ori_parse_data_info dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='', + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="", serialize_data=False, - lazy_init=True) + lazy_init=True, + ) assert not dataset.ann_file # Test `ann_file` and `data_root` could be None. @@ -167,155 +170,140 @@ def test_meta(self): # test dataset.metainfo with setting the metainfo from annotation file # as the metainfo of self.base_dataset. dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) - assert dataset.metainfo == dict( - dataset_type='test_dataset', task_name='test_task', empty_list=[]) + assert dataset.metainfo == dict(dataset_type="test_dataset", task_name="test_task", empty_list=[]) # test dataset.metainfo with setting METAINFO in self.base_dataset - dataset_type = 'new_dataset' - BaseDataset.METAINFO = dict( - dataset_type=dataset_type, classes=('dog', 'cat')) + dataset_type = "new_dataset" + BaseDataset.METAINFO = dict(dataset_type=dataset_type, classes=("dog", "cat")) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='test_task', - classes=('dog', 'cat'), - empty_list=[]) + dataset_type=dataset_type, task_name="test_task", classes=("dog", "cat"), empty_list=[] + ) # test dataset.metainfo with passing metainfo into self.base_dataset - metainfo = dict(classes=('dog', ), task_name='new_task') + metainfo = dict(classes=("dog",), task_name="new_task") dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + 
ann_file="annotations/dummy_annotation.json", + metainfo=metainfo, + ) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat")) assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset_type=dataset_type, task_name="new_task", classes=("dog",), empty_list=[] + ) # test dataset.metainfo with passing metainfo as Config into # self.base_dataset - metainfo = Config(dict(classes=('dog', ), task_name='new_task')) + metainfo = Config(dict(classes=("dog",), task_name="new_task")) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + metainfo=metainfo, + ) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat")) assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset_type=dataset_type, task_name="new_task", classes=("dog",), empty_list=[] + ) # test dataset.metainfo with passing metainfo as ConfigDict (Mapping) # into self.base_dataset - metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task')) + metainfo = ConfigDict(dict(classes=("dog",), task_name="new_task")) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + metainfo=metainfo, + ) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat")) assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset_type=dataset_type, task_name="new_task", classes=("dog",), empty_list=[] + ) # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not # change - BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish') - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + BaseDataset.METAINFO["classes"] = ("dog", "cat", "fish") + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat", "fish")) assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset_type=dataset_type, task_name="new_task", classes=("dog",), empty_list=[] + ) # test dataset.metainfo with passing metainfo containing a file into # self.base_dataset - metainfo = dict( - classes=osp.join( - osp.dirname(__file__), '../data/meta/classes.txt')) + metainfo = dict(classes=osp.join(osp.dirname(__file__), "../data/meta/classes.txt")) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + metainfo=metainfo, + ) assert 
dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='test_task', - classes=['dog'], - empty_list=[]) + dataset_type=dataset_type, task_name="test_task", classes=["dog"], empty_list=[] + ) # test dataset.metainfo with passing unsupported metainfo into # self.base_dataset with pytest.raises(TypeError): - metainfo = 'dog' + metainfo = "dog" dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + metainfo=metainfo, + ) # test dataset.metainfo with passing metainfo into self.base_dataset # and lazy_init is True - metainfo = dict(classes=('dog', )) + metainfo = dict(classes=("dog",)) dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", metainfo=metainfo, - lazy_init=True) + lazy_init=True, + ) # 'task_name' and 'empty_list' not in dataset.metainfo - assert dataset.metainfo == dict( - dataset_type=dataset_type, classes=('dog', )) + assert dataset.metainfo == dict(dataset_type=dataset_type, classes=("dog",)) # test whether self.base_dataset.METAINFO is changed when a customize # dataset inherit self.base_dataset # test reset METAINFO in ToyDataset. class ToyDataset(BaseDataset): - METAINFO = dict(xxx='xxx') + METAINFO = dict(xxx="xxx") - assert ToyDataset.METAINFO == dict(xxx='xxx') - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + assert ToyDataset.METAINFO == dict(xxx="xxx") + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat", "fish")) # test update METAINFO in ToyDataset. 
class ToyDataset(BaseDataset): METAINFO = copy.deepcopy(BaseDataset.METAINFO) - METAINFO['classes'] = ('bird', ) + METAINFO["classes"] = ("bird",) - assert ToyDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('bird', )) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + assert ToyDataset.METAINFO == dict(dataset_type=dataset_type, classes=("bird",)) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, classes=("dog", "cat", "fish")) - @pytest.mark.parametrize('lazy_init', [True, False]) + @pytest.mark.parametrize("lazy_init", [True, False]) def test_length(self, lazy_init): dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=lazy_init, + ) if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") assert len(dataset) == 3 else: # test `__len__()` when lazy_init is True @@ -324,31 +312,28 @@ def test_length(self, lazy_init): # call `full_init()` automatically assert len(dataset) == 3 assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") def test_compose(self): # test callable transform transforms = [function_pipeline] compose = Compose(transforms=transforms) - assert (self.imgs == compose(dict(img=self.imgs))['img']).all() + assert (self.imgs == compose(dict(img=self.imgs))["img"]).all() # test transform build from cfg_dict - transforms = [dict(type='CallableTransform')] + transforms = [dict(type="CallableTransform")] compose = Compose(transforms=transforms) - assert (self.imgs == compose(dict(img=self.imgs))['img']).all() + assert (self.imgs == compose(dict(img=self.imgs))["img"]).all() # test return None in advance none_func = MagicMock(return_value=None) transforms = [none_func, function_pipeline] compose = Compose(transforms=transforms) assert compose(dict(img=self.imgs)) is None # test repr - repr_str = f'Compose(\n' \ - f' {none_func}\n' \ - f' {function_pipeline}\n' \ - f')' + repr_str = f"Compose(\n {none_func}\n {function_pipeline}\n)" assert repr(compose) == repr_str # non-callable transform will raise error with pytest.raises(TypeError): - transforms = [dict(type='NotCallableTransform')] + transforms = [dict(type="NotCallableTransform")] Compose(transforms) # transform must be callable or dict @@ -357,22 +342,23 @@ def test_compose(self): # when the input transform is None, do nothing compose = Compose(None) - assert (compose(dict(img=self.imgs))['img'] == self.imgs).all() + assert (compose(dict(img=self.imgs))["img"] == self.imgs).all() compose = Compose([]) - assert (compose(dict(img=self.imgs))['img'] == self.imgs).all() + assert (compose(dict(img=self.imgs))["img"] == self.imgs).all() - @pytest.mark.parametrize('lazy_init', [True, False]) + @pytest.mark.parametrize("lazy_init", [True, False]) def test_getitem(self, lazy_init): dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=lazy_init, + ) 
dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") assert dataset[0] == dict(imgs=self.imgs) else: # Test `__getitem__()` when lazy_init is True @@ -381,7 +367,7 @@ def test_getitem(self, lazy_init): # Call `full_init()` automatically assert dataset[0] == dict(imgs=self.imgs) assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") # Test with test mode dataset.test_mode = False @@ -404,17 +390,18 @@ def fake_prepare_data(idx): with pytest.raises(Exception): dataset[0] - @pytest.mark.parametrize('lazy_init', [True, False]) + @pytest.mark.parametrize("lazy_init", [True, False]) def test_get_data_info(self, lazy_init): dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=lazy_init, + ) if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") assert dataset.get_data_info(0) == self.data_info else: # test `get_data_info()` when lazy_init is True @@ -423,23 +410,22 @@ def test_get_data_info(self, lazy_init): # call `full_init()` automatically assert dataset.get_data_info(0) == self.data_info assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") # Test parse_data_info with `data_prefix` BaseDataset.parse_data_info = self.ori_parse_data_info - data_root = osp.join(osp.dirname(__file__), '../data/') + data_root = osp.join(osp.dirname(__file__), "../data/") dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + ) data_info = dataset.get_data_info(0) - assert data_info['img_path'] == osp.join(data_root, 'imgs', - 'test_img.jpg') + assert data_info["img_path"] == osp.join(data_root, "imgs", "test_img.jpg") def test_force_full_init(self): with pytest.raises(AttributeError): class ClassWithoutFullInit: - @force_full_init def foo(self): pass @@ -449,10 +435,11 @@ def foo(self): def test_full_init(self): dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=True) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=True, + ) dataset.pipeline = self.pipeline # test `full_init()` when lazy_init is True assert not dataset._fully_initialized @@ -460,49 +447,51 @@ def test_full_init(self): # call `full_init()` manually dataset.full_init() assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") assert len(dataset) == 3 assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=False) + 
data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path="imgs"), + ann_file="annotations/dummy_annotation.json", + lazy_init=False, + ) dataset.pipeline = self.pipeline assert dataset._fully_initialized - assert hasattr(dataset, 'data_list') + assert hasattr(dataset, "data_list") assert len(dataset) == 3 assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset when passing indices dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json') + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path=""), + ann_file="annotations/dummy_annotation.json", + ) dataset_sliced = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json', - indices=1) + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path=""), + ann_file="annotations/dummy_annotation.json", + indices=1, + ) assert dataset_sliced[0] == dataset[0] assert len(dataset_sliced) == 1 - @pytest.mark.parametrize( - 'lazy_init, serialize_data', - ([True, False], [False, True], [True, True], [False, False])) + @pytest.mark.parametrize("lazy_init, serialize_data", ([True, False], [False, True], [True, True], [False, False])) def test_get_subset_(self, lazy_init, serialize_data): # Test positive int indices. indices = 2 dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json', + data_root=osp.join(osp.dirname(__file__), "../data/"), + data_prefix=dict(img_path=""), + ann_file="annotations/dummy_annotation.json", lazy_init=lazy_init, - serialize_data=serialize_data) + serialize_data=serialize_data, + ) dataset_copy = copy.deepcopy(dataset) dataset_copy.get_subset_(indices) @@ -518,7 +507,7 @@ def test_get_subset_(self, lazy_init, serialize_data): assert len(dataset_copy) == 2 for i in range(len(dataset_copy)): ori_data = dataset[i + 1] - ori_data['sample_idx'] = i + ori_data["sample_idx"] = i assert dataset_copy[i] == ori_data # If indices is 0, return empty dataset. @@ -530,7 +519,7 @@ def test_get_subset_(self, lazy_init, serialize_data): indices = [1] dataset_copy = copy.deepcopy(dataset) ori_data = dataset[1] - ori_data['sample_idx'] = 0 + ori_data["sample_idx"] = 0 dataset_copy.get_subset_(indices) assert len(dataset_copy) == 1 assert dataset_copy[0] == ori_data @@ -539,7 +528,7 @@ def test_get_subset_(self, lazy_init, serialize_data): indices = [-1] dataset_copy = copy.deepcopy(dataset) ori_data = dataset[2] - ori_data['sample_idx'] = 0 + ori_data["sample_idx"] = 0 dataset_copy.get_subset_(indices) assert len(dataset_copy) == 1 assert dataset_copy[0] == ori_data @@ -555,7 +544,7 @@ def test_get_subset_(self, lazy_init, serialize_data): dataset_copy.get_subset_(indices) for i in range(len(dataset_copy)): ori_data = dataset[i] - ori_data['sample_idx'] = i + ori_data["sample_idx"] = i assert dataset_copy[i] == ori_data # Test list with multiple negative indices. 
         indices = [-1, -2, 0]
@@ -563,24 +552,23 @@ def test_get_subset_(self, lazy_init, serialize_data):
         dataset_copy.get_subset_(indices)
         for i in range(len(dataset_copy)):
             ori_data = dataset[len(dataset) - i - 1]
-            ori_data['sample_idx'] = i
+            ori_data["sample_idx"] = i
             assert dataset_copy[i] == ori_data

         with pytest.raises(TypeError):
             dataset.get_subset_(dict())

-    @pytest.mark.parametrize(
-        'lazy_init, serialize_data',
-        ([True, False], [False, True], [True, True], [False, False]))
+    @pytest.mark.parametrize("lazy_init, serialize_data", ([True, False], [False, True], [True, True], [False, False]))
     def test_get_subset(self, lazy_init, serialize_data):
         # Test positive indices.
         indices = 2
         dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json',
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path=""),
+            ann_file="annotations/dummy_annotation.json",
             lazy_init=lazy_init,
-            serialize_data=serialize_data)
+            serialize_data=serialize_data,
+        )
         dataset_sliced = dataset.get_subset(indices)
         assert len(dataset_sliced) == 2
         assert dataset_sliced[0] == dataset[0]
@@ -592,7 +580,7 @@ def test_get_subset(self, lazy_init, serialize_data):
         assert len(dataset_sliced) == 2
         for i in range(len(dataset_sliced)):
             ori_data = dataset[i + 1]
-            ori_data['sample_idx'] = i
+            ori_data["sample_idx"] = i
             assert dataset_sliced[i] == ori_data
         # If indices is 0 or empty list, return empty dataset.
         assert len(dataset.get_subset(0)) == 0
@@ -601,7 +589,7 @@ def test_get_subset(self, lazy_init, serialize_data):
         indices = [1]
         dataset_sliced = dataset.get_subset(indices)
         ori_data = dataset[1]
-        ori_data['sample_idx'] = 0
+        ori_data["sample_idx"] = 0
         assert len(dataset_sliced) == 1
         assert dataset_sliced[0] == ori_data
         # Test list with multiple positive index.
@@ -609,77 +597,77 @@ def test_get_subset(self, lazy_init, serialize_data):
         dataset_sliced = dataset.get_subset(indices)
         for i in range(len(dataset_sliced)):
             ori_data = dataset[i]
-            ori_data['sample_idx'] = i
+            ori_data["sample_idx"] = i
             assert dataset_sliced[i] == ori_data
         # Test list with multiple negative index.
         indices = [-1, -2, 0]
         dataset_sliced = dataset.get_subset(indices)
         for i in range(len(dataset_sliced)):
             ori_data = dataset[len(dataset) - i - 1]
-            ori_data['sample_idx'] = i
+            ori_data["sample_idx"] = i
             assert dataset_sliced[i] == ori_data

     def test_rand_another(self):
         # test the instantiation of self.base_dataset when passing num_samples
         dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json',
-            indices=1)
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path=""),
+            ann_file="annotations/dummy_annotation.json",
+            indices=1,
+        )
         assert dataset._rand_another() >= 0
         assert dataset._rand_another() < len(dataset)


 class TestConcatDataset:
-
     def setup_method(self):
         dataset = BaseDataset

         # create dataset_a
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        data_info = dict(filename="test_img.jpg", height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         self.dataset_a = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
         self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

         # create dataset_b
-        data_info = dict(filename='gray.jpg', height=288, width=512)
+        data_info = dict(filename="gray.jpg", height=288, width=512)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         self.dataset_b = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
         self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
         # test init
-        self.cat_datasets = ConcatDataset(
-            datasets=[self.dataset_a, self.dataset_b])
+        self.cat_datasets = ConcatDataset(datasets=[self.dataset_a, self.dataset_b])

     def test_init(self):
         # Test build dataset from cfg.
         dataset_cfg_b = dict(
             type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
         cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
         cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
         assert len(cat_datasets) == len(self.cat_datasets)
         for i in range(len(cat_datasets)):
-            assert (cat_datasets.get_data_info(i) ==
-                    self.cat_datasets.get_data_info(i))
-            assert (cat_datasets[i] == self.cat_datasets[i])
+            assert cat_datasets.get_data_info(i) == self.cat_datasets.get_data_info(i)
+            assert cat_datasets[i] == self.cat_datasets[i]

         with pytest.raises(TypeError):
             ConcatDataset(datasets=[0])

         with pytest.raises(TypeError):
-            ConcatDataset(
-                datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)
+            ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)

     def test_full_init(self):
         # test init with lazy_init=True
@@ -697,10 +685,11 @@ def test_full_init(self):
             self.cat_datasets.get_subset(1)

         dataset_b = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=dict(classes=('cat')))
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+            metainfo=dict(classes=("cat")),
+        )
         # Regardless of order, different meta information without
         # `ignore_keys` will raise error.
         with pytest.raises(ValueError):
@@ -710,11 +699,9 @@ def test_full_init(self):
         # `ignore_keys` does not contain different meta information keys will
         # raise error.
         with pytest.raises(ValueError):
-            ConcatDataset(
-                datasets=[self.dataset_a, dataset_b], ignore_keys=['a'])
+            ConcatDataset(datasets=[self.dataset_a, dataset_b], ignore_keys=["a"])

         # Different meta information with `ignore_keys` will not raise error.
-        cat_datasets = ConcatDataset(
-            datasets=[self.dataset_a, dataset_b], ignore_keys='classes')
+        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_b], ignore_keys="classes")
         cat_datasets.full_init()
         assert len(cat_datasets) == 6
         cat_datasets.full_init()
@@ -727,86 +714,71 @@ def test_metainfo(self):
         assert self.cat_datasets.metainfo == self.dataset_a.metainfo

     def test_length(self):
-        assert len(self.cat_datasets) == (
-            len(self.dataset_a) + len(self.dataset_b))
+        assert len(self.cat_datasets) == (len(self.dataset_a) + len(self.dataset_b))

     def test_getitem(self):
-        assert (
-            self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
-        assert (self.cat_datasets[0]['imgs'] !=
-                self.dataset_b[0]['imgs']).all()
+        assert (self.cat_datasets[0]["imgs"] == self.dataset_a[0]["imgs"]).all()
+        assert (self.cat_datasets[0]["imgs"] != self.dataset_b[0]["imgs"]).all()

-        assert (
-            self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
-        assert (self.cat_datasets[-1]['imgs'] !=
-                self.dataset_a[-1]['imgs']).all()
+        assert (self.cat_datasets[-1]["imgs"] == self.dataset_b[-1]["imgs"]).all()
+        assert (self.cat_datasets[-1]["imgs"] != self.dataset_a[-1]["imgs"]).all()

     def test_get_data_info(self):
-        assert self.cat_datasets.get_data_info(
-            0) == self.dataset_a.get_data_info(0)
-        assert self.cat_datasets.get_data_info(
-            0) != self.dataset_b.get_data_info(0)
+        assert self.cat_datasets.get_data_info(0) == self.dataset_a.get_data_info(0)
+        assert self.cat_datasets.get_data_info(0) != self.dataset_b.get_data_info(0)

-        assert self.cat_datasets.get_data_info(
-            -1) == self.dataset_b.get_data_info(-1)
-        assert self.cat_datasets.get_data_info(
-            -1) != self.dataset_a.get_data_info(-1)
+        assert self.cat_datasets.get_data_info(-1) == self.dataset_b.get_data_info(-1)
+        assert self.cat_datasets.get_data_info(-1) != self.dataset_a.get_data_info(-1)

     def test_get_ori_dataset_idx(self):
-        assert self.cat_datasets._get_ori_dataset_idx(3) == (
-            1, 3 - len(self.dataset_a))
-        assert self.cat_datasets._get_ori_dataset_idx(-1) == (
-            1, len(self.dataset_b) - 1)
+        assert self.cat_datasets._get_ori_dataset_idx(3) == (1, 3 - len(self.dataset_a))
+        assert self.cat_datasets._get_ori_dataset_idx(-1) == (1, len(self.dataset_b) - 1)
         with pytest.raises(ValueError):
             assert self.cat_datasets._get_ori_dataset_idx(-10)


 class TestRepeatDataset:
-
     def setup_method(self):
         dataset = BaseDataset
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        data_info = dict(filename="test_img.jpg", height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         self.dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
         self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

         self.repeat_times = 5
         # test init
-        self.repeat_datasets = RepeatDataset(
-            dataset=self.dataset, times=self.repeat_times)
+        self.repeat_datasets = RepeatDataset(dataset=self.dataset, times=self.repeat_times)

     def test_init(self):
         # Test build dataset from cfg.
         dataset_cfg = dict(
             type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        repeat_dataset = RepeatDataset(
-            dataset=dataset_cfg, times=self.repeat_times)
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
+        repeat_dataset = RepeatDataset(dataset=dataset_cfg, times=self.repeat_times)
         repeat_dataset.dataset.pipeline = self.dataset.pipeline
         assert len(repeat_dataset) == len(self.repeat_datasets)
         for i in range(len(repeat_dataset)):
-            assert (repeat_dataset.get_data_info(i) ==
-                    self.repeat_datasets.get_data_info(i))
-            assert (repeat_dataset[i] == self.repeat_datasets[i])
+            assert repeat_dataset.get_data_info(i) == self.repeat_datasets.get_data_info(i)
+            assert repeat_dataset[i] == self.repeat_datasets[i]

         with pytest.raises(TypeError):
             RepeatDataset(dataset=[0], times=5)

     def test_full_init(self):
         self.repeat_datasets.full_init()
-        assert len(
-            self.repeat_datasets) == self.repeat_times * len(self.dataset)
+        assert len(self.repeat_datasets) == self.repeat_times * len(self.dataset)
         self.repeat_datasets.full_init()
         self.repeat_datasets._fully_initialized = False
         self.repeat_datasets[1]
-        assert len(self.repeat_datasets) == \
-            self.repeat_times * len(self.dataset)
+        assert len(self.repeat_datasets) == self.repeat_times * len(self.dataset)

         with pytest.raises(NotImplementedError):
             self.repeat_datasets.get_subset_(1)
@@ -818,57 +790,51 @@ def test_metainfo(self):
         assert self.repeat_datasets.metainfo == self.dataset.metainfo

     def test_length(self):
-        assert len(
-            self.repeat_datasets) == len(self.dataset) * self.repeat_times
+        assert len(self.repeat_datasets) == len(self.dataset) * self.repeat_times

     def test_getitem(self):
         for i in range(self.repeat_times):
-            assert self.repeat_datasets[len(self.dataset) *
-                                        i] == self.dataset[0]
+            assert self.repeat_datasets[len(self.dataset) * i] == self.dataset[0]

     def test_get_data_info(self):
         for i in range(self.repeat_times):
-            assert self.repeat_datasets.get_data_info(
-                len(self.dataset) * i) == self.dataset.get_data_info(0)
+            assert self.repeat_datasets.get_data_info(len(self.dataset) * i) == self.dataset.get_data_info(0)


 class TestClassBalancedDataset:
-
     def setup_method(self):
         dataset = BaseDataset
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        data_info = dict(filename="test_img.jpg", height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         dataset.get_cat_ids = MagicMock(return_value=[0])
         self.dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
         self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

         self.repeat_indices = [0, 0, 1, 1, 1]
         # test init
-        self.cls_banlanced_datasets = ClassBalancedDataset(
-            dataset=self.dataset, oversample_thr=1e-3)
+        self.cls_banlanced_datasets = ClassBalancedDataset(dataset=self.dataset, oversample_thr=1e-3)
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices

     def test_init(self):
         # Test build dataset from cfg.
         dataset_cfg = dict(
             type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        cls_banlanced_datasets = ClassBalancedDataset(
-            dataset=dataset_cfg, oversample_thr=1e-3)
+            data_root=osp.join(osp.dirname(__file__), "../data/"),
+            data_prefix=dict(img_path="imgs"),
+            ann_file="annotations/dummy_annotation.json",
+        )
+        cls_banlanced_datasets = ClassBalancedDataset(dataset=dataset_cfg, oversample_thr=1e-3)
         cls_banlanced_datasets.repeat_indices = self.repeat_indices
         cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
         assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
         for i in range(len(cls_banlanced_datasets)):
-            assert (cls_banlanced_datasets.get_data_info(i) ==
-                    self.cls_banlanced_datasets.get_data_info(i))
-            assert (
-                cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])
+            assert cls_banlanced_datasets.get_data_info(i) == self.cls_banlanced_datasets.get_data_info(i)
+            assert cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i]

         with pytest.raises(TypeError):
             ClassBalancedDataset(dataset=[0], times=5)
@@ -896,15 +862,12 @@ def test_length(self):

     def test_getitem(self):
         for i in range(len(self.repeat_indices)):
-            assert self.cls_banlanced_datasets[i] == self.dataset[
-                self.repeat_indices[i]]
+            assert self.cls_banlanced_datasets[i] == self.dataset[self.repeat_indices[i]]

     def test_get_data_info(self):
         for i in range(len(self.repeat_indices)):
-            assert self.cls_banlanced_datasets.get_data_info(
-                i) == self.dataset.get_data_info(self.repeat_indices[i])
+            assert self.cls_banlanced_datasets.get_data_info(i) == self.dataset.get_data_info(self.repeat_indices[i])

     def test_get_cat_ids(self):
         for i in range(len(self.repeat_indices)):
-            assert self.cls_banlanced_datasets.get_cat_ids(
-                i) == self.dataset.get_cat_ids(self.repeat_indices[i])
+            assert self.cls_banlanced_datasets.get_cat_ids(i) == self.dataset.get_cat_ids(self.repeat_indices[i])
diff --git a/tests/test_dataset/test_sampler.py b/tests/test_dataset/test_sampler.py
index 31582a8679..01fffaf3b1 100644
--- a/tests/test_dataset/test_sampler.py
+++ b/tests/test_dataset/test_sampler.py
@@ -10,12 +10,11 @@
 class TestDefaultSampler(TestCase):
-
     def setUp(self):
         self.data_length = 100
         self.dataset = list(range(self.data_length))

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(0, 1))
     def test_non_dist(self, mock):
         sampler = DefaultSampler(self.dataset)
         self.assertEqual(sampler.world_size, 1)
@@ -33,7 +32,7 @@ def test_non_dist(self, mock):
         self.assertEqual(sampler.num_samples, self.data_length)
         self.assertEqual(list(sampler), list(range(self.data_length)))

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(2, 3))
     def test_dist(self, mock):
         sampler = DefaultSampler(self.dataset)
         self.assertEqual(sampler.world_size, 3)
@@ -44,20 +43,17 @@ def test_dist(self, mock):
         self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3))
         self.assertEqual(sampler.total_size, sampler.num_samples * 3)
         self.assertEqual(len(sampler), sampler.num_samples)
-        self.assertEqual(
-            list(sampler),
-            list(range(self.data_length))[2::3] + [1])
+        self.assertEqual(list(sampler), list(range(self.data_length))[2::3] + [1])

         # test round_up=False
         sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
-        self.assertEqual(sampler.num_samples,
-                         np.ceil((self.data_length - 2) / 3))
+        self.assertEqual(sampler.num_samples, np.ceil((self.data_length - 2) / 3))
         self.assertEqual(sampler.total_size, self.data_length)
         self.assertEqual(len(sampler), sampler.num_samples)
         self.assertEqual(list(sampler), list(range(self.data_length))[2::3])

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
-    @patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(0, 1))
+    @patch("mmengine.dataset.sampler.sync_random_seed", return_value=7)
     def test_shuffle(self, mock1, mock2):
         # test seed=None
         sampler = DefaultSampler(self.dataset, seed=None)
@@ -68,26 +64,21 @@ def test_shuffle(self, mock1, mock2):
         sampler.set_epoch(10)
         g = torch.Generator()
         g.manual_seed(10)
-        self.assertEqual(
-            list(sampler),
-            torch.randperm(len(self.dataset), generator=g).tolist())
+        self.assertEqual(list(sampler), torch.randperm(len(self.dataset), generator=g).tolist())

         sampler = DefaultSampler(self.dataset, shuffle=True, seed=42)
         sampler.set_epoch(10)
         g = torch.Generator()
         g.manual_seed(42 + 10)
-        self.assertEqual(
-            list(sampler),
-            torch.randperm(len(self.dataset), generator=g).tolist())
+        self.assertEqual(list(sampler), torch.randperm(len(self.dataset), generator=g).tolist())


 class TestInfiniteSampler(TestCase):
-
     def setUp(self):
         self.data_length = 100
         self.dataset = list(range(self.data_length))

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(0, 1))
     def test_non_dist(self, mock):
         sampler = InfiniteSampler(self.dataset)
         self.assertEqual(sampler.world_size, 1)
@@ -101,7 +92,7 @@ def test_non_dist(self, mock):
         items = [next(sampler_iter) for _ in range(self.data_length * 2)]
         self.assertEqual(items, list(range(self.data_length)) * 2)

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(2, 3))
     def test_dist(self, mock):
         sampler = InfiniteSampler(self.dataset)
         self.assertEqual(sampler.world_size, 3)
@@ -117,8 +108,8 @@ def test_dist(self, mock):
         print(samples)
         self.assertEqual(samples, targets)

-    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
-    @patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
+    @patch("mmengine.dataset.sampler.get_dist_info", return_value=(0, 1))
+    @patch("mmengine.dataset.sampler.sync_random_seed", return_value=7)
     def test_shuffle(self, mock1, mock2):
         # test seed=None
         sampler = InfiniteSampler(self.dataset, seed=None)
@@ -132,9 +123,7 @@ def test_shuffle(self, mock1, mock2):
         g = torch.Generator()
         g.manual_seed(42)

-        self.assertEqual(
-            samples,
-            torch.randperm(self.data_length, generator=g).tolist())
+        self.assertEqual(samples, torch.randperm(self.data_length, generator=g).tolist())

     def test_set_epoch(self):
         sampler = InfiniteSampler(self.dataset)
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
index d2171afa58..926addde50 100644
--- a/tests/test_device/test_device.py
+++ b/tests/test_device/test_device.py
@@ -1,20 +1,25 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
-                             is_mps_available, is_musa_available,
-                             is_npu_available)
+from mmengine.device import (
+    get_device,
+    is_cuda_available,
+    is_mlu_available,
+    is_mps_available,
+    is_musa_available,
+    is_npu_available,
+)


 def test_get_device():
     device = get_device()
     if is_npu_available():
-        assert device == 'npu'
+        assert device == "npu"
     elif is_cuda_available():
-        assert device == 'cuda'
+        assert device == "cuda"
     elif is_mlu_available():
-        assert device == 'mlu'
+        assert device == "mlu"
     elif is_mps_available():
-        assert device == 'mps'
+        assert device == "mps"
     elif is_musa_available():
-        assert device == 'musa'
+        assert device == "musa"
     else:
-        assert device == 'cpu'
+        assert device == "cpu"
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index a2ef07b713..c42e108707 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -45,89 +45,76 @@ def test_broadcast(self):
         dist.broadcast(data)
         self.assertTrue(torch.allclose(data, expected))

-    @patch('numpy.random.randint', return_value=10)
+    @patch("numpy.random.randint", return_value=10)
     def test_sync_random_seed(self, mock):
         self.assertEqual(sync_random_seed(), 10)

     def test_broadcast_object_list(self):
         with self.assertRaises(AssertionError):
             # input should be list of object
-            dist.broadcast_object_list('foo')
+            dist.broadcast_object_list("foo")

-        data = ['foo', 12, {1: 2}]
-        expected = ['foo', 12, {1: 2}]
+        data = ["foo", 12, {1: 2}]
+        expected = ["foo", 12, {1: 2}]
         dist.broadcast_object_list(data)
         self.assertEqual(data, expected)

     def test_all_reduce_dict(self):
         with self.assertRaises(AssertionError):
             # input should be dict
-            dist.all_reduce_dict('foo')
-
-        data = {
-            'key1': torch.arange(2, dtype=torch.int64),
-            'key2': torch.arange(3, dtype=torch.int64)
-        }
-        expected = {
-            'key1': torch.arange(2, dtype=torch.int64),
-            'key2': torch.arange(3, dtype=torch.int64)
-        }
+            dist.all_reduce_dict("foo")
+
+        data = {"key1": torch.arange(2, dtype=torch.int64), "key2": torch.arange(3, dtype=torch.int64)}
+        expected = {"key1": torch.arange(2, dtype=torch.int64), "key2": torch.arange(3, dtype=torch.int64)}
         dist.all_reduce_dict(data)
         for key in data:
             self.assertTrue(torch.allclose(data[key], expected[key]))

     def test_all_gather_object(self):
-        data = 'foo'
-        expected = 'foo'
+        data = "foo"
+        expected = "foo"
         gather_objects = dist.all_gather_object(data)
         self.assertEqual(gather_objects[0], expected)

     def test_gather_object(self):
-        data = 'foo'
-        expected = 'foo'
+        data = "foo"
+        expected = "foo"
         gather_objects = dist.gather_object(data)
         self.assertEqual(gather_objects[0], expected)

     def test_collect_results(self):
-        data = ['foo', {1: 2}]
         size = 2
-        expected = ['foo', {1: 2}]
+        data = ["foo", {1: 2}]
         size = 2
+        expected = ["foo", {1: 2}]

         # test `device=cpu`
-        output = dist.collect_results(data, size, device='cpu')
+        output = dist.collect_results(data, size, device="cpu")
         self.assertEqual(output, expected)

         # test `device=gpu`
-        output = dist.collect_results(data, size, device='gpu')
+        output = dist.collect_results(data, size, device="gpu")
         self.assertEqual(output, expected)

     def test_all_reduce_params(self):
-        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
-                                          ['sum', 'mean']):
-            data = [
-                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
-            ]
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32], ["sum", "mean"], strict=False):
+            data = [torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)]
             data_gen = (item for item in data)
-            expected = [
-                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
-            ]
+            expected = [torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)]

             dist.all_reduce_params(data_gen, op=reduce_op)

-            for item1, item2 in zip(data, expected):
+            for item1, item2 in zip(data, expected, strict=False):
                 self.assertTrue(torch.allclose(item1, item2))


-@unittest.skipIf(is_musa_available(), reason='musa do not support gloo yet')
+@unittest.skipIf(is_musa_available(), reason="musa do not support gloo yet")
 class TestDistWithGLOOBackend(MultiProcessTestCase):
-
     def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = '29505'
-        os.environ['RANK'] = str(rank)
-        torch_dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "29505"
+        os.environ["RANK"] = str(rank)
+        torch_dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

     def setUp(self):
         super().setUp()
@@ -136,14 +123,14 @@ def setUp(self):
     def test_all_reduce(self):
         self._init_dist_env(self.rank, self.world_size)
         tensor_types = [torch.int64, torch.float32, torch.int64]
-        reduce_ops = ['sum', 'mean', 'mean']
-        for tensor_type, reduce_op in zip(tensor_types, reduce_ops):
+        reduce_ops = ["sum", "mean", "mean"]
+        for tensor_type, reduce_op in zip(tensor_types, reduce_ops, strict=False):
             if dist.get_rank() == 0:
                 data = torch.tensor([1, 2], dtype=tensor_type)
             else:
                 data = torch.tensor([3, 4], dtype=tensor_type)

-            if reduce_op == 'sum':
+            if reduce_op == "sum":
                 expected = torch.tensor([4, 6], dtype=tensor_type)
             else:
                 expected = torch.tensor([2, 3], dtype=tensor_type)
@@ -161,8 +148,7 @@ def test_all_gather(self):
         expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]

         output = dist.all_gather(data)
-        self.assertTrue(
-            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))
+        self.assertTrue(torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

     def test_gather(self):
         self._init_dist_env(self.rank, self.world_size)
@@ -193,9 +179,7 @@ def test_broadcast_dist(self):
     def test_sync_random_seed(self):
         self._init_dist_env(self.rank, self.world_size)

-        with patch.object(
-                torch, 'tensor',
-                return_value=torch.tensor(1024)) as mock_tensor:
+        with patch.object(torch, "tensor", return_value=torch.tensor(1024)) as mock_tensor:
             output = dist.sync_random_seed()
             assert output == 1024
             mock_tensor.assert_called()
@@ -203,38 +187,37 @@ def test_sync_random_seed(self):
     def test_broadcast_object_list(self):
         self._init_dist_env(self.rank, self.world_size)
         if dist.get_rank() == 0:
-            data = ['foo', 12, {1: 2}]
+            data = ["foo", 12, {1: 2}]
         else:
             data = [None, None, None]

-        expected = ['foo', 12, {1: 2}]
+        expected = ["foo", 12, {1: 2}]
         dist.broadcast_object_list(data)
         self.assertEqual(data, expected)

     def test_all_reduce_dict(self):
         self._init_dist_env(self.rank, self.world_size)
-        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
-                                          ['sum', 'mean']):
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32], ["sum", "mean"], strict=False):
             if dist.get_rank() == 0:
                 data = {
-                    'key1': torch.tensor([0, 1], dtype=tensor_type),
-                    'key2': torch.tensor([1, 2], dtype=tensor_type),
+                    "key1": torch.tensor([0, 1], dtype=tensor_type),
+                    "key2": torch.tensor([1, 2], dtype=tensor_type),
                 }
             else:
                 data = {
-                    'key1': torch.tensor([2, 3], dtype=tensor_type),
-                    'key2': torch.tensor([3, 4], dtype=tensor_type),
+                    "key1": torch.tensor([2, 3], dtype=tensor_type),
+                    "key2": torch.tensor([3, 4], dtype=tensor_type),
                 }

-            if reduce_op == 'sum':
+            if reduce_op == "sum":
                 expected = {
-                    'key1': torch.tensor([2, 4], dtype=tensor_type),
-                    'key2': torch.tensor([4, 6], dtype=tensor_type),
+                    "key1": torch.tensor([2, 4], dtype=tensor_type),
+                    "key2": torch.tensor([4, 6], dtype=tensor_type),
                 }
             else:
                 expected = {
-                    'key1': torch.tensor([1, 2], dtype=tensor_type),
-                    'key2': torch.tensor([2, 3], dtype=tensor_type),
+                    "key1": torch.tensor([1, 2], dtype=tensor_type),
+                    "key2": torch.tensor([2, 3], dtype=tensor_type),
                 }

             dist.all_reduce_dict(data, reduce_op)
@@ -244,24 +227,24 @@ def test_all_reduce_dict(self):

         # `torch.cat` in torch1.5 can not concatenate different types so we
         # fallback to convert them all to float type.
-        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
+        if digit_version(TORCH_VERSION) == digit_version("1.5.0"):
             if dist.get_rank() == 0:
                 data = {
-                    'key1': torch.tensor([0, 1], dtype=torch.float32),
-                    'key2': torch.tensor([1, 2], dtype=torch.int32)
+                    "key1": torch.tensor([0, 1], dtype=torch.float32),
+                    "key2": torch.tensor([1, 2], dtype=torch.int32),
                 }
             else:
                 data = {
-                    'key1': torch.tensor([2, 3], dtype=torch.float32),
-                    'key2': torch.tensor([3, 4], dtype=torch.int32),
+                    "key1": torch.tensor([2, 3], dtype=torch.float32),
+                    "key2": torch.tensor([3, 4], dtype=torch.int32),
                 }

             expected = {
-                'key1': torch.tensor([2, 4], dtype=torch.float32),
-                'key2': torch.tensor([4, 6], dtype=torch.float32),
+                "key1": torch.tensor([2, 4], dtype=torch.float32),
+                "key2": torch.tensor([4, 6], dtype=torch.float32),
             }

-            dist.all_reduce_dict(data, 'sum')
+            dist.all_reduce_dict(data, "sum")
             for key in data:
                 assert torch.allclose(data[key], expected[key])
@@ -271,22 +254,22 @@ def test_all_gather_object(self):

         # data is a pickable python object
         if dist.get_rank() == 0:
-            data = 'foo'
+            data = "foo"
         else:
             data = {1: 2}

-        expected = ['foo', {1: 2}]
+        expected = ["foo", {1: 2}]
         output = dist.all_gather_object(data)
         self.assertEqual(output, expected)

         # data is a list of pickable python object
         if dist.get_rank() == 0:
-            data = ['foo', {1: 2}]
+            data = ["foo", {1: 2}]
         else:
             data = {2: 3}

-        expected = [['foo', {1: 2}], {2: 3}]
+        expected = [["foo", {1: 2}], {2: 3}]
         output = dist.all_gather_object(data)
         self.assertEqual(output, expected)
@@ -296,27 +279,27 @@ def test_gather_object(self):

         # data is a pickable python object
         if dist.get_rank() == 0:
-            data = 'foo'
+            data = "foo"
         else:
             data = {1: 2}

         output = dist.gather_object(data, dst=0)

         if dist.get_rank() == 0:
-            self.assertEqual(output, ['foo', {1: 2}])
+            self.assertEqual(output, ["foo", {1: 2}])
         else:
             self.assertIsNone(output)

         # data is a list of pickable python object
         if dist.get_rank() == 0:
-            data = ['foo', {1: 2}]
+            data = ["foo", {1: 2}]
         else:
             data = {2: 3}

         output = dist.gather_object(data, dst=0)

         if dist.get_rank() == 0:
-            self.assertEqual(output, [['foo', {1: 2}], {2: 3}])
+            self.assertEqual(output, [["foo", {1: 2}], {2: 3}])
         else:
             self.assertIsNone(output)
@@ -324,50 +307,38 @@ def test_all_reduce_params(self):
         self._init_dist_env(self.rank, self.world_size)

         tensor_types = [torch.int64, torch.float32]
-        reduce_ops = ['sum', 'mean']
+        reduce_ops = ["sum", "mean"]
         coalesces = [True, False]
-        for tensor_type, reduce_op, coalesce in zip(tensor_types, reduce_ops,
-                                                    coalesces):
+        for tensor_type, reduce_op, coalesce in zip(tensor_types, reduce_ops, coalesces, strict=False):
             if dist.get_rank() == 0:
-                data = [
-                    torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
-                ]
+                data = [torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)]
             else:
-                data = (
-                    torch.tensor([2, 3], dtype=tensor_type)
-                    for _ in range(100))
+                data = (torch.tensor([2, 3], dtype=tensor_type) for _ in range(100))

             data_gen = (item for item in data)

-            if reduce_op == 'sum':
-                expected = (
-                    torch.tensor([2, 4], dtype=tensor_type)
-                    for _ in range(100))
+            if reduce_op == "sum":
+                expected = (torch.tensor([2, 4], dtype=tensor_type) for _ in range(100))
             else:
-                expected = (
-                    torch.tensor([1, 2], dtype=tensor_type)
-                    for _ in range(100))
+                expected = (torch.tensor([1, 2], dtype=tensor_type) for _ in range(100))

             dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)

-            for item1, item2 in zip(data, expected):
+            for item1, item2 in zip(data, expected, strict=False):
                 self.assertTrue(torch.allclose(item1, item2))


-@unittest.skipIf(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+@unittest.skipIf(torch.cuda.device_count() < 2, reason="need 2 gpu to test nccl")
 class TestDistWithNCCLBackend(MultiProcessTestCase):
-
     def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = '29505'
-        os.environ['RANK'] = str(rank)
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "29505"
+        os.environ["RANK"] = str(rank)

         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
-        torch_dist.init_process_group(
-            backend='nccl', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

     def setUp(self):
         super().setUp()
@@ -376,12 +347,11 @@ def setUp(self):
     def test_all_reduce(self):
         self._init_dist_env(self.rank, self.world_size)
         tensor_types = [torch.int64, torch.float32]
-        reduce_ops = ['sum', 'mean']
-        device_types = ['cpu', 'cuda']
-        for tensor_type, reduce_op, device_type in product(
-                tensor_types, reduce_ops, device_types):
+        reduce_ops = ["sum", "mean"]
+        device_types = ["cpu", "cuda"]
+        for tensor_type, reduce_op, device_type in product(tensor_types, reduce_ops, device_types):
             # 'mean' op does not support torch.int64
-            if tensor_type == torch.int64 and reduce_op == 'mean':
+            if tensor_type == torch.int64 and reduce_op == "mean":
                 continue

             if dist.get_rank() == 0:
@@ -389,37 +359,30 @@ def test_all_reduce(self):
             else:
                 data = torch.tensor([3, 4], dtype=tensor_type).to(device_type)

-            if reduce_op == 'sum':
-                expected = torch.tensor([4, 6],
-                                        dtype=tensor_type).to(device_type)
+            if reduce_op == "sum":
+                expected = torch.tensor([4, 6], dtype=tensor_type).to(device_type)
             else:
-                expected = torch.tensor([2, 3],
-                                        dtype=tensor_type).to(device_type)
+                expected = torch.tensor([2, 3], dtype=tensor_type).to(device_type)

             dist.all_reduce(data, reduce_op)
             self.assertTrue(torch.allclose(data, expected))

     def test_all_gather(self):
         self._init_dist_env(self.rank, self.world_size)
-        for device_type in ('cpu', 'cuda'):
+        for device_type in ("cpu", "cuda"):
             if dist.get_rank() == 0:
                 data = torch.tensor([0, 1]).to(device_type)
             else:
                 data = torch.tensor([1, 2]).to(device_type)

-            expected = [
-                torch.tensor([0, 1]).to(device_type),
-                torch.tensor([1, 2]).to(device_type)
-            ]
+            expected = [torch.tensor([0, 1]).to(device_type), torch.tensor([1, 2]).to(device_type)]

             output = dist.all_gather(data)
-            self.assertTrue(
-                torch.allclose(output[dist.get_rank()],
-                               expected[dist.get_rank()]))
+            self.assertTrue(torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

     def test_broadcast_dist(self):
         self._init_dist_env(self.rank, self.world_size)
-        for device_type in ('cpu', 'cuda'):
+        for device_type in ("cpu", "cuda"):
             if dist.get_rank() == 0:
                 data = torch.tensor([0, 1]).to(device_type)
             else:
@@ -431,9 +394,7 @@ def test_broadcast_dist(self):
     def test_sync_random_seed(self):
         self._init_dist_env(self.rank, self.world_size)

-        with patch.object(
-                torch, 'tensor',
-                return_value=torch.tensor(1024)) as mock_tensor:
+        with patch.object(torch, "tensor", return_value=torch.tensor(1024)) as mock_tensor:
             output = dist.sync_random_seed()
             assert output == 1024
             mock_tensor.assert_called()
@@ -441,53 +402,44 @@ def test_sync_random_seed(self):
     def test_broadcast_object_list(self):
         self._init_dist_env(self.rank, self.world_size)
         if dist.get_rank() == 0:
-            data = ['foo', 12, {1: 2}]
+            data = ["foo", 12, {1: 2}]
         else:
             data = [None, None, None]

-        expected = ['foo', 12, {1: 2}]
+        expected = ["foo", 12, {1: 2}]
         dist.broadcast_object_list(data)
         self.assertEqual(data, expected)

     def test_all_reduce_dict(self):
         self._init_dist_env(self.rank, self.world_size)
         tensor_types = [torch.int64, torch.float32]
-        reduce_ops = ['sum', 'mean']
-        device_types = ['cpu', 'cuda']
-        for tensor_type, reduce_op, device_type in product(
-                tensor_types, reduce_ops, device_types):
+        reduce_ops = ["sum", "mean"]
+        device_types = ["cpu", "cuda"]
+        for tensor_type, reduce_op, device_type in product(tensor_types, reduce_ops, device_types):
             # 'mean' op does not support torch.int64
-            if tensor_type == torch.int64 and reduce_op == 'mean':
+            if tensor_type == torch.int64 and reduce_op == "mean":
                 continue

             if dist.get_rank() == 0:
                 data = {
-                    'key1':
-                    torch.tensor([0, 1], dtype=tensor_type).to(device_type),
-                    'key2':
-                    torch.tensor([1, 2], dtype=tensor_type).to(device_type),
+                    "key1": torch.tensor([0, 1], dtype=tensor_type).to(device_type),
+                    "key2": torch.tensor([1, 2], dtype=tensor_type).to(device_type),
                 }
             else:
                 data = {
-                    'key1':
-                    torch.tensor([2, 3], dtype=tensor_type).to(device_type),
-                    'key2':
-                    torch.tensor([3, 4], dtype=tensor_type).to(device_type),
+                    "key1": torch.tensor([2, 3], dtype=tensor_type).to(device_type),
+                    "key2": torch.tensor([3, 4], dtype=tensor_type).to(device_type),
                 }

-            if reduce_op == 'sum':
+            if reduce_op == "sum":
                 expected = {
-                    'key1':
-                    torch.tensor([2, 4], dtype=tensor_type).to(device_type),
-                    'key2':
-                    torch.tensor([4, 6], dtype=tensor_type).to(device_type),
+                    "key1": torch.tensor([2, 4], dtype=tensor_type).to(device_type),
+                    "key2": torch.tensor([4, 6], dtype=tensor_type).to(device_type),
                 }
             else:
                 expected = {
-                    'key1':
-                    torch.tensor([1, 2], dtype=tensor_type).to(device_type),
-                    'key2':
-                    torch.tensor([2, 3], dtype=tensor_type).to(device_type),
+                    "key1": torch.tensor([1, 2], dtype=tensor_type).to(device_type),
+                    "key2": torch.tensor([2, 3], dtype=tensor_type).to(device_type),
                 }

             dist.all_reduce_dict(data, reduce_op)
@@ -497,35 +449,25 @@ def test_all_reduce_dict(self):

         # `torch.cat` in torch1.5 can not concatenate different types so we
         # fallback to convert them all to float type.
-        for device_type in ('cpu', 'cuda'):
-            if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
+        for device_type in ("cpu", "cuda"):
+            if digit_version(TORCH_VERSION) == digit_version("1.5.0"):
                 if dist.get_rank() == 0:
                     data = {
-                        'key1':
-                        torch.tensor([0, 1],
-                                     dtype=torch.float32).to(device_type),
-                        'key2':
-                        torch.tensor([1, 2],
-                                     dtype=torch.int32).to(device_type),
+                        "key1": torch.tensor([0, 1], dtype=torch.float32).to(device_type),
+                        "key2": torch.tensor([1, 2], dtype=torch.int32).to(device_type),
                     }
                 else:
                     data = {
-                        'key1':
-                        torch.tensor([2, 3],
-                                     dtype=torch.float32).to(device_type),
-                        'key2':
-                        torch.tensor([3, 4],
-                                     dtype=torch.int32).to(device_type),
+                        "key1": torch.tensor([2, 3], dtype=torch.float32).to(device_type),
+                        "key2": torch.tensor([3, 4], dtype=torch.int32).to(device_type),
                     }

                 expected = {
-                    'key1':
-                    torch.tensor([2, 4], dtype=torch.float32).to(device_type),
-                    'key2':
-                    torch.tensor([4, 6], dtype=torch.float32).to(device_type),
+                    "key1": torch.tensor([2, 4], dtype=torch.float32).to(device_type),
+                    "key2": torch.tensor([4, 6], dtype=torch.float32).to(device_type),
                 }

-                dist.all_reduce_dict(data, 'sum')
+                dist.all_reduce_dict(data, "sum")
                 for key in data:
                     assert torch.allclose(data[key], expected[key])
@@ -535,22 +477,22 @@ def test_all_gather_object(self):

         # data is a pickable python object
         if dist.get_rank() == 0:
-            data = 'foo'
+            data = "foo"
         else:
             data = {1: 2}

-        expected = ['foo', {1: 2}]
+        expected = ["foo", {1: 2}]
         output = dist.all_gather_object(data)
         self.assertEqual(output, expected)

         # data is a list of pickable python object
         if dist.get_rank() == 0:
-            data = ['foo', {1: 2}]
+            data = ["foo", {1: 2}]
         else:
             data = {2: 3}

-        expected = [['foo', {1: 2}], {2: 3}]
+        expected = [["foo", {1: 2}], {2: 3}]
         output = dist.all_gather_object(data)
         self.assertEqual(output, expected)
@@ -560,16 +502,16 @@ def test_collect_results(self):

         # 1. test `device` and `tmpdir` parameters
         if dist.get_rank() == 0:
-            data = ['foo', {1: 2}]
+            data = ["foo", {1: 2}]
         else:
-            data = [24, {'a': 'b'}]
+            data = [24, {"a": "b"}]

         size = 4

-        expected = ['foo', 24, {1: 2}, {'a': 'b'}]
+        expected = ["foo", 24, {1: 2}, {"a": "b"}]

         # 1.1 test `device=cpu` and `tmpdir` is None
-        output = dist.collect_results(data, size, device='cpu')
+        output = dist.collect_results(data, size, device="cpu")
         if dist.get_rank() == 0:
             self.assertEqual(output, expected)
         else:
@@ -580,8 +522,7 @@ def test_collect_results(self):
         # broadcast tmpdir to all ranks to make it consistent
         object_list = [tmpdir]
         dist.broadcast_object_list(object_list)
-        output = dist.collect_results(
-            data, size, device='cpu', tmpdir=object_list[0])
+        output = dist.collect_results(data, size, device="cpu", tmpdir=object_list[0])
         if dist.get_rank() == 0:
             self.assertEqual(output, expected)
         else:
@@ -592,7 +533,7 @@ def test_collect_results(self):
         self.assertFalse(osp.exists(object_list[0]))

         # 1.3 test `device=gpu`
-        output = dist.collect_results(data, size, device='gpu')
+        output = dist.collect_results(data, size, device="gpu")
         if dist.get_rank() == 0:
             self.assertEqual(output, expected)
         else:
@@ -600,23 +541,23 @@
         # 2. test `size` parameter
         if dist.get_rank() == 0:
-            data = ['foo', {1: 2}]
+            data = ["foo", {1: 2}]
         else:
-            data = [24, {'a': 'b'}]
+            data = [24, {"a": "b"}]

         size = 3

-        expected = ['foo', 24, {1: 2}]
+        expected = ["foo", 24, {1: 2}]

         # 2.1 test `device=cpu` and `tmpdir` is None
-        output = dist.collect_results(data, size, device='cpu')
+        output = dist.collect_results(data, size, device="cpu")
         if dist.get_rank() == 0:
             self.assertEqual(output, expected)
         else:
             self.assertIsNone(output)

         # 2.2 test `device=gpu`
-        output = dist.collect_results(data, size, device='gpu')
+        output = dist.collect_results(data, size, device="gpu")
         if dist.get_rank() == 0:
             self.assertEqual(output, expected)
         else:
@@ -626,33 +567,22 @@ def test_all_reduce_params(self):
         self._init_dist_env(self.rank, self.world_size)

         tensor_types = [torch.int64, torch.float32]
-        reduce_ops = ['sum', 'mean']
+        reduce_ops = ["sum", "mean"]
         coalesces = [True, False]
-        device_types = ['cpu', 'cuda']
-        for tensor_type, reduce_op, coalesce, device_type in zip(
-                tensor_types, reduce_ops, coalesces, device_types):
+        device_types = ["cpu", "cuda"]
+        for tensor_type, reduce_op, coalesce, device_type in zip(tensor_types, reduce_ops, coalesces, device_types, strict=False):
             if dist.get_rank() == 0:
-                data = [
-                    torch.tensor([0, 1], dtype=tensor_type).to(device_type)
-                    for _ in range(100)
-                ]
+                data = [torch.tensor([0, 1], dtype=tensor_type).to(device_type) for _ in range(100)]
             else:
-                data = [
-                    torch.tensor([2, 3], dtype=tensor_type).to(device_type)
-                    for _ in range(100)
-                ]
+                data = [torch.tensor([2, 3], dtype=tensor_type).to(device_type) for _ in range(100)]

             data_gen = (item for item in data)
             dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)

-            if reduce_op == 'sum':
-                expected = (
-                    torch.tensor([2, 4], dtype=tensor_type).to(device_type)
-                    for _ in range(100))
+            if reduce_op == "sum":
+                expected = (torch.tensor([2, 4], dtype=tensor_type).to(device_type) for _ in range(100))
             else:
-                expected = (
-                    torch.tensor([1, 2], dtype=tensor_type).to(device_type)
-                    for _ in range(100))
+                expected = (torch.tensor([1, 2], dtype=tensor_type).to(device_type) for _ in range(100))

-            for item1, item2 in zip(data_gen, expected):
+            for item1, item2 in zip(data_gen, expected, strict=False):
                 self.assertTrue(torch.allclose(item1, item2))
diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py
index d9af72f964..ad560189be 100644
--- a/tests/test_dist/test_utils.py
+++ b/tests/test_dist/test_utils.py
@@ -12,7 +12,6 @@
 class TestUtils(TestCase):
-
     def test_get_backend(self):
         self.assertIsNone(dist.get_backend())
@@ -35,7 +34,6 @@ def test_is_main_process(self):
         self.assertTrue(dist.is_main_process())

     def test_master_only(self):
-
         @dist.master_only
         def fun():
             assert dist.get_rank() == 0
@@ -48,11 +46,11 @@ def test_barrier(self):
     def test_get_data_device(self):
         # data is a Tensor
         data = torch.tensor([0, 1])
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list of Tensor
         data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list but not all items are Tensor
         data = [torch.tensor([0, 1]), 123]
@@ -60,12 +58,12 @@ def test_get_data_device(self):
             dist.get_data_device(data)

         # data is a list containing Tensor and a dict
-        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}]
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list containing Tensor and a dict but the dict contains
         # invalid type
-        data = [torch.tensor([0, 1]), {'key': '123'}]
+        data = [torch.tensor([0, 1]), {"key": "123"}]
         with self.assertRaises(TypeError):
             dist.get_data_device(data)
@@ -74,20 +72,20 @@ def test_get_data_device(self):
             dist.get_data_device([])

         # data is a dict
-        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([0, 1])}
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a dict but not all values are Tensor
-        data = {'key1': torch.tensor([0, 1]), 'key2': 123}
+        data = {"key1": torch.tensor([0, 1]), "key2": 123}
         with self.assertRaises(TypeError):
             dist.get_data_device(data)

         # data is a dict and one of values is list of Tensor
-        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = {"key1": torch.tensor([0, 1]), "key2": [torch.tensor([0, 1])]}
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a dict and one of values is an invalid type
-        data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
+        data = {"key1": torch.tensor([0, 1]), "key2": ["123"]}
         with self.assertRaises(TypeError):
             dist.get_data_device(data)
@@ -96,15 +94,12 @@ def test_get_data_device(self):
             dist.get_data_device({})

         # data is not a valid type
-        with self.assertRaisesRegex(
-                TypeError,
-                'data should be a Tensor, sequence of tensor or dict'):
-            dist.get_data_device('123')
+        with self.assertRaisesRegex(TypeError, "data should be a Tensor, sequence of tensor or dict"):
+            dist.get_data_device("123")

-    @unittest.skipIf(
-        torch.cuda.device_count() == 0, reason='at lest need 1 gpu to test')
+    @unittest.skipIf(torch.cuda.device_count() == 0, reason="at lest need 1 gpu to test")
     def test_cast_data_device(self):
-        expected_device = torch.device('cuda', torch.cuda.current_device())
+        expected_device = torch.device("cuda", torch.cuda.current_device())
         # data is a Tensor
         data = torch.tensor([0, 1])
         output = dist.cast_data_device(data, expected_device)
@@ -126,100 +121,92 @@ def test_cast_data_device(self):
         data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
         out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
         output = dist.cast_data_device(data, expected_device, out=out)
-        for item1, item2 in zip(output, out):
+        for item1, item2 in zip(output, out, strict=False):
             self.assertEqual(item1.device, expected_device)
             self.assertTrue(torch.allclose(item1.cpu(), item2))

         # data is a list containing a Tensor and a dict
-        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
+        data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}]
         output = dist.cast_data_device(data, expected_device)
         self.assertEqual(output[0].device, expected_device)
-        self.assertEqual(output[1]['key'].device, expected_device)
+        self.assertEqual(output[1]["key"].device, expected_device)

         # data is a list containing a Tensor and a dict, so does out
-        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
-        out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
+        data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}]
+        out = [torch.tensor([3, 4]), {"key": torch.tensor([5, 6])}]
         output = dist.cast_data_device(data, expected_device, out=out)
         self.assertEqual(output[0].device, expected_device)
         self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
-        self.assertEqual(output[1]['key'].device, expected_device)
-        self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))
+        self.assertEqual(output[1]["key"].device, expected_device)
+        self.assertTrue(torch.allclose(output[1]["key"].cpu(), out[1]["key"]))

         # data is an empty list
-        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
+        with self.assertRaisesRegex(ValueError, "data should not be empty"):
            dist.cast_data_device([], expected_device)

         # data is a dict
-        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
+        data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])}
         output = dist.cast_data_device(data, expected_device)
         for k, v in output.items():
             self.assertEqual(v.device, expected_device)

         # data is a dict, so does out
-        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
-        out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
+        data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])}
+        out = {"key1": torch.tensor([3, 4]), "key2": torch.tensor([5, 6])}
         output = dist.cast_data_device(data, expected_device, out=out)
         for k, v in output.items():
             self.assertEqual(v.device, expected_device)
             self.assertTrue(torch.allclose(v.cpu(), out[k]))

         # the length of data and out should be same
-        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
-        out = {'key1': torch.tensor([3, 4])}
-        with self.assertRaisesRegex(ValueError,
-                                    'length of data and out should be same'):
+        data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])}
+        out = {"key1": torch.tensor([3, 4])}
+        with self.assertRaisesRegex(ValueError, "length of data and out should be same"):
             dist.cast_data_device(data, expected_device, out=out)

         # data is an empty dict
-        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
+        with self.assertRaisesRegex(ValueError, "data should not be empty"):
             dist.cast_data_device({}, expected_device)

         # data is a dict and one of values is list
-        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
-        out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
+        data = {"key1": torch.tensor([0, 1]), "key2": [torch.tensor([2, 3])]}
+        out = {"key1": torch.tensor([3, 4]), "key2": [torch.tensor([5, 6])]}
         output = dist.cast_data_device(data, expected_device, out=out)
-        self.assertEqual(output['key1'].device, expected_device)
-        self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
-        self.assertEqual(output['key2'][0].device, expected_device)
-        self.assertTrue(
-            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))
+        self.assertEqual(output["key1"].device, expected_device)
+        self.assertTrue(torch.allclose(output["key1"].cpu(), out["key1"]))
+        self.assertEqual(output["key2"][0].device, expected_device)
+        self.assertTrue(torch.allclose(output["key2"][0].cpu(), out["key2"][0]))

         # data is not a valid type
-        with self.assertRaisesRegex(
-                TypeError, 'data should be a Tensor, list of tensor or dict'):
+        with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"):
             dist.cast_data_device(123, expected_device)

-        with self.assertRaisesRegex(
-                TypeError, 'data should be a Tensor, list of tensor or dict'):
-            dist.cast_data_device('123', expected_device)
+        with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"):
+            dist.cast_data_device("123", expected_device)

-        with self.assertRaisesRegex(
-                TypeError, 'data should be a Tensor, list of tensor or dict'):
+        with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"):
             dist.cast_data_device(np.array([0, 1]), expected_device)

         # data and out are not the same type
         data = torch.tensor([0, 1])
-        out = '123'
-        with self.assertRaisesRegex(TypeError,
-                                    'out should be the same type with data'):
+        out = "123"
+        with self.assertRaisesRegex(TypeError, "out should be the same type with data"):
             dist.cast_data_device(data, expected_device, out=out)

         data = {0, 1}
         out = {2, 3}
-        with self.assertRaisesRegex(TypeError, 'out should not be a set'):
+        with self.assertRaisesRegex(TypeError, "out should not be a set"):
             dist.cast_data_device(data, expected_device, out=out)


 class TestUtilsWithGLOOBackend(MultiProcessTestCase):
-
     def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = '29505'
-        os.environ['RANK'] = str(rank)
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "29505"
+        os.environ["RANK"] = str(rank)

-        torch_dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
         dist.init_local_group(0, world_size)

     def setUp(self):
@@ -247,8 +234,7 @@ def test_local_size(self):

     def test_local_rank(self):
         self._init_dist_env(self.rank, self.world_size)
-        self.assertEqual(
-            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+        self.assertEqual(torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

     def test_get_dist_info(self):
         self._init_dist_env(self.rank, self.world_size)
@@ -278,11 +264,11 @@ def test_get_data_device(self):

         # data is a Tensor
         data = torch.tensor([0, 1])
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list of Tensor
         data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list but not all items are Tensor
         data = [torch.tensor([0, 1]), 123]
@@ -290,12 +276,12 @@ def test_get_data_device(self):
             dist.get_data_device(data)

         # data is a list containing Tensor and a dict
-        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}]
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a list containing Tensor and a dict but the dict contains
         # invalid type
-        data = [torch.tensor([0, 1]), {'key': '123'}]
+        data = [torch.tensor([0, 1]), {"key": "123"}]
         with self.assertRaises(TypeError):
             dist.get_data_device(data)
@@ -304,20 +290,20 @@ def test_get_data_device(self):
             dist.get_data_device({})

         # data is a dict
-        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([0, 1])}
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a dict but not all values are Tensor
-        data = {'key1': torch.tensor([0, 1]), 'key2': 123}
+        data = {"key1": torch.tensor([0, 1]), "key2": 123}
         with self.assertRaises(TypeError):
             dist.get_data_device(data)
         # data is a dict and one of values is list of Tensor
-        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
-        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
+        data = {"key1": torch.tensor([0, 1]), "key2": [torch.tensor([0, 1])]}
+        self.assertEqual(dist.get_data_device(data), torch.device("cpu"))

         # data is a dict and one of values is an invalid type
-        data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
+        data = {"key1": torch.tensor([0, 1]), "key2": ["123"]}
         with self.assertRaises(TypeError):
             dist.get_data_device(data)
@@ -326,31 +312,26 @@ def test_get_data_device(self):
             dist.get_data_device({})

         # data is not a valid type
-        with self.assertRaisesRegex(
-                TypeError,
-                'data should be a Tensor, sequence of tensor or dict'):
-            dist.get_data_device('123')
+        with self.assertRaisesRegex(TypeError, "data should be a Tensor, sequence of tensor or dict"):
+            dist.get_data_device("123")

     def test_get_comm_device(self):
         self._init_dist_env(self.rank, self.world_size)
         group = dist.get_default_group()
-        assert dist.get_comm_device(group) == torch.device('cpu')
+        assert dist.get_comm_device(group) == torch.device("cpu")


-@unittest.skipIf(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+@unittest.skipIf(torch.cuda.device_count() < 2, reason="need 2 gpu to test nccl")
 class TestUtilsWithNCCLBackend(MultiProcessTestCase):
-
     def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = '29505'
-        os.environ['RANK'] = str(rank)
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "29505"
+        os.environ["RANK"] = str(rank)

         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
-        torch_dist.init_process_group(
-            backend='nccl', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
         dist.init_local_group(0, world_size)

     def setUp(self):
@@ -378,8 +359,7 @@ def test_local_size(self):

     def test_local_rank(self):
         self._init_dist_env(self.rank, self.world_size)
-        self.assertEqual(
-            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+        self.assertEqual(torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

     def test_get_dist_info(self):
         self._init_dist_env(self.rank, self.world_size)
@@ -407,17 +387,14 @@ def fun():

     def test_get_data_device(self):
         self._init_dist_env(self.rank, self.world_size)
-        expected_device = torch.device('cuda', torch.cuda.current_device())
+        expected_device = torch.device("cuda", torch.cuda.current_device())

         # data is a Tensor
         data = torch.tensor([0, 1]).to(expected_device)
         self.assertEqual(dist.get_data_device(data), expected_device)

         # data is a list of Tensor
-        data = [
-            torch.tensor([0, 1]).to(expected_device),
-            torch.tensor([2, 3]).to(expected_device)
-        ]
+        data = [torch.tensor([0, 1]).to(expected_device), torch.tensor([2, 3]).to(expected_device)]
         self.assertEqual(dist.get_data_device(data), expected_device)

         # data is a list but not all items are Tensor
@@ -431,16 +408,12 @@ def test_get_data_device(self):
             dist.get_data_device(data)

         # data is a list containing Tensor and a dict
-        data = [
-            torch.tensor([0, 1]).to(expected_device), {
-                'key': torch.tensor([2, 3]).to(expected_device)
-            }
-        ]
+        data = [torch.tensor([0, 1]).to(expected_device), {"key": torch.tensor([2, 3]).to(expected_device)}]
         self.assertEqual(dist.get_data_device(data), expected_device)
contains # invalid type - data = [torch.tensor([0, 1]).to(expected_device), {'key': '123'}] + data = [torch.tensor([0, 1]).to(expected_device), {"key": "123"}] with self.assertRaises(TypeError): dist.get_data_device(data) @@ -449,37 +422,25 @@ def test_get_data_device(self): dist.get_data_device([]) # data is a dict - data = { - 'key1': torch.tensor([0, 1]).to(expected_device), - 'key2': torch.tensor([0, 1]).to(expected_device) - } + data = {"key1": torch.tensor([0, 1]).to(expected_device), "key2": torch.tensor([0, 1]).to(expected_device)} self.assertEqual(dist.get_data_device(data), expected_device) # data is a dict but not all values are Tensor - data = {'key1': torch.tensor([0, 1]).to(expected_device), 'key2': 123} + data = {"key1": torch.tensor([0, 1]).to(expected_device), "key2": 123} with self.assertRaises(TypeError): dist.get_data_device(data) # data is a dict but not all values have the same device type - data = { - 'key1': torch.tensor([0, 1]), - 'key2': torch.tensor([0, 1]).to(expected_device) - } + data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([0, 1]).to(expected_device)} with self.assertRaises(ValueError): dist.get_data_device(data) # data is a dict and one of values is list of Tensor - data = { - 'key1': torch.tensor([0, 1]).to(expected_device), - 'key2': [torch.tensor([0, 1]).to(expected_device)] - } + data = {"key1": torch.tensor([0, 1]).to(expected_device), "key2": [torch.tensor([0, 1]).to(expected_device)]} self.assertEqual(dist.get_data_device(data), expected_device) # data is a dict and one of values is an invalid type - data = { - 'key1': torch.tensor([0, 1]).to(expected_device), - 'key2': ['123'] - } + data = {"key1": torch.tensor([0, 1]).to(expected_device), "key2": ["123"]} with self.assertRaises(TypeError): dist.get_data_device(data) @@ -488,21 +449,19 @@ def test_get_data_device(self): dist.get_data_device({}) # data is not a valid type - with self.assertRaisesRegex( - TypeError, - 'data should be a Tensor, sequence of tensor or dict'): - dist.get_data_device('123') + with self.assertRaisesRegex(TypeError, "data should be a Tensor, sequence of tensor or dict"): + dist.get_data_device("123") def test_get_comm_device(self): self._init_dist_env(self.rank, self.world_size) group = dist.get_default_group() - expected = torch.device('cuda', torch.cuda.current_device()) + expected = torch.device("cuda", torch.cuda.current_device()) self.assertEqual(dist.get_comm_device(group), expected) def test_cast_data_device(self): self._init_dist_env(self.rank, self.world_size) - expected_device = torch.device('cuda', torch.cuda.current_device()) + expected_device = torch.device("cuda", torch.cuda.current_device()) # data is a Tensor data = torch.tensor([0, 1]) output = dist.cast_data_device(data, expected_device) @@ -524,85 +483,79 @@ def test_cast_data_device(self): data = [torch.tensor([0, 1]), torch.tensor([2, 3])] out = [torch.tensor([3, 4]), torch.tensor([5, 6])] output = dist.cast_data_device(data, expected_device, out=out) - for item1, item2 in zip(output, out): + for item1, item2 in zip(output, out, strict=False): self.assertEqual(item1.device, expected_device) self.assertTrue(torch.allclose(item1.cpu(), item2)) # data is a list containing a Tensor and a dict - data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}] + data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}] output = dist.cast_data_device(data, expected_device) self.assertEqual(output[0].device, expected_device) - self.assertEqual(output[1]['key'].device, expected_device) + 
self.assertEqual(output[1]["key"].device, expected_device) # data is a list containing a Tensor and a dict, so does out - data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}] - out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}] + data = [torch.tensor([0, 1]), {"key": torch.tensor([2, 3])}] + out = [torch.tensor([3, 4]), {"key": torch.tensor([5, 6])}] output = dist.cast_data_device(data, expected_device, out=out) self.assertEqual(output[0].device, expected_device) self.assertTrue(torch.allclose(output[0].cpu(), out[0])) - self.assertEqual(output[1]['key'].device, expected_device) - self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key'])) + self.assertEqual(output[1]["key"].device, expected_device) + self.assertTrue(torch.allclose(output[1]["key"].cpu(), out[1]["key"])) # data is an empty list - with self.assertRaisesRegex(ValueError, 'data should not be empty'): + with self.assertRaisesRegex(ValueError, "data should not be empty"): dist.cast_data_device([], expected_device) # data is a dict - data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])} + data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])} output = dist.cast_data_device(data, expected_device) - for k, v in output.items(): + for v in output.values(): self.assertEqual(v.device, expected_device) # data is a dict, so does out - data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])} - out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])} + data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])} + out = {"key1": torch.tensor([3, 4]), "key2": torch.tensor([5, 6])} output = dist.cast_data_device(data, expected_device, out=out) for k, v in output.items(): self.assertEqual(v.device, expected_device) self.assertTrue(torch.allclose(v.cpu(), out[k])) # the length of data and out should be same - data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])} - out = {'key1': torch.tensor([3, 4])} - with self.assertRaisesRegex(ValueError, - 'length of data and out should be same'): + data = {"key1": torch.tensor([0, 1]), "key2": torch.tensor([2, 3])} + out = {"key1": torch.tensor([3, 4])} + with self.assertRaisesRegex(ValueError, "length of data and out should be same"): dist.cast_data_device(data, expected_device, out=out) # data is an empty dict - with self.assertRaisesRegex(ValueError, 'data should not be empty'): + with self.assertRaisesRegex(ValueError, "data should not be empty"): dist.cast_data_device({}, expected_device) # data is a dict and one of values is list - data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]} - out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]} + data = {"key1": torch.tensor([0, 1]), "key2": [torch.tensor([2, 3])]} + out = {"key1": torch.tensor([3, 4]), "key2": [torch.tensor([5, 6])]} output = dist.cast_data_device(data, expected_device, out=out) - self.assertEqual(output['key1'].device, expected_device) - self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1'])) - self.assertEqual(output['key2'][0].device, expected_device) - self.assertTrue( - torch.allclose(output['key2'][0].cpu(), out['key2'][0])) + self.assertEqual(output["key1"].device, expected_device) + self.assertTrue(torch.allclose(output["key1"].cpu(), out["key1"])) + self.assertEqual(output["key2"][0].device, expected_device) + self.assertTrue(torch.allclose(output["key2"][0].cpu(), out["key2"][0])) # data is not a valid type - with self.assertRaisesRegex( - TypeError, 'data should be a Tensor, list of 
tensor or dict'): + with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"): dist.cast_data_device(123, expected_device) - with self.assertRaisesRegex( - TypeError, 'data should be a Tensor, list of tensor or dict'): - dist.cast_data_device('123', expected_device) + with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"): + dist.cast_data_device("123", expected_device) - with self.assertRaisesRegex( - TypeError, 'data should be a Tensor, list of tensor or dict'): + with self.assertRaisesRegex(TypeError, "data should be a Tensor, list of tensor or dict"): dist.cast_data_device(np.array([0, 1]), expected_device) # data and out are not the same type data = torch.tensor([0, 1]) - out = '123' - with self.assertRaisesRegex(TypeError, - 'out should be the same type with data'): + out = "123" + with self.assertRaisesRegex(TypeError, "out should be the same type with data"): dist.cast_data_device(data, expected_device, out=out) data = {0, 1} out = {2, 3} - with self.assertRaisesRegex(TypeError, 'out should not be a set'): + with self.assertRaisesRegex(TypeError, "out should not be a set"): dist.cast_data_device(data, expected_device, out=out) diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py index 58b7e1e6fe..d521220e1f 100644 --- a/tests/test_evaluator/test_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import math import unittest -from typing import Dict, List, Optional, Sequence +from typing import Optional, Sequence from unittest import TestCase import numpy as np @@ -31,34 +31,28 @@ class ToyMetric(BaseMetric): returned as the metrics and override `accuracy` and `size`. 
""" - default_prefix = 'Toy' + default_prefix = "Toy" - def __init__(self, - collect_device: str = 'cpu', - prefix: Optional[str] = None, - dummy_metrics: Optional[Dict] = None): + def __init__(self, collect_device: str = "cpu", prefix: Optional[str] = None, dummy_metrics: Optional[dict] = None): super().__init__(collect_device=collect_device, prefix=prefix) self.dummy_metrics = dummy_metrics def process(self, data_batch, predictions): - results = [{ - 'pred': prediction['label'], - 'label': prediction['label'] - } for prediction in predictions] + results = [{"pred": prediction["label"], "label": prediction["label"]} for prediction in predictions] self.results.extend(results) - def compute_metrics(self, results: List): + def compute_metrics(self, results: list): if self.dummy_metrics is not None: assert isinstance(self.dummy_metrics, dict) return self.dummy_metrics.copy() - pred = np.array([result['pred'] for result in results]) - label = np.array([result['label'] for result in results]) + pred = np.array([result["pred"] for result in results]) + label = np.array([result["label"] for result in results]) acc = (pred == label).sum() / pred.size metrics = { - 'accuracy': acc, - 'size': pred.size, # To check the number of testing samples + "accuracy": acc, + "size": pred.size, # To check the number of testing samples } return metrics @@ -82,88 +76,70 @@ def generate_test_results(size, batch_size, pred, label): for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size data_batch = { - 'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)], - 'data_sample': [BaseDataElement(label=label) for _ in range(bs)] + "inputs": [np.zeros((3, 10, 10)) for _ in range(bs)], + "data_sample": [BaseDataElement(label=label) for _ in range(bs)], } - predictions = [ - BaseDataElement(pred=pred, label=label) for _ in range(bs) - ] + predictions = [BaseDataElement(pred=pred, label=label) for _ in range(bs)] yield (data_batch, predictions) class TestEvaluator(TestCase): - def test_single_metric(self): - cfg = dict(type='ToyMetric') + cfg = dict(type="ToyMetric") evaluator = Evaluator(cfg) size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, batch_size, pred=1, label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = evaluator.evaluate(size=size) - self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0) - self.assertEqual(metrics['Toy/size'], size) + self.assertAlmostEqual(metrics["Toy/accuracy"], 1.0) + self.assertEqual(metrics["Toy/size"], size) # Test empty results - cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0)) + cfg = dict(type="ToyMetric", dummy_metrics=dict(accuracy=1.0)) evaluator = Evaluator(cfg) # Warning should be raised if the results are empty - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): evaluator.evaluate(0) def test_composed_metrics(self): - cfg = [ - dict(type='ToyMetric'), - dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) - ] + cfg = [dict(type="ToyMetric"), dict(type="ToyMetric", dummy_metrics=dict(mAP=0.0))] evaluator = Evaluator(cfg) size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, batch_size, pred=1, label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = 
evaluator.evaluate(size=size) - self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0) - self.assertAlmostEqual(metrics['Toy/mAP'], 0.0) - self.assertEqual(metrics['Toy/size'], size) + self.assertAlmostEqual(metrics["Toy/accuracy"], 1.0) + self.assertAlmostEqual(metrics["Toy/mAP"], 0.0) + self.assertEqual(metrics["Toy/size"], size) def test_ambiguous_metric(self): - cfg = [ - dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)), - dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) - ] + cfg = [dict(type="ToyMetric", dummy_metrics=dict(mAP=0.0)), dict(type="ToyMetric", dummy_metrics=dict(mAP=0.0))] evaluator = Evaluator(cfg) size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, batch_size, pred=1, label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) - with self.assertRaisesRegex( - ValueError, - 'There are multiple evaluation results with the same metric ' - 'name'): + with self.assertRaisesRegex(ValueError, "There are multiple evaluation results with the same metric name"): _ = evaluator.evaluate(size=size) def test_dataset_meta(self): - dataset_meta = dict(classes=('cat', 'dog')) + dataset_meta = dict(classes=("cat", "dog")) - cfg = [ - dict(type='ToyMetric'), - dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) - ] + cfg = [dict(type="ToyMetric"), dict(type="ToyMetric", dummy_metrics=dict(mAP=0.0))] evaluator = Evaluator(cfg) evaluator.dataset_meta = dataset_meta @@ -174,73 +150,67 @@ def test_dataset_meta(self): def test_collect_device(self): cfg = [ - dict(type='ToyMetric', collect_device='cpu'), - dict( - type='ToyMetric', - collect_device='gpu', - dummy_metrics=dict(mAP=0.0)) + dict(type="ToyMetric", collect_device="cpu"), + dict(type="ToyMetric", collect_device="gpu", dummy_metrics=dict(mAP=0.0)), ] evaluator = Evaluator(cfg) - self.assertEqual(evaluator.metrics[0].collect_device, 'cpu') - self.assertEqual(evaluator.metrics[1].collect_device, 'gpu') + self.assertEqual(evaluator.metrics[0].collect_device, "cpu") + self.assertEqual(evaluator.metrics[1].collect_device, "gpu") def test_prefix(self): - cfg = dict(type='NonPrefixedMetric') + cfg = dict(type="NonPrefixedMetric") logger = MMLogger.get_current_instance() # Warning should be raised if prefix is not set. 
-        with self.assertLogs(logger, 'WARNING'):
+        with self.assertLogs(logger, "WARNING"):
             Evaluator(cfg)

     def test_get_metric_value(self):
         metrics = {
-            'prefix_0/metric_0': 0,
-            'prefix_1/metric_0': 1,
-            'prefix_1/metric_1': 2,
-            'nonprefixed': 3,
+            "prefix_0/metric_0": 0,
+            "prefix_1/metric_0": 1,
+            "prefix_1/metric_1": 2,
+            "nonprefixed": 3,
         }

         # Test indicator with prefix
-        indicator = 'prefix_0/metric_0'  # correct indicator
+        indicator = "prefix_0/metric_0"  # correct indicator
         self.assertEqual(get_metric_value(indicator, metrics), 0)

-        indicator = 'prefix_1/metric_0'  # correct indicator
+        indicator = "prefix_1/metric_0"  # correct indicator
         self.assertEqual(get_metric_value(indicator, metrics), 1)

-        indicator = 'prefix_0/metric_1'  # unmatched indicator (wrong metric)
-        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+        indicator = "prefix_0/metric_1"  # unmatched indicator (wrong metric)
+        with self.assertRaisesRegex(ValueError, "can not match any metric"):
             _ = get_metric_value(indicator, metrics)

-        indicator = 'prefix_2/metric'  # unmatched indicator (wrong prefix)
-        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+        indicator = "prefix_2/metric"  # unmatched indicator (wrong prefix)
+        with self.assertRaisesRegex(ValueError, "can not match any metric"):
             _ = get_metric_value(indicator, metrics)

         # Test indicator without prefix
-        indicator = 'metric_1'  # correct indicator (prefixed metric)
+        indicator = "metric_1"  # correct indicator (prefixed metric)
         self.assertEqual(get_metric_value(indicator, metrics), 2)

-        indicator = 'nonprefixed'  # correct indicator (non-prefixed metric)
+        indicator = "nonprefixed"  # correct indicator (non-prefixed metric)
         self.assertEqual(get_metric_value(indicator, metrics), 3)

-        indicator = 'metric_0'  # ambiguous indicator
-        with self.assertRaisesRegex(ValueError, 'matches multiple metrics'):
+        indicator = "metric_0"  # ambiguous indicator
+        with self.assertRaisesRegex(ValueError, "matches multiple metrics"):
             _ = get_metric_value(indicator, metrics)

-        indicator = 'metric_2'  # unmatched indicator
-        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+        indicator = "metric_2"  # unmatched indicator
+        with self.assertRaisesRegex(ValueError, "can not match any metric"):
             _ = get_metric_value(indicator, metrics)

     def test_offline_evaluate(self):
-        cfg = dict(type='ToyMetric')
+        cfg = dict(type="ToyMetric")
         evaluator = Evaluator(cfg)

         size = 10

         all_data = [dict() for _ in range(10)]
-        all_predictions = [
-            BaseDataElement(pred=0, label=1) for _ in range(size)
-        ]
+        all_predictions = [BaseDataElement(pred=0, label=1) for _ in range(size)]
         evaluator.offline_evaluate(all_predictions, all_data)

         # Test with None data
@@ -249,37 +219,34 @@ def test_offline_evaluate(self):
         # Different length of data and predictions will raise an error.
all_data = [dict() for _ in range(9)] - with self.assertRaisesRegex( - AssertionError, - 'data_samples and data should have the same length'): + with self.assertRaisesRegex(AssertionError, "data_samples and data should have the same length"): evaluator.offline_evaluate(all_predictions, all_data) - @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu') + @unittest.skipUnless(torch.cuda.is_available(), "can only run with gpu") def test_evaluate_cast_cpu(self): - cfg = dict(type='ToyMetric') + cfg = dict(type="ToyMetric") evaluator = Evaluator(cfg) size = 10 all_data = [ dict( - inputs=torch.zeros((3, 10, 10), device='cuda'), - data_sample=BaseDataElement( - label=torch.ones((1, ), device='cuda'))) + inputs=torch.zeros((3, 10, 10), device="cuda"), + data_sample=BaseDataElement(label=torch.ones((1,), device="cuda")), + ) for _ in range(size) ] all_predictions = [ - BaseDataElement( - pred=torch.zeros((1, ), device='cuda'), - label=torch.ones((1, ), device='cuda')) for _ in range(size) + BaseDataElement(pred=torch.zeros((1,), device="cuda"), label=torch.ones((1,), device="cuda")) + for _ in range(size) ] - for data, pred in zip(all_data, all_predictions): + for data, pred in zip(all_data, all_predictions, strict=False): evaluator.process([pred], [data]) - def test_results_device(results: List): + def test_results_device(results: list): for result in results: - self.assertEqual(result['pred'].device, torch.device('cpu')) - self.assertEqual(result['label'].device, torch.device('cpu')) + self.assertEqual(result["pred"].device, torch.device("cpu")) + self.assertEqual(result["label"].device, torch.device("cpu")) return {} # replace the `compute_metrics` to the test function diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py index 055bd73ca1..60bc58239d 100644 --- a/tests/test_evaluator/test_metric.py +++ b/tests/test_evaluator/test_metric.py @@ -11,30 +11,24 @@ class TestDumpResults(TestCase): - def test_init(self): - with self.assertRaisesRegex(ValueError, - 'The output file must be a pkl file.'): - DumpResults(out_file_path='./results.json') + with self.assertRaisesRegex(ValueError, "The output file must be a pkl file."): + DumpResults(out_file_path="./results.json") # collect_dir could only be configured when collect_device='cpu' with self.assertRaises(ValueError): - DumpResults( - out_file_path='./results.json', - collect_device='gpu', - collect_dir='./tmp') + DumpResults(out_file_path="./results.json", collect_device="gpu", collect_dir="./tmp") def test_process(self): - metric = DumpResults(out_file_path='./results.pkl') + metric = DumpResults(out_file_path="./results.pkl") data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] metric.process(None, data_samples) self.assertEqual(len(metric.results), 1) - self.assertEqual(metric.results[0]['data'][0].device, - torch.device('cpu')) + self.assertEqual(metric.results[0]["data"][0].device, torch.device("cpu")) def test_compute_metrics(self): temp_dir = tempfile.TemporaryDirectory() - path = osp.join(temp_dir.name, 'results.pkl') + path = osp.join(temp_dir.name, "results.pkl") metric = DumpResults(out_file_path=path) data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] metric.process(None, data_samples) @@ -43,6 +37,6 @@ def test_compute_metrics(self): results = load(path) self.assertEqual(len(results), 1) - self.assertEqual(results[0]['data'][0].device, torch.device('cpu')) + self.assertEqual(results[0]["data"][0].device, torch.device("cpu")) temp_dir.cleanup() diff 
--git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py index 7903f5574e..44ea70e075 100644 --- a/tests/test_fileio/test_backends/test_backend_utils.py +++ b/tests/test_fileio/test_backends/test_backend_utils.py @@ -1,78 +1,72 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest -from mmengine.fileio.backends import (BaseStorageBackend, backends, - prefix_to_backends, register_backend) +from mmengine.fileio.backends import BaseStorageBackend, backends, prefix_to_backends, register_backend def test_register_backend(): # 1. two ways to register backend # 1.1 use it as a decorator - @register_backend('example') + @register_backend("example") class ExampleBackend(BaseStorageBackend): - def get(self, filepath): return filepath def get_text(self, filepath): return filepath - assert 'example' in backends + assert "example" in backends # 1.2 use it as a normal function class ExampleBackend1(BaseStorageBackend): - def get(self, filepath): return filepath def get_text(self, filepath): return filepath - register_backend('example1', ExampleBackend1) - assert 'example1' in backends + register_backend("example1", ExampleBackend1) + assert "example1" in backends # 2. test `name` parameter # 2. name should a string - with pytest.raises(TypeError, match='name should be a string'): + with pytest.raises(TypeError, match="name should be a string"): register_backend(1, ExampleBackend) - register_backend('example2', ExampleBackend) - assert 'example2' in backends + register_backend("example2", ExampleBackend) + assert "example2" in backends # 3. test `backend` parameter # If backend is not None, it should be a class and a subclass of # BaseStorageBackend. - with pytest.raises(TypeError, match='backend should be a class'): + with pytest.raises(TypeError, match="backend should be a class"): def test_backend(): pass - register_backend('example3', test_backend) + register_backend("example3", test_backend) class ExampleBackend2: - def get(self, filepath): return filepath def get_text(self, filepath): return filepath - with pytest.raises( - TypeError, match='not a subclass of BaseStorageBackend'): - register_backend('example3', ExampleBackend2) + with pytest.raises(TypeError, match="not a subclass of BaseStorageBackend"): + register_backend("example3", ExampleBackend2) # 4. test `force` parameter # 4.1 force=False - with pytest.raises(ValueError, match='example is already registered'): - register_backend('example', ExampleBackend) + with pytest.raises(ValueError, match="example is already registered"): + register_backend("example", ExampleBackend) # 4.2 force=True - register_backend('example', ExampleBackend, force=True) - assert 'example' in backends + register_backend("example", ExampleBackend, force=True) + assert "example" in backends # 5. 
test `prefixes` parameter class ExampleBackend3(BaseStorageBackend): - def get(self, filepath): return filepath @@ -80,35 +74,32 @@ def get_text(self, filepath): return filepath # 5.1 prefixes is a string - register_backend('example3', ExampleBackend3, prefixes='prefix1') - assert 'example3' in backends - assert 'prefix1' in prefix_to_backends + register_backend("example3", ExampleBackend3, prefixes="prefix1") + assert "example3" in backends + assert "prefix1" in prefix_to_backends # 5.2 prefixes is a list (tuple) of strings - register_backend( - 'example4', ExampleBackend3, prefixes=['prefix2', 'prefix3']) - assert 'example4' in backends - assert 'prefix2' in prefix_to_backends - assert 'prefix3' in prefix_to_backends - assert prefix_to_backends['prefix2'] == prefix_to_backends['prefix3'] + register_backend("example4", ExampleBackend3, prefixes=["prefix2", "prefix3"]) + assert "example4" in backends + assert "prefix2" in prefix_to_backends + assert "prefix3" in prefix_to_backends + assert prefix_to_backends["prefix2"] == prefix_to_backends["prefix3"] # 5.3 prefixes is an invalid type with pytest.raises(AssertionError): - register_backend('example5', ExampleBackend3, prefixes=1) + register_backend("example5", ExampleBackend3, prefixes=1) # 5.4 prefixes is already registered - with pytest.raises(ValueError, match='prefix2 is already registered'): - register_backend('example6', ExampleBackend3, prefixes='prefix2') + with pytest.raises(ValueError, match="prefix2 is already registered"): + register_backend("example6", ExampleBackend3, prefixes="prefix2") class ExampleBackend4(BaseStorageBackend): - def get(self, filepath): return filepath def get_text(self, filepath): return filepath - register_backend( - 'example6', ExampleBackend4, prefixes='prefix2', force=True) - assert 'example6' in backends - assert 'prefix2' in prefix_to_backends + register_backend("example6", ExampleBackend4, prefixes="prefix2", force=True) + assert "example6" in backends + assert "prefix2" in prefix_to_backends diff --git a/tests/test_fileio/test_backends/test_base_storage_backend.py b/tests/test_fileio/test_backends/test_base_storage_backend.py index 6aa608851d..295006dd3d 100644 --- a/tests/test_fileio/test_backends/test_base_storage_backend.py +++ b/tests/test_fileio/test_backends/test_base_storage_backend.py @@ -9,13 +9,10 @@ def test_base_storage_backend(): class ExampleBackend(BaseStorageBackend): pass - with pytest.raises( - TypeError, - match="Can't instantiate abstract class ExampleBackend"): + with pytest.raises(TypeError, match="Can't instantiate abstract class ExampleBackend"): ExampleBackend() class ExampleBackend(BaseStorageBackend): - def get(self, filepath): return filepath @@ -23,5 +20,5 @@ def get_text(self, filepath): return filepath backend = ExampleBackend() - assert backend.get('test') == 'test' - assert backend.get_text('test') == 'test' + assert backend.get("test") == "test" + assert backend.get_text("test") == "test" diff --git a/tests/test_fileio/test_backends/test_http_backend.py b/tests/test_fileio/test_backends/test_http_backend.py index c69394d147..0342b938df 100644 --- a/tests/test_fileio/test_backends/test_http_backend.py +++ b/tests/test_fileio/test_backends/test_http_backend.py @@ -15,23 +15,20 @@ def imfrombytes(content): def imread(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: content = f.read() img = imfrombytes(content) return img class TestHTTPBackend(TestCase): - @classmethod def setUpClass(cls): - cls.img_url = ( - 
'https://download.openmmlab.com/mmengine/test-data/color.jpg') + cls.img_url = "https://download.openmmlab.com/mmengine/test-data/color.jpg" cls.img_shape = (300, 400, 3) - cls.text_url = ( - 'https://download.openmmlab.com/mmengine/test-data/filelist.txt') - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.text_path = cls.test_data_dir / 'filelist.txt' + cls.text_url = "https://download.openmmlab.com/mmengine/test-data/filelist.txt" + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.text_path = cls.test_data_dir / "filelist.txt" def test_get(self): backend = HTTPBackend() @@ -42,7 +39,7 @@ def test_get(self): def test_get_text(self): backend = HTTPBackend() text = backend.get_text(self.text_url) - self.assertEqual(self.text_path.open('r').read(), text) + self.assertEqual(self.text_path.open("r").read(), text) def test_get_local_path(self): backend = HTTPBackend() diff --git a/tests/test_fileio/test_backends/test_lmdb_backend.py b/tests/test_fileio/test_backends/test_lmdb_backend.py index dc2c7ded2b..c3fe944bc4 100644 --- a/tests/test_fileio/test_backends/test_lmdb_backend.py +++ b/tests/test_fileio/test_backends/test_lmdb_backend.py @@ -16,20 +16,19 @@ def imfrombytes(content): class TestLmdbBackend(TestCase): - @classmethod def setUpClass(cls): - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.lmdb_path = cls.test_data_dir / 'demo.lmdb' + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.lmdb_path = cls.test_data_dir / "demo.lmdb" @parameterized.expand([[Path], [str]]) def test_get(self, path_type): backend = LmdbBackend(path_type(self.lmdb_path)) - img_bytes = backend.get('baboon') + img_bytes = backend.get("baboon") img = imfrombytes(img_bytes) self.assertEqual(img.shape, (120, 125, 3)) def test_get_text(self): backend = LmdbBackend(self.lmdb_path) with self.assertRaises(NotImplementedError): - backend.get_text('filepath') + backend.get_text("filepath") diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py index 427ebf789a..cfe0f34c38 100644 --- a/tests/test_fileio/test_backends/test_local_backend.py +++ b/tests/test_fileio/test_backends/test_local_backend.py @@ -38,43 +38,42 @@ def build_temporary_directory(): | -- text2.txt \n """ with tempfile.TemporaryDirectory() as tmp_dir: - text1 = Path(tmp_dir) / 'text1.txt' - text1.open('w').write('text1') - text2 = Path(tmp_dir) / 'text2.txt' - text2.open('w').write('text2') - dir1 = Path(tmp_dir) / 'dir1' + text1 = Path(tmp_dir) / "text1.txt" + text1.open("w").write("text1") + text2 = Path(tmp_dir) / "text2.txt" + text2.open("w").write("text2") + dir1 = Path(tmp_dir) / "dir1" dir1.mkdir() - text3 = dir1 / 'text3.txt' - text3.open('w').write('text3') - dir2 = Path(tmp_dir) / 'dir2' + text3 = dir1 / "text3.txt" + text3.open("w").write("text3") + dir2 = Path(tmp_dir) / "dir2" dir2.mkdir() - jpg1 = dir2 / 'img.jpg' - jpg1.open('wb').write(b'img') - dir3 = dir2 / 'dir3' + jpg1 = dir2 / "img.jpg" + jpg1.open("wb").write(b"img") + dir3 = dir2 / "dir3" dir3.mkdir() - text4 = dir3 / 'text4.txt' - text4.open('w').write('text4') + text4 = dir3 / "text4.txt" + text4.open("w").write("text4") yield tmp_dir class TestLocalBackend(TestCase): - @classmethod def setUpClass(cls): - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.img_path = cls.test_data_dir / 'color.jpg' + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.img_path = cls.test_data_dir / 
"color.jpg" cls.img_shape = (300, 400, 3) - cls.text_path = cls.test_data_dir / 'filelist.txt' + cls.text_path = cls.test_data_dir / "filelist.txt" def test_name(self): backend = LocalBackend() - self.assertEqual(backend.name, 'LocalBackend') + self.assertEqual(backend.name, "LocalBackend") @parameterized.expand([[Path], [str]]) def test_get(self, path_type): backend = LocalBackend() img_bytes = backend.get(path_type(self.img_path)) - self.assertEqual(self.img_path.open('rb').read(), img_bytes) + self.assertEqual(self.img_path.open("rb").read(), img_bytes) img = imfrombytes(img_bytes) self.assertEqual(img.shape, self.img_shape) @@ -82,46 +81,46 @@ def test_get(self, path_type): def test_get_text(self, path_type): backend = LocalBackend() text = backend.get_text(path_type(self.text_path)) - self.assertEqual(self.text_path.open('r').read(), text) + self.assertEqual(self.text_path.open("r").read(), text) @parameterized.expand([[Path], [str]]) def test_put(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: - filepath = Path(tmp_dir) / 'test.jpg' - backend.put(b'disk', path_type(filepath)) - self.assertEqual(backend.get(filepath), b'disk') + filepath = Path(tmp_dir) / "test.jpg" + backend.put(b"disk", path_type(filepath)) + self.assertEqual(backend.get(filepath), b"disk") # If the directory does not exist, put will create a # directory first - filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' - backend.put(b'disk', path_type(filepath)) - self.assertEqual(backend.get(filepath), b'disk') + filepath = Path(tmp_dir) / "not_existed_dir" / "test.jpg" + backend.put(b"disk", path_type(filepath)) + self.assertEqual(backend.get(filepath), b"disk") @parameterized.expand([[Path], [str]]) def test_put_text(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: - filepath = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', path_type(filepath)) - self.assertEqual(backend.get_text(filepath), 'disk') + filepath = Path(tmp_dir) / "test.txt" + backend.put_text("disk", path_type(filepath)) + self.assertEqual(backend.get_text(filepath), "disk") # If the directory does not exist, put_text will create a # directory first - filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' - backend.put_text('disk', path_type(filepath)) - self.assertEqual(backend.get_text(filepath), 'disk') + filepath = Path(tmp_dir) / "not_existed_dir" / "test.txt" + backend.put_text("disk", path_type(filepath)) + self.assertEqual(backend.get_text(filepath), "disk") @parameterized.expand([[Path], [str]]) def test_exists(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: self.assertTrue(backend.exists(path_type(tmp_dir))) - filepath = Path(tmp_dir) / 'test.txt' + filepath = Path(tmp_dir) / "test.txt" self.assertFalse(backend.exists(path_type(filepath))) - backend.put_text('disk', filepath) + backend.put_text("disk", filepath) self.assertTrue(backend.exists(path_type(filepath))) backend.remove(filepath) @@ -130,8 +129,8 @@ def test_isdir(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: self.assertTrue(backend.isdir(path_type(tmp_dir))) - filepath = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + backend.put_text("disk", filepath) self.assertFalse(backend.isdir(path_type(filepath))) @parameterized.expand([[Path], [str]]) @@ -139,22 +138,19 @@ def test_isfile(self, path_type): backend = LocalBackend() with 
tempfile.TemporaryDirectory() as tmp_dir: self.assertFalse(backend.isfile(path_type(tmp_dir))) - filepath = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + backend.put_text("disk", filepath) self.assertTrue(backend.isfile(path_type(filepath))) @parameterized.expand([[Path], [str]]) def test_join_path(self, path_type): backend = LocalBackend() - filepath = backend.join_path( - path_type(self.test_data_dir), path_type('file')) - expected = osp.join(path_type(self.test_data_dir), path_type('file')) + filepath = backend.join_path(path_type(self.test_data_dir), path_type("file")) + expected = osp.join(path_type(self.test_data_dir), path_type("file")) self.assertEqual(filepath, expected) - filepath = backend.join_path( - path_type(self.test_data_dir), path_type('dir'), path_type('file')) - expected = osp.join( - path_type(self.test_data_dir), path_type('dir'), path_type('file')) + filepath = backend.join_path(path_type(self.test_data_dir), path_type("dir"), path_type("file")) + expected = osp.join(path_type(self.test_data_dir), path_type("dir"), path_type("file")) self.assertEqual(filepath, expected) @parameterized.expand([[Path], [str]]) @@ -167,22 +163,19 @@ def test_get_local_path(self, path_type): def test_copyfile(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) - self.assertEqual(backend.get_text(dst), 'disk') + src = Path(tmp_dir) / "test.txt" + backend.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), path_type(dst)) + self.assertEqual(backend.get_text(dst), "disk") # dst is a directory - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) - self.assertEqual( - backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + backend.copyfile(path_type(src), path_type(dst)), backend.join_path(path_type(dst), "test.txt") + ) + self.assertEqual(backend.get_text(backend.join_path(dst, "test.txt")), "disk") # src and src should not be same file with self.assertRaises(SameFileError): @@ -193,39 +186,33 @@ def test_copytree(self, path_type): backend = LocalBackend() with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" + self.assertEqual(backend.copytree(path_type(src), path_type(dst)), path_type(dst)) self.assertTrue(backend.isdir(dst)) - self.assertTrue(backend.isfile(dst / 'text3.txt')) - self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + self.assertTrue(backend.isfile(dst / "text3.txt")) + self.assertEqual(backend.get_text(dst / "text3.txt"), "text3") # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), path_type(Path(tmp_dir) / "dir2")) @parameterized.expand([[Path], [str]]) def test_copyfile_from_local(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 
'test.txt' - backend.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) - self.assertEqual(backend.get_text(dst), 'disk') + src = Path(tmp_dir) / "test.txt" + backend.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), path_type(dst)) + self.assertEqual(backend.get_text(dst), "disk") - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) - self.assertEqual( - backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + backend.copyfile(path_type(src), path_type(dst)), backend.join_path(path_type(dst), "test.txt") + ) + self.assertEqual(backend.get_text(backend.join_path(dst, "test.txt")), "disk") # src and src should not be same file with self.assertRaises(SameFileError): @@ -236,39 +223,33 @@ def test_copytree_from_local(self, path_type): backend = LocalBackend() with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" + self.assertEqual(backend.copytree(path_type(src), path_type(dst)), path_type(dst)) self.assertTrue(backend.isdir(dst)) - self.assertTrue(backend.isfile(dst / 'text3.txt')) - self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + self.assertTrue(backend.isfile(dst / "text3.txt")) + self.assertEqual(backend.get_text(dst / "text3.txt"), "text3") # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), path_type(Path(tmp_dir) / "dir2")) @parameterized.expand([[Path], [str]]) def test_copyfile_to_local(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) - self.assertEqual(backend.get_text(dst), 'disk') + src = Path(tmp_dir) / "test.txt" + backend.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), path_type(dst)) + self.assertEqual(backend.get_text(dst), "disk") - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) - self.assertEqual( - backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + backend.copyfile(path_type(src), path_type(dst)), backend.join_path(path_type(dst), "test.txt") + ) + self.assertEqual(backend.get_text(backend.join_path(dst, "test.txt")), "disk") # src and src should not be same file with self.assertRaises(SameFileError): @@ -279,38 +260,35 @@ def test_copytree_to_local(self, path_type): backend = LocalBackend() with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" + 
self.assertEqual(backend.copytree(path_type(src), path_type(dst)), path_type(dst)) self.assertTrue(backend.isdir(dst)) - self.assertTrue(backend.isfile(dst / 'text3.txt')) - self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + self.assertTrue(backend.isfile(dst / "text3.txt")) + self.assertEqual(backend.get_text(dst / "text3.txt"), "text3") # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), path_type(Path(tmp_dir) / "dir2")) @parameterized.expand([[Path], [str]]) def test_remove(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: # filepath is a Path object - filepath = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + backend.put_text("disk", filepath) self.assertTrue(backend.exists(filepath)) backend.remove(path_type(filepath)) self.assertFalse(backend.exists(filepath)) # raise error if file does not exist with self.assertRaises(FileNotFoundError): - filepath = Path(tmp_dir) / 'test1.txt' + filepath = Path(tmp_dir) / "test1.txt" backend.remove(path_type(filepath)) # can not remove directory - filepath = Path(tmp_dir) / 'dir' + filepath = Path(tmp_dir) / "dir" filepath.mkdir() with self.assertRaises(IsADirectoryError): backend.remove(path_type(filepath)) @@ -320,12 +298,12 @@ def test_rmtree(self, path_type): backend = LocalBackend() with build_temporary_directory() as tmp_dir: # src and dst are Path objects - dir_path = Path(tmp_dir) / 'dir1' + dir_path = Path(tmp_dir) / "dir1" self.assertTrue(backend.exists(dir_path)) backend.rmtree(path_type(dir_path)) self.assertFalse(backend.exists(dir_path)) - dir_path = Path(tmp_dir) / 'dir2' + dir_path = Path(tmp_dir) / "dir2" self.assertTrue(backend.exists(dir_path)) backend.rmtree(path_type(dir_path)) self.assertFalse(backend.exists(dir_path)) @@ -335,21 +313,21 @@ def test_copy_if_symlink_fails(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: # create a symlink for a file - src = Path(tmp_dir) / 'test.txt' - backend.put_text('disk', src) - dst = Path(tmp_dir) / 'test_link.txt' + src = Path(tmp_dir) / "test.txt" + backend.put_text("disk", src) + dst = Path(tmp_dir) / "test_link.txt" res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) - if platform.system() == 'Linux': + if platform.system() == "Linux": self.assertTrue(res) self.assertTrue(osp.islink(dst)) - self.assertEqual(backend.get_text(dst), 'disk') + self.assertEqual(backend.get_text(dst), "disk") # create a symlink for a directory - src = Path(tmp_dir) / 'dir' + src = Path(tmp_dir) / "dir" src.mkdir() - dst = Path(tmp_dir) / 'dir_link' + dst = Path(tmp_dir) / "dir_link" res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) - if platform.system() == 'Linux': + if platform.system() == "Linux": self.assertTrue(res) self.assertTrue(osp.islink(dst)) self.assertTrue(backend.exists(dst)) @@ -358,21 +336,19 @@ def symlink(src, dst): raise Exception # copy files if symblink fails - with patch.object(os, 'symlink', side_effect=symlink): - src = Path(tmp_dir) / 'test.txt' - dst = Path(tmp_dir) / 'test_link1.txt' - res = backend.copy_if_symlink_fails( - path_type(src), path_type(dst)) + with patch.object(os, "symlink", side_effect=symlink): + src = Path(tmp_dir) / "test.txt" + dst = Path(tmp_dir) / "test_link1.txt" + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) self.assertFalse(res) 
self.assertFalse(osp.islink(dst)) self.assertTrue(backend.exists(dst)) # copy directory if symblink fails - with patch.object(os, 'symlink', side_effect=symlink): - src = Path(tmp_dir) / 'dir' - dst = Path(tmp_dir) / 'dir_link1' - res = backend.copy_if_symlink_fails( - path_type(src), path_type(dst)) + with patch.object(os, "symlink", side_effect=symlink): + src = Path(tmp_dir) / "dir" + dst = Path(tmp_dir) / "dir_link1" + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) self.assertFalse(res) self.assertFalse(osp.islink(dst)) self.assertTrue(backend.exists(dst)) @@ -383,104 +359,84 @@ def test_list_dir_or_file(self, path_type): with build_temporary_directory() as tmp_dir: # list directories and files self.assertEqual( - set(backend.list_dir_or_file(path_type(tmp_dir))), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + set(backend.list_dir_or_file(path_type(tmp_dir))), {"dir1", "dir2", "text1.txt", "text2.txt"} + ) # list directories and files recursively self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), recursive=True)), + set(backend.list_dir_or_file(path_type(tmp_dir), recursive=True)), { - 'dir1', - osp.join('dir1', 'text3.txt'), 'dir2', - osp.join('dir2', 'dir3'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - }) + "dir1", + osp.join("dir1", "text3.txt"), + "dir2", + osp.join("dir2", "dir3"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + }, + ) # only list directories - self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False)), - {'dir1', 'dir2'}) + self.assertEqual(set(backend.list_dir_or_file(path_type(tmp_dir), list_file=False)), {"dir1", "dir2"}) - with self.assertRaisesRegex( - TypeError, - '`suffix` should be None when `list_dir` is True'): - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False, suffix='.txt') + with self.assertRaisesRegex(TypeError, "`suffix` should be None when `list_dir` is True"): + backend.list_dir_or_file(path_type(tmp_dir), list_file=False, suffix=".txt") # only list directories recursively self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False, recursive=True)), - {'dir1', 'dir2', osp.join('dir2', 'dir3')}) + set(backend.list_dir_or_file(path_type(tmp_dir), list_file=False, recursive=True)), + {"dir1", "dir2", osp.join("dir2", "dir3")}, + ) # only list files self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False)), - {'text1.txt', 'text2.txt'}) + set(backend.list_dir_or_file(path_type(tmp_dir), list_dir=False)), {"text1.txt", "text2.txt"} + ) # only list files recursively self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False, recursive=True)), + set(backend.list_dir_or_file(path_type(tmp_dir), list_dir=False, recursive=True)), { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - }) + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + }, + ) # only list files ending with suffix self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False, suffix='.txt')), - {'text1.txt', 'text2.txt'}) + set(backend.list_dir_or_file(path_type(tmp_dir), list_dir=False, suffix=".txt")), + {"text1.txt", "text2.txt"}, + ) self.assertEqual( - set( - backend.list_dir_or_file( - 
path_type(tmp_dir), - list_dir=False, - suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) - - with self.assertRaisesRegex( - TypeError, - '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix=['.txt', '.jpg']) + set(backend.list_dir_or_file(path_type(tmp_dir), list_dir=False, suffix=(".txt", ".jpg"))), + {"text1.txt", "text2.txt"}, + ) + + with self.assertRaisesRegex(TypeError, "`suffix` must be a string or tuple of strings"): + backend.list_dir_or_file(path_type(tmp_dir), list_dir=False, suffix=[".txt", ".jpg"]) # only list files ending with suffix recursively self.assertEqual( - set( - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix='.txt', - recursive=True)), { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', - 'text2.txt' - }) + set(backend.list_dir_or_file(path_type(tmp_dir), list_dir=False, suffix=".txt", recursive=True)), + {osp.join("dir1", "text3.txt"), osp.join("dir2", "dir3", "text4.txt"), "text1.txt", "text2.txt"}, + ) # only list files ending with suffix self.assertEqual( set( backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + path_type(tmp_dir), list_dir=False, suffix=(".txt", ".jpg"), recursive=True + ) + ), { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - }) + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + }, + ) diff --git a/tests/test_fileio/test_backends/test_memcached_backend.py b/tests/test_fileio/test_backends/test_memcached_backend.py index d320fcb16b..f3426bc318 100644 --- a/tests/test_fileio/test_backends/test_memcached_backend.py +++ b/tests/test_fileio/test_backends/test_memcached_backend.py @@ -17,43 +17,41 @@ def imfrombytes(content): return img -sys.modules['mc'] = MagicMock() +sys.modules["mc"] = MagicMock() class MockMemcachedClient: - def __init__(self, server_list_cfg, client_cfg): pass def Get(self, filepath, buffer): - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: buffer.content = f.read() class TestMemcachedBackend(TestCase): - @classmethod def setUpClass(cls): - cls.mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None) - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.img_path = cls.test_data_dir / 'color.jpg' + cls.mc_cfg = dict(server_list_cfg="", client_cfg="", sys_path=None) + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.img_path = cls.test_data_dir / "color.jpg" cls.img_shape = (300, 400, 3) @parameterized.expand([[Path], [str]]) - @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) - @patch('mc.pyvector', MagicMock) - @patch('mc.ConvertBuffer', lambda x: x.content) + @patch("mc.MemcachedClient.GetInstance", MockMemcachedClient) + @patch("mc.pyvector", MagicMock) + @patch("mc.ConvertBuffer", lambda x: x.content) def test_get(self, path_type): backend = MemcachedBackend(**self.mc_cfg) img_bytes = backend.get(path_type(self.img_path)) - self.assertEqual(self.img_path.open('rb').read(), img_bytes) + self.assertEqual(self.img_path.open("rb").read(), img_bytes) img = imfrombytes(img_bytes) self.assertEqual(img.shape, self.img_shape) - @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) - @patch('mc.pyvector', MagicMock) - @patch('mc.ConvertBuffer', lambda x: x.content) + 
@patch("mc.MemcachedClient.GetInstance", MockMemcachedClient) + @patch("mc.pyvector", MagicMock) + @patch("mc.ConvertBuffer", lambda x: x.content) def test_get_text(self): backend = MemcachedBackend(**self.mc_cfg) with self.assertRaises(NotImplementedError): - backend.get_text('filepath') + backend.get_text("filepath") diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index 6f379c3f23..20c2004d55 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -30,49 +30,45 @@ def build_temporary_directory(): | -- text2.txt \n """ with tempfile.TemporaryDirectory() as tmp_dir: - text1 = Path(tmp_dir) / 'text1.txt' - text1.open('w').write('text1') - text2 = Path(tmp_dir) / 'text2.txt' - text2.open('w').write('text2') - dir1 = Path(tmp_dir) / 'dir1' + text1 = Path(tmp_dir) / "text1.txt" + text1.open("w").write("text1") + text2 = Path(tmp_dir) / "text2.txt" + text2.open("w").write("text2") + dir1 = Path(tmp_dir) / "dir1" dir1.mkdir() - text3 = dir1 / 'text3.txt' - text3.open('w').write('text3') - dir2 = Path(tmp_dir) / 'dir2' + text3 = dir1 / "text3.txt" + text3.open("w").write("text3") + dir2 = Path(tmp_dir) / "dir2" dir2.mkdir() - jpg1 = dir2 / 'img.jpg' - jpg1.open('wb').write(b'img') - dir3 = dir2 / 'dir3' + jpg1 = dir2 / "img.jpg" + jpg1.open("wb").write(b"img") + dir3 = dir2 / "dir3" dir3.mkdir() - text4 = dir3 / 'text4.txt' - text4.open('w').write('text4') + text4 = dir3 / "text4.txt" + text4.open("w").write("text4") yield tmp_dir try: # Other unit tests may mock these modules so we need to pop them first. - sys.modules.pop('petrel_client', None) - sys.modules.pop('petrel_client.client', None) + sys.modules.pop("petrel_client", None) + sys.modules.pop("petrel_client.client", None) # If petrel_client is imported successfully, we can test PetrelBackend # without mock. 
import petrel_client # noqa: F401 except ImportError: - sys.modules['petrel_client'] = MagicMock() - sys.modules['petrel_client.client'] = MagicMock() + sys.modules["petrel_client"] = MagicMock() + sys.modules["petrel_client.client"] = MagicMock() class MockPetrelClient: - - def __init__(self, - enable_mc=True, - enable_multi_cluster=False, - conf_path=None): + def __init__(self, enable_mc=True, enable_multi_cluster=False, conf_path=None): self.enable_mc = enable_mc self.enable_multi_cluster = enable_multi_cluster self.conf_path = conf_path def Get(self, filepath): - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: content = f.read() return content @@ -90,10 +86,10 @@ def isdir(self): def list(self, dir_path): for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): + if not entry.name.startswith(".") and entry.is_file(): yield entry.name elif osp.isdir(entry.path): - yield entry.name + '/' + yield entry.name + "/" @contextmanager def delete_and_reset_method(obj, method): @@ -104,148 +100,124 @@ def delete_and_reset_method(obj, method): finally: setattr(type(obj), method, method_obj) - @patch('petrel_client.client.Client', MockPetrelClient) + @patch("petrel_client.client.Client", MockPetrelClient) class TestPetrelBackend(TestCase): - @classmethod def setUpClass(cls): - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.img_path = cls.test_data_dir / 'color.jpg' + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.img_path = cls.test_data_dir / "color.jpg" cls.img_shape = (300, 400, 3) - cls.text_path = cls.test_data_dir / 'filelist.txt' - cls.petrel_dir = 'petrel://user/data' - cls.petrel_path = f'{cls.petrel_dir}/test.jpg' - cls.expected_dir = 's3://user/data' - cls.expected_path = f'{cls.expected_dir}/test.jpg' + cls.text_path = cls.test_data_dir / "filelist.txt" + cls.petrel_dir = "petrel://user/data" + cls.petrel_path = f"{cls.petrel_dir}/test.jpg" + cls.expected_dir = "s3://user/data" + cls.expected_path = f"{cls.expected_dir}/test.jpg" def test_name(self): backend = PetrelBackend() - self.assertEqual(backend.name, 'PetrelBackend') + self.assertEqual(backend.name, "PetrelBackend") def test_map_path(self): backend = PetrelBackend(path_mapping=None) - self.assertEqual( - backend._map_path(self.petrel_path), self.petrel_path) + self.assertEqual(backend._map_path(self.petrel_path), self.petrel_path) - backend = PetrelBackend( - path_mapping={'data/': 'petrel://user/data/'}) - self.assertEqual( - backend._map_path('data/test.jpg'), self.petrel_path) + backend = PetrelBackend(path_mapping={"data/": "petrel://user/data/"}) + self.assertEqual(backend._map_path("data/test.jpg"), self.petrel_path) def test_format_path(self): backend = PetrelBackend() - formatted_filepath = backend._format_path( - 'petrel://user\\data\\test.jpg') + formatted_filepath = backend._format_path("petrel://user\\data\\test.jpg") self.assertEqual(formatted_filepath, self.petrel_path) def test_replace_prefix(self): backend = PetrelBackend() - self.assertEqual( - backend._replace_prefix(self.petrel_path), self.expected_path) + self.assertEqual(backend._replace_prefix(self.petrel_path), self.expected_path) def test_join_path(self): backend = PetrelBackend() - self.assertEqual( - backend.join_path(self.petrel_dir, 'file'), - f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(f'{self.petrel_dir}/', 'file'), - f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(f'{self.petrel_dir}/', '/file'), - 
f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(self.petrel_dir, 'dir', 'file'), - f'{self.petrel_dir}/dir/file') + self.assertEqual(backend.join_path(self.petrel_dir, "file"), f"{self.petrel_dir}/file") + self.assertEqual(backend.join_path(f"{self.petrel_dir}/", "file"), f"{self.petrel_dir}/file") + self.assertEqual(backend.join_path(f"{self.petrel_dir}/", "/file"), f"{self.petrel_dir}/file") + self.assertEqual(backend.join_path(self.petrel_dir, "dir", "file"), f"{self.petrel_dir}/dir/file") def test_get(self): backend = PetrelBackend() - with patch.object( - backend._client, 'Get', - return_value=b'petrel') as patched_get: - self.assertEqual(backend.get(self.petrel_path), b'petrel') + with patch.object(backend._client, "Get", return_value=b"petrel") as patched_get: + self.assertEqual(backend.get(self.petrel_path), b"petrel") patched_get.assert_called_once_with(self.expected_path) def test_get_text(self): backend = PetrelBackend() - with patch.object( - backend._client, 'Get', - return_value=b'petrel') as patched_get: - self.assertEqual(backend.get_text(self.petrel_path), 'petrel') + with patch.object(backend._client, "Get", return_value=b"petrel") as patched_get: + self.assertEqual(backend.get_text(self.petrel_path), "petrel") patched_get.assert_called_once_with(self.expected_path) def test_put(self): backend = PetrelBackend() - with patch.object(backend._client, 'put') as patched_put: - backend.put(b'petrel', self.petrel_path) - patched_put.assert_called_once_with(self.expected_path, - b'petrel') + with patch.object(backend._client, "put") as patched_put: + backend.put(b"petrel", self.petrel_path) + patched_put.assert_called_once_with(self.expected_path, b"petrel") def test_put_text(self): backend = PetrelBackend() - with patch.object(backend._client, 'put') as patched_put: - backend.put_text('petrel', self.petrel_path) - patched_put.assert_called_once_with(self.expected_path, - b'petrel') + with patch.object(backend._client, "put") as patched_put: + backend.put_text("petrel", self.petrel_path) + patched_put.assert_called_once_with(self.expected_path, b"petrel") def test_exists(self): backend = PetrelBackend() - self.assertTrue(has_method(backend._client, 'contains')) - self.assertTrue(has_method(backend._client, 'isdir')) + self.assertTrue(has_method(backend._client, "contains")) + self.assertTrue(has_method(backend._client, "isdir")) # raise Exception if `_client.contains` and '_client.isdir' are not # implemented - with delete_and_reset_method(backend._client, 'contains'), \ - delete_and_reset_method(backend._client, 'isdir'): - self.assertFalse(has_method(backend._client, 'contains')) - self.assertFalse(has_method(backend._client, 'isdir')) + with ( + delete_and_reset_method(backend._client, "contains"), + delete_and_reset_method(backend._client, "isdir"), + ): + self.assertFalse(has_method(backend._client, "contains")) + self.assertFalse(has_method(backend._client, "isdir")) with self.assertRaises(NotImplementedError): backend.exists(self.petrel_path) - with patch.object( - backend._client, 'contains', - return_value=True) as patched_contains: + with patch.object(backend._client, "contains", return_value=True) as patched_contains: self.assertTrue(backend.exists(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) def test_isdir(self): backend = PetrelBackend() - self.assertTrue(has_method(backend._client, 'isdir')) + self.assertTrue(has_method(backend._client, "isdir")) # raise Exception if `_client.isdir` is not implemented - with 
delete_and_reset_method(backend._client, 'isdir'): - self.assertFalse(has_method(backend._client, 'isdir')) + with delete_and_reset_method(backend._client, "isdir"): + self.assertFalse(has_method(backend._client, "isdir")) with self.assertRaises(NotImplementedError): backend.isdir(self.petrel_path) - with patch.object( - backend._client, 'isdir', - return_value=True) as patched_contains: + with patch.object(backend._client, "isdir", return_value=True) as patched_contains: self.assertTrue(backend.isdir(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) def test_isfile(self): backend = PetrelBackend() - self.assertTrue(has_method(backend._client, 'contains')) + self.assertTrue(has_method(backend._client, "contains")) # raise Exception if `_client.contains` is not implemented - with delete_and_reset_method(backend._client, 'contains'): - self.assertFalse(has_method(backend._client, 'contains')) + with delete_and_reset_method(backend._client, "contains"): + self.assertFalse(has_method(backend._client, "contains")) with self.assertRaises(NotImplementedError): backend.isfile(self.petrel_path) - with patch.object( - backend._client, 'contains', - return_value=True) as patched_contains: + with patch.object(backend._client, "contains", return_value=True) as patched_contains: self.assertTrue(backend.isfile(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) def test_get_local_path(self): backend = PetrelBackend() - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - patch.object(backend._client, 'contains', - return_value=True) as patch_contains: + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + patch.object(backend._client, "contains", return_value=True) as patch_contains, + ): with backend.get_local_path(self.petrel_path) as path: self.assertTrue(osp.isfile(path)) - self.assertEqual(Path(path).open('rb').read(), b'petrel') + self.assertEqual(Path(path).open("rb").read(), b"petrel") # exist the with block and path will be released self.assertFalse(osp.isfile(path)) patched_get.assert_called_once_with(self.expected_path) @@ -253,37 +225,36 @@ def test_get_local_path(self): def test_copyfile(self): backend = PetrelBackend() - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - patch.object(backend._client, 'put') as patched_put, \ - patch.object(backend._client, 'isdir', return_value=False) as \ - patched_isdir: + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + patch.object(backend._client, "put") as patched_put, + patch.object(backend._client, "isdir", return_value=False) as patched_isdir, + ): src = self.petrel_path - dst = f'{self.petrel_dir}/test.bak.jpg' - expected_dst = f'{self.expected_dir}/test.bak.jpg' + dst = f"{self.petrel_dir}/test.bak.jpg" + expected_dst = f"{self.expected_dir}/test.bak.jpg" self.assertEqual(backend.copyfile(src, dst), dst) patched_get.assert_called_once_with(self.expected_path) - patched_put.assert_called_once_with(expected_dst, b'petrel') + patched_put.assert_called_once_with(expected_dst, b"petrel") patched_isdir.assert_called_once_with(expected_dst) - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - patch.object(backend._client, 'put') as patched_put, \ - patch.object(backend._client, 'isdir', return_value=True) as \ - patched_isdir: + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + 
patch.object(backend._client, "put") as patched_put, + patch.object(backend._client, "isdir", return_value=True) as patched_isdir, + ): # dst is a directory - dst = f'{self.petrel_dir}/dir' - expected_dst = f'{self.expected_dir}/dir/test.jpg' - self.assertEqual(backend.copyfile(src, dst), f'{dst}/test.jpg') + dst = f"{self.petrel_dir}/dir" + expected_dst = f"{self.expected_dir}/dir/test.jpg" + self.assertEqual(backend.copyfile(src, dst), f"{dst}/test.jpg") patched_get.assert_called_once_with(self.expected_path) - patched_put.assert_called_once_with(expected_dst, b'petrel') - patched_isdir.assert_called_once_with( - f'{self.expected_dir}/dir') - - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - patch.object(backend._client, 'isdir', return_value=False) as \ - patched_isdir: + patched_put.assert_called_once_with(expected_dst, b"petrel") + patched_isdir.assert_called_once_with(f"{self.expected_dir}/dir") + + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + patch.object(backend._client, "isdir", return_value=False) as patched_isdir, + ): # src and src should not be same file with self.assertRaises(SameFileError): backend.copyfile(src, src) @@ -299,48 +270,48 @@ def put(obj, filepath): def get(filepath): get_inputs.append(filepath) - with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'put', side_effect=put), \ - patch.object(backend, 'get', side_effect=get), \ - patch.object(backend, 'exists', return_value=False): - tmp_dir = tmp_dir.replace('\\', '/') - dst = f'{tmp_dir}/dir' + with ( + build_temporary_directory() as tmp_dir, + patch.object(backend, "put", side_effect=put), + patch.object(backend, "get", side_effect=get), + patch.object(backend, "exists", return_value=False), + ): + tmp_dir = tmp_dir.replace("\\", "/") + dst = f"{tmp_dir}/dir" self.assertEqual(backend.copytree(tmp_dir, dst), dst) self.assertEqual(len(put_inputs), 5) self.assertEqual(len(get_inputs), 5) # dst should not exist - with patch.object(backend, 'exists', return_value=True): + with patch.object(backend, "exists", return_value=True): with self.assertRaises(FileExistsError): backend.copytree(dst, tmp_dir) def test_copyfile_from_local(self): backend = PetrelBackend() - with patch.object(backend._client, 'put') as patched_put, \ - patch.object(backend._client, 'isdir', return_value=False) \ - as patched_isdir: + with ( + patch.object(backend._client, "put") as patched_put, + patch.object(backend._client, "isdir", return_value=False) as patched_isdir, + ): src = self.img_path - dst = f'{self.petrel_dir}/color.bak.jpg' - expected_dst = f'{self.expected_dir}/color.bak.jpg' + dst = f"{self.petrel_dir}/color.bak.jpg" + expected_dst = f"{self.expected_dir}/color.bak.jpg" self.assertEqual(backend.copyfile_from_local(src, dst), dst) - patched_put.assert_called_once_with(expected_dst, - src.open('rb').read()) + patched_put.assert_called_once_with(expected_dst, src.open("rb").read()) patched_isdir.assert_called_once_with(expected_dst) - with patch.object(backend._client, 'put') as patched_put, \ - patch.object(backend._client, 'isdir', return_value=True) as \ - patched_isdir: + with ( + patch.object(backend._client, "put") as patched_put, + patch.object(backend._client, "isdir", return_value=True) as patched_isdir, + ): # dst is a directory src = self.img_path - dst = f'{self.petrel_dir}/dir' - expected_dst = f'{self.expected_dir}/dir/color.jpg' - self.assertEqual( - backend.copyfile_from_local(src, dst), f'{dst}/color.jpg') - 
patched_put.assert_called_once_with(expected_dst, - src.open('rb').read()) - patched_isdir.assert_called_once_with( - f'{self.expected_dir}/dir') + dst = f"{self.petrel_dir}/dir" + expected_dst = f"{self.expected_dir}/dir/color.jpg" + self.assertEqual(backend.copyfile_from_local(src, dst), f"{dst}/color.jpg") + patched_put.assert_called_once_with(expected_dst, src.open("rb").read()) + patched_isdir.assert_called_once_with(f"{self.expected_dir}/dir") def test_copytree_from_local(self): backend = PetrelBackend() @@ -349,42 +320,43 @@ def test_copytree_from_local(self): def copyfile_from_local(src, dst): inputs.append((src, dst)) - with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'copyfile_from_local', - side_effect=copyfile_from_local), \ - patch.object(backend, 'exists', return_value=False): + with ( + build_temporary_directory() as tmp_dir, + patch.object(backend, "copyfile_from_local", side_effect=copyfile_from_local), + patch.object(backend, "exists", return_value=False), + ): backend.copytree_from_local(tmp_dir, self.petrel_dir) self.assertEqual(len(inputs), 5) # dst should not exist - with patch.object(backend, 'exists', return_value=True): + with patch.object(backend, "exists", return_value=True): with self.assertRaises(FileExistsError): backend.copytree_from_local(tmp_dir, self.petrel_dir) def test_copyfile_to_local(self): backend = PetrelBackend() - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - tempfile.TemporaryDirectory() as tmp_dir: + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + tempfile.TemporaryDirectory() as tmp_dir, + ): src = self.petrel_path - dst = Path(tmp_dir) / 'test.bak.jpg' + dst = Path(tmp_dir) / "test.bak.jpg" self.assertEqual(backend.copyfile_to_local(src, dst), dst) patched_get.assert_called_once_with(self.expected_path) - self.assertEqual(dst.open('rb').read(), b'petrel') + self.assertEqual(dst.open("rb").read(), b"petrel") - with patch.object(backend._client, 'Get', - return_value=b'petrel') as patched_get, \ - tempfile.TemporaryDirectory() as tmp_dir: + with ( + patch.object(backend._client, "Get", return_value=b"petrel") as patched_get, + tempfile.TemporaryDirectory() as tmp_dir, + ): # dst is a directory src = self.petrel_path - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() - self.assertEqual( - backend.copyfile_to_local(src, dst), dst / 'test.jpg') + self.assertEqual(backend.copyfile_to_local(src, dst), dst / "test.jpg") patched_get.assert_called_once_with(self.expected_path) - self.assertEqual((dst / 'test.jpg').open('rb').read(), - b'petrel') + self.assertEqual((dst / "test.jpg").open("rb").read(), b"petrel") def test_copytree_to_local(self): backend = PetrelBackend() @@ -392,29 +364,28 @@ def test_copytree_to_local(self): def get(filepath): inputs.append(filepath) - return b'petrel' + return b"petrel" - with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'get', side_effect=get): - dst = f'{tmp_dir}/dir' + with build_temporary_directory() as tmp_dir, patch.object(backend, "get", side_effect=get): + dst = f"{tmp_dir}/dir" backend.copytree_to_local(tmp_dir, dst) self.assertEqual(len(inputs), 5) def test_remove(self): backend = PetrelBackend() - self.assertTrue(has_method(backend._client, 'delete')) + self.assertTrue(has_method(backend._client, "delete")) # raise Exception if `delete` is not implemented - with delete_and_reset_method(backend._client, 'delete'): - self.assertFalse(has_method(backend._client, 
'delete')) + with delete_and_reset_method(backend._client, "delete"): + self.assertFalse(has_method(backend._client, "delete")) with self.assertRaises(NotImplementedError): backend.remove(self.petrel_path) - with patch.object(backend._client, 'delete') as patched_delete, \ - patch.object(backend._client, 'isdir', return_value=False) \ - as patched_isdir, \ - patch.object(backend._client, 'contains', return_value=True) \ - as patched_contains: + with ( + patch.object(backend._client, "delete") as patched_delete, + patch.object(backend._client, "isdir", return_value=False) as patched_isdir, + patch.object(backend._client, "contains", return_value=True) as patched_contains, + ): backend.remove(self.petrel_path) patched_delete.assert_called_once_with(self.expected_path) patched_isdir.assert_called_once_with(self.expected_path) @@ -427,8 +398,7 @@ def test_rmtree(self): def remove(filepath): inputs.append(filepath) - with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'remove', side_effect=remove): + with build_temporary_directory() as tmp_dir, patch.object(backend, "remove", side_effect=remove): backend.rmtree(tmp_dir) self.assertEqual(len(inputs), 5) @@ -444,15 +414,19 @@ def copyfile(src, dst): def copytree(src, dst): copytree_inputs.append((src, dst)) - with patch.object(backend, 'copyfile', side_effect=copyfile), \ - patch.object(backend, 'isfile', return_value=True): - backend.copy_if_symlink_fails(self.petrel_path, 'path') + with ( + patch.object(backend, "copyfile", side_effect=copyfile), + patch.object(backend, "isfile", return_value=True), + ): + backend.copy_if_symlink_fails(self.petrel_path, "path") self.assertEqual(len(copyfile_inputs), 1) - with patch.object(backend, 'copytree', side_effect=copytree), \ - patch.object(backend, 'isfile', return_value=False): - backend.copy_if_symlink_fails(self.petrel_dir, 'path') + with ( + patch.object(backend, "copytree", side_effect=copytree), + patch.object(backend, "isfile", return_value=False), + ): + backend.copy_if_symlink_fails(self.petrel_dir, "path") self.assertEqual(len(copytree_inputs), 1) @@ -460,104 +434,90 @@ def test_list_dir_or_file(self): backend = PetrelBackend() # raise Exception if `_client.list` is not implemented - self.assertTrue(has_method(backend._client, 'list')) - with delete_and_reset_method(backend._client, 'list'): - self.assertFalse(has_method(backend._client, 'list')) + self.assertTrue(has_method(backend._client, "list")) + with delete_and_reset_method(backend._client, "list"): + self.assertFalse(has_method(backend._client, "list")) with self.assertRaises(NotImplementedError): list(backend.list_dir_or_file(self.petrel_dir)) with build_temporary_directory() as tmp_dir: # list directories and files - self.assertEqual( - set(backend.list_dir_or_file(tmp_dir)), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(tmp_dir)), {"dir1", "dir2", "text1.txt", "text2.txt"}) # list directories and files recursively self.assertEqual( - set(backend.list_dir_or_file(tmp_dir, recursive=True)), { - 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', - '/'.join(('dir2', 'dir3')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + set(backend.list_dir_or_file(tmp_dir, recursive=True)), + { + "dir1", + "/".join(("dir1", "text3.txt")), + "dir2", + "/".join(("dir2", "dir3")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) # only list directories - 
self.assertEqual( - set(backend.list_dir_or_file(tmp_dir, list_file=False)), - {'dir1', 'dir2'}) - with self.assertRaisesRegex( - TypeError, - '`list_dir` should be False when `suffix` is not None' - ): - backend.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + self.assertEqual(set(backend.list_dir_or_file(tmp_dir, list_file=False)), {"dir1", "dir2"}) + with self.assertRaisesRegex(TypeError, "`list_dir` should be False when `suffix` is not None"): + backend.list_dir_or_file(tmp_dir, list_file=False, suffix=".txt") # only list directories recursively self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)), - {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) + set(backend.list_dir_or_file(tmp_dir, list_file=False, recursive=True)), + {"dir1", "dir2", "/".join(("dir2", "dir3"))}, + ) # only list files - self.assertEqual( - set(backend.list_dir_or_file(tmp_dir, list_dir=False)), - {'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(tmp_dir, list_dir=False)), {"text1.txt", "text2.txt"}) # only list files recursively self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, recursive=True)), + set(backend.list_dir_or_file(tmp_dir, list_dir=False, recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) # only list files ending with suffix self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt')), - {'text1.txt', 'text2.txt'}) + set(backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt")), {"text1.txt", "text2.txt"} + ) self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix=('.txt', '.jpg'))), - {'text1.txt', 'text2.txt'}) - with self.assertRaisesRegex( - TypeError, - '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + set(backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"))), + {"text1.txt", "text2.txt"}, + ) + with self.assertRaisesRegex(TypeError, "`suffix` must be a string or tuple of strings"): + backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=[".txt", ".jpg"]) # only list files ending with suffix recursively self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix='.txt', - recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), - 'text1.txt', 'text2.txt' - }) + set(backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt", recursive=True)), + { + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "text1.txt", + "text2.txt", + }, + ) # only list files ending with suffix self.assertEqual( - set( - backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + set(backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"), recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) def test_generate_presigned_url(self): pass @@ -565,13 +525,12 @@ 
def test_generate_presigned_url(self): else: class TestPetrelBackend(TestCase): # type: ignore - @classmethod def setUpClass(cls): - cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' - cls.local_img_path = cls.test_data_dir / 'color.jpg' + cls.test_data_dir = Path(__file__).parent.parent.parent / "data" + cls.local_img_path = cls.test_data_dir / "color.jpg" cls.local_img_shape = (300, 400, 3) - cls.petrel_dir = 'petrel://mmengine-test/data' + cls.petrel_dir = "petrel://mmengine-test/data" def setUp(self): backend = PetrelBackend() @@ -579,11 +538,11 @@ def setUp(self): with build_temporary_directory() as tmp_dir: backend.copytree_from_local(tmp_dir, self.petrel_dir) - text1_path = f'{self.petrel_dir}/text1.txt' - text2_path = f'{self.petrel_dir}/text2.txt' - text3_path = f'{self.petrel_dir}/dir1/text3.txt' - text4_path = f'{self.petrel_dir}/dir2/dir3/text4.txt' - img_path = f'{self.petrel_dir}/dir2/img.jpg' + text1_path = f"{self.petrel_dir}/text1.txt" + text2_path = f"{self.petrel_dir}/text2.txt" + text3_path = f"{self.petrel_dir}/dir1/text3.txt" + text4_path = f"{self.petrel_dir}/dir2/dir3/text4.txt" + img_path = f"{self.petrel_dir}/dir2/img.jpg" self.assertTrue(backend.isfile(text1_path)) self.assertTrue(backend.isfile(text2_path)) self.assertTrue(backend.isfile(text3_path)) @@ -592,59 +551,59 @@ def setUp(self): def test_get(self): backend = PetrelBackend() - img_path = f'{self.petrel_dir}/dir2/img.jpg' - self.assertEqual(backend.get(img_path), b'img') + img_path = f"{self.petrel_dir}/dir2/img.jpg" + self.assertEqual(backend.get(img_path), b"img") def test_get_text(self): backend = PetrelBackend() - text_path = f'{self.petrel_dir}/text1.txt' - self.assertEqual(backend.get_text(text_path), 'text1') + text_path = f"{self.petrel_dir}/text1.txt" + self.assertEqual(backend.get_text(text_path), "text1") def test_put(self): backend = PetrelBackend() - img_path = f'{self.petrel_dir}/img.jpg' - backend.put(b'img', img_path) + img_path = f"{self.petrel_dir}/img.jpg" + backend.put(b"img", img_path) def test_put_text(self): backend = PetrelBackend() - text_path = f'{self.petrel_dir}/text5.txt' - backend.put_text('text5', text_path) + text_path = f"{self.petrel_dir}/text5.txt" + backend.put_text("text5", text_path) def test_exists(self): backend = PetrelBackend() # file and directory exist - dir_path = f'{self.petrel_dir}/dir2' + dir_path = f"{self.petrel_dir}/dir2" self.assertTrue(backend.exists(dir_path)) - img_path = f'{self.petrel_dir}/dir2/img.jpg' + img_path = f"{self.petrel_dir}/dir2/img.jpg" self.assertTrue(backend.exists(img_path)) # file and directory does not exist - not_existed_dir = f'{self.petrel_dir}/not_existed_dir' + not_existed_dir = f"{self.petrel_dir}/not_existed_dir" self.assertFalse(backend.exists(not_existed_dir)) - not_existed_path = f'{self.petrel_dir}/img.jpg' + not_existed_path = f"{self.petrel_dir}/img.jpg" self.assertFalse(backend.exists(not_existed_path)) def test_isdir(self): backend = PetrelBackend() - dir_path = f'{self.petrel_dir}/dir2' + dir_path = f"{self.petrel_dir}/dir2" self.assertTrue(backend.isdir(dir_path)) - img_path = f'{self.petrel_dir}/dir2/img.jpg' + img_path = f"{self.petrel_dir}/dir2/img.jpg" self.assertFalse(backend.isdir(img_path)) def test_isfile(self): backend = PetrelBackend() - dir_path = f'{self.petrel_dir}/dir2' + dir_path = f"{self.petrel_dir}/dir2" self.assertFalse(backend.isfile(dir_path)) - img_path = f'{self.petrel_dir}/dir2/img.jpg' + img_path = f"{self.petrel_dir}/dir2/img.jpg" self.assertTrue(backend.isfile(img_path)) 
def test_get_local_path(self): backend = PetrelBackend() - img_path = f'{self.petrel_dir}/dir2/img.jpg' + img_path = f"{self.petrel_dir}/dir2/img.jpg" with backend.get_local_path(img_path) as path: self.assertTrue(osp.isfile(path)) - self.assertEqual(Path(path).open('rb').read(), b'img') + self.assertEqual(Path(path).open("rb").read(), b"img") # exist the with block and path will be released self.assertFalse(osp.isfile(path)) @@ -652,14 +611,14 @@ def test_copyfile(self): backend = PetrelBackend() # dst is a file - src = f'{self.petrel_dir}/dir2/img.jpg' - dst = f'{self.petrel_dir}/img.jpg' + src = f"{self.petrel_dir}/dir2/img.jpg" + dst = f"{self.petrel_dir}/img.jpg" self.assertEqual(backend.copyfile(src, dst), dst) self.assertTrue(backend.isfile(dst)) # dst is a directory - dst = f'{self.petrel_dir}/dir1' - expected_dst = f'{self.petrel_dir}/dir1/img.jpg' + dst = f"{self.petrel_dir}/dir1" + expected_dst = f"{self.petrel_dir}/dir1/img.jpg" self.assertEqual(backend.copyfile(src, dst), expected_dst) self.assertTrue(backend.isfile(expected_dst)) @@ -669,13 +628,11 @@ def test_copyfile(self): def test_copytree(self): backend = PetrelBackend() - src = f'{self.petrel_dir}/dir2' - dst = f'{self.petrel_dir}/dir3' + src = f"{self.petrel_dir}/dir2" + dst = f"{self.petrel_dir}/dir3" self.assertFalse(backend.exists(dst)) self.assertEqual(backend.copytree(src, dst), dst) - self.assertEqual( - list(backend.list_dir_or_file(src)), - list(backend.list_dir_or_file(dst))) + self.assertEqual(list(backend.list_dir_or_file(src)), list(backend.list_dir_or_file(dst))) # dst should not exist with self.assertRaises(FileExistsError): @@ -686,18 +643,17 @@ def test_copyfile_from_local(self): # dst is a file src = self.local_img_path - dst = f'{self.petrel_dir}/color.jpg' + dst = f"{self.petrel_dir}/color.jpg" self.assertFalse(backend.exists(dst)) self.assertEqual(backend.copyfile_from_local(src, dst), dst) self.assertTrue(backend.isfile(dst)) # dst is a directory src = self.local_img_path - dst = f'{self.petrel_dir}/dir1' - expected_dst = f'{self.petrel_dir}/dir1/color.jpg' + dst = f"{self.petrel_dir}/dir1" + expected_dst = f"{self.petrel_dir}/dir1/color.jpg" self.assertFalse(backend.exists(expected_dst)) - self.assertEqual( - backend.copyfile_from_local(src, dst), expected_dst) + self.assertEqual(backend.copyfile_from_local(src, dst), expected_dst) self.assertTrue(backend.isfile(expected_dst)) def test_copytree_from_local(self): @@ -705,43 +661,41 @@ def test_copytree_from_local(self): backend.rmtree(self.petrel_dir) with build_temporary_directory() as tmp_dir: backend.copytree_from_local(tmp_dir, self.petrel_dir) - files = backend.list_dir_or_file( - self.petrel_dir, recursive=True) + files = backend.list_dir_or_file(self.petrel_dir, recursive=True) self.assertEqual(len(list(files)), 8) def test_copyfile_to_local(self): backend = PetrelBackend() with tempfile.TemporaryDirectory() as tmp_dir: # dst is a file - src = f'{self.petrel_dir}/dir2/img.jpg' - dst = Path(tmp_dir) / 'img.jpg' + src = f"{self.petrel_dir}/dir2/img.jpg" + dst = Path(tmp_dir) / "img.jpg" self.assertEqual(backend.copyfile_to_local(src, dst), dst) - self.assertEqual(dst.open('rb').read(), b'img') + self.assertEqual(dst.open("rb").read(), b"img") # dst is a directory - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() - self.assertEqual( - backend.copyfile_to_local(src, dst), dst / 'img.jpg') - self.assertEqual((dst / 'img.jpg').open('rb').read(), b'img') + self.assertEqual(backend.copyfile_to_local(src, dst), dst / "img.jpg") 
+ self.assertEqual((dst / "img.jpg").open("rb").read(), b"img") def test_copytree_to_local(self): backend = PetrelBackend() with tempfile.TemporaryDirectory() as tmp_dir: backend.copytree_to_local(self.petrel_dir, tmp_dir) - self.assertTrue(osp.exists(Path(tmp_dir) / 'text1.txt')) - self.assertTrue(osp.exists(Path(tmp_dir) / 'dir2' / 'img.jpg')) + self.assertTrue(osp.exists(Path(tmp_dir) / "text1.txt")) + self.assertTrue(osp.exists(Path(tmp_dir) / "dir2" / "img.jpg")) def test_remove(self): backend = PetrelBackend() - img_path = f'{self.petrel_dir}/dir2/img.jpg' + img_path = f"{self.petrel_dir}/dir2/img.jpg" self.assertTrue(backend.isfile(img_path)) backend.remove(img_path) self.assertFalse(backend.exists(img_path)) def test_rmtree(self): backend = PetrelBackend() - dir_path = f'{self.petrel_dir}/dir2' + dir_path = f"{self.petrel_dir}/dir2" self.assertTrue(backend.isdir(dir_path)) backend.rmtree(dir_path) self.assertFalse(backend.exists(dir_path)) @@ -750,15 +704,15 @@ def test_copy_if_symlink_fails(self): backend = PetrelBackend() # dst is a file - src = f'{self.petrel_dir}/dir2/img.jpg' - dst = f'{self.petrel_dir}/img.jpg' + src = f"{self.petrel_dir}/dir2/img.jpg" + dst = f"{self.petrel_dir}/img.jpg" self.assertFalse(backend.exists(dst)) self.assertFalse(backend.copy_if_symlink_fails(src, dst)) self.assertTrue(backend.isfile(dst)) # dst is a directory - src = f'{self.petrel_dir}/dir2' - dst = f'{self.petrel_dir}/dir' + src = f"{self.petrel_dir}/dir2" + dst = f"{self.petrel_dir}/dir" self.assertFalse(backend.exists(dst)) self.assertFalse(backend.copy_if_symlink_fails(src, dst)) self.assertTrue(backend.isdir(dst)) @@ -767,96 +721,75 @@ def test_list_dir_or_file(self): backend = PetrelBackend() # list directories and files - self.assertEqual( - set(backend.list_dir_or_file(self.petrel_dir)), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(self.petrel_dir)), {"dir1", "dir2", "text1.txt", "text2.txt"}) # list directories and files recursively self.assertEqual( set(backend.list_dir_or_file(self.petrel_dir, recursive=True)), { - 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join( - ('dir2', 'dir3')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + "dir1", + "/".join(("dir1", "text3.txt")), + "dir2", + "/".join(("dir2", "dir3")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) # only list directories - self.assertEqual( - set( - backend.list_dir_or_file(self.petrel_dir, - list_file=False)), - {'dir1', 'dir2'}) - with self.assertRaisesRegex( - TypeError, - '`list_dir` should be False when `suffix` is not None'): - backend.list_dir_or_file( - self.petrel_dir, list_file=False, suffix='.txt') + self.assertEqual(set(backend.list_dir_or_file(self.petrel_dir, list_file=False)), {"dir1", "dir2"}) + with self.assertRaisesRegex(TypeError, "`list_dir` should be False when `suffix` is not None"): + backend.list_dir_or_file(self.petrel_dir, list_file=False, suffix=".txt") # only list directories recursively self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, list_file=False, recursive=True)), - {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) + set(backend.list_dir_or_file(self.petrel_dir, list_file=False, recursive=True)), + {"dir1", "dir2", "/".join(("dir2", "dir3"))}, + ) # only list files - self.assertEqual( - set(backend.list_dir_or_file(self.petrel_dir, list_dir=False)), - {'text1.txt', 'text2.txt'}) + 
self.assertEqual(set(backend.list_dir_or_file(self.petrel_dir, list_dir=False)), {"text1.txt", "text2.txt"}) # only list files recursively self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, recursive=True)), + set(backend.list_dir_or_file(self.petrel_dir, list_dir=False, recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) # only list files ending with suffix self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, suffix='.txt')), - {'text1.txt', 'text2.txt'}) + set(backend.list_dir_or_file(self.petrel_dir, list_dir=False, suffix=".txt")), + {"text1.txt", "text2.txt"}, + ) self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) - with self.assertRaisesRegex( - TypeError, - '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, suffix=['.txt', '.jpg']) + set(backend.list_dir_or_file(self.petrel_dir, list_dir=False, suffix=(".txt", ".jpg"))), + {"text1.txt", "text2.txt"}, + ) + with self.assertRaisesRegex(TypeError, "`suffix` must be a string or tuple of strings"): + backend.list_dir_or_file(self.petrel_dir, list_dir=False, suffix=[".txt", ".jpg"]) # only list files ending with suffix recursively self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix='.txt', - recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), 'text1.txt', - 'text2.txt' - }) + set(backend.list_dir_or_file(self.petrel_dir, list_dir=False, suffix=".txt", recursive=True)), + {"/".join(("dir1", "text3.txt")), "/".join(("dir2", "dir3", "text4.txt")), "text1.txt", "text2.txt"}, + ) # only list files ending with suffix self.assertEqual( - set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + set(backend.list_dir_or_file(self.petrel_dir, list_dir=False, suffix=(".txt", ".jpg"), recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - }) + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + }, + ) diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py index 345832a026..6a1ac92e8c 100644 --- a/tests/test_fileio/test_fileclient.py +++ b/tests/test_fileio/test_fileclient.py @@ -15,10 +15,11 @@ from mmengine.fileio import BaseStorageBackend, FileClient from mmengine.utils import has_method -sys.modules['ceph'] = MagicMock() -sys.modules['petrel_client'] = MagicMock() -sys.modules['petrel_client.client'] = MagicMock() -sys.modules['mc'] = MagicMock() + +sys.modules["ceph"] = MagicMock() +sys.modules["petrel_client"] = MagicMock() +sys.modules["petrel_client.client"] = MagicMock() +sys.modules["mc"] = MagicMock() def imfrombytes(content): @@ -43,22 +44,22 @@ def build_temporary_directory(): | -- text2.txt \n """ with tempfile.TemporaryDirectory() as tmp_dir: - text1 = Path(tmp_dir) / 'text1.txt' - text1.open('w').write('text1') - text2 = Path(tmp_dir) / 'text2.txt' - 
text2.open('w').write('text2') - dir1 = Path(tmp_dir) / 'dir1' + text1 = Path(tmp_dir) / "text1.txt" + text1.open("w").write("text1") + text2 = Path(tmp_dir) / "text2.txt" + text2.open("w").write("text2") + dir1 = Path(tmp_dir) / "dir1" dir1.mkdir() - text3 = dir1 / 'text3.txt' - text3.open('w').write('text3') - dir2 = Path(tmp_dir) / 'dir2' + text3 = dir1 / "text3.txt" + text3.open("w").write("text3") + dir2 = Path(tmp_dir) / "dir2" dir2.mkdir() - jpg1 = dir2 / 'img.jpg' - jpg1.open('wb').write(b'img') - dir3 = dir2 / 'dir3' + jpg1 = dir2 / "img.jpg" + jpg1.open("wb").write(b"img") + dir3 = dir2 / "dir3" dir3.mkdir() - text4 = dir3 / 'text4.txt' - text4.open('w').write('text4') + text4 = dir3 / "text4.txt" + text4.open("w").write("text4") yield tmp_dir @@ -73,28 +74,23 @@ def delete_and_reset_method(obj, method): class MockS3Client: - def __init__(self, enable_mc=True): self.enable_mc = enable_mc def Get(self, filepath): - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: content = f.read() return content class MockPetrelClient: - - def __init__(self, - enable_mc=True, - enable_multi_cluster=False, - conf_path=None): + def __init__(self, enable_mc=True, enable_multi_cluster=False, conf_path=None): self.enable_mc = enable_mc self.enable_multi_cluster = enable_multi_cluster self.conf_path = conf_path def Get(self, filepath): - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: content = f.read() return content @@ -112,84 +108,82 @@ def isdir(self): def list(self, dir_path): for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): + if not entry.name.startswith(".") and entry.is_file(): yield entry.name elif osp.isdir(entry.path): - yield entry.name + '/' + yield entry.name + "/" class MockMemcachedClient: - def __init__(self, server_list_cfg, client_cfg): pass def Get(self, filepath, buffer): - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: buffer.content = f.read() class TestFileClient: - @classmethod def setup_class(cls): - cls.test_data_dir = Path(__file__).parent.parent / 'data' - cls.img_path = cls.test_data_dir / 'color.jpg' + cls.test_data_dir = Path(__file__).parent.parent / "data" + cls.img_path = cls.test_data_dir / "color.jpg" cls.img_shape = (300, 400, 3) - cls.text_path = cls.test_data_dir / 'filelist.txt' + cls.text_path = cls.test_data_dir / "filelist.txt" def test_error(self): with pytest.raises(ValueError): - FileClient('hadoop') + FileClient("hadoop") def test_disk_backend(self): - disk_backend = FileClient('disk') + disk_backend = FileClient("disk") # test `name` attribute - assert disk_backend.name == 'HardDiskBackend' + assert disk_backend.name == "HardDiskBackend" # test `allow_symlink` attribute assert disk_backend.allow_symlink # test `get` # input path is Path object img_bytes = disk_backend.get(self.img_path) img = imfrombytes(img_bytes) - assert self.img_path.open('rb').read() == img_bytes + assert self.img_path.open("rb").read() == img_bytes assert img.shape == self.img_shape # input path is str img_bytes = disk_backend.get(str(self.img_path)) img = imfrombytes(img_bytes) - assert self.img_path.open('rb').read() == img_bytes + assert self.img_path.open("rb").read() == img_bytes assert img.shape == self.img_shape # test `get_text` # input path is Path object value_buf = disk_backend.get_text(self.text_path) - assert self.text_path.open('r').read() == value_buf + assert self.text_path.open("r").read() == value_buf # input path is str value_buf = 
disk_backend.get_text(str(self.text_path)) - assert self.text_path.open('r').read() == value_buf + assert self.text_path.open("r").read() == value_buf with tempfile.TemporaryDirectory() as tmp_dir: # test `put` - filepath1 = Path(tmp_dir) / 'test.jpg' - disk_backend.put(b'disk', filepath1) - assert filepath1.open('rb').read() == b'disk' + filepath1 = Path(tmp_dir) / "test.jpg" + disk_backend.put(b"disk", filepath1) + assert filepath1.open("rb").read() == b"disk" # test the `mkdir_or_exist` behavior in `put` - _filepath1 = Path(tmp_dir) / 'not_existed_dir1' / 'test.jpg' - disk_backend.put(b'disk', _filepath1) - assert _filepath1.open('rb').read() == b'disk' + _filepath1 = Path(tmp_dir) / "not_existed_dir1" / "test.jpg" + disk_backend.put(b"disk", _filepath1) + assert _filepath1.open("rb").read() == b"disk" # test `put_text` - filepath2 = Path(tmp_dir) / 'test.txt' - disk_backend.put_text('disk', filepath2) - assert filepath2.open('r').read() == 'disk' + filepath2 = Path(tmp_dir) / "test.txt" + disk_backend.put_text("disk", filepath2) + assert filepath2.open("r").read() == "disk" # test the `mkdir_or_exist` behavior in `put_text` - _filepath2 = Path(tmp_dir) / 'not_existed_dir2' / 'test.txt' - disk_backend.put_text('disk', _filepath2) - assert _filepath2.open('r').read() == 'disk' + _filepath2 = Path(tmp_dir) / "not_existed_dir2" / "test.txt" + disk_backend.put_text("disk", _filepath2) + assert _filepath2.open("r").read() == "disk" # test `isfile` assert disk_backend.isfile(filepath2) - assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path') + assert not disk_backend.isfile(Path(tmp_dir) / "not/existed/path") # test `remove` disk_backend.remove(filepath2) @@ -204,94 +198,78 @@ def test_disk_backend(self): assert osp.isfile(filepath1) # test `join_path` - disk_dir = '/path/of/your/directory' - assert disk_backend.join_path(disk_dir, 'file') == \ - osp.join(disk_dir, 'file') - assert disk_backend.join_path(disk_dir, 'dir', 'file') == \ - osp.join(disk_dir, 'dir', 'file') + disk_dir = "/path/of/your/directory" + assert disk_backend.join_path(disk_dir, "file") == osp.join(disk_dir, "file") + assert disk_backend.join_path(disk_dir, "dir", "file") == osp.join(disk_dir, "dir", "file") # test `list_dir_or_file` with build_temporary_directory() as tmp_dir: # 1. list directories and files - assert set(disk_backend.list_dir_or_file(tmp_dir)) == { - 'dir1', 'dir2', 'text1.txt', 'text2.txt' - } + assert set(disk_backend.list_dir_or_file(tmp_dir)) == {"dir1", "dir2", "text1.txt", "text2.txt"} # 2. list directories and files recursively - assert set(disk_backend.list_dir_or_file( - tmp_dir, recursive=True)) == { - 'dir1', - osp.join('dir1', 'text3.txt'), 'dir2', - osp.join('dir2', 'dir3'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - } + assert set(disk_backend.list_dir_or_file(tmp_dir, recursive=True)) == { + "dir1", + osp.join("dir1", "text3.txt"), + "dir2", + osp.join("dir2", "dir3"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + } # 3. 
only list directories - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_file=False)) == {'dir1', 'dir2'} - with pytest.raises( - TypeError, - match='`suffix` should be None when `list_dir` is True'): + assert set(disk_backend.list_dir_or_file(tmp_dir, list_file=False)) == {"dir1", "dir2"} + with pytest.raises(TypeError, match="`suffix` should be None when `list_dir` is True"): # Exception is raised among the `list_dir_or_file` of client, # so we need to invode the client to trigger the exception - disk_backend.client.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + disk_backend.client.list_dir_or_file(tmp_dir, list_file=False, suffix=".txt") # 4. only list directories recursively - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)) == { - 'dir1', 'dir2', - osp.join('dir2', 'dir3') - } + assert set(disk_backend.list_dir_or_file(tmp_dir, list_file=False, recursive=True)) == { + "dir1", + "dir2", + osp.join("dir2", "dir3"), + } # 5. only list files - assert set(disk_backend.list_dir_or_file( - tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} + assert set(disk_backend.list_dir_or_file(tmp_dir, list_dir=False)) == {"text1.txt", "text2.txt"} # 6. only list files recursively - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - } + assert set(disk_backend.list_dir_or_file(tmp_dir, list_dir=False, recursive=True)) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + } # 7. only list files ending with suffix - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} - with pytest.raises( - TypeError, - match='`suffix` must be a string or tuple of strings'): - disk_backend.client.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + assert set(disk_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt")) == { + "text1.txt", + "text2.txt", + } + assert set(disk_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"))) == { + "text1.txt", + "text2.txt", + } + with pytest.raises(TypeError, match="`suffix` must be a string or tuple of strings"): + disk_backend.client.list_dir_or_file(tmp_dir, list_dir=False, suffix=[".txt", ".jpg"]) # 8. only list files ending with suffix recursively - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', - recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', - 'text2.txt' - } + assert set(disk_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt", recursive=True)) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + "text1.txt", + "text2.txt", + } # 7. 
only list files ending with suffix assert set( - disk_backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - } - - @patch('petrel_client.client.Client', MockPetrelClient) - @pytest.mark.parametrize('backend,prefix', [('petrel', None), - (None, 's3')]) + disk_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"), recursive=True) + ) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + } + + @patch("petrel_client.client.Client", MockPetrelClient) + @pytest.mark.parametrize("backend,prefix", [("petrel", None), (None, "s3")]) def test_petrel_backend(self, backend, prefix): petrel_backend = FileClient(backend=backend, prefix=prefix) @@ -309,222 +287,193 @@ def test_petrel_backend(self, backend, prefix): # `path_mapping` is either None or dict with pytest.raises(AssertionError): - FileClient('petrel', path_mapping=1) + FileClient("petrel", path_mapping=1) # test `_map_path` - petrel_dir = 's3://user/data' - petrel_backend = FileClient( - 'petrel', path_mapping={str(self.test_data_dir): petrel_dir}) - assert petrel_backend.client._map_path(str(self.img_path)) == \ - str(self.img_path).replace(str(self.test_data_dir), petrel_dir) + petrel_dir = "s3://user/data" + petrel_backend = FileClient("petrel", path_mapping={str(self.test_data_dir): petrel_dir}) + assert petrel_backend.client._map_path(str(self.img_path)) == str(self.img_path).replace( + str(self.test_data_dir), petrel_dir + ) - petrel_path = f'{petrel_dir}/test.jpg' - petrel_backend = FileClient('petrel') + petrel_path = f"{petrel_dir}/test.jpg" + petrel_backend = FileClient("petrel") # test `_format_path` - assert petrel_backend.client._format_path('s3://user\\data\\test.jpg')\ - == petrel_path + assert petrel_backend.client._format_path("s3://user\\data\\test.jpg") == petrel_path # test `get` - with patch.object( - petrel_backend.client._client, 'Get', - return_value=b'petrel') as mock_get: - assert petrel_backend.get(petrel_path) == b'petrel' + with patch.object(petrel_backend.client._client, "Get", return_value=b"petrel") as mock_get: + assert petrel_backend.get(petrel_path) == b"petrel" mock_get.assert_called_once_with(petrel_path) # test `get_text` - with patch.object( - petrel_backend.client._client, 'Get', - return_value=b'petrel') as mock_get: - assert petrel_backend.get_text(petrel_path) == 'petrel' + with patch.object(petrel_backend.client._client, "Get", return_value=b"petrel") as mock_get: + assert petrel_backend.get_text(petrel_path) == "petrel" mock_get.assert_called_once_with(petrel_path) # test `put` - with patch.object(petrel_backend.client._client, 'put') as mock_put: - petrel_backend.put(b'petrel', petrel_path) - mock_put.assert_called_once_with(petrel_path, b'petrel') + with patch.object(petrel_backend.client._client, "put") as mock_put: + petrel_backend.put(b"petrel", petrel_path) + mock_put.assert_called_once_with(petrel_path, b"petrel") # test `put_text` - with patch.object(petrel_backend.client._client, 'put') as mock_put: - petrel_backend.put_text('petrel', petrel_path) - mock_put.assert_called_once_with(petrel_path, b'petrel') + with patch.object(petrel_backend.client._client, "put") as mock_put: + petrel_backend.put_text("petrel", petrel_path) + mock_put.assert_called_once_with(petrel_path, b"petrel") # test `remove` - assert 
has_method(petrel_backend.client._client, 'delete') + assert has_method(petrel_backend.client._client, "delete") # raise Exception if `delete` is not implemented - with delete_and_reset_method(petrel_backend.client._client, 'delete'): - assert not has_method(petrel_backend.client._client, 'delete') + with delete_and_reset_method(petrel_backend.client._client, "delete"): + assert not has_method(petrel_backend.client._client, "delete") with pytest.raises(NotImplementedError): petrel_backend.remove(petrel_path) - with patch.object(petrel_backend.client._client, - 'delete') as mock_delete, \ - patch.object(petrel_backend.client._client, - 'isdir', return_value=False) as mock_isdir, \ - patch.object(petrel_backend.client._client, - 'contains', return_value=True) as mock_contains: + with ( + patch.object(petrel_backend.client._client, "delete") as mock_delete, + patch.object(petrel_backend.client._client, "isdir", return_value=False) as mock_isdir, + patch.object(petrel_backend.client._client, "contains", return_value=True) as mock_contains, + ): petrel_backend.remove(petrel_path) mock_delete.assert_called_once_with(petrel_path) mock_isdir.assert_called_once_with(petrel_path) mock_contains.assert_called_once_with(petrel_path) # test `exists` - assert has_method(petrel_backend.client._client, 'contains') - assert has_method(petrel_backend.client._client, 'isdir') + assert has_method(petrel_backend.client._client, "contains") + assert has_method(petrel_backend.client._client, "isdir") # raise Exception if `delete` is not implemented - with delete_and_reset_method(petrel_backend.client._client, - 'contains'), delete_and_reset_method( - petrel_backend.client._client, - 'isdir'): - assert not has_method(petrel_backend.client._client, 'contains') - assert not has_method(petrel_backend.client._client, 'isdir') + with ( + delete_and_reset_method(petrel_backend.client._client, "contains"), + delete_and_reset_method(petrel_backend.client._client, "isdir"), + ): + assert not has_method(petrel_backend.client._client, "contains") + assert not has_method(petrel_backend.client._client, "isdir") with pytest.raises(NotImplementedError): petrel_backend.exists(petrel_path) - with patch.object( - petrel_backend.client._client, 'contains', - return_value=True) as mock_contains: + with patch.object(petrel_backend.client._client, "contains", return_value=True) as mock_contains: assert petrel_backend.exists(petrel_path) mock_contains.assert_called_once_with(petrel_path) # test `isdir` - assert has_method(petrel_backend.client._client, 'isdir') - with delete_and_reset_method(petrel_backend.client._client, 'isdir'): - assert not has_method(petrel_backend.client._client, 'isdir') + assert has_method(petrel_backend.client._client, "isdir") + with delete_and_reset_method(petrel_backend.client._client, "isdir"): + assert not has_method(petrel_backend.client._client, "isdir") with pytest.raises(NotImplementedError): petrel_backend.isdir(petrel_path) - with patch.object( - petrel_backend.client._client, 'isdir', - return_value=True) as mock_isdir: + with patch.object(petrel_backend.client._client, "isdir", return_value=True) as mock_isdir: assert petrel_backend.isdir(petrel_dir) mock_isdir.assert_called_once_with(petrel_dir) # test `isfile` - assert has_method(petrel_backend.client._client, 'contains') - with delete_and_reset_method(petrel_backend.client._client, - 'contains'): - assert not has_method(petrel_backend.client._client, 'contains') + assert has_method(petrel_backend.client._client, "contains") + with 
delete_and_reset_method(petrel_backend.client._client, "contains"): + assert not has_method(petrel_backend.client._client, "contains") with pytest.raises(NotImplementedError): petrel_backend.isfile(petrel_path) - with patch.object( - petrel_backend.client._client, 'contains', - return_value=True) as mock_contains: + with patch.object(petrel_backend.client._client, "contains", return_value=True) as mock_contains: assert petrel_backend.isfile(petrel_path) mock_contains.assert_called_once_with(petrel_path) # test `join_path` - assert petrel_backend.join_path(petrel_dir, 'file') == \ - f'{petrel_dir}/file' - assert petrel_backend.join_path(f'{petrel_dir}/', 'file') == \ - f'{petrel_dir}/file' - assert petrel_backend.join_path(petrel_dir, 'dir', 'file') == \ - f'{petrel_dir}/dir/file' + assert petrel_backend.join_path(petrel_dir, "file") == f"{petrel_dir}/file" + assert petrel_backend.join_path(f"{petrel_dir}/", "file") == f"{petrel_dir}/file" + assert petrel_backend.join_path(petrel_dir, "dir", "file") == f"{petrel_dir}/dir/file" # test `get_local_path` - with patch.object(petrel_backend.client._client, 'Get', - return_value=b'petrel') as mock_get, \ - patch.object(petrel_backend.client._client, 'contains', - return_value=True) as mock_contains: + with ( + patch.object(petrel_backend.client._client, "Get", return_value=b"petrel") as mock_get, + patch.object(petrel_backend.client._client, "contains", return_value=True) as mock_contains, + ): with petrel_backend.get_local_path(petrel_path) as path: - assert Path(path).open('rb').read() == b'petrel' + assert Path(path).open("rb").read() == b"petrel" # exist the with block and path will be released assert not osp.isfile(path) mock_get.assert_called_once_with(petrel_path) mock_contains.assert_called_once_with(petrel_path) # test `list_dir_or_file` - assert has_method(petrel_backend.client._client, 'list') - with delete_and_reset_method(petrel_backend.client._client, 'list'): - assert not has_method(petrel_backend.client._client, 'list') + assert has_method(petrel_backend.client._client, "list") + with delete_and_reset_method(petrel_backend.client._client, "list"): + assert not has_method(petrel_backend.client._client, "list") with pytest.raises(NotImplementedError): list(petrel_backend.list_dir_or_file(petrel_dir)) with build_temporary_directory() as tmp_dir: # 1. list directories and files - assert set(petrel_backend.list_dir_or_file(tmp_dir)) == { - 'dir1', 'dir2', 'text1.txt', 'text2.txt' - } + assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {"dir1", "dir2", "text1.txt", "text2.txt"} # 2. list directories and files recursively - assert set( - petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == { - 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join( - ('dir2', 'dir3')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - } + assert set(petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == { + "dir1", + "/".join(("dir1", "text3.txt")), + "dir2", + "/".join(("dir2", "dir3")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + } # 3. 
only list directories - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_file=False)) == {'dir1', 'dir2'} - with pytest.raises( - TypeError, - match=('`list_dir` should be False when `suffix` is not ' - 'None')): + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_file=False)) == {"dir1", "dir2"} + with pytest.raises(TypeError, match=("`list_dir` should be False when `suffix` is not None")): # Exception is raised among the `list_dir_or_file` of client, # so we need to invode the client to trigger the exception - petrel_backend.client.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + petrel_backend.client.list_dir_or_file(tmp_dir, list_file=False, suffix=".txt") # 4. only list directories recursively - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)) == { - 'dir1', 'dir2', '/'.join(('dir2', 'dir3')) - } + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_file=False, recursive=True)) == { + "dir1", + "dir2", + "/".join(("dir2", "dir3")), + } # 5. only list files - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_dir=False)) == {"text1.txt", "text2.txt"} # 6. only list files recursively - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, recursive=True)) == { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - } + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_dir=False, recursive=True)) == { + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + } # 7. only list files ending with suffix - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} - with pytest.raises( - TypeError, - match='`suffix` must be a string or tuple of strings'): - petrel_backend.client.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt")) == { + "text1.txt", + "text2.txt", + } + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"))) == { + "text1.txt", + "text2.txt", + } + with pytest.raises(TypeError, match="`suffix` must be a string or tuple of strings"): + petrel_backend.client.list_dir_or_file(tmp_dir, list_dir=False, suffix=[".txt", ".jpg"]) # 8. only list files ending with suffix recursively - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', - recursive=True)) == { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), 'text1.txt', - 'text2.txt' - } + assert set(petrel_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt", recursive=True)) == { + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "text1.txt", + "text2.txt", + } # 7. 
only list files ending with suffix assert set( - petrel_backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)) == { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), '/'.join( - ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' - } - - @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) - @patch('mc.pyvector', MagicMock) - @patch('mc.ConvertBuffer', lambda x: x.content) + petrel_backend.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"), recursive=True) + ) == { + "/".join(("dir1", "text3.txt")), + "/".join(("dir2", "dir3", "text4.txt")), + "/".join(("dir2", "img.jpg")), + "text1.txt", + "text2.txt", + } + + @patch("mc.MemcachedClient.GetInstance", MockMemcachedClient) + @patch("mc.pyvector", MagicMock) + @patch("mc.ConvertBuffer", lambda x: x.content) def test_memcached_backend(self): - mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None) - mc_backend = FileClient('memcached', **mc_cfg) + mc_cfg = dict(server_list_cfg="", client_cfg="", sys_path=None) + mc_backend = FileClient("memcached", **mc_cfg) # test `allow_symlink` attribute assert not mc_backend.allow_symlink @@ -546,10 +495,10 @@ def test_memcached_backend(self): assert img.shape == self.img_shape def test_lmdb_backend(self): - lmdb_path = self.test_data_dir / 'demo.lmdb' + lmdb_path = self.test_data_dir / "demo.lmdb" # db_path is Path object - lmdb_backend = FileClient('lmdb', db_path=lmdb_path) + lmdb_backend = FileClient("lmdb", db_path=lmdb_path) # test `allow_symlink` attribute assert not lmdb_backend.allow_symlink @@ -557,26 +506,23 @@ def test_lmdb_backend(self): with pytest.raises(NotImplementedError): lmdb_backend.get_text(self.text_path) - img_bytes = lmdb_backend.get('baboon') + img_bytes = lmdb_backend.get("baboon") img = imfrombytes(img_bytes) assert img.shape == (120, 125, 3) # db_path is str - lmdb_backend = FileClient('lmdb', db_path=str(lmdb_path)) + lmdb_backend = FileClient("lmdb", db_path=str(lmdb_path)) with pytest.raises(NotImplementedError): lmdb_backend.get_text(str(self.text_path)) - img_bytes = lmdb_backend.get('baboon') + img_bytes = lmdb_backend.get("baboon") img = imfrombytes(img_bytes) assert img.shape == (120, 125, 3) - @pytest.mark.parametrize('backend,prefix', [('http', None), - (None, 'http')]) + @pytest.mark.parametrize("backend,prefix", [("http", None), (None, "http")]) def test_http_backend(self, backend, prefix): http_backend = FileClient(backend=backend, prefix=prefix) - img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ - 'master/tests/data/color.jpg' - text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ - 'master/tests/data/filelist.txt' + img_url = "https://raw.githubusercontent.com/open-mmlab/mmcv/master/tests/data/color.jpg" + text_url = "https://raw.githubusercontent.com/open-mmlab/mmcv/master/tests/data/filelist.txt" # test `allow_symlink` attribute assert not http_backend.allow_symlink @@ -598,44 +544,41 @@ def test_http_backend(self, backend, prefix): # input url is http text value_buf = http_backend.get_text(text_url) - assert self.text_path.open('r').read() == value_buf + assert self.text_path.open("r").read() == value_buf # test `_get_local_path` # exist the with block and path will be released with http_backend.get_local_path(img_url) as path: - img_bytes = Path(path).open('rb').read() + img_bytes = Path(path).open("rb").read() img = imfrombytes(img_bytes) assert img.shape == self.img_shape assert not osp.isfile(path) def 
test_new_magic_method(self): - class DummyBackend1(BaseStorageBackend): - def get(self, filepath): return filepath - def get_text(self, filepath, encoding='utf-8'): + def get_text(self, filepath, encoding="utf-8"): return filepath - FileClient.register_backend('dummy_backend', DummyBackend1) - client1 = FileClient(backend='dummy_backend') - client2 = FileClient(backend='dummy_backend') + FileClient.register_backend("dummy_backend", DummyBackend1) + client1 = FileClient(backend="dummy_backend") + client2 = FileClient(backend="dummy_backend") assert client1 is client2 # if a backend is overwrote, it will disable the singleton pattern for # the backend class DummyBackend2(BaseStorageBackend): - def get(self, filepath): pass def get_text(self, filepath): pass - FileClient.register_backend('dummy_backend', DummyBackend2, force=True) - client3 = FileClient(backend='dummy_backend') - client4 = FileClient(backend='dummy_backend') + FileClient.register_backend("dummy_backend", DummyBackend2, force=True) + client3 = FileClient(backend="dummy_backend") + client4 = FileClient(backend="dummy_backend") assert client2 is not client3 assert client3 is client4 @@ -653,36 +596,34 @@ def test_parse_uri_prefix(self): assert FileClient.parse_uri_prefix(str(self.img_path)) is None # input path starts with https - img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ - 'master/tests/data/color.jpg' - assert FileClient.parse_uri_prefix(img_url) == 'https' + img_url = "https://raw.githubusercontent.com/open-mmlab/mmcv/master/tests/data/color.jpg" + assert FileClient.parse_uri_prefix(img_url) == "https" # input path starts with s3 - img_url = 's3://your_bucket/img.png' - assert FileClient.parse_uri_prefix(img_url) == 's3' + img_url = "s3://your_bucket/img.png" + assert FileClient.parse_uri_prefix(img_url) == "s3" # input path starts with clusterName:s3 - img_url = 'clusterName:s3://your_bucket/img.png' - assert FileClient.parse_uri_prefix(img_url) == 's3' + img_url = "clusterName:s3://your_bucket/img.png" + assert FileClient.parse_uri_prefix(img_url) == "s3" def test_infer_client(self): # HardDiskBackend - file_client_args = {'backend': 'disk'} + file_client_args = {"backend": "disk"} client = FileClient.infer_client(file_client_args) - assert client.name == 'HardDiskBackend' + assert client.name == "HardDiskBackend" client = FileClient.infer_client(uri=self.img_path) - assert client.name == 'HardDiskBackend' + assert client.name == "HardDiskBackend" # PetrelBackend - file_client_args = {'backend': 'petrel'} + file_client_args = {"backend": "petrel"} client = FileClient.infer_client(file_client_args) - assert client.name == 'PetrelBackend' - uri = 's3://user_data' + assert client.name == "PetrelBackend" + uri = "s3://user_data" client = FileClient.infer_client(uri=uri) - assert client.name == 'PetrelBackend' + assert client.name == "PetrelBackend" def test_register_backend(self): - # name must be a string with pytest.raises(TypeError): @@ -693,7 +634,7 @@ class TestClass1: # module must be a class with pytest.raises(TypeError): - FileClient.register_backend('int', 0) + FileClient.register_backend("int", 0) # module must be a subclass of BaseStorageBackend with pytest.raises(TypeError): @@ -701,143 +642,125 @@ class TestClass1: class TestClass1: pass - FileClient.register_backend('TestClass1', TestClass1) + FileClient.register_backend("TestClass1", TestClass1) class ExampleBackend(BaseStorageBackend): - def get(self, filepath): return filepath - def get_text(self, filepath, encoding='utf-8'): + def 
get_text(self, filepath, encoding="utf-8"): return filepath - FileClient.register_backend('example', ExampleBackend) - example_backend = FileClient('example') + FileClient.register_backend("example", ExampleBackend) + example_backend = FileClient("example") assert example_backend.get(self.img_path) == self.img_path assert example_backend.get_text(self.text_path) == self.text_path - assert 'example' in FileClient._backends + assert "example" in FileClient._backends class Example2Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes2' + return b"bytes2" - def get_text(self, filepath, encoding='utf-8'): - return 'text2' + def get_text(self, filepath, encoding="utf-8"): + return "text2" # force=False with pytest.raises(KeyError): - FileClient.register_backend('example', Example2Backend) + FileClient.register_backend("example", Example2Backend) - FileClient.register_backend('example', Example2Backend, force=True) - example_backend = FileClient('example') - assert example_backend.get(self.img_path) == b'bytes2' - assert example_backend.get_text(self.text_path) == 'text2' + FileClient.register_backend("example", Example2Backend, force=True) + example_backend = FileClient("example") + assert example_backend.get(self.img_path) == b"bytes2" + assert example_backend.get_text(self.text_path) == "text2" - @FileClient.register_backend(name='example3') + @FileClient.register_backend(name="example3") class Example3Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes3' + return b"bytes3" - def get_text(self, filepath, encoding='utf-8'): - return 'text3' + def get_text(self, filepath, encoding="utf-8"): + return "text3" - example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == b'bytes3' - assert example_backend.get_text(self.text_path) == 'text3' - assert 'example3' in FileClient._backends + example_backend = FileClient("example3") + assert example_backend.get(self.img_path) == b"bytes3" + assert example_backend.get_text(self.text_path) == "text3" + assert "example3" in FileClient._backends # force=False with pytest.raises(KeyError): - @FileClient.register_backend(name='example3') + @FileClient.register_backend(name="example3") class Example4Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes4' + return b"bytes4" - def get_text(self, filepath, encoding='utf-8'): - return 'text4' + def get_text(self, filepath, encoding="utf-8"): + return "text4" - @FileClient.register_backend(name='example3', force=True) + @FileClient.register_backend(name="example3", force=True) class Example5Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes5' + return b"bytes5" - def get_text(self, filepath, encoding='utf-8'): - return 'text5' + def get_text(self, filepath, encoding="utf-8"): + return "text5" - example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == b'bytes5' - assert example_backend.get_text(self.text_path) == 'text5' + example_backend = FileClient("example3") + assert example_backend.get(self.img_path) == b"bytes5" + assert example_backend.get_text(self.text_path) == "text5" # prefixes is a str class Example6Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes6' - - def get_text(self, filepath, encoding='utf-8'): - return 'text6' - - FileClient.register_backend( - 'example4', - Example6Backend, - force=True, - prefixes='example4_prefix') - example_backend = FileClient('example4') - assert example_backend.get(self.img_path) == b'bytes6' - assert 
example_backend.get_text(self.text_path) == 'text6' - example_backend = FileClient(prefix='example4_prefix') - assert example_backend.get(self.img_path) == b'bytes6' - assert example_backend.get_text(self.text_path) == 'text6' - example_backend = FileClient('example4', prefix='example4_prefix') - assert example_backend.get(self.img_path) == b'bytes6' - assert example_backend.get_text(self.text_path) == 'text6' + return b"bytes6" + + def get_text(self, filepath, encoding="utf-8"): + return "text6" + + FileClient.register_backend("example4", Example6Backend, force=True, prefixes="example4_prefix") + example_backend = FileClient("example4") + assert example_backend.get(self.img_path) == b"bytes6" + assert example_backend.get_text(self.text_path) == "text6" + example_backend = FileClient(prefix="example4_prefix") + assert example_backend.get(self.img_path) == b"bytes6" + assert example_backend.get_text(self.text_path) == "text6" + example_backend = FileClient("example4", prefix="example4_prefix") + assert example_backend.get(self.img_path) == b"bytes6" + assert example_backend.get_text(self.text_path) == "text6" # prefixes is a list of str class Example7Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes7' + return b"bytes7" - def get_text(self, filepath, encoding='utf-8'): - return 'text7' + def get_text(self, filepath, encoding="utf-8"): + return "text7" FileClient.register_backend( - 'example5', - Example7Backend, - force=True, - prefixes=['example5_prefix1', 'example5_prefix2']) - example_backend = FileClient('example5') - assert example_backend.get(self.img_path) == b'bytes7' - assert example_backend.get_text(self.text_path) == 'text7' - example_backend = FileClient(prefix='example5_prefix1') - assert example_backend.get(self.img_path) == b'bytes7' - assert example_backend.get_text(self.text_path) == 'text7' - example_backend = FileClient(prefix='example5_prefix2') - assert example_backend.get(self.img_path) == b'bytes7' - assert example_backend.get_text(self.text_path) == 'text7' + "example5", Example7Backend, force=True, prefixes=["example5_prefix1", "example5_prefix2"] + ) + example_backend = FileClient("example5") + assert example_backend.get(self.img_path) == b"bytes7" + assert example_backend.get_text(self.text_path) == "text7" + example_backend = FileClient(prefix="example5_prefix1") + assert example_backend.get(self.img_path) == b"bytes7" + assert example_backend.get_text(self.text_path) == "text7" + example_backend = FileClient(prefix="example5_prefix2") + assert example_backend.get(self.img_path) == b"bytes7" + assert example_backend.get_text(self.text_path) == "text7" # backend has a higher priority than prefixes class Example8Backend(BaseStorageBackend): - def get(self, filepath): - return b'bytes8' - - def get_text(self, filepath, encoding='utf-8'): - return 'text8' - - FileClient.register_backend( - 'example6', - Example8Backend, - force=True, - prefixes='example6_prefix') - example_backend = FileClient('example6') - assert example_backend.get(self.img_path) == b'bytes8' - assert example_backend.get_text(self.text_path) == 'text8' - example_backend = FileClient('example6', prefix='example4_prefix') - assert example_backend.get(self.img_path) == b'bytes8' - assert example_backend.get_text(self.text_path) == 'text8' + return b"bytes8" + + def get_text(self, filepath, encoding="utf-8"): + return "text8" + + FileClient.register_backend("example6", Example8Backend, force=True, prefixes="example6_prefix") + example_backend = FileClient("example6") + assert 
example_backend.get(self.img_path) == b"bytes8" + assert example_backend.get_text(self.text_path) == "text8" + example_backend = FileClient("example6", prefix="example4_prefix") + assert example_backend.get(self.img_path) == b"bytes8" + assert example_backend.get_text(self.text_path) == "text8" diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 33a0956fed..1af7457939 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -10,19 +10,20 @@ import mmengine from mmengine.fileio import HTTPBackend, PetrelBackend -sys.modules['petrel_client'] = MagicMock() -sys.modules['petrel_client.client'] = MagicMock() + +sys.modules["petrel_client"] = MagicMock() +sys.modules["petrel_client.client"] = MagicMock() test_data_dir = osp.dirname(osp.dirname(__file__)) -def _test_handler(file_format, test_obj, str_checker, mode='r+'): +def _test_handler(file_format, test_obj, str_checker, mode="r+"): # dump to a string dump_str = mmengine.dump(test_obj, file_format=file_format) str_checker(dump_str) # load/dump with filenames from disk - tmp_filename = osp.join(tempfile.gettempdir(), 'mmengine_test_dump') + tmp_filename = osp.join(tempfile.gettempdir(), "mmengine_test_dump") mmengine.dump(test_obj, tmp_filename, file_format=file_format) assert osp.isfile(tmp_filename) load_obj = mmengine.load(tmp_filename, file_format=file_format) @@ -30,9 +31,9 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): os.remove(tmp_filename) # load/dump with filename from petrel - method = 'put' if 'b' in mode else 'put_text' + method = "put" if "b" in mode else "put_text" with patch.object(PetrelBackend, method, return_value=None) as mock_method: - filename = 's3://path/of/your/file' + filename = "s3://path/of/your/file" mmengine.dump(test_obj, filename, file_format=file_format) mock_method.assert_called() @@ -47,8 +48,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): os.remove(tmp_filename) # automatically inference the file format from the given filename - tmp_filename = osp.join(tempfile.gettempdir(), - 'mmengine_test_dump.' + file_format) + tmp_filename = osp.join(tempfile.gettempdir(), "mmengine_test_dump." 
+ file_format) mmengine.dump(test_obj, tmp_filename) assert osp.isfile(tmp_filename) load_obj = mmengine.load(tmp_filename) @@ -56,54 +56,50 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): os.remove(tmp_filename) -obj_for_test = [{'a': 'abc', 'b': 1}, 2, 'c'] +obj_for_test = [{"a": "abc", "b": 1}, 2, "c"] def test_json(): - def json_checker(dump_str): - assert dump_str in [ - '[{"a": "abc", "b": 1}, 2, "c"]', '[{"b": 1, "a": "abc"}, 2, "c"]' - ] + assert dump_str in ['[{"a": "abc", "b": 1}, 2, "c"]', '[{"b": 1, "a": "abc"}, 2, "c"]'] - _test_handler('json', obj_for_test, json_checker) + _test_handler("json", obj_for_test, json_checker) def test_yaml(): - def yaml_checker(dump_str): assert dump_str in [ - '- {a: abc, b: 1}\n- 2\n- c\n', '- {b: 1, a: abc}\n- 2\n- c\n', - '- a: abc\n b: 1\n- 2\n- c\n', '- b: 1\n a: abc\n- 2\n- c\n' + "- {a: abc, b: 1}\n- 2\n- c\n", + "- {b: 1, a: abc}\n- 2\n- c\n", + "- a: abc\n b: 1\n- 2\n- c\n", + "- b: 1\n a: abc\n- 2\n- c\n", ] - _test_handler('yaml', obj_for_test, yaml_checker) + _test_handler("yaml", obj_for_test, yaml_checker) def test_pickle(): - def pickle_checker(dump_str): import pickle + assert pickle.loads(dump_str) == obj_for_test - _test_handler('pickle', obj_for_test, pickle_checker, mode='rb+') + _test_handler("pickle", obj_for_test, pickle_checker, mode="rb+") def test_exception(): - test_obj = [{'a': 'abc', 'b': 1}, 2, 'c'] + test_obj = [{"a": "abc", "b": 1}, 2, "c"] with pytest.raises(ValueError): mmengine.dump(test_obj) with pytest.raises(TypeError): - mmengine.dump(test_obj, 'tmp.txt') + mmengine.dump(test_obj, "tmp.txt") def test_register_handler(): - - @mmengine.register_handler('txt') + @mmengine.register_handler("txt") class TxtHandler1(mmengine.BaseFileHandler): - def load_from_fileobj(self, file): return file.read() @@ -113,116 +109,98 @@ def dump_to_fileobj(self, obj, file): def dump_to_str(self, obj, **kwargs): return str(obj) - @mmengine.register_handler(['txt1', 'txt2']) + @mmengine.register_handler(["txt1", "txt2"]) class TxtHandler2(mmengine.BaseFileHandler): - def load_from_fileobj(self, file): return file.read() def dump_to_fileobj(self, obj, file): - file.write('\n') + file.write("\n") file.write(str(obj)) def dump_to_str(self, obj, **kwargs): return str(obj) - content = mmengine.load(osp.join(test_data_dir, 'data/filelist.txt')) - assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg' - tmp_filename = osp.join(tempfile.gettempdir(), 'mmengine_test.txt2') + content = mmengine.load(osp.join(test_data_dir, "data/filelist.txt")) + assert content == "1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg" + tmp_filename = osp.join(tempfile.gettempdir(), "mmengine_test.txt2") mmengine.dump(content, tmp_filename) with open(tmp_filename) as f: written = f.read() os.remove(tmp_filename) - assert written == '\n' + content + assert written == "\n" + content def test_list_from_file(): # get list from disk - filename = osp.join(test_data_dir, 'data/filelist.txt') + filename = osp.join(test_data_dir, "data/filelist.txt") filelist = mmengine.list_from_file(filename) - assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg'] - filelist = mmengine.list_from_file(filename, prefix='a/') - assert filelist == ['a/1.jpg', 'a/2.jpg', 'a/3.jpg', 'a/4.jpg', 'a/5.jpg'] + assert filelist == ["1.jpg", "2.jpg", "3.jpg", "4.jpg", "5.jpg"] + filelist = mmengine.list_from_file(filename, prefix="a/") + assert filelist == ["a/1.jpg", "a/2.jpg", "a/3.jpg", "a/4.jpg", "a/5.jpg"] filelist = mmengine.list_from_file(filename, offset=2) - assert filelist 
== ['3.jpg', '4.jpg', '5.jpg'] + assert filelist == ["3.jpg", "4.jpg", "5.jpg"] filelist = mmengine.list_from_file(filename, max_num=2) - assert filelist == ['1.jpg', '2.jpg'] + assert filelist == ["1.jpg", "2.jpg"] filelist = mmengine.list_from_file(filename, offset=3, max_num=3) - assert filelist == ['4.jpg', '5.jpg'] + assert filelist == ["4.jpg", "5.jpg"] # get list from http - filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filelist = mmengine.list_from_file( - filename, file_client_args={'backend': 'http'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 'http'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filename = "http://path/of/your/file" + with patch.object(HTTPBackend, "get_text", return_value="1.jpg\n2.jpg\n3.jpg"): + filelist = mmengine.list_from_file(filename, file_client_args={"backend": "http"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] + filelist = mmengine.list_from_file(filename, file_client_args={"prefix": "http"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] filelist = mmengine.list_from_file(filename) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'http'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] + filelist = mmengine.list_from_file(filename, backend_args={"backend": "http"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] # get list from petrel - filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filelist = mmengine.list_from_file( - filename, file_client_args={'backend': 'petrel'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 's3'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filename = "s3://path/of/your/file" + with patch.object(PetrelBackend, "get_text", return_value="1.jpg\n2.jpg\n3.jpg"): + filelist = mmengine.list_from_file(filename, file_client_args={"backend": "petrel"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] + filelist = mmengine.list_from_file(filename, file_client_args={"prefix": "s3"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] filelist = mmengine.list_from_file(filename) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'petrel'}) - assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] + filelist = mmengine.list_from_file(filename, backend_args={"backend": "petrel"}) + assert filelist == ["1.jpg", "2.jpg", "3.jpg"] def test_dict_from_file(): # get dict from disk - filename = osp.join(test_data_dir, 'data/mapping.txt') + filename = osp.join(test_data_dir, "data/mapping.txt") mapping = mmengine.dict_from_file(filename) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} mapping = mmengine.dict_from_file(filename, key_type=int) - assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} + assert mapping == {1: "cat", 2: ["dog", "cow"], 3: "panda"} # get dict from http - filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - mapping = mmengine.dict_from_file( - filename, file_client_args={'backend': 
'http'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 'http'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + filename = "http://path/of/your/file" + with patch.object(HTTPBackend, "get_text", return_value="1 cat\n2 dog cow\n3 panda"): + mapping = mmengine.dict_from_file(filename, file_client_args={"backend": "http"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} + mapping = mmengine.dict_from_file(filename, file_client_args={"prefix": "http"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} mapping = mmengine.dict_from_file(filename) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'http'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} + mapping = mmengine.dict_from_file(filename, backend_args={"backend": "http"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} # get dict from petrel - filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', - return_value='1 cat\n2 dog cow\n3 panda'): - mapping = mmengine.dict_from_file( - filename, file_client_args={'backend': 'petrel'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 's3'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + filename = "s3://path/of/your/file" + with patch.object(PetrelBackend, "get_text", return_value="1 cat\n2 dog cow\n3 panda"): + mapping = mmengine.dict_from_file(filename, file_client_args={"backend": "petrel"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} + mapping = mmengine.dict_from_file(filename, file_client_args={"prefix": "s3"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} mapping = mmengine.dict_from_file(filename) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'petrel'}) - assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} + mapping = mmengine.dict_from_file(filename, backend_args={"backend": "petrel"}) + assert mapping == {"1": "cat", "2": ["dog", "cow"], "3": "panda"} diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index c34af47e0b..6ff93a834a 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -13,13 +13,14 @@ import mmengine.fileio as fileio -sys.modules['petrel_client'] = MagicMock() -sys.modules['petrel_client.client'] = MagicMock() -test_data_dir = Path(__file__).parent.parent / 'data' -text_path = test_data_dir / 'filelist.txt' -img_path = test_data_dir / 'color.jpg' -img_url = 'https://raw.githubusercontent.com/mmengine/tests/data/img.png' +sys.modules["petrel_client"] = MagicMock() +sys.modules["petrel_client.client"] = MagicMock() + +test_data_dir = Path(__file__).parent.parent / "data" +text_path = test_data_dir / "filelist.txt" +img_path = test_data_dir / "color.jpg" +img_url = "https://raw.githubusercontent.com/mmengine/tests/data/img.png" @contextmanager @@ -38,22 +39,22 @@ def build_temporary_directory(): | -- text2.txt \n """ with tempfile.TemporaryDirectory() as tmp_dir: - text1 = Path(tmp_dir) / 'text1.txt' - 
text1.open('w').write('text1') - text2 = Path(tmp_dir) / 'text2.txt' - text2.open('w').write('text2') - dir1 = Path(tmp_dir) / 'dir1' + text1 = Path(tmp_dir) / "text1.txt" + text1.open("w").write("text1") + text2 = Path(tmp_dir) / "text2.txt" + text2.open("w").write("text2") + dir1 = Path(tmp_dir) / "dir1" dir1.mkdir() - text3 = dir1 / 'text3.txt' - text3.open('w').write('text3') - dir2 = Path(tmp_dir) / 'dir2' + text3 = dir1 / "text3.txt" + text3.open("w").write("text3") + dir2 = Path(tmp_dir) / "dir2" dir2.mkdir() - jpg1 = dir2 / 'img.jpg' - jpg1.open('wb').write(b'img') - dir3 = dir2 / 'dir3' + jpg1 = dir2 / "img.jpg" + jpg1.open("wb").write(b"img") + dir3 = dir2 / "dir3" dir3.mkdir() - text4 = dir3 / 'text4.txt' - text4.open('w').write('text4') + text4 = dir3 / "text4.txt" + text4.open("w").write("text4") yield tmp_dir @@ -67,18 +68,18 @@ def test_parse_uri_prefix(): fileio.io._parse_uri_prefix([]) # input path is Path object - assert fileio.io._parse_uri_prefix(uri=text_path) == '' + assert fileio.io._parse_uri_prefix(uri=text_path) == "" # input path starts with https - assert fileio.io._parse_uri_prefix(uri=img_url) == 'https' + assert fileio.io._parse_uri_prefix(uri=img_url) == "https" # input path starts with s3 - uri = 's3://your_bucket/img.png' - assert fileio.io._parse_uri_prefix(uri) == 's3' + uri = "s3://your_bucket/img.png" + assert fileio.io._parse_uri_prefix(uri) == "s3" # input path starts with clusterName:s3 - uri = 'clusterName:s3://your_bucket/img.png' - assert fileio.io._parse_uri_prefix(uri) == 's3' + uri = "clusterName:s3://your_bucket/img.png" + assert fileio.io._parse_uri_prefix(uri) == "s3" def test_get_file_backend(): @@ -86,37 +87,37 @@ def test_get_file_backend(): fileio.io.backend_instances = {} # uri should not be None when "backend" does not exist in backend_args - with pytest.raises(ValueError, match='uri should not be None'): + with pytest.raises(ValueError, match="uri should not be None"): fileio.get_file_backend(None, backend_args=None) # uri is not None backend = fileio.get_file_backend(uri=text_path) assert isinstance(backend, fileio.backends.LocalBackend) - uri = 'petrel://your_bucket/img.png' + uri = "petrel://your_bucket/img.png" backend = fileio.get_file_backend(uri=uri) assert isinstance(backend, fileio.backends.PetrelBackend) backend = fileio.get_file_backend(uri=img_url) assert isinstance(backend, fileio.backends.HTTPBackend) - uri = 'http://raw.githubusercontent.com/mmengine/tests/data/img.png' + uri = "http://raw.githubusercontent.com/mmengine/tests/data/img.png" backend = fileio.get_file_backend(uri=uri) assert isinstance(backend, fileio.backends.HTTPBackend) # backend_args is not None and it contains a backend name - backend_args = {'backend': 'local'} + backend_args = {"backend": "local"} backend = fileio.get_file_backend(uri=None, backend_args=backend_args) assert isinstance(backend, fileio.backends.LocalBackend) # backend_args should not be modified - assert backend_args == {'backend': 'local'} + assert backend_args == {"backend": "local"} - backend_args = {'backend': 'petrel', 'enable_mc': True} + backend_args = {"backend": "petrel", "enable_mc": True} backend = fileio.get_file_backend(uri=None, backend_args=backend_args) assert isinstance(backend, fileio.backends.PetrelBackend) - assert backend_args == {'backend': 'petrel', 'enable_mc': True} + assert backend_args == {"backend": "petrel", "enable_mc": True} # backend name has a higher priority - backend_args = {'backend': 'http'} + backend_args = {"backend": "http"} backend = 
fileio.get_file_backend(uri=text_path, backend_args=backend_args) assert isinstance(backend, fileio.backends.HTTPBackend) @@ -125,7 +126,7 @@ def test_get_file_backend(): backend1 = fileio.get_file_backend(uri=text_path, enable_singleton=True) assert isinstance(backend1, fileio.backends.LocalBackend) assert len(fileio.io.backend_instances) == 1 - assert fileio.io.backend_instances[':{}'] is backend1 + assert fileio.io.backend_instances[":{}"] is backend1 backend2 = fileio.get_file_backend(uri=text_path, enable_singleton=True) assert isinstance(backend2, fileio.backends.LocalBackend) @@ -137,27 +138,24 @@ def test_get_file_backend(): assert len(fileio.io.backend_instances) == 1 assert backend3 is not backend2 - backend_args = {'path_mapping': {'src': 'dst'}, 'enable_mc': True} - uri = 'petrel://your_bucket/img.png' - backend4 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend_args = {"path_mapping": {"src": "dst"}, "enable_mc": True} + uri = "petrel://your_bucket/img.png" + backend4 = fileio.get_file_backend(uri=uri, backend_args=backend_args, enable_singleton=True) assert isinstance(backend4, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 2 unique_key = 'petrel:{"path_mapping": {"src": "dst"}, "enable_mc": true}' assert fileio.io.backend_instances[unique_key] is backend4 assert backend4 is not backend2 - uri = 'petrel://your_bucket/img1.png' - backend5 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + uri = "petrel://your_bucket/img1.png" + backend5 = fileio.get_file_backend(uri=uri, backend_args=backend_args, enable_singleton=True) assert isinstance(backend5, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 2 assert backend5 is backend4 assert backend5 is not backend2 - backend_args = {'path_mapping': {'src1': 'dst1'}, 'enable_mc': True} - backend6 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend_args = {"path_mapping": {"src1": "dst1"}, "enable_mc": True} + backend6 = fileio.get_file_backend(uri=uri, backend_args=backend_args, enable_singleton=True) assert isinstance(backend6, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 3 unique_key = 'petrel:{"path_mapping": {"src1": "dst1"}, "enable_mc": true}' @@ -165,8 +163,7 @@ def test_get_file_backend(): assert backend6 is not backend4 assert backend6 is not backend5 - backend7 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=False) + backend7 = fileio.get_file_backend(uri=uri, backend_args=backend_args, enable_singleton=False) assert isinstance(backend7, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 3 assert backend7 is not backend6 @@ -176,51 +173,51 @@ def test_get(): # test LocalBackend filepath = Path(img_path) img_bytes = fileio.get(filepath) - assert filepath.open('rb').read() == img_bytes + assert filepath.open("rb").read() == img_bytes def test_get_text(): # test LocalBackend filepath = Path(text_path) text = fileio.get_text(filepath) - assert filepath.open('r').read() == text + assert filepath.open("r").read() == text def test_put(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: - filepath = Path(tmp_dir) / 'img.png' - fileio.put(b'disk', filepath) - assert fileio.get(filepath) == b'disk' + filepath = Path(tmp_dir) / "img.png" + fileio.put(b"disk", filepath) + assert fileio.get(filepath) == b"disk" # If the directory does not exist, put will 
create a # directory first - filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' - fileio.put(b'disk', filepath) - assert fileio.get(filepath) == b'disk' + filepath = Path(tmp_dir) / "not_existed_dir" / "test.jpg" + fileio.put(b"disk", filepath) + assert fileio.get(filepath) == b"disk" def test_put_text(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: - filepath = Path(tmp_dir) / 'text.txt' - fileio.put_text('text', filepath) - assert fileio.get_text(filepath) == 'text' + filepath = Path(tmp_dir) / "text.txt" + fileio.put_text("text", filepath) + assert fileio.get_text(filepath) == "text" # If the directory does not exist, put_text will create a # directory first - filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' - fileio.put_text('disk', filepath) - assert fileio.get_text(filepath) == 'disk' + filepath = Path(tmp_dir) / "not_existed_dir" / "test.txt" + fileio.put_text("disk", filepath) + assert fileio.get_text(filepath) == "disk" def test_exists(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: assert fileio.exists(tmp_dir) - filepath = Path(tmp_dir) / 'test.txt' + filepath = Path(tmp_dir) / "test.txt" assert not fileio.exists(filepath) - fileio.put_text('disk', filepath) + fileio.put_text("disk", filepath) assert fileio.exists(filepath) @@ -228,8 +225,8 @@ def test_isdir(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: assert fileio.isdir(tmp_dir) - filepath = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", filepath) assert not fileio.isdir(filepath) @@ -237,19 +234,19 @@ def test_isfile(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: assert not fileio.isfile(tmp_dir) - filepath = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", filepath) assert fileio.isfile(filepath) def test_join_path(): # test LocalBackend - filepath = fileio.join_path(test_data_dir, 'file') - expected = osp.join(test_data_dir, 'file') + filepath = fileio.join_path(test_data_dir, "file") + expected = osp.join(test_data_dir, "file") assert filepath == expected - filepath = fileio.join_path(test_data_dir, 'dir', 'file') - expected = osp.join(test_data_dir, 'dir', 'file') + filepath = fileio.join_path(test_data_dir, "dir", "file") + expected = osp.join(test_data_dir, "dir", "file") assert filepath == expected @@ -262,17 +259,17 @@ def test_get_local_path(): def test_copyfile(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' + src = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" assert fileio.copyfile(src, dst) == dst - assert fileio.get_text(dst) == 'disk' + assert fileio.get_text(dst) == "disk" # dst is a directory - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() - assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') - assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + assert fileio.copyfile(src, dst) == fileio.join_path(dst, "test.txt") + assert fileio.get_text(fileio.join_path(dst, "test.txt")) == "disk" # src and src should not be same file with pytest.raises(SameFileError): @@ -283,31 +280,31 @@ def test_copytree(): # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = 
Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" assert fileio.copytree(src, dst) == dst assert fileio.isdir(dst) - assert fileio.isfile(dst / 'text3.txt') - assert fileio.get_text(dst / 'text3.txt') == 'text3' + assert fileio.isfile(dst / "text3.txt") + assert fileio.get_text(dst / "text3.txt") == "text3" # dst should not exist with pytest.raises(FileExistsError): - fileio.copytree(src, Path(tmp_dir) / 'dir2') + fileio.copytree(src, Path(tmp_dir) / "dir2") def test_copyfile_from_local(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' + src = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" assert fileio.copyfile(src, dst) == dst - assert fileio.get_text(dst) == 'disk' + assert fileio.get_text(dst) == "disk" - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() - assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') - assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + assert fileio.copyfile(src, dst) == fileio.join_path(dst, "test.txt") + assert fileio.get_text(fileio.join_path(dst, "test.txt")) == "disk" # src and src should not be same file with pytest.raises(SameFileError): @@ -318,31 +315,31 @@ def test_copytree_from_local(): # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" assert fileio.copytree(src, dst) == dst assert fileio.isdir(dst) - assert fileio.isfile(dst / 'text3.txt') - assert fileio.get_text(dst / 'text3.txt') == 'text3' + assert fileio.isfile(dst / "text3.txt") + assert fileio.get_text(dst / "text3.txt") == "text3" # dst should not exist with pytest.raises(FileExistsError): - fileio.copytree(src, Path(tmp_dir) / 'dir2') + fileio.copytree(src, Path(tmp_dir) / "dir2") def test_copyfile_to_local(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: - src = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', src) - dst = Path(tmp_dir) / 'test.txt.bak' + src = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", src) + dst = Path(tmp_dir) / "test.txt.bak" assert fileio.copyfile(src, dst) == dst - assert fileio.get_text(dst) == 'disk' + assert fileio.get_text(dst) == "disk" - dst = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / "dir" dst.mkdir() - assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') - assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + assert fileio.copyfile(src, dst) == fileio.join_path(dst, "test.txt") + assert fileio.get_text(fileio.join_path(dst, "test.txt")) == "disk" # src and src should not be same file with pytest.raises(SameFileError): @@ -353,35 +350,35 @@ def test_copytree_to_local(): # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects - src = Path(tmp_dir) / 'dir1' - dst = Path(tmp_dir) / 'dir100' + src = Path(tmp_dir) / "dir1" + dst = Path(tmp_dir) / "dir100" assert fileio.copytree(src, dst) == dst assert fileio.isdir(dst) - assert fileio.isfile(dst / 'text3.txt') - assert fileio.get_text(dst / 'text3.txt') == 'text3' + assert fileio.isfile(dst / "text3.txt") + assert fileio.get_text(dst / "text3.txt") == "text3" # dst should not exist with pytest.raises(FileExistsError): - 
fileio.copytree(src, Path(tmp_dir) / 'dir2') + fileio.copytree(src, Path(tmp_dir) / "dir2") def test_remove(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: # filepath is a Path object - filepath = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', filepath) + filepath = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", filepath) assert fileio.exists(filepath) fileio.remove(filepath) assert not fileio.exists(filepath) # raise error if file does not exist with pytest.raises(FileNotFoundError): - filepath = Path(tmp_dir) / 'test1.txt' + filepath = Path(tmp_dir) / "test1.txt" fileio.remove(filepath) # can not remove directory - filepath = Path(tmp_dir) / 'dir' + filepath = Path(tmp_dir) / "dir" filepath.mkdir() with pytest.raises(IsADirectoryError): fileio.remove(filepath) @@ -391,12 +388,12 @@ def test_rmtree(): # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects - dir_path = Path(tmp_dir) / 'dir1' + dir_path = Path(tmp_dir) / "dir1" assert fileio.exists(dir_path) fileio.rmtree(dir_path) assert not fileio.exists(dir_path) - dir_path = Path(tmp_dir) / 'dir2' + dir_path = Path(tmp_dir) / "dir2" assert fileio.exists(dir_path) fileio.rmtree(dir_path) assert not fileio.exists(dir_path) @@ -406,21 +403,21 @@ def test_copy_if_symlink_fails(): # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: # create a symlink for a file - src = Path(tmp_dir) / 'test.txt' - fileio.put_text('disk', src) - dst = Path(tmp_dir) / 'test_link.txt' + src = Path(tmp_dir) / "test.txt" + fileio.put_text("disk", src) + dst = Path(tmp_dir) / "test_link.txt" res = fileio.copy_if_symlink_fails(src, dst) - if platform.system() == 'Linux': + if platform.system() == "Linux": assert res assert osp.islink(dst) - assert fileio.get_text(dst) == 'disk' + assert fileio.get_text(dst) == "disk" # create a symlink for a directory - src = Path(tmp_dir) / 'dir' + src = Path(tmp_dir) / "dir" src.mkdir() - dst = Path(tmp_dir) / 'dir_link' + dst = Path(tmp_dir) / "dir_link" res = fileio.copy_if_symlink_fails(src, dst) - if platform.system() == 'Linux': + if platform.system() == "Linux": assert res assert osp.islink(dst) assert fileio.exists(dst) @@ -429,18 +426,18 @@ def symlink(src, dst): raise Exception # copy files if symblink fails - with patch.object(os, 'symlink', side_effect=symlink): - src = Path(tmp_dir) / 'test.txt' - dst = Path(tmp_dir) / 'test_link1.txt' + with patch.object(os, "symlink", side_effect=symlink): + src = Path(tmp_dir) / "test.txt" + dst = Path(tmp_dir) / "test_link1.txt" res = fileio.copy_if_symlink_fails(src, dst) assert not res assert not osp.islink(dst) assert fileio.exists(dst) # copy directory if symblink fails - with patch.object(os, 'symlink', side_effect=symlink): - src = Path(tmp_dir) / 'dir' - dst = Path(tmp_dir) / 'dir_link1' + with patch.object(os, "symlink", side_effect=symlink): + src = Path(tmp_dir) / "dir" + dst = Path(tmp_dir) / "dir_link1" res = fileio.copy_if_symlink_fails(src, dst) assert not res assert not osp.islink(dst) @@ -451,85 +448,68 @@ def test_list_dir_or_file(): # test LocalBackend with build_temporary_directory() as tmp_dir: # list directories and files - assert set(fileio.list_dir_or_file(tmp_dir)) == { - 'dir1', 'dir2', 'text1.txt', 'text2.txt' - } + assert set(fileio.list_dir_or_file(tmp_dir)) == {"dir1", "dir2", "text1.txt", "text2.txt"} # list directories and files recursively assert set(fileio.list_dir_or_file(tmp_dir, recursive=True)) == { - 'dir1', - osp.join('dir1', 'text3.txt'), 'dir2', - 
osp.join('dir2', 'dir3'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + "dir1", + osp.join("dir1", "text3.txt"), + "dir2", + osp.join("dir2", "dir3"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", } # only list directories - assert set(fileio.list_dir_or_file( - tmp_dir, list_file=False)) == {'dir1', 'dir2'} + assert set(fileio.list_dir_or_file(tmp_dir, list_file=False)) == {"dir1", "dir2"} - with pytest.raises( - TypeError, - match='`suffix` should be None when `list_dir` is True'): - list( - fileio.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt')) + with pytest.raises(TypeError, match="`suffix` should be None when `list_dir` is True"): + list(fileio.list_dir_or_file(tmp_dir, list_file=False, suffix=".txt")) # only list directories recursively - assert set( - fileio.list_dir_or_file( - tmp_dir, list_file=False, - recursive=True)) == {'dir1', 'dir2', - osp.join('dir2', 'dir3')} + assert set(fileio.list_dir_or_file(tmp_dir, list_file=False, recursive=True)) == { + "dir1", + "dir2", + osp.join("dir2", "dir3"), + } # only list files - assert set(fileio.list_dir_or_file( - tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False)) == {"text1.txt", "text2.txt"} # only list files recursively - assert set( - fileio.list_dir_or_file(tmp_dir, list_dir=False, - recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), - 'text1.txt', 'text2.txt' - } + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False, recursive=True)) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + } # only list files ending with suffix - assert set( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} - assert set( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} - - with pytest.raises( - TypeError, - match='`suffix` must be a string or tuple of strings'): - list( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])) + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt")) == {"text1.txt", "text2.txt"} + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"))) == { + "text1.txt", + "text2.txt", + } + + with pytest.raises(TypeError, match="`suffix` must be a string or tuple of strings"): + list(fileio.list_dir_or_file(tmp_dir, list_dir=False, suffix=[".txt", ".jpg"])) # only list files ending with suffix recursively - assert set( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', - 'text2.txt' - } + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False, suffix=".txt", recursive=True)) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + "text1.txt", + "text2.txt", + } # only list files ending with suffix - assert set( - fileio.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - } + assert set(fileio.list_dir_or_file(tmp_dir, list_dir=False, suffix=(".txt", ".jpg"), 
recursive=True)) == { + osp.join("dir1", "text3.txt"), + osp.join("dir2", "dir3", "text4.txt"), + osp.join("dir2", "img.jpg"), + "text1.txt", + "text2.txt", + } diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d731a42b76..b0764e5c7c 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -18,8 +18,7 @@ class TriangleMetric(BaseMetric): - - default_prefix: str = 'test' + default_prefix: str = "test" def __init__(self, length): super().__init__() @@ -37,7 +36,6 @@ def compute_metrics(self, *args, **kwargs): class TestCheckpointHook(RunnerTestCase): - def setUp(self): super().setUp() METRICS.register_module(module=TriangleMetric, force=True) @@ -53,52 +51,45 @@ def test_init(self): # '"file_client_args" will be deprecated in future'): # CheckpointHook(file_client_args={'backend': 'disk'}) - with self.assertRaisesRegex( - ValueError, - '"file_client_args" and "backend_args" cannot be set ' - 'at the same time'): - CheckpointHook( - file_client_args={'backend': 'disk'}, - backend_args={'backend': 'local'}) + with self.assertRaisesRegex(ValueError, '"file_client_args" and "backend_args" cannot be set at the same time'): + CheckpointHook(file_client_args={"backend": "disk"}, backend_args={"backend": "local"}) # Test save best - CheckpointHook(save_best='acc') - CheckpointHook(save_best=['acc']) + CheckpointHook(save_best="acc") + CheckpointHook(save_best=["acc"]) with self.assertRaisesRegex(AssertionError, '"save_best" should be'): - CheckpointHook(save_best=dict(acc='acc')) + CheckpointHook(save_best=dict(acc="acc")) # error when 'auto' in `save_best` list - with self.assertRaisesRegex(AssertionError, 'Only support one'): - CheckpointHook(interval=2, save_best=['auto', 'acc']) + with self.assertRaisesRegex(AssertionError, "Only support one"): + CheckpointHook(interval=2, save_best=["auto", "acc"]) # Test rules - CheckpointHook(save_best=['acc', 'mAcc'], rule='greater') + CheckpointHook(save_best=["acc", "mAcc"], rule="greater") with self.assertRaisesRegex(AssertionError, '"rule" should be a str'): - CheckpointHook(save_best=['acc'], rule=1) + CheckpointHook(save_best=["acc"], rule=1) - with self.assertRaisesRegex(AssertionError, - 'Number of "rule" must be'): - CheckpointHook(save_best=['acc'], rule=['greater', 'loss']) + with self.assertRaisesRegex(AssertionError, 'Number of "rule" must be'): + CheckpointHook(save_best=["acc"], rule=["greater", "loss"]) # Test greater_keys - hook = CheckpointHook(greater_keys='acc') - self.assertEqual(hook.greater_keys, ('acc', )) + hook = CheckpointHook(greater_keys="acc") + self.assertEqual(hook.greater_keys, ("acc",)) - hook = CheckpointHook(greater_keys=['acc']) - self.assertEqual(hook.greater_keys, ['acc']) + hook = CheckpointHook(greater_keys=["acc"]) + self.assertEqual(hook.greater_keys, ["acc"]) - hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) - self.assertEqual(hook.key_indicators, ['acc', 'mIoU']) - self.assertEqual(hook.rules, ['greater', 'greater']) + hook = CheckpointHook(interval=2, by_epoch=False, save_best=["acc", "mIoU"]) + self.assertEqual(hook.key_indicators, ["acc", "mIoU"]) + self.assertEqual(hook.rules, ["greater", "greater"]) # Test less keys - hook = CheckpointHook(less_keys='loss_cls') - self.assertEqual(hook.less_keys, ('loss_cls', )) + hook = CheckpointHook(less_keys="loss_cls") + self.assertEqual(hook.less_keys, ("loss_cls",)) - hook = CheckpointHook(less_keys=['loss_cls']) - 
self.assertEqual(hook.less_keys, ['loss_cls']) + hook = CheckpointHook(less_keys=["loss_cls"]) + self.assertEqual(hook.less_keys, ["loss_cls"]) def test_before_train(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -110,12 +101,11 @@ def test_before_train(self): self.assertIsInstance(checkpoint_hook.file_backend, LocalBackend) # file_client_args is not None - checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'}) + checkpoint_hook = CheckpointHook(file_client_args={"backend": "disk"}) checkpoint_hook.before_train(runner) self.assertIsInstance(checkpoint_hook.file_client, FileClient) # file_backend is the alias of file_client - self.assertIs(checkpoint_hook.file_backend, - checkpoint_hook.file_client) + self.assertIs(checkpoint_hook.file_backend, checkpoint_hook.file_client) # the out_dir of the checkpoint hook is None checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) @@ -123,37 +113,33 @@ def test_before_train(self): self.assertEqual(checkpoint_hook.out_dir, runner.work_dir) # the out_dir of the checkpoint hook is not None - checkpoint_hook = CheckpointHook( - interval=1, by_epoch=True, out_dir='test_dir') + checkpoint_hook = CheckpointHook(interval=1, by_epoch=True, out_dir="test_dir") checkpoint_hook.before_train(runner) - self.assertEqual(checkpoint_hook.out_dir, - osp.join('test_dir', osp.basename(cfg.work_dir))) + self.assertEqual(checkpoint_hook.out_dir, osp.join("test_dir", osp.basename(cfg.work_dir))) # If `save_best` is a list of string, the path to save the best # checkpoint will be defined in attribute `best_ckpt_path_dict`. - checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU']) + checkpoint_hook = CheckpointHook(interval=1, save_best=["acc", "mIoU"]) checkpoint_hook.before_train(runner) - self.assertEqual(checkpoint_hook.best_ckpt_path_dict, - dict(acc=None, mIoU=None)) - self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path')) + self.assertEqual(checkpoint_hook.best_ckpt_path_dict, dict(acc=None, mIoU=None)) + self.assertFalse(hasattr(checkpoint_hook, "best_ckpt_path")) # Resume 'best_ckpt_path' from message_hub - runner.message_hub.update_info('best_ckpt_acc', 'best_acc') + runner.message_hub.update_info("best_ckpt_acc", "best_acc") checkpoint_hook.before_train(runner) - self.assertEqual(checkpoint_hook.best_ckpt_path_dict, - dict(acc='best_acc', mIoU=None)) + self.assertEqual(checkpoint_hook.best_ckpt_path_dict, dict(acc="best_acc", mIoU=None)) # If `save_best` is a string, the path to save best ckpt will be # defined in attribute `best_ckpt_path` - checkpoint_hook = CheckpointHook(interval=1, save_best='acc') + checkpoint_hook = CheckpointHook(interval=1, save_best="acc") checkpoint_hook.before_train(runner) self.assertIsNone(checkpoint_hook.best_ckpt_path) - self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path_dict')) + self.assertFalse(hasattr(checkpoint_hook, "best_ckpt_path_dict")) # Resume `best_ckpt` path from message_hub - runner.message_hub.update_info('best_ckpt', 'best_ckpt') + runner.message_hub.update_info("best_ckpt", "best_ckpt") checkpoint_hook.before_train(runner) - self.assertEqual(checkpoint_hook.best_ckpt_path, 'best_ckpt') + self.assertEqual(checkpoint_hook.best_ckpt_path, "best_ckpt") def test_after_val_epoch(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -161,103 +147,87 @@ def test_after_val_epoch(self): runner.train_loop._epoch = 9 # if metrics is an empty dict, print a warning information - with self.assertLogs(runner.logger, level='WARNING'): - checkpoint_hook = CheckpointHook( - 
interval=2, by_epoch=True, save_best='auto') + with self.assertLogs(runner.logger, level="WARNING"): + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="auto") checkpoint_hook.after_val_epoch(runner, {}) # if save_best is None,no best_ckpt meta should be stored - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best=None) + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None) checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, {}) - self.assertNotIn('best_score', runner.message_hub.runtime_info) - self.assertNotIn('best_ckpt', runner.message_hub.runtime_info) + self.assertNotIn("best_score", runner.message_hub.runtime_info) + self.assertNotIn("best_ckpt", runner.message_hub.runtime_info) # when `save_best` is set to `auto`, first metric will be used. - metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='auto') + metrics = {"acc": 0.5, "map": 0.3} + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="auto") checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) - best_ckpt_name = 'best_acc_epoch_9.pth' - best_ckpt_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual(checkpoint_hook.key_indicators, ['acc']) - self.assertEqual(checkpoint_hook.rules, ['greater']) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + best_ckpt_name = "best_acc_epoch_9.pth" + best_ckpt_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_ckpt_name) + self.assertEqual(checkpoint_hook.key_indicators, ["acc"]) + self.assertEqual(checkpoint_hook.rules, ["greater"]) + self.assertEqual(runner.message_hub.get_info("best_score"), 0.5) + self.assertEqual(runner.message_hub.get_info("best_ckpt"), best_ckpt_path) # # when `save_best` is set to `acc`, it should update greater value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="acc") checkpoint_hook.before_train(runner) - metrics['acc'] = 0.8 + metrics["acc"] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.8) + self.assertEqual(runner.message_hub.get_info("best_score"), 0.8) # # when `save_best` is set to `loss`, it should update less value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss') + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="loss") checkpoint_hook.before_train(runner) - metrics['loss'] = 0.8 + metrics["loss"] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) - metrics['loss'] = 0.5 + metrics["loss"] = 0.5 checkpoint_hook.after_val_epoch(runner, metrics) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) + self.assertEqual(runner.message_hub.get_info("best_score"), 0.5) # when `rule` is set to `less`,then it should update less value # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc', rule='less') + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="acc", rule="less") checkpoint_hook.before_train(runner) - metrics['acc'] = 0.3 + metrics["acc"] = 0.3 checkpoint_hook.after_val_epoch(runner, metrics) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.3) + 
self.assertEqual(runner.message_hub.get_info("best_score"), 0.3) # # when `rule` is set to `greater`,then it should update greater value # # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss', rule='greater') + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True, save_best="loss", rule="greater") checkpoint_hook.before_train(runner) - metrics['loss'] = 1.0 + metrics["loss"] = 1.0 checkpoint_hook.after_val_epoch(runner, metrics) - self.assertEqual(runner.message_hub.get_info('best_score'), 1.0) + self.assertEqual(runner.message_hub.get_info("best_score"), 1.0) # test multi `save_best` with one rule - checkpoint_hook = CheckpointHook( - interval=2, save_best=['acc', 'mIoU'], rule='greater') - self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU']) - self.assertEqual(checkpoint_hook.rules, ['greater', 'greater']) + checkpoint_hook = CheckpointHook(interval=2, save_best=["acc", "mIoU"], rule="greater") + self.assertEqual(checkpoint_hook.key_indicators, ["acc", "mIoU"]) + self.assertEqual(checkpoint_hook.rules, ["greater", "greater"]) # test multi `save_best` with multi rules - checkpoint_hook = CheckpointHook( - interval=2, save_best=['FID', 'IS'], rule=['less', 'greater']) - self.assertEqual(checkpoint_hook.key_indicators, ['FID', 'IS']) - self.assertEqual(checkpoint_hook.rules, ['less', 'greater']) + checkpoint_hook = CheckpointHook(interval=2, save_best=["FID", "IS"], rule=["less", "greater"]) + self.assertEqual(checkpoint_hook.key_indicators, ["FID", "IS"]) + self.assertEqual(checkpoint_hook.rules, ["less", "greater"]) # test multi `save_best` with default rule - checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU']) - self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU']) - self.assertEqual(checkpoint_hook.rules, ['greater', 'greater']) - runner.message_hub = MessageHub.get_instance( - 'test_after_val_epoch_save_multi_best') + checkpoint_hook = CheckpointHook(interval=2, save_best=["acc", "mIoU"]) + self.assertEqual(checkpoint_hook.key_indicators, ["acc", "mIoU"]) + self.assertEqual(checkpoint_hook.rules, ["greater", "greater"]) + runner.message_hub = MessageHub.get_instance("test_after_val_epoch_save_multi_best") checkpoint_hook.before_train(runner) metrics = dict(acc=0.5, mIoU=0.6) checkpoint_hook.after_val_epoch(runner, metrics) - best_acc_name = 'best_acc_epoch_9.pth' - best_acc_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_acc_name) - best_mIoU_name = 'best_mIoU_epoch_9.pth' - best_mIoU_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_mIoU_name) - self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) - self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + best_acc_name = "best_acc_epoch_9.pth" + best_acc_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_acc_name) + best_mIoU_name = "best_mIoU_epoch_9.pth" + best_mIoU_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_mIoU_name) + self.assertEqual(runner.message_hub.get_info("best_score_acc"), 0.5) + self.assertEqual(runner.message_hub.get_info("best_score_mIoU"), 0.6) + self.assertEqual(runner.message_hub.get_info("best_ckpt_acc"), best_acc_path) + self.assertEqual(runner.message_hub.get_info("best_ckpt_mIoU"), 
best_mIoU_path) # test behavior when by_epoch is False cfg = copy.deepcopy(self.iter_based_cfg) @@ -265,97 +235,79 @@ def test_after_val_epoch(self): runner.train_loop._iter = 9 # check best ckpt name and best score - metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best='acc', rule='greater') + metrics = {"acc": 0.5, "map": 0.3} + checkpoint_hook = CheckpointHook(interval=2, by_epoch=False, save_best="acc", rule="greater") checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) - self.assertEqual(checkpoint_hook.key_indicators, ['acc']) - self.assertEqual(checkpoint_hook.rules, ['greater']) - best_ckpt_name = 'best_acc_iter_9.pth' - best_ckpt_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_ckpt_name) + self.assertEqual(checkpoint_hook.key_indicators, ["acc"]) + self.assertEqual(checkpoint_hook.rules, ["greater"]) + best_ckpt_name = "best_acc_iter_9.pth" + best_ckpt_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) + self.assertEqual(runner.message_hub.get_info("best_ckpt"), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info("best_score"), 0.5) # check best score updating - metrics['acc'] = 0.666 + metrics["acc"] = 0.666 checkpoint_hook.after_val_epoch(runner, metrics) - best_ckpt_name = 'best_acc_iter_9.pth' - best_ckpt_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) - self.assertEqual(runner.message_hub.get_info('best_score'), 0.666) + best_ckpt_name = "best_acc_iter_9.pth" + best_ckpt_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_ckpt_name) + self.assertEqual(runner.message_hub.get_info("best_ckpt"), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info("best_score"), 0.666) # check best checkpoint name with `by_epoch` is False - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) + checkpoint_hook = CheckpointHook(interval=2, by_epoch=False, save_best=["acc", "mIoU"]) checkpoint_hook.before_train(runner) metrics = dict(acc=0.5, mIoU=0.6) checkpoint_hook.after_val_epoch(runner, metrics) - best_acc_name = 'best_acc_iter_9.pth' - best_acc_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_acc_name) - best_mIoU_name = 'best_mIoU_iter_9.pth' - best_mIoU_path = checkpoint_hook.file_client.join_path( - checkpoint_hook.out_dir, best_mIoU_name) - - self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) - self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + best_acc_name = "best_acc_iter_9.pth" + best_acc_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_acc_name) + best_mIoU_name = "best_mIoU_iter_9.pth" + best_mIoU_path = checkpoint_hook.file_client.join_path(checkpoint_hook.out_dir, best_mIoU_name) + + self.assertEqual(runner.message_hub.get_info("best_score_acc"), 0.5) + self.assertEqual(runner.message_hub.get_info("best_score_mIoU"), 0.6) + self.assertEqual(runner.message_hub.get_info("best_ckpt_acc"), best_acc_path) + 
self.assertEqual(runner.message_hub.get_info("best_ckpt_mIoU"), best_mIoU_path) # after_val_epoch should not save last_checkpoint - self.assertFalse( - osp.isfile(osp.join(runner.work_dir, 'last_checkpoint'))) + self.assertFalse(osp.isfile(osp.join(runner.work_dir, "last_checkpoint"))) # There should only one best checkpoint be reserved # dist backend - for by_epoch, cfg in [(True, self.epoch_based_cfg), - (False, self.iter_based_cfg)]: + for by_epoch, cfg in [(True, self.epoch_based_cfg), (False, self.iter_based_cfg)]: self.clear_work_dir() cfg = copy.deepcopy(cfg) runner = self.build_runner(cfg) - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=by_epoch, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, by_epoch=by_epoch, save_best="acc") checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) all_files = os.listdir(runner.work_dir) - best_ckpts = [ - file for file in all_files if file.startswith('best') - ] + best_ckpts = [file for file in all_files if file.startswith("best")] self.assertTrue(len(best_ckpts) == 1) # petrel backend # TODO use real petrel oss bucket to test petrel_client = MagicMock() - for by_epoch, cfg in [(True, self.epoch_based_cfg), - (False, self.iter_based_cfg)]: - isfile = MagicMock(return_value=True) + for by_epoch, cfg in [(True, self.epoch_based_cfg), (False, self.iter_based_cfg)]: self.clear_work_dir() - with patch.dict(sys.modules, {'petrel_client': petrel_client}), \ - patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \ - patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \ - patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile: # noqa: E501 + with ( + patch.dict(sys.modules, {"petrel_client": petrel_client}), + patch("mmengine.fileio.backends.PetrelBackend.put") as put_mock, + patch("mmengine.fileio.backends.PetrelBackend.remove") as remove_mock, + patch("mmengine.fileio.backends.PetrelBackend.isfile") as isfile, + ): # noqa: E501 cfg = copy.deepcopy(cfg) runner = self.build_runner(cfg) metrics = dict(acc=0.5) - petrel_client.client.Client = MagicMock( - return_value=petrel_client) + petrel_client.client.Client = MagicMock(return_value=petrel_client) checkpoint_hook = CheckpointHook( - interval=2, - by_epoch=by_epoch, - save_best='acc', - backend_args=dict(backend='petrel')) + interval=2, by_epoch=by_epoch, save_best="acc", backend_args=dict(backend="petrel") + ) checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) put_mock.assert_called_once() - metrics['acc'] += 0.1 + metrics["acc"] += 0.1 runner.train_loop._epoch += 1 runner.train_loop._iter += 1 checkpoint_hook.after_val_epoch(runner, metrics) @@ -373,23 +325,19 @@ def test_after_train_epoch(self): checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) self.assertEqual((runner.epoch + 1) % 2, 0) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info("last_ckpt"), osp.join(cfg.work_dir, "epoch_10.pth")) - last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint') + last_ckpt_path = osp.join(cfg.work_dir, "last_checkpoint") self.assertTrue(osp.isfile(last_ckpt_path)) with open(last_ckpt_path) as f: filepath = f.read() - self.assertEqual(filepath, osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(filepath, osp.join(cfg.work_dir, "epoch_10.pth")) # epoch can not be evenly divided by 2 runner.train_loop._epoch = 10 
checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info("last_ckpt"), osp.join(cfg.work_dir, "epoch_10.pth")) runner.message_hub.runtime_info.clear() # by epoch is False @@ -397,7 +345,7 @@ def test_after_train_epoch(self): checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) - self.assertNotIn('last_ckpt', runner.message_hub.runtime_info) + self.assertNotIn("last_ckpt", runner.message_hub.runtime_info) runner.message_hub.runtime_info.clear() def test_after_train_iter(self): @@ -409,32 +357,25 @@ def test_after_train_iter(self): checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=9) - self.assertNotIn('last_ckpt', runner.message_hub.runtime_info) + self.assertNotIn("last_ckpt", runner.message_hub.runtime_info) # by epoch is False checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=9) - self.assertIn('last_ckpt', runner.message_hub.runtime_info) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertIn("last_ckpt", runner.message_hub.runtime_info) + self.assertEqual(runner.message_hub.get_info("last_ckpt"), osp.join(cfg.work_dir, "iter_10.pth")) # epoch can not be evenly divided by 2 runner.train_loop._iter = 10 checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertEqual(runner.message_hub.get_info("last_ckpt"), osp.join(cfg.work_dir, "iter_10.pth")) - @parameterized.expand([['iter'], ['epoch']]) + @parameterized.expand([["iter"], ["epoch"]]) def test_with_runner(self, training_type): - common_cfg = getattr(self, f'{training_type}_based_cfg') - setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) - checkpoint_cfg = dict( - type='CheckpointHook', - interval=1, - by_epoch=training_type == 'epoch') + common_cfg = getattr(self, f"{training_type}_based_cfg") + setattr(common_cfg.train_cfg, f"max_{training_type}s", 11) + checkpoint_cfg = dict(type="CheckpointHook", interval=1, by_epoch=training_type == "epoch") common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) # Test interval in epoch based training @@ -444,13 +385,10 @@ def test_with_runner(self, training_type): runner.train() for i in range(1, 11): - self.assertEqual( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')), - i % 2 == 0) + self.assertEqual(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth")), i % 2 == 0) # save_last=True - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_11.pth"))) self.clear_work_dir() @@ -458,48 +396,40 @@ def test_with_runner(self, training_type): cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) - self.assertIn('optimizer', ckpt) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_11.pth")) + self.assertIn("optimizer", ckpt) cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, 
f'{training_type}_11.pth')) - self.assertNotIn('optimizer', ckpt) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_11.pth")) + self.assertNotIn("optimizer", ckpt) # Test save_param_scheduler=False cfg = copy.deepcopy(common_cfg) cfg.param_scheduler = [ - dict( - type='LinearLR', - start_factor=0.1, - begin=0, - end=500, - by_epoch=training_type == 'epoch') + dict(type="LinearLR", start_factor=0.1, begin=0, end=500, by_epoch=training_type == "epoch") ] runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) - self.assertIn('param_schedulers', ckpt) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_11.pth")) + self.assertIn("param_schedulers", ckpt) cfg.default_hooks.checkpoint.save_param_scheduler = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) - self.assertNotIn('param_schedulers', ckpt) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_11.pth")) + self.assertNotIn("param_schedulers", ckpt) self.clear_work_dir() # Test out_dir cfg = copy.deepcopy(common_cfg) - out_dir = osp.join(self.temp_dir.name, 'out_dir') + out_dir = osp.join(self.temp_dir.name, "out_dir") cfg.default_hooks.checkpoint.out_dir = out_dir runner = self.build_runner(cfg) runner.train() - self.assertTrue( - osp.isfile( - osp.join(out_dir, osp.basename(cfg.work_dir), - f'{training_type}_11.pth'))) + self.assertTrue(osp.isfile(osp.join(out_dir, osp.basename(cfg.work_dir), f"{training_type}_11.pth"))) self.clear_work_dir() @@ -508,12 +438,10 @@ def test_with_runner(self, training_type): cfg.default_hooks.checkpoint.max_keep_ckpts = 1 runner = self.build_runner(cfg) runner.train() - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_11.pth"))) for i in range(11): - self.assertFalse( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) self.clear_work_dir() @@ -522,114 +450,100 @@ def test_with_runner(self, training_type): cfg.default_hooks.checkpoint.max_keep_ckpts = 3 runner = self.build_runner(cfg) runner.train() - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth'))) - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_9.pth"))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_10.pth"))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_11.pth"))) for i in range(9): - self.assertFalse( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) - self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], - [9, 10, 11]) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_11.pth")) + self.assertEqual(ckpt["message_hub"]["runtime_info"]["keep_ckpt_ids"], [9, 10, 11]) # Test max_keep_ckpts when resuming traing cfg = copy.deepcopy(common_cfg) - setattr(cfg.train_cfg, f'max_{training_type}s', 12) + setattr(cfg.train_cfg, f"max_{training_type}s", 12) cfg.default_hooks.checkpoint.max_keep_ckpts = 2 - cfg.load_from = osp.join(cfg.work_dir, 
f'{training_type}_11.pth') + cfg.load_from = osp.join(cfg.work_dir, f"{training_type}_11.pth") cfg.resume = True runner = self.build_runner(cfg) runner.train() - self.assertFalse( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth'))) - self.assertFalse( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_9.pth"))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_10.pth"))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_11.pth"))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_12.pth"))) self.clear_work_dir() # Test filename_tmpl cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth' + cfg.default_hooks.checkpoint.filename_tmpl = "test_{}.pth" runner = self.build_runner(cfg) runner.train() - self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, "test_11.pth"))) self.clear_work_dir() # Test save_best cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.save_best = 'test/acc' - cfg.val_evaluator = dict(type='TriangleMetric', length=11) + cfg.default_hooks.checkpoint.save_best = "test/acc" + cfg.val_evaluator = dict(type="TriangleMetric", length=11) cfg.train_cfg.val_interval = 1 runner = self.build_runner(cfg) runner.train() - best_ckpt_path = osp.join(cfg.work_dir, - f'best_test_acc_{training_type}_5.pth') + best_ckpt_path = osp.join(cfg.work_dir, f"best_test_acc_{training_type}_5.pth") best_ckpt = torch.load(best_ckpt_path) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) - self.assertEqual(best_ckpt_path, - ckpt['message_hub']['runtime_info']['best_ckpt']) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_5.pth")) + self.assertEqual(best_ckpt_path, ckpt["message_hub"]["runtime_info"]["best_ckpt"]) - if training_type == 'epoch': - self.assertEqual(ckpt['meta']['epoch'], 5) - self.assertEqual(ckpt['meta']['iter'], 20) - self.assertEqual(best_ckpt['meta']['epoch'], 5) - self.assertEqual(best_ckpt['meta']['iter'], 20) + if training_type == "epoch": + self.assertEqual(ckpt["meta"]["epoch"], 5) + self.assertEqual(ckpt["meta"]["iter"], 20) + self.assertEqual(best_ckpt["meta"]["epoch"], 5) + self.assertEqual(best_ckpt["meta"]["iter"], 20) else: - self.assertEqual(ckpt['meta']['epoch'], 0) - self.assertEqual(ckpt['meta']['iter'], 5) - self.assertEqual(best_ckpt['meta']['epoch'], 0) - self.assertEqual(best_ckpt['meta']['iter'], 5) + self.assertEqual(ckpt["meta"]["epoch"], 0) + self.assertEqual(ckpt["meta"]["iter"], 5) + self.assertEqual(best_ckpt["meta"]["epoch"], 0) + self.assertEqual(best_ckpt["meta"]["iter"], 5) self.clear_work_dir() # Test save_best with interval=2 cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.save_best = 'test/acc' + cfg.default_hooks.checkpoint.save_best = "test/acc" cfg.default_hooks.checkpoint.interval = 2 - cfg.val_evaluator = dict(type='TriangleMetric', length=11) + cfg.val_evaluator = dict(type="TriangleMetric", length=11) cfg.train_cfg.val_interval = 1 runner = self.build_runner(cfg) runner.train() - best_ckpt_path = osp.join(cfg.work_dir, - f'best_test_acc_{training_type}_5.pth') + best_ckpt_path = osp.join(cfg.work_dir, f"best_test_acc_{training_type}_5.pth") best_ckpt = 
torch.load(best_ckpt_path) # if the current ckpt is the best, the interval will be ignored the # the ckpt will also be saved - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) - self.assertEqual(best_ckpt_path, - ckpt['message_hub']['runtime_info']['best_ckpt']) - - if training_type == 'epoch': - self.assertEqual(ckpt['meta']['epoch'], 5) - self.assertEqual(ckpt['meta']['iter'], 20) - self.assertEqual(best_ckpt['meta']['epoch'], 5) - self.assertEqual(best_ckpt['meta']['iter'], 20) + ckpt = torch.load(osp.join(cfg.work_dir, f"{training_type}_5.pth")) + self.assertEqual(best_ckpt_path, ckpt["message_hub"]["runtime_info"]["best_ckpt"]) + + if training_type == "epoch": + self.assertEqual(ckpt["meta"]["epoch"], 5) + self.assertEqual(ckpt["meta"]["iter"], 20) + self.assertEqual(best_ckpt["meta"]["epoch"], 5) + self.assertEqual(best_ckpt["meta"]["iter"], 20) else: - self.assertEqual(ckpt['meta']['epoch'], 0) - self.assertEqual(ckpt['meta']['iter'], 5) - self.assertEqual(best_ckpt['meta']['epoch'], 0) - self.assertEqual(best_ckpt['meta']['iter'], 5) + self.assertEqual(ckpt["meta"]["epoch"], 0) + self.assertEqual(ckpt["meta"]["iter"], 5) + self.assertEqual(best_ckpt["meta"]["epoch"], 0) + self.assertEqual(best_ckpt["meta"]["iter"], 5) # Test save published keys cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict'] + cfg.default_hooks.checkpoint.published_keys = ["meta", "state_dict"] runner = self.build_runner(cfg) runner.train() ckpt_files = os.listdir(runner.work_dir) - self.assertTrue( - any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files)) + self.assertTrue(any(re.findall(r"-[\d\w]{8}\.pth", file) for file in ckpt_files)) self.clear_work_dir() @@ -641,17 +555,12 @@ def test_with_runner(self, training_type): runner.train() for i in range(5): - self.assertFalse( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) for i in range(5, 11): if (i - 5) % 2 == 1: - self.assertFalse( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) else: - self.assertTrue( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) self.clear_work_dir() # Test save_begin with interval=2, save_begin=0 @@ -662,13 +571,9 @@ def test_with_runner(self, training_type): for i in range(1, 11): if i % 2 == 1: - self.assertFalse( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) else: - self.assertTrue( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) self.clear_work_dir() # Test save_begin with interval=2, save_begin=1 @@ -680,11 +585,7 @@ def test_with_runner(self, training_type): for i in range(1, 11): if i % 2 == 1: - self.assertTrue( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) else: - self.assertFalse( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertFalse(osp.isfile(osp.join(cfg.work_dir, f"{training_type}_{i}.pth"))) self.clear_work_dir() diff --git a/tests/test_hooks/test_early_stopping_hook.py 
b/tests/test_hooks/test_early_stopping_hook.py index 16f8fd981c..e0ad2d5cbe 100644 --- a/tests/test_hooks/test_early_stopping_hook.py +++ b/tests/test_hooks/test_early_stopping_hook.py @@ -19,18 +19,17 @@ class ToyModel(BaseModel): - def __init__(self): super().__init__() self.linear = nn.Linear(2, 1) - def forward(self, inputs, data_sample, mode='tensor'): + def forward(self, inputs, data_sample, mode="tensor"): labels = torch.stack(data_sample) inputs = torch.stack(inputs) outputs = self.linear(inputs) - if mode == 'tensor': + if mode == "tensor": return outputs - elif mode == 'loss': + elif mode == "loss": loss = (labels - outputs).sum() outputs = dict(loss=loss) return outputs @@ -55,8 +54,7 @@ def __getitem__(self, index): class DummyMetric(BaseMetric): - - default_prefix: str = 'test' + default_prefix: str = "test" def __init__(self, length): super().__init__() @@ -82,7 +80,6 @@ def get_mock_runner(): class TestEarlyStoppingHook(RunnerTestCase): - def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -94,38 +91,37 @@ def tearDown(self): self.temp_dir.cleanup() def test_init(self): - - hook = EarlyStoppingHook(monitor='acc') - self.assertEqual(hook.rule, 'greater') + hook = EarlyStoppingHook(monitor="acc") + self.assertEqual(hook.rule, "greater") self.assertLess(hook.best_score, 0) - hook = EarlyStoppingHook(monitor='ACC') - self.assertEqual(hook.rule, 'greater') + hook = EarlyStoppingHook(monitor="ACC") + self.assertEqual(hook.rule, "greater") self.assertLess(hook.best_score, 0) - hook = EarlyStoppingHook(monitor='mAP_50') - self.assertEqual(hook.rule, 'greater') + hook = EarlyStoppingHook(monitor="mAP_50") + self.assertEqual(hook.rule, "greater") self.assertLess(hook.best_score, 0) - hook = EarlyStoppingHook(monitor='loss') - self.assertEqual(hook.rule, 'less') + hook = EarlyStoppingHook(monitor="loss") + self.assertEqual(hook.rule, "less") self.assertGreater(hook.best_score, 0) - hook = EarlyStoppingHook(monitor='Loss') - self.assertEqual(hook.rule, 'less') + hook = EarlyStoppingHook(monitor="Loss") + self.assertEqual(hook.rule, "less") self.assertGreater(hook.best_score, 0) - hook = EarlyStoppingHook(monitor='ce_loss') - self.assertEqual(hook.rule, 'less') + hook = EarlyStoppingHook(monitor="ce_loss") + self.assertEqual(hook.rule, "less") self.assertGreater(hook.best_score, 0) with self.assertRaises(ValueError): # `rule` should be passed. - EarlyStoppingHook(monitor='recall') + EarlyStoppingHook(monitor="recall") with self.assertRaises(ValueError): # Invalid `rule` - EarlyStoppingHook(monitor='accuracy/top1', rule='the world') + EarlyStoppingHook(monitor="accuracy/top1", rule="the world") def test_before_run(self): runner = Mock() @@ -133,13 +129,13 @@ def test_before_run(self): # `train_loop` must contain `stop_training` variable. with self.assertRaises(AssertionError): - hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater') + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater") hook.before_run(runner) def test_after_val_epoch(self): runner = get_mock_runner() - metrics = {'accuracy/top1': 0.5, 'loss': 0.23} - hook = EarlyStoppingHook(monitor='acc', rule='greater') + metrics = {"accuracy/top1": 0.5, "loss": 0.23} + hook = EarlyStoppingHook(monitor="acc", rule="greater") with self.assertWarns(UserWarning): # Skip early stopping process since the evaluation results does not @@ -148,15 +144,14 @@ def test_after_val_epoch(self): # if `monitor` does not match and strict=True, crash the training. 
with self.assertRaises(RuntimeError): - metrics = {'accuracy/top1': 0.5, 'loss': 0.23} - hook = EarlyStoppingHook( - monitor='acc', rule='greater', strict=True) + metrics = {"accuracy/top1": 0.5, "loss": 0.23} + hook = EarlyStoppingHook(monitor="acc", rule="greater", strict=True) hook.after_val_epoch(runner, metrics) # Check largest value runner = get_mock_runner() - metrics = [{'accuracy/top1': i / 9.} for i in range(8)] - hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater') + metrics = [{"accuracy/top1": i / 9.0} for i in range(8)] + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater") for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -165,8 +160,8 @@ def test_after_val_epoch(self): # Check smallest value runner = get_mock_runner() - metrics = [{'loss': i / 9.} for i in range(8, 0, -1)] - hook = EarlyStoppingHook(monitor='loss') + metrics = [{"loss": i / 9.0} for i in range(8, 0, -1)] + hook = EarlyStoppingHook(monitor="loss") for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -175,9 +170,8 @@ def test_after_val_epoch(self): # Check stop training runner = get_mock_runner() - metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + metrics = [{"accuracy/top1": i} for i in torch.linspace(98, 99, 8)] + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater", min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -186,9 +180,8 @@ def test_after_val_epoch(self): # Check finite runner = get_mock_runner() - metrics = [{'accuracy/top1': math.inf} for i in range(5)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + metrics = [{"accuracy/top1": math.inf} for i in range(5)] + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater", min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -197,9 +190,8 @@ def test_after_val_epoch(self): # Check patience runner = get_mock_runner() - metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1, patience=10) + metrics = [{"accuracy/top1": i} for i in torch.linspace(98, 99, 8)] + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater", min_delta=1, patience=10) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -208,12 +200,8 @@ def test_after_val_epoch(self): # Check stopping_threshold runner = get_mock_runner() - metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', - rule='greater', - stopping_threshold=98.5, - patience=0) + metrics = [{"accuracy/top1": i} for i in torch.linspace(98, 99, 8)] + hook = EarlyStoppingHook(monitor="accuracy/top1", rule="greater", stopping_threshold=98.5, patience=0) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -222,11 +210,11 @@ def test_after_val_epoch(self): def test_with_runner(self): max_epoch = 10 - work_dir = osp.join(self.temp_dir.name, 'runner_test') + work_dir = osp.join(self.temp_dir.name, "runner_test") early_stop_cfg = dict( - type='EarlyStoppingHook', - monitor='test/acc', - rule='greater', + type="EarlyStoppingHook", + monitor="test/acc", + rule="greater", 
min_delta=1, patience=3, ) @@ -234,22 +222,17 @@ def test_with_runner(self): model=ToyModel(), work_dir=work_dir, train_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), + dataset=DummyDataset(), sampler=dict(type="DefaultSampler", shuffle=True), batch_size=3, num_workers=0 + ), val_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + dataset=DummyDataset(), sampler=dict(type="DefaultSampler", shuffle=False), batch_size=3, num_workers=0 + ), val_evaluator=dict(type=DummyMetric, length=max_epoch), - optim_wrapper=OptimWrapper( - torch.optim.Adam(ToyModel().parameters())), - train_cfg=dict( - by_epoch=True, max_epochs=max_epoch, val_interval=1), + optim_wrapper=OptimWrapper(torch.optim.Adam(ToyModel().parameters())), + train_cfg=dict(by_epoch=True, max_epochs=max_epoch, val_interval=1), val_cfg=dict(), custom_hooks=[early_stop_cfg], - experiment_name='earlystop_test') + experiment_name="earlystop_test", + ) runner.train() self.assertEqual(runner.epoch, 6) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 6dad7ba4f0..65b1fe8ddb 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -16,7 +16,6 @@ class DummyWrapper(BaseModel): - def __init__(self, model): super().__init__() if not isinstance(model, nn.Module): @@ -28,7 +27,6 @@ def forward(self, *args, **kwargs): class ToyModel2(ToyModel): - def __init__(self): super().__init__() self.linear3 = nn.Linear(2, 1) @@ -38,7 +36,6 @@ def forward(self, *args, **kwargs): class ToyModel3(ToyModel): - def __init__(self): super().__init__() self.linear2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1)) @@ -48,33 +45,30 @@ def forward(self, *args, **kwargs): # TODO:haowen.han@mtheads.com -@unittest.skipIf(is_musa_available(), - "musa backend do not support 'aten::lerp.Scalar_out'") +@unittest.skipIf(is_musa_available(), "musa backend do not support 'aten::lerp.Scalar_out'") class TestEMAHook(RunnerTestCase): - def setUp(self) -> None: - MODELS.register_module(name='DummyWrapper', module=DummyWrapper) - MODELS.register_module(name='ToyModel2', module=ToyModel2) - MODELS.register_module(name='ToyModel3', module=ToyModel3) + MODELS.register_module(name="DummyWrapper", module=DummyWrapper) + MODELS.register_module(name="ToyModel2", module=ToyModel2) + MODELS.register_module(name="ToyModel3", module=ToyModel3) return super().setUp() def tearDown(self): - MODELS.module_dict.pop('DummyWrapper') - MODELS.module_dict.pop('ToyModel2') - MODELS.module_dict.pop('ToyModel3') + MODELS.module_dict.pop("DummyWrapper") + MODELS.module_dict.pop("ToyModel2") + MODELS.module_dict.pop("ToyModel3") return super().tearDown() def test_init(self): EMAHook() - with self.assertRaisesRegex(AssertionError, '`begin_iter` must'): + with self.assertRaisesRegex(AssertionError, "`begin_iter` must"): EMAHook(begin_iter=-1) - with self.assertRaisesRegex(AssertionError, '`begin_epoch` must'): + with self.assertRaisesRegex(AssertionError, "`begin_epoch` must"): EMAHook(begin_epoch=-1) - with self.assertRaisesRegex(AssertionError, - '`begin_iter` and `begin_epoch`'): + with self.assertRaisesRegex(AssertionError, "`begin_iter` and `begin_epoch`"): EMAHook(begin_iter=1, begin_epoch=1) def _get_ema_hook(self, runner): @@ -84,7 +78,7 @@ def _get_ema_hook(self, runner): def test_before_run(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = 
[dict(type='EMAHook')] + cfg.custom_hooks = [dict(type="EMAHook")] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) ema_hook.before_run(runner) @@ -93,36 +87,30 @@ def test_before_run(self): def test_before_train(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [ - dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs - 1) - ] + cfg.custom_hooks = [dict(type="EMAHook", begin_epoch=cfg.train_cfg.max_epochs - 1)] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) ema_hook.before_train(runner) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [ - dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs + 1) - ] + cfg.custom_hooks = [dict(type="EMAHook", begin_epoch=cfg.train_cfg.max_epochs + 1)] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) - with self.assertRaisesRegex(AssertionError, 'self.begin_epoch'): + with self.assertRaisesRegex(AssertionError, "self.begin_epoch"): ema_hook.before_train(runner) cfg = copy.deepcopy(self.iter_based_cfg) - cfg.custom_hooks = [ - dict(type='EMAHook', begin_iter=cfg.train_cfg.max_iters + 1) - ] + cfg.custom_hooks = [dict(type="EMAHook", begin_iter=cfg.train_cfg.max_iters + 1)] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) - with self.assertRaisesRegex(AssertionError, 'self.begin_iter'): + with self.assertRaisesRegex(AssertionError, "self.begin_iter"): ema_hook.before_train(runner) def test_after_train_iter(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [dict(type='EMAHook')] + cfg.custom_hooks = [dict(type="EMAHook")] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) @@ -138,7 +126,7 @@ def test_after_train_iter(self): parameter.data.copy_(torch.randn(parameter.shape)) ema_hook.after_train_iter(runner, 1) - for src, ema in zip(src_model.parameters(), ema_model.parameters()): + for src, ema in zip(src_model.parameters(), ema_model.parameters(), strict=False): assert_allclose(src.data, ema.data) with torch.no_grad(): @@ -147,20 +135,20 @@ def test_after_train_iter(self): ema_hook.after_train_iter(runner, 1) - for src, ema in zip(src_model.parameters(), ema_model.parameters()): + for src, ema in zip(src_model.parameters(), ema_model.parameters(), strict=False): self.assertFalse((src.data == ema.data).all()) def test_before_val_epoch(self): - self._test_swap_parameters('before_val_epoch') + self._test_swap_parameters("before_val_epoch") def test_after_val_epoch(self): - self._test_swap_parameters('after_val_epoch') + self._test_swap_parameters("after_val_epoch") def test_before_test_epoch(self): - self._test_swap_parameters('before_test_epoch') + self._test_swap_parameters("before_test_epoch") def test_after_test_epoch(self): - self._test_swap_parameters('after_test_epoch') + self._test_swap_parameters("after_test_epoch") def test_before_save_checkpoint(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -173,14 +161,12 @@ def test_before_save_checkpoint(self): ori_checkpoint = copy.deepcopy(checkpoint) ema_hook.before_save_checkpoint(runner, checkpoint) - for key in ori_checkpoint['state_dict'].keys(): + for key in ori_checkpoint["state_dict"].keys(): assert_allclose( - ori_checkpoint['state_dict'][key].cpu(), - checkpoint['ema_state_dict'][f'module.{key}'].cpu()) + ori_checkpoint["state_dict"][key].cpu(), checkpoint["ema_state_dict"][f"module.{key}"].cpu() + ) - assert_allclose( - ema_hook.ema_model.state_dict()[f'module.{key}'].cpu(), - checkpoint['state_dict'][key].cpu()) + 
assert_allclose(ema_hook.ema_model.state_dict()[f"module.{key}"].cpu(), checkpoint["state_dict"][key].cpu()) def test_after_load_checkpoint(self): # Test load a checkpoint without ema_state_dict. @@ -192,112 +178,104 @@ def test_after_load_checkpoint(self): ema_hook.before_train(runner) ema_hook.after_load_checkpoint(runner, checkpoint) - for key in checkpoint['state_dict'].keys(): - assert_allclose( - checkpoint['state_dict'][key].cpu(), - ema_hook.ema_model.state_dict()[f'module.{key}'].cpu()) + for key in checkpoint["state_dict"].keys(): + assert_allclose(checkpoint["state_dict"][key].cpu(), ema_hook.ema_model.state_dict()[f"module.{key}"].cpu()) # Test a warning should be raised when resuming from a checkpoint # without `ema_state_dict` runner._resume = True ema_hook.after_load_checkpoint(runner, checkpoint) - with self.assertLogs(runner.logger, level='WARNING') as cm: + with self.assertLogs(runner.logger, level="WARNING") as cm: ema_hook.after_load_checkpoint(runner, checkpoint) - self.assertRegex(cm.records[0].msg, 'There is no `ema_state_dict`') + self.assertRegex(cm.records[0].msg, "There is no `ema_state_dict`") # Check the weight of state_dict and ema_state_dict have been swapped. # when runner._resume is True runner._resume = True checkpoint = dict( - state_dict=ToyModel().state_dict(), - ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict()) + state_dict=ToyModel().state_dict(), ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict() + ) ori_checkpoint = copy.deepcopy(checkpoint) ema_hook.after_load_checkpoint(runner, checkpoint) - for key in ori_checkpoint['state_dict'].keys(): + for key in ori_checkpoint["state_dict"].keys(): assert_allclose( - ori_checkpoint['state_dict'][key].cpu(), - ema_hook.ema_model.state_dict()[f'module.{key}'].cpu()) + ori_checkpoint["state_dict"][key].cpu(), ema_hook.ema_model.state_dict()[f"module.{key}"].cpu() + ) runner._resume = False ema_hook.after_load_checkpoint(runner, checkpoint) def test_with_runner(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [ConfigDict(type='EMAHook')] + cfg.custom_hooks = [ConfigDict(type="EMAHook")] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) runner.train() - self.assertTrue( - isinstance(ema_hook.ema_model, ExponentialMovingAverage)) + self.assertTrue(isinstance(ema_hook.ema_model, ExponentialMovingAverage)) - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) - self.assertTrue('ema_state_dict' in checkpoint) - self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8) + checkpoint = torch.load(osp.join(self.temp_dir.name, "epoch_2.pth")) + self.assertTrue("ema_state_dict" in checkpoint) + self.assertTrue(checkpoint["ema_state_dict"]["steps"] == 8) # load and testing - cfg.load_from = osp.join(self.temp_dir.name, 'epoch_2.pth') + cfg.load_from = osp.join(self.temp_dir.name, "epoch_2.pth") runner = self.build_runner(cfg) runner.test() # with model wrapper - cfg.model = ConfigDict(type='DummyWrapper', model=cfg.model) + cfg.model = ConfigDict(type="DummyWrapper", model=cfg.model) runner = self.build_runner(cfg) runner.test() # Test load checkpoint without ema_state_dict - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) - checkpoint.pop('ema_state_dict') - torch.save(checkpoint, - osp.join(self.temp_dir.name, 'without_ema_state_dict.pth')) + checkpoint = torch.load(osp.join(self.temp_dir.name, "epoch_2.pth")) + checkpoint.pop("ema_state_dict") + torch.save(checkpoint, osp.join(self.temp_dir.name, 
"without_ema_state_dict.pth")) - cfg.load_from = osp.join(self.temp_dir.name, - 'without_ema_state_dict.pth') + cfg.load_from = osp.join(self.temp_dir.name, "without_ema_state_dict.pth") runner = self.build_runner(cfg) runner.test() # Test does not load checkpoint strictly (different name). # Test load checkpoint without ema_state_dict - cfg.model = ConfigDict(type='ToyModel2') - cfg.custom_hooks = [ConfigDict(type='EMAHook', strict_load=False)] + cfg.model = ConfigDict(type="ToyModel2") + cfg.custom_hooks = [ConfigDict(type="EMAHook", strict_load=False)] runner = self.build_runner(cfg) runner.test() # Test does not load ckpt strictly (different weight size). # Test load checkpoint without ema_state_dict - cfg.model = ConfigDict(type='ToyModel3') + cfg.model = ConfigDict(type="ToyModel3") runner = self.build_runner(cfg) runner.test() # Test enable ema at 5 epochs. cfg.train_cfg.max_epochs = 10 - cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)] + cfg.custom_hooks = [ConfigDict(type="EMAHook", begin_epoch=5)] runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu') - self.assertIn('ema_state_dict', state_dict) - for k, v in state_dict['state_dict'].items(): - assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) + state_dict = torch.load(osp.join(self.temp_dir.name, "epoch_4.pth"), map_location="cpu") + self.assertIn("ema_state_dict", state_dict) + for k, v in state_dict["state_dict"].items(): + assert_allclose(v, state_dict["ema_state_dict"]["module." + k]) # Test enable ema at 5 iterations. cfg = copy.deepcopy(self.iter_based_cfg) cfg.train_cfg.val_interval = 1 - cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_iter=5)] + cfg.custom_hooks = [ConfigDict(type="EMAHook", begin_iter=5)] cfg.default_hooks.checkpoint.interval = 1 runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu') - self.assertIn('ema_state_dict', state_dict) - for k, v in state_dict['state_dict'].items(): - assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu') - self.assertIn('ema_state_dict', state_dict) + state_dict = torch.load(osp.join(self.temp_dir.name, "iter_4.pth"), map_location="cpu") + self.assertIn("ema_state_dict", state_dict) + for k, v in state_dict["state_dict"].items(): + assert_allclose(v, state_dict["ema_state_dict"]["module." 
+ k]) + state_dict = torch.load(osp.join(self.temp_dir.name, "iter_5.pth"), map_location="cpu") + self.assertIn("ema_state_dict", state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [dict(type='EMAHook')] + cfg.custom_hooks = [dict(type="EMAHook")] runner = self.build_runner(cfg) ema_hook = self._get_ema_hook(runner) @@ -317,7 +295,7 @@ def _test_swap_parameters(self, func_name, *args, **kwargs): swapped_ema = ema_hook.ema_model for src, ema, swapped_src, swapped_ema in zip( - src_model.parameters(), ema_model.parameters(), - swapped_src.parameters(), swapped_ema.parameters()): + src_model.parameters(), ema_model.parameters(), swapped_src.parameters(), swapped_ema.parameters(), strict=False + ): self.assertTrue((src.data == swapped_ema.data).all()) self.assertTrue((ema.data == swapped_src.data).all()) diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py index d30972d360..04aee923ab 100644 --- a/tests/test_hooks/test_empty_cache_hook.py +++ b/tests/test_hooks/test_empty_cache_hook.py @@ -8,13 +8,11 @@ class TestEmptyCacheHook(RunnerTestCase): - - @pytest.mark.skipif( - not is_cuda_available(), reason='cuda should be available') + @pytest.mark.skipif(not is_cuda_available(), reason="cuda should be available") def test_with_runner(self): - with patch('torch.cuda.empty_cache') as mock_empty_cache: + with patch("torch.cuda.empty_cache") as mock_empty_cache: cfg = self.epoch_based_cfg - cfg.custom_hooks = [dict(type='EmptyCacheHook')] + cfg.custom_hooks = [dict(type="EmptyCacheHook")] cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501 runner = self.build_runner(cfg) @@ -29,8 +27,8 @@ def test_with_runner(self): target_called_times = runner.max_epochs + 2 self.assertEqual(mock_empty_cache.call_count, target_called_times) - with patch('torch.cuda.empty_cache') as mock_empty_cache: - cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)] + with patch("torch.cuda.empty_cache") as mock_empty_cache: + cfg.custom_hooks = [dict(type="EmptyCacheHook", before_epoch=True)] runner = self.build_runner(cfg) runner.train() @@ -45,11 +43,8 @@ def test_with_runner(self): target_called_times = runner.max_epochs * 2 + 4 self.assertEqual(mock_empty_cache.call_count, target_called_times) - with patch('torch.cuda.empty_cache') as mock_empty_cache: - cfg.custom_hooks = [ - dict( - type='EmptyCacheHook', after_iter=True, before_epoch=True) - ] + with patch("torch.cuda.empty_cache") as mock_empty_cache: + cfg.custom_hooks = [dict(type="EmptyCacheHook", after_iter=True, before_epoch=True)] runner = self.build_runner(cfg) runner.train() @@ -62,9 +57,11 @@ def test_with_runner(self): # runner.val: `1*2 + len(val_dataloader)` times. # runner.test: `1*2 + len(val_dataloader)` times. 
- target_called_times = \ - runner.max_epochs * 2 + 4 + \ - len(runner.train_dataloader) * runner.max_epochs + \ - len(runner.val_dataloader) + \ - len(runner.test_dataloader) + target_called_times = ( + runner.max_epochs * 2 + + 4 + + len(runner.train_dataloader) * runner.max_epochs + + len(runner.val_dataloader) + + len(runner.test_dataloader) + ) self.assertEqual(mock_empty_cache.call_count, target_called_times) diff --git a/tests/test_hooks/test_hook.py b/tests/test_hooks/test_hook.py index 96cf066a31..0fbdba453b 100644 --- a/tests/test_hooks/test_hook.py +++ b/tests/test_hooks/test_hook.py @@ -6,7 +6,6 @@ class TestHook(RunnerTestCase): - def test_before_run(self): hook = Hook() runner = Mock() @@ -195,32 +194,24 @@ def test_is_last_train_iter(self): assert return_val def test_get_triggered_stages(self): - class CustomHook(Hook): - def after_train(self, runner): return super().after_train(runner) hook = CustomHook() triggered_stages = hook.get_triggered_stages() - self.assertListEqual(triggered_stages, ['after_train']) + self.assertListEqual(triggered_stages, ["after_train"]) class CustomHook(Hook): - - def _before_iter(self, runner): - ... + def _before_iter(self, runner): ... hook = CustomHook() triggered_stages = hook.get_triggered_stages() self.assertEqual(len(triggered_stages), 3) - self.assertSetEqual( - set(triggered_stages), - {'before_train_iter', 'before_val_iter', 'before_test_iter'}) + self.assertSetEqual(set(triggered_stages), {"before_train_iter", "before_val_iter", "before_test_iter"}) class CustomHook(Hook): - - def is_last_train_epoch(self, runner): - ... + def is_last_train_epoch(self, runner): ... hook = CustomHook() triggered_stages = hook.get_triggered_stages() diff --git a/tests/test_hooks/test_iter_timer_hook.py b/tests/test_hooks/test_iter_timer_hook.py index 54119805a6..db3106ad79 100644 --- a/tests/test_hooks/test_iter_timer_hook.py +++ b/tests/test_hooks/test_iter_timer_hook.py @@ -17,19 +17,17 @@ def time(cls): class TestIterTimerHook(RunnerTestCase): - - @patch('mmengine.hooks.iter_timer_hook.time', patched_time) + @patch("mmengine.hooks.iter_timer_hook.time", patched_time) def test_before_iter(self): runner = self.build_runner(self.epoch_based_cfg) hook = self._get_iter_timer_hook(runner) - for mode in ('train', 'val', 'test'): + for mode in ("train", "val", "test"): hook._before_epoch(runner) hook._before_iter(runner, batch_idx=1, mode=mode) - time = runner.message_hub.get_scalar( - f'{mode}/data_time')._log_history + time = runner.message_hub.get_scalar(f"{mode}/data_time")._log_history self.assertEqual(list(time)[-1], 1) - @patch('mmengine.hooks.iter_timer_hook.time', patched_time) + @patch("mmengine.hooks.iter_timer_hook.time", patched_time) def test_after_iter(self): cfg = copy.deepcopy(self.iter_based_cfg) cfg.train_cfg.max_iters = 100 @@ -48,46 +46,46 @@ def test_after_iter(self): runner.train_loop._iter += 1 # Left 90 iterations, so the ETA should be 90 * 2s - self.assertEqual(runner.message_hub.get_info('eta'), 180) + self.assertEqual(runner.message_hub.get_info("eta"), 180) hook.after_train_epoch(runner) for i in range(2): hook.before_val_iter(runner, i) hook.after_val_iter(runner, batch_idx=i) - self.assertEqual(runner.message_hub.get_info('eta'), 4) + self.assertEqual(runner.message_hub.get_info("eta"), 4) for i in range(2, 4): hook.before_val_iter(runner, i) hook.after_val_iter(runner, batch_idx=i) hook.after_val_epoch(runner) - self.assertEqual(runner.message_hub.get_info('eta'), 0) + self.assertEqual(runner.message_hub.get_info("eta"), 0) 
for i in range(2): hook.before_test_iter(runner, i) hook.after_test_iter(runner, batch_idx=i) - self.assertEqual(runner.message_hub.get_info('eta'), 4) + self.assertEqual(runner.message_hub.get_info("eta"), 4) for i in range(2, 4): hook.before_test_iter(runner, i) hook.after_test_iter(runner, batch_idx=i) hook.after_test_epoch(runner) - self.assertEqual(runner.message_hub.get_info('eta'), 0) + self.assertEqual(runner.message_hub.get_info("eta"), 0) def test_with_runner(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) cfg.train_cfg.val_interval = 1e6 # disable validation - with patch('mmengine.hooks.iter_timer_hook.time', patched_time): + with patch("mmengine.hooks.iter_timer_hook.time", patched_time): runner.train() # 4 iteration per epoch, totally 2 epochs # Under pathced_time, before_iter will cost "1s" and after_iter will # cost "1s", so the total time for each iteration is 2s. - train_time = runner.message_hub.log_scalars['train/time']._log_history + train_time = runner.message_hub.log_scalars["train/time"]._log_history self.assertEqual(len(train_time), 8) self.assertListEqual(list(train_time), [2] * 8) - eta = runner.message_hub.runtime_info['eta'] + eta = runner.message_hub.runtime_info["eta"] self.assertEqual(eta, 0) def _get_iter_timer_hook(self, runner): diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 52b8bc1fa3..e1e7946c1b 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -15,51 +15,45 @@ class TestLoggerHook(RunnerTestCase): - def test_init(self): # Test build logger hook. LoggerHook() LoggerHook(interval=100, ignore_last=False, interval_exp_name=100) - with self.assertRaisesRegex(TypeError, 'interval must be'): - LoggerHook(interval='100') + with self.assertRaisesRegex(TypeError, "interval must be"): + LoggerHook(interval="100") - with self.assertRaisesRegex(ValueError, 'interval must be'): + with self.assertRaisesRegex(ValueError, "interval must be"): LoggerHook(interval=-1) - with self.assertRaisesRegex(TypeError, 'ignore_last must be'): - LoggerHook(ignore_last='False') + with self.assertRaisesRegex(TypeError, "ignore_last must be"): + LoggerHook(ignore_last="False") - with self.assertRaisesRegex(TypeError, 'interval_exp_name'): - LoggerHook(interval_exp_name='100') + with self.assertRaisesRegex(TypeError, "interval_exp_name"): + LoggerHook(interval_exp_name="100") - with self.assertRaisesRegex(ValueError, 'interval_exp_name'): + with self.assertRaisesRegex(ValueError, "interval_exp_name"): LoggerHook(interval_exp_name=-1) - with self.assertRaisesRegex(TypeError, 'out_suffix'): + with self.assertRaisesRegex(TypeError, "out_suffix"): LoggerHook(out_suffix=[100]) # out_dir should be None or string or tuple of string. 
- with self.assertRaisesRegex(TypeError, 'out_dir must be'): + with self.assertRaisesRegex(TypeError, "out_dir must be"): LoggerHook(out_dir=1) - with self.assertRaisesRegex(ValueError, 'file_client_args'): + with self.assertRaisesRegex(ValueError, "file_client_args"): LoggerHook(file_client_args=dict(enable_mc=True)) # test deprecated warning raised by `file_client_args` logger = MMLogger.get_current_instance() - with self.assertLogs(logger, level='WARNING'): - LoggerHook( - out_dir=self.temp_dir.name, - file_client_args=dict(backend='disk')) + with self.assertLogs(logger, level="WARNING"): + LoggerHook(out_dir=self.temp_dir.name, file_client_args=dict(backend="disk")) - with self.assertRaisesRegex( - ValueError, - '"file_client_args" and "backend_args" cannot be '): + with self.assertRaisesRegex(ValueError, '"file_client_args" and "backend_args" cannot be '): LoggerHook( - out_dir=self.temp_dir.name, - file_client_args=dict(enable_mc=True), - backend_args=dict(enable_mc=True)) + out_dir=self.temp_dir.name, file_client_args=dict(enable_mc=True), backend_args=dict(enable_mc=True) + ) def test_after_train_iter(self): # Test LoggerHook by iter. @@ -67,8 +61,7 @@ def test_after_train_iter(self): ori_every_n_train_iters = LoggerHook.every_n_train_iters LoggerHook.every_n_train_iters = MagicMock(return_value=True) runner = MagicMock() - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) logger_hook = LoggerHook() logger_hook.after_train_iter(runner, batch_idx=5) # `cur_iter=10+1`, which cannot be exact division by @@ -80,8 +73,7 @@ def test_after_train_iter(self): # Test LoggerHook by epoch. logger_hook = LoggerHook() runner = MagicMock() - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) # Only `batch_idx` will work. logger_hook.after_train_iter(runner, batch_idx=10) runner.log_processor.get_log_after_iter.assert_not_called() @@ -90,8 +82,7 @@ def test_after_train_iter(self): # Test end of the epoch. 
runner = MagicMock() - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) logger_hook = LoggerHook(ignore_last=False) runner.train_dataloader = [0] * 5 logger_hook.after_train_iter(runner, batch_idx=4) @@ -99,8 +90,7 @@ def test_after_train_iter(self): # Test print exp_name runner = MagicMock() - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) runner.logger = MagicMock() logger_hook = LoggerHook() logger_hook.after_train_iter(runner, batch_idx=999) @@ -109,8 +99,7 @@ def test_after_train_iter(self): # Test print training log when the num of # iterations is smaller than the default interval runner = MagicMock() - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) runner.train_dataloader = [0] * 9 logger_hook = LoggerHook() logger_hook.after_train_iter(runner, batch_idx=8) @@ -122,77 +111,54 @@ def test_after_val_epoch(self): runner = MagicMock() # Test when `log_metric_by_epoch` is True runner.log_processor.get_log_after_epoch = MagicMock( - return_value=({ - 'time': 1, - 'datatime': 1, - 'acc': 0.8 - }, 'string')) + return_value=({"time": 1, "datatime": 1, "acc": 0.8}, "string") + ) logger_hook.after_val_epoch(runner) # expect visualizer log `time` and `metric` respectively - args = {'step': ANY, 'file_path': ANY} + args = {"step": ANY, "file_path": ANY} calls = [ - call({ - 'time': 1, - 'datatime': 1, - 'acc': 0.8 - }, **args), + call({"time": 1, "datatime": 1, "acc": 0.8}, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) # Test when `log_metric_by_epoch` is False logger_hook = LoggerHook(log_metric_by_epoch=False) runner.log_processor.get_log_after_epoch = MagicMock( - return_value=({ - 'time': 5, - 'datatime': 5, - 'acc': 0.5 - }, 'string')) + return_value=({"time": 5, "datatime": 5, "acc": 0.5}, "string") + ) logger_hook.after_val_epoch(runner) # expect visualizer log `time` and `metric` jointly calls = [ - call({ - 'time': 1, - 'datatime': 1, - 'acc': 0.8 - }, **args), - call({ - 'time': 5, - 'datatime': 5, - 'acc': 0.5 - }, **args), + call({"time": 1, "datatime": 1, "acc": 0.8}, **args), + call({"time": 5, "datatime": 5, "acc": 0.5}, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) def test_after_test_epoch(self): logger_hook = LoggerHook() runner = MagicMock() runner.log_dir = self.temp_dir.name - runner.timestamp = 'test_after_test_epoch' + runner.timestamp = "test_after_test_epoch" runner.log_processor.get_log_after_epoch = MagicMock( - return_value=( - dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])), - 'log_str')) + return_value=(dict(a=1, b=2, c={"list": [1, 2]}, d=torch.tensor([1, 2, 3])), "log_str") + ) logger_hook.before_run(runner) logger_hook.after_test_epoch(runner) runner.log_processor.get_log_after_epoch.assert_called() runner.logger.info.assert_called() - osp.isfile(osp.join(runner.log_dir, 'test_after_test_epoch.json')) - json_content = 
load( - osp.join(runner.log_dir, 'test_after_test_epoch.json')) - assert json_content == dict(a=1, b=2, c={'list': [1, 2]}, d=[1, 2, 3]) + osp.isfile(osp.join(runner.log_dir, "test_after_test_epoch.json")) + json_content = load(osp.join(runner.log_dir, "test_after_test_epoch.json")) + assert json_content == dict(a=1, b=2, c={"list": [1, 2]}, d=[1, 2, 3]) def test_after_val_iter(self): logger_hook = LoggerHook() runner = MagicMock() runner.iter = 0 - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) logger_hook.after_val_iter(runner, 1) runner.log_processor.get_log_after_iter.assert_not_called() logger_hook.after_val_iter(runner, 9) @@ -202,8 +168,7 @@ def test_after_test_iter(self): logger_hook = LoggerHook() runner = MagicMock() runner.iter = 0 - runner.log_processor.get_log_after_iter = MagicMock( - return_value=(dict(), 'log_str')) + runner.log_processor.get_log_after_iter = MagicMock(return_value=(dict(), "log_str")) logger_hook.after_test_iter(runner, 1) runner.log_processor.get_log_after_iter.assert_not_called() logger_hook.after_test_iter(runner, 9) @@ -212,18 +177,17 @@ def test_after_test_iter(self): def test_with_runner(self): # Test dumped the json exits cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.default_hooks.logger = dict(type='LoggerHook') + cfg.default_hooks.logger = dict(type="LoggerHook") cfg.train_cfg.max_epochs = 10 runner = self.build_runner(cfg) runner.train() - json_path = osp.join(runner._log_dir, 'vis_data', - f'{runner.timestamp}.json') + json_path = osp.join(runner._log_dir, "vis_data", f"{runner.timestamp}.json") self.assertTrue(osp.isfile(json_path)) # Test out_dir - out_dir = osp.join(cfg.work_dir, 'test') + out_dir = osp.join(cfg.work_dir, "test") mkdir_or_exist(out_dir) - cfg.default_hooks.logger = dict(type='LoggerHook', out_dir=out_dir) + cfg.default_hooks.logger = dict(type="LoggerHook", out_dir=out_dir) runner = self.build_runner(cfg) runner.train() self.assertTrue(os.listdir(out_dir)) @@ -232,22 +196,17 @@ def test_with_runner(self): shutil.rmtree(osp.join(out_dir, filename)) # Test out_suffix - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, out_suffix='.log') + cfg.default_hooks.logger = dict(type="LoggerHook", out_dir=out_dir, out_suffix=".log") runner = self.build_runner(cfg) runner.train() filenames = scandir(out_dir, recursive=True) - self.assertTrue( - all(filename.endswith('.log') for filename in filenames)) + self.assertTrue(all(filename.endswith(".log") for filename in filenames)) # Test keep_local=False - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, keep_local=False) + cfg.default_hooks.logger = dict(type="LoggerHook", out_dir=out_dir, keep_local=False) runner = self.build_runner(cfg) runner.train() filenames = scandir(runner._log_dir, recursive=True) for filename in filenames: - self.assertFalse( - filename.endswith(('.log', '.json', '.py', '.yaml')), - f'{filename} should not be kept.') + self.assertFalse(filename.endswith((".log", ".json", ".py", ".yaml")), f"{filename} should not be kept.") diff --git a/tests/test_hooks/test_naive_visualization_hook.py b/tests/test_hooks/test_naive_visualization_hook.py index 2e39e94527..57ec49f3c7 100644 --- a/tests/test_hooks/test_naive_visualization_hook.py +++ b/tests/test_hooks/test_naive_visualization_hook.py @@ -8,7 +8,6 @@ class TestNaiveVisualizationHook: - def test_after_train_iter(self): naive_visualization_hook = 
NaiveVisualizationHook() runner = Mock(iter=1) @@ -18,54 +17,40 @@ def test_after_train_iter(self): # test with normalize, resize, pad gt_datasamples = BaseDataElement( metainfo=dict( - img_norm_cfg=dict( - mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + img_norm_cfg=dict(mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), scale=(10, 10), pad_shape=(15, 15, 3), ori_height=5, ori_width=5, - img_path='tmp.jpg')) + img_path="tmp.jpg", + ) + ) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] - naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, - pred_datasamples) + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + metainfo=dict(scale=(10, 10), pad_shape=(15, 15, 3), ori_height=5, ori_width=5, img_path="tmp.jpg") + ) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] - naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, - pred_datasamples) + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize - gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(scale=(15, 15), ori_height=5, ori_width=5, img_path="tmp.jpg")) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] - naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, - pred_datasamples) + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad gt_datasamples = BaseDataElement( - metainfo=dict( - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + metainfo=dict(pad_shape=(15, 15, 3), ori_height=5, ori_width=5, img_path="tmp.jpg") + ) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] - naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, - pred_datasamples) + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test no transform - gt_datasamples = BaseDataElement( - metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(ori_height=15, ori_width=15, img_path="tmp.jpg")) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] - naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, - pred_datasamples) + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) diff --git a/tests/test_hooks/test_param_scheduler_hook.py b/tests/test_hooks/test_param_scheduler_hook.py index c1d8e6a84b..b2b86268ab 100644 --- a/tests/test_hooks/test_param_scheduler_hook.py +++ b/tests/test_hooks/test_param_scheduler_hook.py @@ -8,8 +8,7 @@ class TestParamSchedulerHook(RunnerTestCase): - error_msg = ('runner.param_schedulers should be list of ParamScheduler or ' - 'a dict containing list of ParamScheduler') + error_msg = "runner.param_schedulers should be list of ParamScheduler or a dict containing list of ParamScheduler" def test_after_train_iter(self): # runner.param_schedulers should be a list or dict @@ -85,7 +84,6 @@ def 
test_after_val_epoch(self): # mock super _ParamScheduler class class MockParamScheduler(_ParamScheduler): - def __init__(self): pass @@ -133,15 +131,15 @@ def test_with_runner(self): cfg.train_cfg.max_epochs = 3 cfg.param_scheduler = [ dict( - type='ConstantLR', + type="ConstantLR", factor=0.5, begin=0, ), dict( - type='ConstantLR', + type="ConstantLR", factor=0.5, begin=1, - ) + ), ] init_lr = cfg.optim_wrapper.optimizer.lr runner = self.build_runner(cfg) @@ -151,26 +149,24 @@ def test_with_runner(self): # Learning rate of the first epoch is init_lr*0.5 # Learning rate of the second epoch is init_lr*0.5*0.5 # Learning rate of the last epoch will be reset to 0.1 - train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history) - target_lr = [init_lr * 0.5] * 4 + \ - [init_lr * 0.5 * 0.5] * 4 + \ - [init_lr] * 4 + train_lr = list(runner.message_hub.get_scalar("train/lr")._log_history) + target_lr = [init_lr * 0.5] * 4 + [init_lr * 0.5 * 0.5] * 4 + [init_lr] * 4 self.assertListEqual(train_lr, target_lr) cfg = copy.deepcopy(self.iter_based_cfg) cfg.param_scheduler = [ dict( - type='ConstantLR', + type="ConstantLR", factor=0.5, begin=0, by_epoch=False, ), dict( - type='ConstantLR', + type="ConstantLR", factor=0.5, begin=4, by_epoch=False, - ) + ), ] init_lr = cfg.optim_wrapper.optimizer.lr @@ -179,8 +175,6 @@ def test_with_runner(self): # Learning rate of 1-4 iteration is init_lr*0.5 # Learning rate of 5-11 iteration is init_lr*0.5*0.5 - train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history) - target_lr = [init_lr * 0.5] * 4 + \ - [init_lr * 0.5 * 0.5] * 7 + \ - [init_lr] + train_lr = list(runner.message_hub.get_scalar("train/lr")._log_history) + target_lr = [init_lr * 0.5] * 4 + [init_lr * 0.5 * 0.5] * 7 + [init_lr] self.assertListEqual(train_lr, target_lr) diff --git a/tests/test_hooks/test_prepare_tta_hook.py b/tests/test_hooks/test_prepare_tta_hook.py index a356164ef6..53b952651e 100644 --- a/tests/test_hooks/test_prepare_tta_hook.py +++ b/tests/test_hooks/test_prepare_tta_hook.py @@ -33,31 +33,27 @@ def __getitem__(self, index): class ToyModel(BaseModel): - def __init__(self): super().__init__() # DDPWrapper requires at least one parameter. 
self.linear = torch.nn.Linear(1, 1) - def forward(self, inputs, data_samples, mode='tensor'): + def forward(self, inputs, data_samples, mode="tensor"): return data_samples class ToyTestTimeAugModel(BaseTTAModel): - def merge_preds(self, data_samples_list): result = [sum(x) for x in data_samples_list] return result class ToyTTAPipeline: - def __call__(self, result): return {key: [value] for key, value in result.items()} class TestPrepareTTAHook(RunnerTestCase): - def setUp(self) -> None: super().setUp() TRANSFORMS.register_module(module=ToyTTAPipeline, force=True) @@ -67,13 +63,13 @@ def setUp(self) -> None: def tearDown(self): super().tearDown() - TRANSFORMS.module_dict.pop('ToyTTAPipeline', None) - MODELS.module_dict.pop('ToyModel', None) - MODELS.module_dict.pop('ToyTestTimeAugModel', None) - DATASETS.module_dict.pop('ToyDatasetTTA', None) + TRANSFORMS.module_dict.pop("ToyTTAPipeline", None) + MODELS.module_dict.pop("ToyModel", None) + MODELS.module_dict.pop("ToyTestTimeAugModel", None) + DATASETS.module_dict.pop("ToyDatasetTTA", None) def test_init(self): - tta_cfg = dict(type='ToyTTAModel') + tta_cfg = dict(type="ToyTTAModel") prepare_tta_hook = PrepareTTAHook(tta_cfg) self.assertIsInstance(prepare_tta_hook, Hook) self.assertIs(tta_cfg, prepare_tta_hook.tta_cfg) @@ -81,13 +77,9 @@ def test_init(self): def test_before_test(self): # Test with epoch based runner. cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) - cfg.model = dict(type='ToyModel') - cfg.test_dataloader.dataset = dict( - type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) + cfg.custom_hooks.append(dict(type="PrepareTTAHook", tta_cfg=dict(type="ToyTestTimeAugModel"))) + cfg.model = dict(type="ToyModel") + cfg.test_dataloader.dataset = dict(type="ToyDatasetTTA", pipeline=dict(type="ToyTTAPipeline")) runner = self.build_runner(cfg) self.assertNotIsInstance(runner.model, BaseTTAModel) runner.test() @@ -95,13 +87,9 @@ def test_before_test(self): # Test with iteration based runner cfg = copy.deepcopy(self.iter_based_cfg) - cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) - cfg.model = dict(type='ToyModel') - cfg.test_dataloader.dataset = dict( - type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) + cfg.custom_hooks.append(dict(type="PrepareTTAHook", tta_cfg=dict(type="ToyTestTimeAugModel"))) + cfg.model = dict(type="ToyModel") + cfg.test_dataloader.dataset = dict(type="ToyDatasetTTA", pipeline=dict(type="ToyTTAPipeline")) runner = self.build_runner(cfg) self.assertNotIsInstance(runner.model, BaseTTAModel) runner.test() @@ -110,7 +98,7 @@ def test_before_test(self): # Test with ddp if torch.cuda.is_available() and torch.distributed.is_nccl_available(): self.setup_dist_env() - cfg.launcher = 'pytorch' + cfg.launcher = "pytorch" runner = self.build_runner(cfg) self.assertNotIsInstance(runner.model, BaseTTAModel) runner.test() @@ -118,13 +106,12 @@ def test_before_test(self): class TestBuildRunenrWithTTA(TestPrepareTTAHook): - def test_build_runner_with_tta(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.model = dict(type='ToyModel') - cfg.test_dataloader.dataset = dict(type='ToyDatasetTTA') - cfg.tta_pipeline = dict(type='ToyTTAPipeline') - cfg.tta_model = dict(type='ToyTestTimeAugModel') + cfg.model = dict(type="ToyModel") + cfg.test_dataloader.dataset = dict(type="ToyDatasetTTA") + cfg.tta_pipeline = dict(type="ToyTTAPipeline") + cfg.tta_model = 
dict(type="ToyTestTimeAugModel") runner = build_runner_with_tta(cfg) runner.test() self.assertIsInstance(runner.model, ToyTestTimeAugModel) diff --git a/tests/test_hooks/test_profiler_hook.py b/tests/test_hooks/test_profiler_hook.py index 2db6df01b6..ba435c469d 100644 --- a/tests/test_hooks/test_profiler_hook.py +++ b/tests/test_hooks/test_profiler_hook.py @@ -17,10 +17,9 @@ @unittest.skipIf( not mmengine.hooks.profiler_hook.check_kineto(), - reason='Due to Kineto support issues, ' - 'please upgrade pytorch above 1.8.1 (windows users above 1.9.1)') + reason="Due to Kineto support issues, please upgrade pytorch above 1.8.1 (windows users above 1.9.1)", +) class TestProfilerHook(RunnerTestCase): - def test_init(self): # Test profile_times_args ProfilerHook(by_epoch=False, profile_times=1) @@ -49,43 +48,37 @@ def deal_profile(_profile): hook._parse_trace_config(runner) with self.assertRaises(ValueError): - hook.on_trace_ready = dict(type='unknown') + hook.on_trace_ready = dict(type="unknown") hook._parse_trace_config(runner) - hook.on_trace_ready = dict( - type='log_trace', sort_by='self_cpu_time_total', row_limit=10) + hook.on_trace_ready = dict(type="log_trace", sort_by="self_cpu_time_total", row_limit=10) hook._parse_trace_config(runner) - @unittest.skipIf( - not is_installed('torch-tb-profiler'), - reason='required torch-tb-profiler') + @unittest.skipIf(not is_installed("torch-tb-profiler"), reason="required torch-tb-profiler") def test_parse_trace_config_tensorboard(self): # Test on_trace_ready_args runner = MagicMock() runner.log_dir = self.temp_dir.name - runner.logger = MMLogger.get_instance('test_profiler') + runner.logger = MMLogger.get_instance("test_profiler") hook = ProfilerHook(on_trace_ready=None) - hook.on_trace_ready = dict(type='tb_trace') + hook.on_trace_ready = dict(type="tb_trace") hook._parse_trace_config(runner) - hook.on_trace_ready['dir_name'] = 'tb' + hook.on_trace_ready["dir_name"] = "tb" hook._parse_trace_config(runner) - hook.on_trace_ready['dir_name'] = ops.join(self.temp_dir.name, 'tb') + hook.on_trace_ready["dir_name"] = ops.join(self.temp_dir.name, "tb") hook._parse_trace_config(runner) # with self.assertWarns(DeprecationWarning): hook = ProfilerHook( - on_trace_ready=dict(type='tb_trace'), - json_trace_path=ops.join(self.temp_dir.name, 'demo.json')) + on_trace_ready=dict(type="tb_trace"), json_trace_path=ops.join(self.temp_dir.name, "demo.json") + ) hook._parse_trace_config(runner) - self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - on_trace_ready=dict( - type='tb_trace', dir_name=self.temp_dir.name)) + self.epoch_based_cfg["custom_hooks"] = [ + dict(type="ProfilerHook", on_trace_ready=dict(type="tb_trace", dir_name=self.temp_dir.name)) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -94,7 +87,7 @@ def test_before_run(self): runner = MagicMock() runner.max_epochs = 1000 runner.max_iters = 10000 - runner.logger = MMLogger.get_instance('test_profiler') + runner.logger = MMLogger.get_instance("test_profiler") hook = ProfilerHook() hook.before_run(runner) @@ -113,17 +106,16 @@ def test_before_run(self): def test_export_chrome_trace(self): runner = MagicMock() runner.max_epochs = 1000 - runner.logger = MMLogger.get_instance('test_profiler') + runner.logger = MMLogger.get_instance("test_profiler") - hook = ProfilerHook( - json_trace_path=ops.join(self.temp_dir.name, 'demo.json')) + hook = ProfilerHook(json_trace_path=ops.join(self.temp_dir.name, "demo.json")) hook.before_run(runner) hook._export_chrome_trace(runner) def 
test_after_train_epoch(self): runner = MagicMock() runner.max_epochs = 1000 - runner.logger = MMLogger.get_instance('test_profiler') + runner.logger = MMLogger.get_instance("test_profiler") runner.epoch = 0 @@ -138,7 +130,7 @@ def test_after_train_epoch(self): def test_after_train_iter(self): runner = MagicMock() runner.max_iters = 10000 - runner.logger = MMLogger.get_instance('test_profiler') + runner.logger = MMLogger.get_instance("test_profiler") runner.iter = 9 @@ -148,65 +140,47 @@ def test_after_train_iter(self): hook.profiler.__exit__.assert_called_once() hook.profiler.step.assert_called_once() - hook = ProfilerHook( - by_epoch=False, - schedule=dict(wait=1, warmup=1, active=3, repeat=1)) + hook = ProfilerHook(by_epoch=False, schedule=dict(wait=1, warmup=1, active=3, repeat=1)) hook.profiler = MagicMock() hook.after_train_iter(runner, 1, 1, 1) hook.profiler.step.assert_called_once() def test_with_runner(self): - self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - activity_with_cpu=False, - activity_with_cuda=False) + self.epoch_based_cfg["custom_hooks"] = [ + dict(type="ProfilerHook", activity_with_cpu=False, activity_with_cuda=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - json_path = ops.join(self.temp_dir.name, 'demo.json') - self.epoch_based_cfg['custom_hooks'] = [ - dict(type='ProfilerHook', json_trace_path=json_path) - ] + json_path = ops.join(self.temp_dir.name, "demo.json") + self.epoch_based_cfg["custom_hooks"] = [dict(type="ProfilerHook", json_trace_path=json_path)] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(json_path), 'ERROR::json file is not generated!') + self.assertTrue(ops.exists(json_path), "ERROR::json file is not generated!") - self.epoch_based_cfg['custom_hooks'] = [ + self.epoch_based_cfg["custom_hooks"] = [ dict( - type='ProfilerHook', - on_trace_ready=dict( - type='log_trace', - sort_by='self_cpu_time_total', - row_limit=10)) + type="ProfilerHook", on_trace_ready=dict(type="log_trace", sort_by="self_cpu_time_total", row_limit=10) + ) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() with self.assertRaises(ValueError): - self.epoch_based_cfg['custom_hooks'] = [ - dict(type='ProfilerHook', on_trace_ready=0) - ] + self.epoch_based_cfg["custom_hooks"] = [dict(type="ProfilerHook", on_trace_ready=0)] runner = self.build_runner(self.epoch_based_cfg) runner.train() if torch.cuda.is_available(): - self.epoch_based_cfg['custom_hooks'] = [ - dict(type='ProfilerHook', activity_with_cuda=True) - ] + self.epoch_based_cfg["custom_hooks"] = [dict(type="ProfilerHook", activity_with_cuda=True)] runner = self.build_runner(self.epoch_based_cfg) runner.train() -@unittest.skipIf( - not is_npu_available(), reason='Ascend PyTorch and npu devices not exist') +@unittest.skipIf(not is_npu_available(), reason="Ascend PyTorch and npu devices not exist") class TestNPUProfilerHook(RunnerTestCase): - def test_init(self): - - result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') + result_path = ops.join(self.temp_dir.name, "test/cann_profiling") NPUProfilerHook(result_path=result_path) @@ -214,10 +188,10 @@ def test_init(self): NPUProfilerHook(begin=1, end=0, result_path=result_path) def test_before_run(self): - result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') + result_path = ops.join(self.temp_dir.name, "test/cann_profiling") runner = MagicMock() runner.max_iters = 1 - runner.logger = MMLogger.get_instance('test_npu_profiler') + runner.logger = 
MMLogger.get_instance("test_npu_profiler") hook = NPUProfilerHook(result_path=result_path) hook.before_run(runner) @@ -227,10 +201,10 @@ def test_before_run(self): hook.before_run(runner) def test_after_train_iter(self): - result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') + result_path = ops.join(self.temp_dir.name, "test/cann_profiling") runner = MagicMock() runner.max_iters = 10000 - runner.logger = MMLogger.get_instance('test_npu_profiler') + runner.logger = MMLogger.get_instance("test_npu_profiler") runner.iter = 0 @@ -241,30 +215,24 @@ def test_after_train_iter(self): hook.after_train_iter(runner, 1) def test_with_runner(self): - result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') - self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='NPUProfilerHook', - begin=0, - result_path=result_path, - exit_after_profiling=False) + result_path = ops.join(self.temp_dir.name, "test/cann_profiling") + self.epoch_based_cfg["custom_hooks"] = [ + dict(type="NPUProfilerHook", begin=0, result_path=result_path, exit_after_profiling=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.epoch_based_cfg['custom_hooks'] = [ + self.epoch_based_cfg["custom_hooks"] = [ dict( - type='NPUProfilerHook', + type="NPUProfilerHook", result_path=result_path, ge_profiling_to_std_out=True, - exit_after_profiling=False) + exit_after_profiling=False, + ) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(result_path), 'profiler result path is not generated!') + self.assertTrue(ops.exists(result_path), "profiler result path is not generated!") - self.assertTrue( - os.getenv('GE_PROFILING_TO_STD_OUT', '0') == '1', - 'GE PROFILING failed to start!') + self.assertTrue(os.getenv("GE_PROFILING_TO_STD_OUT", "0") == "1", "GE PROFILING failed to start!") diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py index c7e7a3c339..4d3633846b 100644 --- a/tests/test_hooks/test_runtime_info_hook.py +++ b/tests/test_hooks/test_runtime_info_hook.py @@ -25,36 +25,34 @@ class DatasetWithMetainfo(DatasetWithoutMetainfo): class TestRuntimeInfoHook(RunnerTestCase): - def setUp(self) -> None: DATASETS.register_module(module=DatasetWithoutMetainfo, force=True) DATASETS.register_module(module=DatasetWithMetainfo, force=True) return super().setUp() def tearDown(self): - DATASETS.module_dict.pop('DatasetWithoutMetainfo') - DATASETS.module_dict.pop('DatasetWithMetainfo') + DATASETS.module_dict.pop("DatasetWithoutMetainfo") + DATASETS.module_dict.pop("DatasetWithMetainfo") return super().tearDown() def test_before_and_after_train(self): - cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo' + cfg.train_dataloader.dataset.type = "DatasetWithoutMetainfo" runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) hook.before_train(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train') - self.assertEqual(runner.message_hub.get_info('epoch'), 0) - self.assertEqual(runner.message_hub.get_info('iter'), 0) - self.assertEqual(runner.message_hub.get_info('max_epochs'), 2) - self.assertEqual(runner.message_hub.get_info('max_iters'), 8) + self.assertEqual(runner.message_hub.get_info("loop_stage"), "train") + self.assertEqual(runner.message_hub.get_info("epoch"), 0) + self.assertEqual(runner.message_hub.get_info("iter"), 0) + self.assertEqual(runner.message_hub.get_info("max_epochs"), 2) + 
self.assertEqual(runner.message_hub.get_info("max_iters"), 8) hook.after_train(runner) - self.assertIsNone(runner.message_hub.get_info('loop_stage')) + self.assertIsNone(runner.message_hub.get_info("loop_stage")) - cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo' + cfg.train_dataloader.dataset.type = "DatasetWithMetainfo" runner = self.build_runner(cfg) hook.before_train(runner) - self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict()) + self.assertEqual(runner.message_hub.get_info("dataset_meta"), dict()) def test_before_train_epoch(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -62,7 +60,7 @@ def test_before_train_epoch(self): runner.train_loop._epoch = 9 hook = self._get_runtime_info_hook(runner) hook.before_train_epoch(runner) - self.assertEqual(runner.message_hub.get_info('epoch'), 9) + self.assertEqual(runner.message_hub.get_info("epoch"), 9) def test_before_train_iter(self): # single optimizer @@ -75,14 +73,12 @@ def test_before_train_iter(self): runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper) hook = self._get_runtime_info_hook(runner) hook.before_train_iter(runner, batch_idx=2, data_batch=None) - self.assertEqual(runner.message_hub.get_info('iter'), 9) - self.assertEqual( - runner.message_hub.get_scalar('train/lr').current(), lr) + self.assertEqual(runner.message_hub.get_info("iter"), 9) + self.assertEqual(runner.message_hub.get_scalar("train/lr").current(), lr) - with self.assertRaisesRegex(AssertionError, - 'runner.optim_wrapper.get_lr()'): + with self.assertRaisesRegex(AssertionError, "runner.optim_wrapper.get_lr()"): runner.optim_wrapper = Mock() - runner.optim_wrapper.get_lr = Mock(return_value='error type') + runner.optim_wrapper.get_lr = Mock(return_value="error type") hook.before_train_iter(runner, batch_idx=2, data_batch=None) # multiple optimizers @@ -90,75 +86,69 @@ def test_before_train_iter(self): dict( layer1=nn.Linear(1, 1), layer2=nn.Linear(1, 1), - )) + ) + ) optim1 = SGD(model.layer1.parameters(), lr=0.01) optim2 = SGD(model.layer2.parameters(), lr=0.02) optim_wrapper1 = OptimWrapper(optim1) optim_wrapper2 = OptimWrapper(optim2) - optim_wrapper_dict = OptimWrapperDict( - key1=optim_wrapper1, key2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(key1=optim_wrapper1, key2=optim_wrapper2) runner.optim_wrapper = optim_wrapper_dict hook.before_train_iter(runner, batch_idx=2, data_batch=None) - self.assertEqual( - runner.message_hub.get_scalar('train/key1.lr').current(), 0.01) - self.assertEqual( - runner.message_hub.get_scalar('train/key2.lr').current(), 0.02) + self.assertEqual(runner.message_hub.get_scalar("train/key1.lr").current(), 0.01) + self.assertEqual(runner.message_hub.get_scalar("train/key2.lr").current(), 0.02) def test_after_train_iter(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) - hook.after_train_iter( - runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111}) - self.assertEqual( - runner.message_hub.get_scalar('train/loss_cls').current(), 1.111) + hook.after_train_iter(runner, batch_idx=2, data_batch=None, outputs={"loss_cls": 1.111}) + self.assertEqual(runner.message_hub.get_scalar("train/loss_cls").current(), 1.111) def test_before_and_after_val(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) hook.before_val(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val') + 
self.assertEqual(runner.message_hub.get_info("loop_stage"), "val") self.assertIsNone(hook.last_loop_stage) hook.after_val(runner) - self.assertIsNone(runner.message_hub.get_info('loop_stage')) + self.assertIsNone(runner.message_hub.get_info("loop_stage")) # Simulate the workflow of calling the ValLoop within the TrainLoop runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) hook.before_train(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train') + self.assertEqual(runner.message_hub.get_info("loop_stage"), "train") hook.before_val(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val') - self.assertEqual(hook.last_loop_stage, 'train') + self.assertEqual(runner.message_hub.get_info("loop_stage"), "val") + self.assertEqual(hook.last_loop_stage, "train") hook.after_val(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train') + self.assertEqual(runner.message_hub.get_info("loop_stage"), "train") self.assertIsNone(hook.last_loop_stage) def test_after_val_epoch(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) - hook.after_val_epoch(runner, metrics={'acc': 0.8}) - self.assertEqual( - runner.message_hub.get_scalar('val/acc').current(), 0.8) + hook.after_val_epoch(runner, metrics={"acc": 0.8}) + self.assertEqual(runner.message_hub.get_scalar("val/acc").current(), 0.8) def test_before_and_after_test(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) hook.before_test(runner) - self.assertEqual(runner.message_hub.get_info('loop_stage'), 'test') + self.assertEqual(runner.message_hub.get_info("loop_stage"), "test") hook.after_test(runner) - self.assertIsNone(runner.message_hub.get_info('loop_stage')) + self.assertIsNone(runner.message_hub.get_info("loop_stage")) def test_after_test_epoch(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) - hook.after_test_epoch(runner, metrics={'acc': 0.8}) - self.assertEqual( - runner.message_hub.get_scalar('test/acc').current(), 0.8) + hook.after_test_epoch(runner, metrics={"acc": 0.8}) + self.assertEqual(runner.message_hub.get_scalar("test/acc").current(), 0.8) def test_scalar_check(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -170,34 +160,29 @@ def test_scalar_check(self): hook.after_val_epoch( runner, metrics={ - 'acc_f32': val.astype(np.float32), - 'acc_i32': val.astype(np.int32), - 'acc_u8': val.astype(np.uint8), - 'acc_ndarray': np.array([5]), - }) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_f32').current(), 5) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_i32').current(), 5) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_u8').current(), 5) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_ndarray').current(), 5) + "acc_f32": val.astype(np.float32), + "acc_i32": val.astype(np.int32), + "acc_u8": val.astype(np.uint8), + "acc_ndarray": np.array([5]), + }, + ) + self.assertEqual(runner.message_hub.get_scalar("val/acc_f32").current(), 5) + self.assertEqual(runner.message_hub.get_scalar("val/acc_i32").current(), 5) + self.assertEqual(runner.message_hub.get_scalar("val/acc_u8").current(), 5) + self.assertEqual(runner.message_hub.get_scalar("val/acc_ndarray").current(), 5) val = torch.tensor([5.0]).mean() hook.after_val_epoch( runner, metrics={ - 'acc_f32': val.float(), - 'acc_i64': val.long(), - 
'acc_tensor': torch.tensor([5]), - }) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_f32').current(), 5) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_i64').current(), 5) - self.assertEqual( - runner.message_hub.get_scalar('val/acc_tensor').current(), 5) + "acc_f32": val.float(), + "acc_i64": val.long(), + "acc_tensor": torch.tensor([5]), + }, + ) + self.assertEqual(runner.message_hub.get_scalar("val/acc_f32").current(), 5) + self.assertEqual(runner.message_hub.get_scalar("val/acc_i64").current(), 5) + self.assertEqual(runner.message_hub.get_scalar("val/acc_tensor").current(), 5) def _get_runtime_info_hook(self, runner): for hook in runner.hooks: diff --git a/tests/test_hooks/test_sampler_seed_hook.py b/tests/test_hooks/test_sampler_seed_hook.py index 879febcf2a..f1af0ec8cb 100644 --- a/tests/test_hooks/test_sampler_seed_hook.py +++ b/tests/test_hooks/test_sampler_seed_hook.py @@ -6,9 +6,7 @@ class TestDistSamplerSeedHook(RunnerTestCase): - def test_before_train_epoch(self): - hook = DistSamplerSeedHook() # Test dataset sampler runner = MagicMock() @@ -16,13 +14,12 @@ def test_before_train_epoch(self): hook.before_train_epoch(runner) runner.train_loop.dataloader.sampler.set_epoch.assert_called() # Test batch sampler - runner.train_loop.dataloader = MagicMock(spec_set=['batch_sampler']) + runner.train_loop.dataloader = MagicMock(spec_set=["batch_sampler"]) hook.before_train_epoch(runner) - runner.train_loop.dataloader.\ - batch_sampler.sampler.set_epoch.assert_called() + runner.train_loop.dataloader.batch_sampler.sampler.set_epoch.assert_called() def test_with_runner(self): cfg = self.epoch_based_cfg - cfg.custom_hooks = [dict(type='DistSamplerSeedHook')] + cfg.custom_hooks = [dict(type="DistSamplerSeedHook")] runner = self.build_runner(cfg) runner.train() diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py index 6d4019dc58..ed0b1144a1 100644 --- a/tests/test_hooks/test_sync_buffers_hook.py +++ b/tests/test_hooks/test_sync_buffers_hook.py @@ -14,7 +14,6 @@ class ToyModuleWithNorm(ToyModel): - def __init__(self, data_preprocessor=None): super().__init__(data_preprocessor=data_preprocessor) bn = nn.BatchNorm1d(2) @@ -22,13 +21,11 @@ def __init__(self, data_preprocessor=None): def init_weights(self): for buffer in self.buffers(): - buffer.fill_( - torch.tensor(int(os.environ['RANK']), dtype=torch.float32)) + buffer.fill_(torch.tensor(int(os.environ["RANK"]), dtype=torch.float32)) return super().init_weights() class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase): - def setUp(self) -> None: super().setUp() self._spawn_processes() @@ -57,9 +54,9 @@ def test_sync_buffers_hook(self): def test_with_runner(self): self.setup_dist_env() cfg = self.epoch_based_cfg - cfg.model = dict(type='ToyModuleWithNorm') - cfg.launch = 'pytorch' - cfg.custom_hooks = [dict(type='SyncBuffersHook')] + cfg.model = dict(type="ToyModuleWithNorm") + cfg.launch = "pytorch" + cfg.custom_hooks = [dict(type="SyncBuffersHook")] runner = self.build_runner(cfg) runner.train() @@ -69,6 +66,5 @@ def test_with_runner(self): def setup_dist_env(self): super().setup_dist_env() - os.environ['RANK'] = str(self.rank) - torch_dist.init_process_group( - backend='gloo', rank=self.rank, world_size=self.world_size) + os.environ["RANK"] = str(self.rank) + torch_dist.init_process_group(backend="gloo", rank=self.rank, world_size=self.world_size) diff --git a/tests/test_hub/test_hub.py b/tests/test_hub/test_hub.py index ae21d3dab4..08a3c52565 100644 --- 
a/tests/test_hub/test_hub.py +++ b/tests/test_hub/test_hub.py @@ -7,46 +7,44 @@ from mmengine.hub import get_config, get_model from mmengine.utils import get_installed_path, is_installed -data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/') + +data_path = osp.join(osp.dirname(osp.dirname(__file__)), "data/") # mmdet has a more typical config structure, while mmpose has a complex # config structure @pytest.mark.skipif( - not (is_installed('mmdet') and is_installed('mmpose')), - reason='mmdet and mmpose should be installed') + not (is_installed("mmdet") and is_installed("mmpose")), reason="mmdet and mmpose should be installed" +) def test_get_config(): # Test load base config. - base_cfg = get_config('mmdet::_base_/models/faster-rcnn_r50_fpn.py') - package_path = get_installed_path('mmdet') - test_base_cfg = Config.fromfile( - osp.join(package_path, '.mim', - 'configs/_base_/models/faster-rcnn_r50_fpn.py')) + base_cfg = get_config("mmdet::_base_/models/faster-rcnn_r50_fpn.py") + package_path = get_installed_path("mmdet") + test_base_cfg = Config.fromfile(osp.join(package_path, ".mim", "configs/_base_/models/faster-rcnn_r50_fpn.py")) assert test_base_cfg._cfg_dict == base_cfg._cfg_dict # Test load faster_rcnn config - cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py') - test_cfg = Config.fromfile( - osp.join(package_path, '.mim', - 'configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')) + cfg = get_config("mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py") + test_cfg = Config.fromfile(osp.join(package_path, ".mim", "configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py")) assert cfg._cfg_dict == test_cfg._cfg_dict # Test pretrained - cfg = get_config( - 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True) - assert cfg.model_path == 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa E301 + cfg = get_config("mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py", pretrained=True) + assert ( + cfg.model_path + == "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth" + ) # noqa E301 # Test load mmpose get_config( - 'mmpose::face_2d_keypoint/topdown_heatmap/wflw/td-hm_hrnetv2-w18_8xb64-60e_wflw-256x256.py' # noqa E501 + "mmpose::face_2d_keypoint/topdown_heatmap/wflw/td-hm_hrnetv2-w18_8xb64-60e_wflw-256x256.py" # noqa E501 ) -@pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet and mmpose should be installed') +@pytest.mark.skipif(not is_installed("mmdet"), reason="mmdet and mmpose should be installed") def test_get_model(): # TODO compatible with downstream codebase. 
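# (Aside: the "mmdet::..." strings passed to get_config/get_model in this file
#  are mmengine's cross-repo syntax, "<scope>::<config path inside that repo>",
#  resolved against the installed package's bundled configs under
#  <package>/.mim/configs/, as the osp.join above shows. A minimal,
#  illustrative use, assuming mmdet is installed:
#
#      from mmengine.hub import get_config, get_model
#      cfg = get_config("mmdet::_base_/models/faster-rcnn_r50_fpn.py")
#      model = get_model("mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py")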
- DefaultScope.get_instance('test_get_model', scope_name='test_scope') - get_model('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py') - assert DefaultScope.get_current_instance().scope_name == 'test_scope' - DefaultScope._instance_dict.pop('test_get_model') + DefaultScope.get_instance("test_get_model", scope_name="test_scope") + get_model("mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py") + assert DefaultScope.get_current_instance().scope_name == "test_scope" + DefaultScope._instance_dict.pop("test_get_model") diff --git a/tests/test_infer/test_infer.py b/tests/test_infer/test_infer.py index 2d020b6300..6f751412ba 100644 --- a/tests/test_infer/test_infer.py +++ b/tests/test_infer/test_infer.py @@ -23,10 +23,10 @@ def is_imported(package): class ToyInferencer(BaseInferencer): - preprocess_kwargs = {'pre_arg'} - forward_kwargs = {'for_arg'} - visualize_kwargs = {'vis_arg'} - postprocess_kwargs = {'pos_arg'} + preprocess_kwargs = {"pre_arg"} + forward_kwargs = {"for_arg"} + visualize_kwargs = {"vis_arg"} + postprocess_kwargs = {"pos_arg"} def preprocess(self, inputs, batch_size=1, pre_arg=None, **kwargs): return super().preprocess(inputs, batch_size, **kwargs) @@ -37,16 +37,10 @@ def forward(self, inputs, for_arg=None, **kwargs): def visualize(self, inputs, preds, vis_arg=None, **kwargs): return inputs - def postprocess(self, - preds, - imgs, - return_datasamples, - pos_arg=None, - **kwargs): + def postprocess(self, preds, imgs, return_datasamples, pos_arg=None, **kwargs): return imgs, preds def _init_pipeline(self, cfg): - def pipeline(img): if isinstance(img, str): img = np.load(img, allow_pickle=True) @@ -60,30 +54,28 @@ def pipeline(img): return pipeline -class ToyVisualizer(Visualizer): - ... +class ToyVisualizer(Visualizer): ... class TestBaseInferencer(RunnerTestCase): - def setUp(self) -> None: super().setUp() runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg)) runner.train() - self.cfg_path = osp.join(runner.work_dir, f'{runner.timestamp}.py') - self.ckpt_path = osp.join(runner.work_dir, 'epoch_1.pth') - VISUALIZERS.register_module(module=ToyVisualizer, name='ToyVisualizer') + self.cfg_path = osp.join(runner.work_dir, f"{runner.timestamp}.py") + self.ckpt_path = osp.join(runner.work_dir, "epoch_1.pth") + VISUALIZERS.register_module(module=ToyVisualizer, name="ToyVisualizer") def test_custom_inferencer(self): # Inferencer should not define ***_kwargs with duplicate keys. 
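# (Aside, restating the contract under test: a BaseInferencer subclass declares
#  which extra keyword arguments belong to each stage through the class
#  attributes preprocess_kwargs / forward_kwargs / visualize_kwargs /
#  postprocess_kwargs, as ToyInferencer does above; the four sets must be
#  disjoint so that _dispatch_kwargs can route every keyword to exactly one
#  stage, and a subclass that reuses a key, as in the snippet below, fails at
#  class-definition time.)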
- with self.assertRaisesRegex(AssertionError, 'Class define error'): + with self.assertRaisesRegex(AssertionError, "Class define error"): class CustomInferencer(BaseInferencer): - preprocess_kwargs = set('a') - forward_kwargs = set('a') + preprocess_kwargs = set("a") + forward_kwargs = set("a") def tearDown(self): - VISUALIZERS._module_dict.pop('ToyVisualizer') + VISUALIZERS._module_dict.pop("ToyVisualizer") return super().tearDown() def test_init(self): @@ -97,22 +89,22 @@ def test_init(self): # Pass model as string point to path of config ToyInferencer(self.cfg_path, self.ckpt_path) - cfg.model.pretrained = 'fake_path' + cfg.model.pretrained = "fake_path" inferencer = ToyInferencer(cfg, self.ckpt_path) - self.assertNotIn('pretrained', inferencer.cfg.model) + self.assertNotIn("pretrained", inferencer.cfg.model) # Pass invalid model - with self.assertRaisesRegex(TypeError, 'model must'): + with self.assertRaisesRegex(TypeError, "model must"): ToyInferencer([self.epoch_based_cfg], self.ckpt_path) # Pass model as model name defined in metafile - if is_imported('mmdet'): + if is_imported("mmdet"): from mmdet.utils import register_all_modules register_all_modules() ToyInferencer( - 'faster-rcnn_s50_fpn_syncbn-backbone+head_ms-range-1x_coco', - 'https://download.openmmlab.com/mmdetection/v2.0/resnest/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco_20200926_125502-20289c16.pth', # noqa: E501 + "faster-rcnn_s50_fpn_syncbn-backbone+head_ms-range-1x_coco", + "https://download.openmmlab.com/mmdetection/v2.0/resnest/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco_20200926_125502-20289c16.pth", # noqa: E501 ) checkpoint = self.ckpt_path @@ -124,7 +116,7 @@ def test_call(self): img_paths = [] for i in range(num_imgs): img = np.random.random((1, 2)) - img_path = osp.join(self.temp_dir.name, f'{i}.npy') + img_path = osp.join(self.temp_dir.name, f"{i}.npy") img.dump(img_path) imgs.append(img) img_paths.append(img_path) @@ -133,16 +125,15 @@ def test_call(self): inferencer(imgs) inferencer(img_paths) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported("mmdet"), reason="mmdet is not installed") def test_load_model_from_meta(self): from mmdet.utils import register_all_modules register_all_modules() inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) - inferencer._load_model_from_metafile('retinanet_r18_fpn_1x_coco') - with self.assertRaisesRegex(ValueError, 'Cannot find model'): - inferencer._load_model_from_metafile('fake_model') + inferencer._load_model_from_metafile("retinanet_r18_fpn_1x_coco") + with self.assertRaisesRegex(ValueError, "Cannot find model"): + inferencer._load_model_from_metafile("fake_model") # TODO: Test alias def test_init_model(self): @@ -154,39 +145,33 @@ def test_get_chunk_data(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) data = list(range(1, 11)) chunk_data = inferencer._get_chunk_data(data, 3) - self.assertEqual( - list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) + self.assertEqual(list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) def test_init_visualizer(self): cfg = copy.deepcopy(self.epoch_based_cfg) inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) visualizer = inferencer._init_visualizer(cfg) self.assertIsNone(visualizer, None) - cfg.visualizer = dict(type='ToyVisualizer') + cfg.visualizer = 
dict(type="ToyVisualizer") visualizer = inferencer._init_visualizer(cfg) self.assertIsInstance(visualizer, ToyVisualizer) # Visualizer could be built with the same name repeatedly. - cfg.visualizer = dict(type='ToyVisualizer', name='toy') + cfg.visualizer = dict(type="ToyVisualizer", name="toy") visualizer = inferencer._init_visualizer(cfg) visualizer = inferencer._init_visualizer(cfg) def test_dispatch_kwargs(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) - kwargs = dict( - pre_arg=dict(a=1), - for_arg=dict(c=2), - vis_arg=dict(b=3), - pos_arg=dict(d=4)) - pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs( - **kwargs) + kwargs = dict(pre_arg=dict(a=1), for_arg=dict(c=2), vis_arg=dict(b=3), pos_arg=dict(d=4)) + pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs(**kwargs) self.assertEqual(pre_arg, dict(pre_arg=dict(a=1))) self.assertEqual(for_arg, dict(for_arg=dict(c=2))) self.assertEqual(vis_arg, dict(vis_arg=dict(b=3))) self.assertEqual(pos_arg, dict(pos_arg=dict(d=4))) # Test unknown arg. kwargs = dict(return_datasample=dict()) - with self.assertRaisesRegex(ValueError, 'unknown'): + with self.assertRaisesRegex(ValueError, "unknown"): inferencer._dispatch_kwargs(**kwargs) def test_preprocess(self): @@ -194,36 +179,28 @@ def test_preprocess(self): data = list(range(1, 11)) pre_data = inferencer.preprocess(data, batch_size=3) target_data = [ - [torch.tensor(1), - torch.tensor(2), - torch.tensor(3)], - [torch.tensor(4), - torch.tensor(5), - torch.tensor(6)], - [torch.tensor(7), - torch.tensor(8), - torch.tensor(9)], + [torch.tensor(1), torch.tensor(2), torch.tensor(3)], + [torch.tensor(4), torch.tensor(5), torch.tensor(6)], + [torch.tensor(7), torch.tensor(8), torch.tensor(9)], [torch.tensor(10)], ] self.assertEqual(list(pre_data), target_data) - os.mkdir(osp.join(self.temp_dir.name, 'imgs')) + os.mkdir(osp.join(self.temp_dir.name, "imgs")) for i in range(1, 11): img = np.array(1) - img.dump(osp.join(self.temp_dir.name, 'imgs', f'{i}.npy')) + img.dump(osp.join(self.temp_dir.name, "imgs", f"{i}.npy")) # Passing a directory of images. 
- inputs = inferencer._inputs_to_list( - osp.join(self.temp_dir.name, 'imgs')) + inputs = inferencer._inputs_to_list(osp.join(self.temp_dir.name, "imgs")) dataloader = inferencer.preprocess(inputs, batch_size=3) for data in dataloader: self.assertTrue(is_list_of(data, torch.Tensor)) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported("mmdet"), reason="mmdet is not installed") def test_list_models(self): - model_list = BaseInferencer.list_models('mmdet') + model_list = BaseInferencer.list_models("mmdet") self.assertTrue(len(model_list) > 0) DefaultScope._instance_dict.clear() - with self.assertRaisesRegex(AssertionError, 'scope should be'): + with self.assertRaisesRegex(AssertionError, "scope should be"): BaseInferencer.list_models() - with self.assertRaisesRegex(AssertionError, 'unknown not in'): - BaseInferencer.list_models('unknown') + with self.assertRaisesRegex(AssertionError, "unknown not in"): + BaseInferencer.list_models("unknown") diff --git a/tests/test_logging/test_history_buffer.py b/tests/test_logging/test_history_buffer.py index 99c03165f8..e6fb25fbf3 100644 --- a/tests/test_logging/test_history_buffer.py +++ b/tests/test_logging/test_history_buffer.py @@ -4,6 +4,7 @@ from mmengine.logging import HistoryBuffer + array_method = [np.array, lambda x: x] try: import torch @@ -19,7 +20,6 @@ def custom_statistics(self): class TestLoggerBuffer: - def test_init(self): log_buffer = HistoryBuffer() assert log_buffer.max_length == 1000000 @@ -42,7 +42,7 @@ def test_init(self): with pytest.raises(AssertionError): HistoryBuffer([1, 2], [1]) - @pytest.mark.parametrize('array_method', array_method) + @pytest.mark.parametrize("array_method", array_method) def test_update(self, array_method): # test `update` method log_buffer = HistoryBuffer() @@ -52,9 +52,9 @@ def test_update(self, array_method): log_buffer.update(float(log_history[i]), float(count_history[i])) recorded_history, recorded_count = log_buffer.data - for a, b in zip(log_history, recorded_history): + for a, b in zip(log_history, recorded_history, strict=False): assert float(a) == float(b) - for a, b in zip(count_history, recorded_count): + for a, b in zip(count_history, recorded_count, strict=False): assert float(a) == float(b) # test the length of `array` exceed `max_length` @@ -72,27 +72,20 @@ def test_update(self, array_method): with pytest.raises(TypeError): log_buffer.update(array_method([1, 2])) - @pytest.mark.parametrize('statistics_method, log_buffer_type', - [(np.min, 'min'), (np.max, 'max')]) + @pytest.mark.parametrize("statistics_method, log_buffer_type", [(np.min, "min"), (np.max, "max")]) def test_max_min(self, statistics_method, log_buffer_type): log_history = np.random.randint(1, 5, 20) count_history = np.ones(20) log_buffer = HistoryBuffer(log_history, count_history) - assert statistics_method(log_history[-10:]) == \ - getattr(log_buffer, log_buffer_type)(10) - assert statistics_method(log_history) == \ - getattr(log_buffer, log_buffer_type)() + assert statistics_method(log_history[-10:]) == getattr(log_buffer, log_buffer_type)(10) + assert statistics_method(log_history) == getattr(log_buffer, log_buffer_type)() def test_mean(self): log_history = np.random.randint(1, 5, 20) count_history = np.ones(20) log_buffer = HistoryBuffer(log_history, count_history) - assert np.sum(log_history[-10:]) / \ - np.sum(count_history[-10:]) == \ - log_buffer.mean(10) - assert np.sum(log_history) / \ - np.sum(count_history) == \ - log_buffer.mean() + assert 
np.sum(log_history[-10:]) / np.sum(count_history[-10:]) == log_buffer.mean(10) + assert np.sum(log_history) / np.sum(count_history) == log_buffer.mean() def test_current(self): log_history = np.random.randint(1, 5, 20) @@ -108,14 +101,14 @@ def test_statistics(self): log_history = np.array([1, 2, 3, 4, 5]) count_history = np.array([1, 1, 1, 1, 1]) log_buffer = HistoryBuffer(log_history, count_history) - assert log_buffer.statistics('mean') == 3 - assert log_buffer.statistics('min') == 1 - assert log_buffer.statistics('max') == 5 - assert log_buffer.statistics('current') == 5 + assert log_buffer.statistics("mean") == 3 + assert log_buffer.statistics("min") == 1 + assert log_buffer.statistics("max") == 5 + assert log_buffer.statistics("current") == 5 # Access unknown method will raise an error. with pytest.raises(KeyError): - log_buffer.statistics('unknown') + log_buffer.statistics("unknown") def test_register_statistics(self): log_buffer = HistoryBuffer() - assert log_buffer.statistics('custom_statistics') == -1 + assert log_buffer.statistics("custom_statistics") == -1 diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index 2ac2b3548e..31dbaed45f 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -15,15 +15,15 @@ class TestLogger: - stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}' - file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}' + stream_handler_regex_time = r"\d{2}/\d{2} \d{2}:\d{2}:\d{2}" + file_handler_regex_time = r"\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}" - @patch('mmengine.logging.logger._get_rank', lambda: 0) + @patch("mmengine.logging.logger._get_rank", lambda: 0) def test_init_rank0(self, tmp_path): - logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO') - assert logger.name == 'mmengine' - assert logger.instance_name == 'rank0.pkg1' - assert logger.instance_name == 'rank0.pkg1' + logger = MMLogger.get_instance("rank0.pkg1", log_level="INFO") + assert logger.name == "mmengine" + assert logger.instance_name == "rank0.pkg1" + assert logger.instance_name == "rank0.pkg1" # Logger get from `MMLogger.get_instance` does not inherit from # `logging.root` assert logger.parent is None @@ -33,40 +33,33 @@ def test_init_rank0(self, tmp_path): assert logger.handlers[0].level == logging.INFO # If `rank=0`, the `log_level` of stream_handler and file_handler # depends on the given arguments. 
- tmp_file = tmp_path / 'tmp_file.log' - logger = MMLogger.get_instance( - 'rank0.pkg2', log_level='INFO', log_file=str(tmp_file)) + tmp_file = tmp_path / "tmp_file.log" + logger = MMLogger.get_instance("rank0.pkg2", log_level="INFO", log_file=str(tmp_file)) assert isinstance(logger, logging.Logger) assert len(logger.handlers) == 2 assert isinstance(logger.handlers[0], logging.StreamHandler) assert isinstance(logger.handlers[1], logging.FileHandler) - logger_pkg3 = MMLogger.get_instance('rank0.pkg2') + logger_pkg3 = MMLogger.get_instance("rank0.pkg2") assert id(logger_pkg3) == id(logger) - logger = MMLogger.get_instance( - 'rank0.pkg3', logger_name='logger_test', log_level='INFO') - assert logger.name == 'logger_test' - assert logger.instance_name == 'rank0.pkg3' + logger = MMLogger.get_instance("rank0.pkg3", logger_name="logger_test", log_level="INFO") + assert logger.name == "logger_test" + assert logger.instance_name == "rank0.pkg3" # `FileHandler` should be closed in Windows, otherwise we cannot # delete the temporary directory logging.shutdown() MMLogger._instance_dict.clear() - @patch('mmengine.logging.logger._get_rank', lambda: 1) - @patch('mmengine.logging.logger._get_device_id', lambda: 1) - @patch('mmengine.logging.logger._get_world_size', lambda: 2) - @patch('mmengine.logging.logger._get_host_info', lambda: 'test') + @patch("mmengine.logging.logger._get_rank", lambda: 1) + @patch("mmengine.logging.logger._get_device_id", lambda: 1) + @patch("mmengine.logging.logger._get_world_size", lambda: 2) + @patch("mmengine.logging.logger._get_host_info", lambda: "test") def test_init_rank1(self, tmp_path): # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. - tmp_file = tmp_path / 'tmp_file.log' - log_path = tmp_path / 'tmp_file_test_device1_rank1.log' - logger = MMLogger.get_instance( - 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) + tmp_file = tmp_path / "tmp_file.log" + log_path = tmp_path / "tmp_file_test_device1_rank1.log" + logger = MMLogger.get_instance("rank1.pkg2", log_level="INFO", log_file=str(tmp_file)) assert len(logger.handlers) == 1 - logger = MMLogger.get_instance( - 'rank1.pkg3', - log_level='INFO', - log_file=str(tmp_file), - distributed=True) + logger = MMLogger.get_instance("rank1.pkg3", log_level="INFO", log_file=str(tmp_file), distributed=True) assert logger.handlers[0].level == logging.ERROR assert logger.handlers[1].level == logging.INFO assert len(logger.handlers) == 2 @@ -76,34 +69,27 @@ def test_init_rank1(self, tmp_path): logging.shutdown() MMLogger._instance_dict.clear() - @pytest.mark.parametrize('log_level', - [logging.WARNING, logging.INFO, logging.DEBUG]) + @pytest.mark.parametrize("log_level", [logging.WARNING, logging.INFO, logging.DEBUG]) def test_handler(self, capsys, tmp_path, log_level): # test stream handler can output correct format logs - instance_name = f'test_stream_{str(log_level)}' + instance_name = f"test_stream_{str(log_level)}" logger = MMLogger.get_instance(instance_name, log_level=log_level) - logger.log(level=log_level, msg='welcome') + logger.log(level=log_level, msg="welcome") out, _ = capsys.readouterr() # Skip match colored INFO loglevl_name = logging._levelToName[log_level] - match = re.fullmatch( - self.stream_handler_regex_time + f' - mmengine - ' - f'(.*){loglevl_name}(.*) - welcome\n', out) + match = re.fullmatch(self.stream_handler_regex_time + f" - mmengine - (.*){loglevl_name}(.*) - welcome\n", out) assert match is not None # test file_handler output plain text without color. 
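# (Aside: the log lines asserted in this test look like, for example,
#      04/15 12:00:00 - mmengine - INFO - welcome          (stream handler, level possibly colorized)
#      2024/04/15 12:00:00 - mmengine - INFO - welcome     (file handler, plain text)
#  which is exactly what stream_handler_regex_time and file_handler_regex_time
#  at the top of the class encode.)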
- tmp_file = tmp_path / 'tmp_file.log' - instance_name = f'test_file_{log_level}' - logger = MMLogger.get_instance( - instance_name, log_level=log_level, log_file=tmp_file) - logger.log(level=log_level, msg='welcome') + tmp_file = tmp_path / "tmp_file.log" + instance_name = f"test_file_{log_level}" + logger = MMLogger.get_instance(instance_name, log_level=log_level, log_file=tmp_file) + logger.log(level=log_level, msg="welcome") with open(tmp_file) as f: log_text = f.read() - match = re.fullmatch( - self.file_handler_regex_time + - f' - mmengine - {loglevl_name} - ' - f'welcome\n', log_text) + match = re.fullmatch(self.file_handler_regex_time + f" - mmengine - {loglevl_name} - welcome\n", log_text) assert match is not None # `FileHandler` should be closed in Windows, otherwise we cannot # delete the temporary directory @@ -113,18 +99,19 @@ def test_handler(self, capsys, tmp_path, log_level): def test_error_format(self, capsys): # test error level log can output file path, function name and # line number - logger = MMLogger.get_instance('test_error', log_level='INFO') - logger.error('welcome') + logger = MMLogger.get_instance("test_error", log_level="INFO") + logger.error("welcome") lineno = sys._getframe().f_lineno - 1 # replace \ for windows: # origin: c:\\a\\b\\c.py # replaced: c:\\\\a\\\\b\\\\c.py for re.match. - file_path = __file__.replace('\\', '\\\\') + file_path = __file__.replace("\\", "\\\\") function_name = sys._getframe().f_code.co_name - pattern = self.stream_handler_regex_time + \ - r' - mmengine - (.*)ERROR(.*) - ' \ - f'{file_path} - {function_name} - ' \ - f'{lineno} - welcome\n' + pattern = ( + self.stream_handler_regex_time + r" - mmengine - (.*)ERROR(.*) - " + f"{file_path} - {function_name} - " + f"{lineno} - welcome\n" + ) out, _ = capsys.readouterr() match = re.fullmatch(pattern, out) assert match is not None @@ -132,161 +119,139 @@ def test_error_format(self, capsys): def test_print_log(self, capsys, tmp_path): # caplog cannot record MMLogger's logs. # Test simple print. - print_log('welcome', logger=None) + print_log("welcome", logger=None) out, _ = capsys.readouterr() - assert out == 'welcome\n' + assert out == "welcome\n" # Test silent logger and skip print. - print_log('welcome', logger='silent') + print_log("welcome", logger="silent") out, _ = capsys.readouterr() - assert out == '' - logger = MMLogger.get_instance('test_print_log') + assert out == "" + logger = MMLogger.get_instance("test_print_log") # Test using specified logger - print_log('welcome', logger=logger) + print_log("welcome", logger=logger) out, _ = capsys.readouterr() - match = re.fullmatch( - self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' - 'welcome\n', out) + match = re.fullmatch(self.stream_handler_regex_time + " - mmengine - (.*)INFO(.*) - welcome\n", out) assert match is not None # Test access logger by name. - print_log('welcome', logger='test_print_log') + print_log("welcome", logger="test_print_log") out, _ = capsys.readouterr() - match = re.fullmatch( - self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' - 'welcome\n', out) + match = re.fullmatch(self.stream_handler_regex_time + " - mmengine - (.*)INFO(.*) - welcome\n", out) assert match is not None # Test access the latest created logger. 
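# (Aside, summarizing the `logger` argument accepted by print_log as exercised
#  here: None prints plainly to stdout, "silent" suppresses output, a Logger
#  instance or a registered instance name routes the message through that
#  logger, "current" uses the most recently created MMLogger, and any other
#  value raises TypeError or ValueError.)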
- print_log('welcome', logger='current') + print_log("welcome", logger="current") out, _ = capsys.readouterr() - match = re.fullmatch( - self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' - 'welcome\n', out) + match = re.fullmatch(self.stream_handler_regex_time + " - mmengine - (.*)INFO(.*) - welcome\n", out) assert match is not None # Test invalid logger type. with pytest.raises(TypeError): - print_log('welcome', logger=dict) + print_log("welcome", logger=dict) with pytest.raises(ValueError): - print_log('welcome', logger='unknown') + print_log("welcome", logger="unknown") def test_get_instance(self): # Test get root mmengine logger. MMLogger._instance_dict = OrderedDict() root_logger = MMLogger.get_current_instance() - mmdet_logger = MMLogger.get_instance('mmdet') + mmdet_logger = MMLogger.get_instance("mmdet") assert root_logger.name == mmdet_logger.name assert id(root_logger) != id(mmdet_logger) - assert id(MMLogger.get_instance('mmengine')) == id(root_logger) + assert id(MMLogger.get_instance("mmengine")) == id(root_logger) # Test original `get_current_instance` function. - MMLogger.get_instance('mmdet') - assert MMLogger.get_current_instance().instance_name == 'mmdet' + MMLogger.get_instance("mmdet") + assert MMLogger.get_current_instance().instance_name == "mmdet" def test_set_level(self, capsys): - logger = MMLogger.get_instance('test_set_level') - logger.info('hello') + logger = MMLogger.get_instance("test_set_level") + logger.info("hello") out, _ = capsys.readouterr() - assert 'INFO' in out - logger.setLevel('WARNING') - logger.info('hello') + assert "INFO" in out + logger.setLevel("WARNING") + logger.info("hello") out, _ = capsys.readouterr() assert not out - logger.warning('hello') + logger.warning("hello") out, _ = capsys.readouterr() - assert 'WARNING' in out + assert "WARNING" in out def test_filter(self, capsys): - logger = MMLogger.get_instance('test_filter') - logger.warning('hello') + logger = MMLogger.get_instance("test_filter") + logger.warning("hello") out, _ = capsys.readouterr() - assert 'WARNING' in out + assert "WARNING" in out # Filter repeated warning. 
- logger.warning('hello') + logger.warning("hello") out, _ = capsys.readouterr() assert not out # Pass new warning - logger.warning('hello1') + logger.warning("hello1") out, _ = capsys.readouterr() - assert 'WARNING' in out + assert "WARNING" in out def test_file_handlers(self, tmp_path): - tmp_file = tmp_path / 'tmp_file.log' + tmp_file = tmp_path / "tmp_file.log" fh = None - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name="test_file_handlers", log_file=tmp_file, file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.FileHandler) - fh = dict(type='BaseRotatingHandler', mode='a') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) - assert isinstance(logger.handlers[-1], - logging.handlers.BaseRotatingHandler) - fh = dict(type='RotatingFileHandler', maxBytes=1024) - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) - assert isinstance(logger.handlers[-1], - logging.handlers.RotatingFileHandler) - fh = dict(type='TimedRotatingFileHandler', when='MIDNIGHT') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) - assert isinstance(logger.handlers[-1], - logging.handlers.TimedRotatingFileHandler) - fh = dict(type='WatchedFileHandler') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) - assert isinstance(logger.handlers[-1], - logging.handlers.WatchedFileHandler) + fh = dict(type="BaseRotatingHandler", mode="a") + logger = MMLogger(name="test_file_handlers", log_file=tmp_file, file_handler_cfg=fh) + assert isinstance(logger.handlers[-1], logging.handlers.BaseRotatingHandler) + fh = dict(type="RotatingFileHandler", maxBytes=1024) + logger = MMLogger(name="test_file_handlers", log_file=tmp_file, file_handler_cfg=fh) + assert isinstance(logger.handlers[-1], logging.handlers.RotatingFileHandler) + fh = dict(type="TimedRotatingFileHandler", when="MIDNIGHT") + logger = MMLogger(name="test_file_handlers", log_file=tmp_file, file_handler_cfg=fh) + assert isinstance(logger.handlers[-1], logging.handlers.TimedRotatingFileHandler) + fh = dict(type="WatchedFileHandler") + logger = MMLogger(name="test_file_handlers", log_file=tmp_file, file_handler_cfg=fh) + assert isinstance(logger.handlers[-1], logging.handlers.WatchedFileHandler) # `FileHandler` should be closed in Windows, otherwise we cannot # delete the temporary directory logging.shutdown() MMLogger._instance_dict.clear() -@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch') -@patch('torch.cuda.device_count', lambda: 4) +@pytest.mark.skipif(not is_installed("torch"), reason="tests requires torch") +@patch("torch.cuda.device_count", lambda: 4) def test_get_device_id(): - @contextmanager def patch_env(local_rank, cuda_visible_devices): - ori_local_rank = os.getenv('LOCAL_RANK', None) - ori_cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) + ori_local_rank = os.getenv("LOCAL_RANK", None) + ori_cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None) if local_rank is not None: - os.environ['LOCAL_RANK'] = local_rank + os.environ["LOCAL_RANK"] = local_rank if cuda_visible_devices is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices yield if ori_local_rank is not None: - os.environ['LOCAL_RANK'] = ori_local_rank - elif 'LOCAL_RANK' in os.environ: - os.environ.pop('LOCAL_RANK') + os.environ["LOCAL_RANK"] = 
ori_local_rank + elif "LOCAL_RANK" in os.environ: + os.environ.pop("LOCAL_RANK") if ori_cuda_visible_devices is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = ori_cuda_visible_devices - elif 'CUDA_VISIBLE_DEVICES' in os.environ: - os.environ.pop('CUDA_VISIBLE_DEVICES') + os.environ["CUDA_VISIBLE_DEVICES"] = ori_cuda_visible_devices + elif "CUDA_VISIBLE_DEVICES" in os.environ: + os.environ.pop("CUDA_VISIBLE_DEVICES") # cuda is not available and local_rank is not set - with patch('torch.cuda.is_available', lambda: False), \ - patch_env(None, '0,1,2,3'): + with patch("torch.cuda.is_available", lambda: False), patch_env(None, "0,1,2,3"): assert _get_device_id() == 0 # cuda is not available and local_rank is set - with patch('torch.cuda.is_available', lambda: False), \ - patch_env('1', '0,1,2,3'): + with patch("torch.cuda.is_available", lambda: False), patch_env("1", "0,1,2,3"): assert _get_device_id() == 1 # CUDA_VISIBLE_DEVICES will not influence non-cuda device - with patch('torch.cuda.is_available', lambda: False), \ - patch_env('1', '0,100,2,3'): + with patch("torch.cuda.is_available", lambda: False), patch_env("1", "0,100,2,3"): assert _get_device_id() == 1 # cuda is available and local_rank is not set - with patch('torch.cuda.is_available', lambda: True), \ - patch_env(None, '0,1,2,3'): + with patch("torch.cuda.is_available", lambda: True), patch_env(None, "0,1,2,3"): assert _get_device_id() == 0 # cuda is available and local_rank is set - with patch('torch.cuda.is_available', lambda: True), \ - patch_env('2', '0,1,2,3'): + with patch("torch.cuda.is_available", lambda: True), patch_env("2", "0,1,2,3"): assert _get_device_id() == 2 # CUDA_VISIBLE_DEVICES worked - with patch('torch.cuda.is_available', lambda: True), \ - patch_env('2', '0,1,3,5'): + with patch("torch.cuda.is_available", lambda: True), patch_env("2", "0,1,3,5"): assert _get_device_id() == 3 diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 3dc5cef748..8e25eef81e 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -10,105 +10,95 @@ class NoDeepCopy: - - def __deepcopy__(self, memodict={}): + def __deepcopy__(self, memodict=None): raise NotImplementedError class TestMessageHub: - def test_init(self): - message_hub = MessageHub('name') - assert message_hub.instance_name == 'name' + message_hub = MessageHub("name") + assert message_hub.instance_name == "name" assert len(message_hub.log_scalars) == 0 assert len(message_hub.log_scalars) == 0 # The type of log_scalars's value must be `HistoryBuffer`. with pytest.raises(AssertionError): - MessageHub('hello', log_scalars=OrderedDict(a=1)) + MessageHub("hello", log_scalars=OrderedDict(a=1)) # `Resumed_keys` with pytest.raises(AssertionError): - MessageHub( - 'hello', - runtime_info=OrderedDict(iter=1), - resumed_keys=OrderedDict(iters=False)) + MessageHub("hello", runtime_info=OrderedDict(iter=1), resumed_keys=OrderedDict(iters=False)) def test_update_scalar(self): - message_hub = MessageHub.get_instance('mmengine') + message_hub = MessageHub.get_instance("mmengine") # Update scalar with int. - message_hub.update_scalar('name', 1) - log_buffer = message_hub.log_scalars['name'] + message_hub.update_scalar("name", 1) + log_buffer = message_hub.log_scalars["name"] assert (log_buffer._log_history == np.array([1])).all() # Update scalar with np.ndarray. 
- message_hub.update_scalar('name', np.array(1)) + message_hub.update_scalar("name", np.array(1)) assert (log_buffer._log_history == np.array([1, 1])).all() # Update scalar with np.int - message_hub.update_scalar('name', np.int32(1)) + message_hub.update_scalar("name", np.int32(1)) assert (log_buffer._log_history == np.array([1, 1, 1])).all() def test_update_info(self): - message_hub = MessageHub.get_instance('mmengine') + message_hub = MessageHub.get_instance("mmengine") # test runtime value can be overwritten. - message_hub.update_info('key', 2) - assert message_hub.runtime_info['key'] == 2 - message_hub.update_info('key', 1) - assert message_hub.runtime_info['key'] == 1 + message_hub.update_info("key", 2) + assert message_hub.runtime_info["key"] == 2 + message_hub.update_info("key", 1) + assert message_hub.runtime_info["key"] == 1 def test_pop_info(self): - message_hub = MessageHub.get_instance('mmengine') - message_hub.update_info('pop_key', 'pop_info') - assert message_hub.runtime_info['pop_key'] == 'pop_info' - assert message_hub.pop_info('pop_key') == 'pop_info' + message_hub = MessageHub.get_instance("mmengine") + message_hub.update_info("pop_key", "pop_info") + assert message_hub.runtime_info["pop_key"] == "pop_info" + assert message_hub.pop_info("pop_key") == "pop_info" - assert message_hub.pop_info('not_existed_key', 'info') == 'info' + assert message_hub.pop_info("not_existed_key", "info") == "info" def test_update_infos(self): - message_hub = MessageHub.get_instance('mmengine') + message_hub = MessageHub.get_instance("mmengine") # test runtime value can be overwritten. - message_hub.update_info_dict({'a': 2, 'b': 3}) - assert message_hub.runtime_info['a'] == 2 - assert message_hub.runtime_info['b'] == 3 - assert message_hub._resumed_keys['a'] - assert message_hub._resumed_keys['b'] + message_hub.update_info_dict({"a": 2, "b": 3}) + assert message_hub.runtime_info["a"] == 2 + assert message_hub.runtime_info["b"] == 3 + assert message_hub._resumed_keys["a"] + assert message_hub._resumed_keys["b"] def test_get_scalar(self): - message_hub = MessageHub.get_instance('mmengine') + message_hub = MessageHub.get_instance("mmengine") # Get undefined key will raise error with pytest.raises(KeyError): - message_hub.get_scalar('unknown') + message_hub.get_scalar("unknown") # test get log_buffer as wished log_history = np.array([1, 2, 3, 4, 5]) count = np.array([1, 1, 1, 1, 1]) for i in range(len(log_history)): - message_hub.update_scalar('test_value', float(log_history[i]), - int(count[i])) - recorded_history, recorded_count = \ - message_hub.get_scalar('test_value').data + message_hub.update_scalar("test_value", float(log_history[i]), int(count[i])) + recorded_history, recorded_count = message_hub.get_scalar("test_value").data assert (log_history == recorded_history).all() assert (recorded_count == count).all() def test_get_runtime(self): - message_hub = MessageHub.get_instance('mmengine') - assert message_hub.get_info('unknown') is None + message_hub = MessageHub.get_instance("mmengine") + assert message_hub.get_info("unknown") is None recorded_dict = dict(a=1, b=2) - message_hub.update_info('test_value', recorded_dict) - assert message_hub.get_info('test_value') == recorded_dict + message_hub.update_info("test_value", recorded_dict) + assert message_hub.get_info("test_value") == recorded_dict - @pytest.mark.skipif(not is_installed('torch'), reason='requires torch') + @pytest.mark.skipif(not is_installed("torch"), reason="requires torch") def test_get_scalars(self): import torch - 
message_hub = MessageHub.get_instance('mmengine') - log_dict = dict( - loss=1, - loss_cls=torch.tensor(2), - loss_bbox=np.array(3), - loss_iou=dict(value=1, count=2)) + + message_hub = MessageHub.get_instance("mmengine") + log_dict = dict(loss=1, loss_cls=torch.tensor(2), loss_bbox=np.array(3), loss_iou=dict(value=1, count=2)) message_hub.update_scalars(log_dict) - loss = message_hub.get_scalar('loss') - loss_cls = message_hub.get_scalar('loss_cls') - loss_bbox = message_hub.get_scalar('loss_bbox') - loss_iou = message_hub.get_scalar('loss_iou') + loss = message_hub.get_scalar("loss") + loss_cls = message_hub.get_scalar("loss_cls") + loss_bbox = message_hub.get_scalar("loss_bbox") + loss_iou = message_hub.get_scalar("loss_iou") assert loss.current() == 1 assert loss_cls.current() == 2 assert loss_bbox.current() == 3 @@ -123,87 +113,79 @@ def test_get_scalars(self): message_hub.update_scalars(loss_dict) def test_state_dict(self): - message_hub = MessageHub.get_instance('test_state_dict') + message_hub = MessageHub.get_instance("test_state_dict") # update log_scalars. - message_hub.update_scalar('loss', 0.1) - message_hub.update_scalar('lr', 0.1, resumed=False) + message_hub.update_scalar("loss", 0.1) + message_hub.update_scalar("lr", 0.1, resumed=False) # update runtime information - message_hub.update_info('iter', 1, resumed=True) - message_hub.update_info('tensor', [1, 2, 3], resumed=False) + message_hub.update_info("iter", 1, resumed=True) + message_hub.update_info("tensor", [1, 2, 3], resumed=False) no_copy = NoDeepCopy() - message_hub.update_info('no_copy', no_copy, resumed=True) + message_hub.update_info("no_copy", no_copy, resumed=True) state_dict = message_hub.state_dict() - assert state_dict['log_scalars']['loss'].data == (np.array([0.1]), - np.array([1])) - assert 'lr' not in state_dict['log_scalars'] - assert state_dict['runtime_info']['iter'] == 1 - assert 'tensor' not in state_dict['runtime_info'] - assert state_dict['runtime_info']['no_copy'] is no_copy + assert state_dict["log_scalars"]["loss"].data == (np.array([0.1]), np.array([1])) + assert "lr" not in state_dict["log_scalars"] + assert state_dict["runtime_info"]["iter"] == 1 + assert "tensor" not in state_dict["runtime_info"] + assert state_dict["runtime_info"]["no_copy"] is no_copy def test_load_state_dict(self, capsys): - message_hub1 = MessageHub.get_instance('test_load_state_dict1') + message_hub1 = MessageHub.get_instance("test_load_state_dict1") # update log_scalars. - message_hub1.update_scalar('loss', 0.1) - message_hub1.update_scalar('lr', 0.1, resumed=False) + message_hub1.update_scalar("loss", 0.1) + message_hub1.update_scalar("lr", 0.1, resumed=False) # update runtime information - message_hub1.update_info('iter', 1, resumed=True) - message_hub1.update_info('tensor', [1, 2, 3], resumed=False) + message_hub1.update_info("iter", 1, resumed=True) + message_hub1.update_info("tensor", [1, 2, 3], resumed=False) state_dict = message_hub1.state_dict() # Resume from state_dict - message_hub2 = MessageHub.get_instance('test_load_state_dict2') + message_hub2 = MessageHub.get_instance("test_load_state_dict2") message_hub2.load_state_dict(state_dict) - assert message_hub2.get_scalar('loss').data == (np.array([0.1]), - np.array([1])) - assert message_hub2.get_info('iter') == 1 + assert message_hub2.get_scalar("loss").data == (np.array([0.1]), np.array([1])) + assert message_hub2.get_info("iter") == 1 # Test resume from `MessageHub` instance. 
- message_hub3 = MessageHub.get_instance('test_load_state_dict3') + message_hub3 = MessageHub.get_instance("test_load_state_dict3") message_hub3.load_state_dict(state_dict) - assert message_hub3.get_scalar('loss').data == (np.array([0.1]), - np.array([1])) - assert message_hub3.get_info('iter') == 1 + assert message_hub3.get_scalar("loss").data == (np.array([0.1]), np.array([1])) + assert message_hub3.get_info("iter") == 1 # Test resume custom state_dict state_dict = OrderedDict() - state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer()) - state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1) - state_dict['resumed_keys'] = dict( - a=True, b=True, c=True, e=False, f=True) + state_dict["log_scalars"] = dict(a=1, b=HistoryBuffer()) + state_dict["runtime_info"] = dict(c=1, d=NoDeepCopy(), e=1) + state_dict["resumed_keys"] = dict(a=True, b=True, c=True, e=False, f=True) - message_hub4 = MessageHub.get_instance('test_load_state_dict4') + message_hub4 = MessageHub.get_instance("test_load_state_dict4") message_hub4.load_state_dict(state_dict) - assert 'a' not in message_hub4.log_scalars and 'b' in \ - message_hub4.log_scalars - assert 'c' in message_hub4.runtime_info and \ - state_dict['runtime_info']['d'] is \ - message_hub4.runtime_info['d'] - assert message_hub4._resumed_keys == OrderedDict( - b=True, c=True, e=False) + assert "a" not in message_hub4.log_scalars and "b" in message_hub4.log_scalars + assert "c" in message_hub4.runtime_info and state_dict["runtime_info"]["d"] is message_hub4.runtime_info["d"] + assert message_hub4._resumed_keys == OrderedDict(b=True, c=True, e=False) def test_getstate(self): - message_hub = MessageHub.get_instance('name') + message_hub = MessageHub.get_instance("name") # update log_scalars. - message_hub.update_scalar('loss', 0.1) - message_hub.update_scalar('lr', 0.1, resumed=False) + message_hub.update_scalar("loss", 0.1) + message_hub.update_scalar("lr", 0.1, resumed=False) # update runtime information - message_hub.update_info('iter', 1, resumed=True) - message_hub.update_info('tensor', [1, 2, 3], resumed=False) + message_hub.update_info("iter", 1, resumed=True) + message_hub.update_info("tensor", [1, 2, 3], resumed=False) obj = pickle.dumps(message_hub) instance = pickle.loads(obj) - assert instance.get_info('feat') is None - assert instance.get_info('lr') is None + assert instance.get_info("feat") is None + assert instance.get_info("lr") is None - instance.get_info('iter') - instance.get_scalar('loss') + instance.get_info("iter") + instance.get_scalar("loss") def test_get_instance(self): # Test get root mmengine message hub. MessageHub._instance_dict = OrderedDict() message_hub = MessageHub.get_current_instance() - assert id(MessageHub.get_instance('mmengine')) == id(message_hub) + assert id(MessageHub.get_instance("mmengine")) == id(message_hub) # Test original `get_current_instance` function. 
- MessageHub.get_instance('mmdet') - assert MessageHub.get_current_instance().instance_name == 'mmdet' + MessageHub.get_instance("mmdet") + assert MessageHub.get_current_instance().instance_name == "mmdet" diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index 6438b8bde5..73aabf4a82 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -5,8 +5,7 @@ import torch from mmengine.logging import MMLogger -from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA, - StochasticWeightAverage) +from mmengine.model import ExponentialMovingAverage, MomentumAnnealingEMA, StochasticWeightAverage from mmengine.testing import assert_allclose @@ -18,22 +17,18 @@ class TestAveragedModel(TestCase): """ # noqa: E501 def _test_swa_model(self, net_device, avg_device): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10)).to(net_device) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)).to(net_device) averaged_model = StochasticWeightAverage(model, device=avg_device) - averaged_params = [ - torch.zeros_like(param) for param in model.parameters() - ] + averaged_params = [torch.zeros_like(param) for param in model.parameters()] n_updates = 2 - for i in range(n_updates): - for p, p_avg in zip(model.parameters(), averaged_params): + for _i in range(n_updates): + for p, p_avg in zip(model.parameters(), averaged_params, strict=False): p.detach().add_(torch.randn_like(p)) p_avg += p.detach() / n_updates averaged_model.update_parameters(model) - for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()): + for p_avg, p_swa in zip(averaged_params, averaged_model.parameters(), strict=False): # Check that AveragedModel is on the correct device self.assertTrue(p_swa.device == avg_device) self.assertTrue(p.device == net_device) @@ -41,7 +36,7 @@ def _test_swa_model(self, net_device, avg_device): self.assertTrue(averaged_model.steps.device == avg_device) def test_averaged_model_all_devices(self): - cpu = torch.device('cpu') + cpu = torch.device("cpu") self._test_swa_model(cpu, cpu) if torch.cuda.is_available(): cuda = torch.device(0) @@ -52,90 +47,77 @@ def test_averaged_model_all_devices(self): def test_swa_mixed_device(self): if not torch.cuda.is_available(): return - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) model[0].cuda() model[1].cpu() averaged_model = StochasticWeightAverage(model) - averaged_params = [ - torch.zeros_like(param) for param in model.parameters() - ] + averaged_params = [torch.zeros_like(param) for param in model.parameters()] n_updates = 10 - for i in range(n_updates): - for p, p_avg in zip(model.parameters(), averaged_params): + for _i in range(n_updates): + for p, p_avg in zip(model.parameters(), averaged_params, strict=False): p.detach().add_(torch.randn_like(p)) p_avg += p.detach() / n_updates averaged_model.update_parameters(model) - for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()): + for p_avg, p_swa in zip(averaged_params, averaged_model.parameters(), strict=False): assert_allclose(p_avg, p_swa) # Check that AveragedModel is on the correct device self.assertTrue(p_avg.device == p_swa.device) def test_swa_state_dict(self): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = 
torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) averaged_model = StochasticWeightAverage(model) averaged_model2 = StochasticWeightAverage(model) n_updates = 10 - for i in range(n_updates): + for _i in range(n_updates): for p in model.parameters(): p.detach().add_(torch.randn_like(p)) averaged_model.update_parameters(model) averaged_model2.load_state_dict(averaged_model.state_dict()) - for p_swa, p_swa2 in zip(averaged_model.parameters(), - averaged_model2.parameters()): + for p_swa, p_swa2 in zip(averaged_model.parameters(), averaged_model2.parameters(), strict=False): assert_allclose(p_swa, p_swa2) self.assertTrue(averaged_model.steps == averaged_model2.steps) def test_ema(self): # test invalid momentum - with self.assertRaisesRegex(AssertionError, - 'momentum must be in range'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + with self.assertRaisesRegex(AssertionError, "momentum must be in range"): + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=3) # Warning should be raised if the value of momentum in EMA is # a large number - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=0.9) # test EMA - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) momentum = 0.1 ema_model = ExponentialMovingAverage(model, momentum=momentum) - averaged_params = [ - torch.zeros_like(param) for param in model.parameters() - ] + averaged_params = [torch.zeros_like(param) for param in model.parameters()] n_updates = 10 for i in range(n_updates): updated_averaged_params = [] - for p, p_avg in zip(model.parameters(), averaged_params): + for p, p_avg in zip(model.parameters(), averaged_params, strict=False): p.detach().add_(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: - updated_averaged_params.append( - (p_avg * (1 - momentum) + p * momentum).clone()) + updated_averaged_params.append((p_avg * (1 - momentum) + p * momentum).clone()) ema_model.update_parameters(model) averaged_params = updated_averaged_params - for p_target, p_ema in zip(averaged_params, ema_model.parameters()): + for p_target, p_ema in zip(averaged_params, ema_model.parameters(), strict=False): assert_allclose(p_target, p_ema) def test_ema_update_buffers(self): # Test EMA and update_buffers as True. 
model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10) + ) momentum = 0.1 - ema_model = ExponentialMovingAverage( - model, momentum=momentum, update_buffers=True) + ema_model = ExponentialMovingAverage(model, momentum=momentum, update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -145,43 +127,40 @@ def test_ema_update_buffers(self): for i in range(n_updates): updated_averaged_params = [] params = [ - param for param in itertools.chain(model.parameters(), - model.buffers()) + param + for param in itertools.chain(model.parameters(), model.buffers()) if param.size() != torch.Size([]) ] - for p, p_avg in zip(params, averaged_params): + for p, p_avg in zip(params, averaged_params, strict=False): p.detach().add_(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: - updated_averaged_params.append( - (p_avg * (1 - momentum) + p * momentum).clone()) + updated_averaged_params.append((p_avg * (1 - momentum) + p * momentum).clone()) ema_model.update_parameters(model) averaged_params = updated_averaged_params ema_params = [ - param for param in itertools.chain(ema_model.module.parameters(), - ema_model.module.buffers()) + param + for param in itertools.chain(ema_model.module.parameters(), ema_model.module.buffers()) if param.size() != torch.Size([]) ] - for p_target, p_ema in zip(averaged_params, ema_params): + for p_target, p_ema in zip(averaged_params, ema_params, strict=False): assert_allclose(p_target, p_ema) def test_momentum_annealing_ema(self): model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10) + ) # Test invalid gamma - with self.assertRaisesRegex(AssertionError, - 'gamma must be greater than 0'): + with self.assertRaisesRegex(AssertionError, "gamma must be greater than 0"): MomentumAnnealingEMA(model, gamma=-1) # Test EMA with momentum annealing. 
momentum = 0.1 gamma = 4 - ema_model = MomentumAnnealingEMA( - model, gamma=gamma, momentum=momentum, update_buffers=True) + ema_model = MomentumAnnealingEMA(model, gamma=gamma, momentum=momentum, update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -191,44 +170,38 @@ def test_momentum_annealing_ema(self): for i in range(n_updates): updated_averaged_params = [] params = [ - param for param in itertools.chain(model.parameters(), - model.buffers()) + param + for param in itertools.chain(model.parameters(), model.buffers()) if param.size() != torch.Size([]) ] - for p, p_avg in zip(params, averaged_params): + for p, p_avg in zip(params, averaged_params, strict=False): p.add(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: m = max(momentum, gamma / (gamma + i)) - updated_averaged_params.append( - (p_avg * (1 - m) + p * m).clone()) + updated_averaged_params.append((p_avg * (1 - m) + p * m).clone()) ema_model.update_parameters(model) averaged_params = updated_averaged_params ema_params = [ - param for param in itertools.chain(ema_model.module.parameters(), - ema_model.module.buffers()) + param + for param in itertools.chain(ema_model.module.parameters(), ema_model.module.buffers()) if param.size() != torch.Size([]) ] - for p_target, p_ema in zip(averaged_params, ema_params): + for p_target, p_ema in zip(averaged_params, ema_params, strict=False): assert_allclose(p_target, p_ema) def test_momentum_annealing_ema_with_interval(self): # Test EMA with momentum annealing and interval model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10) + ) momentum = 0.1 gamma = 4 interval = 3 - ema_model = MomentumAnnealingEMA( - model, - gamma=gamma, - momentum=momentum, - interval=interval, - update_buffers=True) + ema_model = MomentumAnnealingEMA(model, gamma=gamma, momentum=momentum, interval=interval, update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -238,27 +211,26 @@ def test_momentum_annealing_ema_with_interval(self): for i in range(n_updates): updated_averaged_params = [] params = [ - param for param in itertools.chain(model.parameters(), - model.buffers()) + param + for param in itertools.chain(model.parameters(), model.buffers()) if param.size() != torch.Size([]) ] - for p, p_avg in zip(params, averaged_params): + for p, p_avg in zip(params, averaged_params, strict=False): p.add(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) elif i % interval == 0: m = max(momentum, gamma / (gamma + i)) - updated_averaged_params.append( - (p_avg * (1 - m) + p * m).clone()) + updated_averaged_params.append((p_avg * (1 - m) + p * m).clone()) else: updated_averaged_params.append(p_avg.clone()) ema_model.update_parameters(model) averaged_params = updated_averaged_params ema_params = [ - param for param in itertools.chain(ema_model.module.parameters(), - ema_model.module.buffers()) + param + for param in itertools.chain(ema_model.module.parameters(), ema_model.module.buffers()) if param.size() != torch.Size([]) ] - for p_target, p_ema in zip(averaged_params, ema_params): + for p_target, p_ema in zip(averaged_params, ema_params, strict=False): assert_allclose(p_target, p_ema) diff --git 
a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 8dc23eec86..b68569590e 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -13,10 +13,11 @@ from mmengine.registry import MODELS from mmengine.testing import assert_allclose + dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half] -cpu_devices = ['cpu', torch.device('cpu')] -cuda_devices = ['cuda', 0, torch.device('cuda')] +cpu_devices = ["cpu", torch.device("cpu")] +cuda_devices = ["cuda", 0, torch.device("cuda")] devices_to_test = cpu_devices if torch.cuda.is_available(): devices_to_test += cuda_devices @@ -28,7 +29,6 @@ def list_product(*args): @MODELS.register_module() class CustomDataPreprocessor(BaseDataPreprocessor): - def forward(self, data, training=False): if training: return 1 @@ -37,25 +37,23 @@ def forward(self, data, training=False): class ToyModel(BaseModel): - def __init__(self, data_preprocessor=None): super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) self.conv = nn.Conv2d(3, 1, 1) - def forward(self, inputs, data_sample=None, mode='tensor'): - if mode == 'loss': + def forward(self, inputs, data_sample=None, mode="tensor"): + if mode == "loss": out = self.conv(inputs) return dict(loss=out) - elif mode == 'predict': + elif mode == "predict": out = self.conv(inputs) return out - elif mode == 'tensor': + elif mode == "tensor": out = self.conv(inputs) return out class NestedModel(BaseModel): - def __init__(self): super().__init__() self.toy_model = ToyModel() @@ -65,12 +63,11 @@ def forward(self): class TestBaseModel(TestCase): - def test_init(self): # initiate model without `data_preprocessor` model = ToyModel() self.assertIsInstance(model.data_preprocessor, BaseDataPreprocessor) - data_preprocessor = dict(type='CustomDataPreprocessor') + data_preprocessor = dict(type="CustomDataPreprocessor") model = ToyModel(data_preprocessor=data_preprocessor) self.assertIsInstance(model.data_preprocessor, CustomDataPreprocessor) self.assertEqual(model.data_preprocessor(1, training=True), 1) @@ -82,22 +79,20 @@ def test_init(self): self.assertIs(model.data_preprocessor, data_preprocessor) # initiate model with error type `data_preprocessor`. 
- with self.assertRaisesRegex(TypeError, 'data_preprocessor should be'): + with self.assertRaisesRegex(TypeError, "data_preprocessor should be"): ToyModel(data_preprocessor=[data_preprocessor]) def test_parse_losses(self): model = ToyModel() loss_cls = torch.tensor(1, dtype=torch.float32) - loss_list = [ - torch.tensor(2, dtype=torch.float32), - torch.tensor(3, dtype=torch.float32) - ] + loss_list = [torch.tensor(2, dtype=torch.float32), torch.tensor(3, dtype=torch.float32)] losses = dict(loss_cls=loss_cls, loss_list=loss_list) target_parsed_losses = torch.tensor(6, dtype=torch.float32) targe_log_vars = dict( loss=torch.tensor(6, dtype=torch.float32), loss_cls=torch.tensor(1, dtype=torch.float32), - loss_list=torch.tensor(5, dtype=torch.float32)) + loss_list=torch.tensor(5, dtype=torch.float32), + ) parse_losses, log_vars = model.parse_losses(losses) assert_allclose(parse_losses, target_parsed_losses) for key in log_vars: @@ -105,7 +100,7 @@ def test_parse_losses(self): assert_allclose(log_vars[key], targe_log_vars[key]) with self.assertRaises(TypeError): - losses['error_key'] = dict() + losses["error_key"] = dict() model.parse_losses(losses) def test_train_step(self): @@ -117,7 +112,7 @@ def test_train_step(self): data = dict(inputs=inputs, data_sample=None) log_vars = model.train_step(data, optim_wrapper) self.assertFalse(torch.equal(ori_conv_weight, model.conv.weight)) - self.assertIsInstance(log_vars['loss'], torch.Tensor) + self.assertIsInstance(log_vars["loss"], torch.Tensor) def test_val_step(self): inputs = torch.randn(1, 3, 1, 1) @@ -133,70 +128,63 @@ def test_test_step(self): out = model.val_step(data) self.assertIsInstance(out, torch.Tensor) - @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') + @unittest.skipIf(not torch.cuda.is_available(), "cuda should be available") def test_cuda(self): inputs = torch.randn(1, 3, 1, 1).cuda() data = dict(inputs=inputs, data_sample=None) model = ToyModel().cuda() out = model.val_step(data) - self.assertEqual(out.device.type, 'cuda') + self.assertEqual(out.device.type, "cuda") model = NestedModel() - self.assertEqual(model.data_preprocessor._device, torch.device('cpu')) - self.assertEqual(model.toy_model.data_preprocessor._device, - torch.device('cpu')) + self.assertEqual(model.data_preprocessor._device, torch.device("cpu")) + self.assertEqual(model.toy_model.data_preprocessor._device, torch.device("cpu")) model.cuda() - self.assertEqual(model.data_preprocessor._device, torch.device('cuda')) - self.assertEqual(model.toy_model.data_preprocessor._device, - torch.device('cuda')) + self.assertEqual(model.data_preprocessor._device, torch.device("cuda")) + self.assertEqual(model.toy_model.data_preprocessor._device, torch.device("cuda")) - @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') + @unittest.skipIf(not torch.cuda.is_available(), "cuda should be available") def test_to(self): - inputs = torch.randn(1, 3, 1, 1).to('cuda:0') + inputs = torch.randn(1, 3, 1, 1).to("cuda:0") data = dict(inputs=inputs, data_sample=None) model = ToyModel().to(torch.cuda.current_device()) out = model.val_step(data) - self.assertEqual(out.device.type, 'cuda') + self.assertEqual(out.device.type, "cuda") model = NestedModel() - self.assertEqual(model.data_preprocessor._device, torch.device('cpu')) - self.assertEqual(model.toy_model.data_preprocessor._device, - torch.device('cpu')) - model.to('cuda') - self.assertEqual(model.data_preprocessor._device, torch.device('cuda')) - 
self.assertEqual(model.toy_model.data_preprocessor._device, - torch.device('cuda')) + self.assertEqual(model.data_preprocessor._device, torch.device("cpu")) + self.assertEqual(model.toy_model.data_preprocessor._device, torch.device("cpu")) + model.to("cuda") + self.assertEqual(model.data_preprocessor._device, torch.device("cuda")) + self.assertEqual(model.toy_model.data_preprocessor._device, torch.device("cuda")) model.to() - self.assertEqual(model.data_preprocessor._device, torch.device('cuda')) - self.assertEqual(model.toy_model.data_preprocessor._device, - torch.device('cuda')) + self.assertEqual(model.data_preprocessor._device, torch.device("cuda")) + self.assertEqual(model.toy_model.data_preprocessor._device, torch.device("cuda")) @parameterized.expand(list_product(devices_to_test)) def test_to_device(self, device): model = ToyModel().to(device) self.assertTrue( - all(p.device.type == torch.device(device).type - for p in model.parameters()) - and model.data_preprocessor._device == torch.device(device)) + all(p.device.type == torch.device(device).type for p in model.parameters()) + and model.data_preprocessor._device == torch.device(device) + ) @parameterized.expand(list_product(dtypes_to_test)) def test_to_dtype(self, dtype): model = ToyModel().to(dtype) self.assertTrue(all(p.dtype == dtype for p in model.parameters())) - @parameterized.expand( - list_product(devices_to_test, dtypes_to_test, - ['args', 'kwargs', 'hybrid'])) + @parameterized.expand(list_product(devices_to_test, dtypes_to_test, ["args", "kwargs", "hybrid"])) def test_to_device_and_dtype(self, device, dtype, mode): - if mode == 'args': + if mode == "args": model = ToyModel().to(device, dtype) - elif mode == 'kwargs': + elif mode == "kwargs": model = ToyModel().to(device=device, dtype=dtype) - elif mode == 'hybrid': + elif mode == "hybrid": model = ToyModel().to(device, dtype=dtype) self.assertTrue( all(p.dtype == dtype for p in model.parameters()) and model.data_preprocessor._device == torch.device(device) - and all(p.device.type == torch.device(device).type - for p in model.parameters())) + and all(p.device.type == torch.device(device).type for p in model.parameters()) + ) diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index c409260a50..3490610454 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -10,14 +10,13 @@ class TestBaseDataPreprocessor(TestCase): - def test_init(self): base_data_preprocessor = BaseDataPreprocessor() - self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._device.type, "cpu") self.assertEqual(base_data_preprocessor._non_blocking, False) base_data_preprocessor = BaseDataPreprocessor(True) - self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._device.type, "cpu") self.assertEqual(base_data_preprocessor._non_blocking, True) def test_forward(self): @@ -31,7 +30,7 @@ def test_forward(self): # Test with dict of batch inputs and batch data samples data = dict(inputs=[input1, input2], data_sample=[label1, label2]) output = base_data_preprocessor(data) - batch_inputs, batch_labels = output['inputs'], output['data_sample'] + batch_inputs, batch_labels = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(batch_inputs[0])) self.assertEqual(batch_inputs[0].shape, (1, 3, 5)) @@ -54,83 +53,68 @@ def 
test_forward(self): data = dict(inputs=[input1, input2], data_sample=[label1, label2]) base_data_preprocessor = base_data_preprocessor.cuda() output = base_data_preprocessor(data) - batch_inputs, batch_labels = output['inputs'], output[ - 'data_sample'] + batch_inputs, batch_labels = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(batch_inputs[0])) - self.assertEqual(batch_inputs[0].device.type, 'cuda') + self.assertEqual(batch_inputs[0].device.type, "cuda") # Fallback to test with cpu. base_data_preprocessor = base_data_preprocessor.cpu() output = base_data_preprocessor(data) - batch_inputs, batch_labels = output['inputs'], output[ - 'data_sample'] + batch_inputs, batch_labels = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(batch_inputs[0])) - self.assertEqual(batch_inputs[0].device.type, 'cpu') + self.assertEqual(batch_inputs[0].device.type, "cpu") # Test `base_data_preprocessor` can be moved to cuda again. - base_data_preprocessor = base_data_preprocessor.to('cuda:0') + base_data_preprocessor = base_data_preprocessor.to("cuda:0") output = base_data_preprocessor(data) - batch_inputs, batch_labels = output['inputs'], output[ - 'data_sample'] + batch_inputs, batch_labels = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(batch_inputs[0])) - self.assertEqual(batch_inputs[0].device.type, 'cuda') + self.assertEqual(batch_inputs[0].device.type, "cuda") # device of `base_data_preprocessor` is cuda, output should be # cuda tensor. - self.assertEqual(batch_inputs[0].device.type, 'cuda') - self.assertEqual(batch_labels[0].device.type, 'cuda') + self.assertEqual(batch_inputs[0].device.type, "cuda") + self.assertEqual(batch_labels[0].device.type, "cuda") # Test forward with string value - data = dict(string='abc') + data = dict(string="abc") base_data_preprocessor(data) class TestImgDataPreprocessor(TestBaseDataPreprocessor): - def test_init(self): # Initiate processor without arguments data_processor = ImgDataPreprocessor() self.assertFalse(data_processor._channel_conversion) - self.assertFalse(hasattr(data_processor, 'mean')) - self.assertFalse(hasattr(data_processor, 'std')) + self.assertFalse(hasattr(data_processor, "mean")) + self.assertFalse(hasattr(data_processor, "std")) self.assertEqual(data_processor.pad_size_divisor, 1) assert_allclose(data_processor.pad_value, torch.tensor(0)) # Initiate model with bgr2rgb, mean, std .etc.. 
data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, - mean=[0, 0, 0], - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) + bgr_to_rgb=True, mean=[0, 0, 0], std=[255, 255, 255], pad_size_divisor=16, pad_value=10 + ) self.assertTrue(data_processor._enable_normalize) self.assertTrue(data_processor._channel_conversion, True) - assert_allclose(data_processor.mean, - torch.tensor([0, 0, 0]).view(-1, 1, 1)) - assert_allclose(data_processor.std, - torch.tensor([255, 255, 255]).view(-1, 1, 1)) + assert_allclose(data_processor.mean, torch.tensor([0, 0, 0]).view(-1, 1, 1)) + assert_allclose(data_processor.std, torch.tensor([255, 255, 255]).view(-1, 1, 1)) assert_allclose(data_processor.pad_value, torch.tensor(10)) self.assertEqual(data_processor.pad_size_divisor, 16) - with self.assertRaisesRegex(AssertionError, '`mean` should have'): + with self.assertRaisesRegex(AssertionError, "`mean` should have"): ImgDataPreprocessor(mean=(1, 2), std=(1, 2, 3)) - with self.assertRaisesRegex(AssertionError, '`std` should have'): + with self.assertRaisesRegex(AssertionError, "`std` should have"): ImgDataPreprocessor(mean=(1, 2, 3), std=(1, 2)) - with self.assertRaisesRegex(AssertionError, '`bgr2rgb` and `rgb2bgr`'): + with self.assertRaisesRegex(AssertionError, "`bgr2rgb` and `rgb2bgr`"): ImgDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True) - with self.assertRaisesRegex(AssertionError, 'mean and std should be'): - ImgDataPreprocessor( - bgr_to_rgb=True, - mean=None, - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) + with self.assertRaisesRegex(AssertionError, "mean and std should be"): + ImgDataPreprocessor(bgr_to_rgb=True, mean=None, std=[255, 255, 255], pad_size_divisor=16, pad_value=10) - data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, pad_size_divisor=16, pad_value=10) + data_processor = ImgDataPreprocessor(bgr_to_rgb=True, pad_size_divisor=16, pad_value=10) self.assertFalse(data_processor._enable_normalize) def test_forward(self): @@ -147,10 +131,7 @@ def test_forward(self): data_sample1 = InstanceData(bboxes=torch.randn(5, 4)) data_sample2 = InstanceData(bboxes=torch.randn(5, 4)) - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], - data_sample=[data_sample1.clone(), - data_sample2.clone()]) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], data_sample=[data_sample1.clone(), data_sample2.clone()]) std = torch.tensor([1, 2, 3]).view(-1, 1, 1) target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std @@ -161,12 +142,13 @@ def test_forward(self): target_inputs = [target_inputs1, target_inputs2] output = data_preprocessor(data, True) - inputs, data_samples = output['inputs'], output['data_sample'] + inputs, data_samples = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] for input_, data_sample, target_input, target_data_sample in zip( - inputs, data_samples, target_inputs, target_data_samples): + inputs, data_samples, target_inputs, target_data_samples, strict=False + ): assert_allclose(input_, target_input) assert_allclose(data_sample.bboxes, target_data_sample.bboxes) @@ -176,48 +158,42 @@ def test_forward(self): pad_value=10, rgb_to_bgr=True, ) - target_inputs1 = (inputs1.clone()[[2, 1, 0], ...]) - target_inputs2 = (inputs2.clone()[[2, 1, 0], ...]) + target_inputs1 = inputs1.clone()[[2, 1, 0], ...] + target_inputs2 = inputs2.clone()[[2, 1, 0], ...] 
target_inputs1 = F.pad(target_inputs1, (0, 6, 0, 6), value=10) target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10) target_inputs = [target_inputs1, target_inputs2] output = data_preprocessor(data, True) - inputs, data_samples = output['inputs'], output['data_sample'] + inputs, data_samples = output["inputs"], output["data_sample"] self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] for input_, data_sample, target_input, target_data_sample in zip( - inputs, data_samples, target_inputs, target_data_samples): + inputs, data_samples, target_inputs, target_data_samples, strict=False + ): assert_allclose(input_, target_input) assert_allclose(data_sample.bboxes, target_data_sample.bboxes) # Test gray image with 3 dim mean will raise error - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) - with self.assertRaisesRegex(AssertionError, - 'If the mean has 3 values'): + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + data = dict(inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + with self.assertRaisesRegex(AssertionError, "If the mean has 3 values"): data_preprocessor(data) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) - with self.assertRaisesRegex(AssertionError, - 'If the mean has 3 values'): + data = dict(inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + with self.assertRaisesRegex(AssertionError, "If the mean has 3 values"): data_preprocessor(data) # Test stacked batch inputs and batch data samples data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), - std=(127.5, 127.5, 127.5), - rgb_to_bgr=True, - pad_size_divisor=16) + mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5), rgb_to_bgr=True, pad_size_divisor=16 + ) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) output = data_preprocessor(data) - inputs, data_samples = output['inputs'], output['data_sample'] + inputs, data_samples = output["inputs"], output["data_sample"] target_batch_inputs = _batch_inputs[:, [2, 1, 0], ...] 
target_batch_inputs = (target_batch_inputs - 127.5) / 127.5 target_batch_inputs = F.pad(target_batch_inputs, (0, 6, 0, 6), value=0) @@ -226,22 +202,20 @@ def test_forward(self): assert_allclose(target_batch_inputs, inputs) # Test batch inputs without convert channel order and pad - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) output = data_preprocessor(data) - inputs, data_samples = output['inputs'], output['data_sample'] + inputs, data_samples = output["inputs"], output["data_sample"] target_batch_inputs = (_batch_inputs - 127.5) / 127.5 self.assertEqual(inputs.shape, torch.Size([2, 3, 10, 10])) self.assertTrue(torch.is_floating_point(inputs)) assert_allclose(target_batch_inputs, inputs) # Test empty `data_sample` - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], data_sample=None) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], data_sample=None) output = data_preprocessor(data, True) - inputs, data_samples = output['inputs'], output['data_sample'] + inputs, data_samples = output["inputs"], output["data_sample"] self.assertIsNone(data_samples) self.assertTrue(torch.is_floating_point(inputs)) diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py index 1401eed298..12e48b8b95 100644 --- a/tests/test_model/test_base_module.py +++ b/tests/test_model/test_base_module.py @@ -14,15 +14,15 @@ from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential from mmengine.registry import Registry, build_from_cfg -COMPONENTS = Registry('component') -FOOMODELS = Registry('model') + +COMPONENTS = Registry("component") +FOOMODELS = Registry("model") Logger = MMLogger.get_current_instance() @COMPONENTS.register_module() class FooConv1d(BaseModule): - def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv1d = nn.Conv1d(4, 1, 4) @@ -33,7 +33,6 @@ def forward(self, x): @COMPONENTS.register_module() class FooConv2d(BaseModule): - def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv2d = nn.Conv2d(3, 1, 3) @@ -44,7 +43,6 @@ def forward(self, x): @COMPONENTS.register_module() class FooLinear(BaseModule): - def __init__(self, init_cfg=None): super().__init__(init_cfg) self.linear = nn.Linear(3, 4) @@ -55,7 +53,6 @@ def forward(self, x): @COMPONENTS.register_module() class FooLinearConv1d(BaseModule): - def __init__(self, linear=None, conv1d=None, init_cfg=None): super().__init__(init_cfg) if linear is not None: @@ -70,13 +67,7 @@ def forward(self, x): @FOOMODELS.register_module() class FooModel(BaseModule): - - def __init__(self, - component1=None, - component2=None, - component3=None, - component4=None, - init_cfg=None) -> None: + def __init__(self, component1=None, component2=None, component3=None, component4=None, init_cfg=None) -> None: super().__init__(init_cfg) if component1 is not None: self.component1 = build_from_cfg(component1, COMPONENTS) @@ -93,24 +84,21 @@ def __init__(self, class TestBaseModule(TestCase): - def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.BaseModule = BaseModule() self.model_cfg = dict( - type='FooModel', + type="FooModel", init_cfg=[ - dict(type='Constant', val=1, bias=2, layer='Linear'), - dict(type='Constant', val=3, bias=4, layer='Conv1d'), - 
dict(type='Constant', val=5, bias=6, layer='Conv2d') + dict(type="Constant", val=1, bias=2, layer="Linear"), + dict(type="Constant", val=3, bias=4, layer="Conv1d"), + dict(type="Constant", val=5, bias=6, layer="Conv2d"), ], - component1=dict(type='FooConv1d'), - component2=dict(type='FooConv2d'), - component3=dict(type='FooLinear'), - component4=dict( - type='FooLinearConv1d', - linear=dict(type='FooLinear'), - conv1d=dict(type='FooConv1d'))) + component1=dict(type="FooConv1d"), + component2=dict(type="FooConv2d"), + component3=dict(type="FooLinear"), + component4=dict(type="FooLinearConv1d", linear=dict(type="FooLinear"), conv1d=dict(type="FooConv1d")), + ) self.model = build_from_cfg(self.model_cfg, FOOMODELS) self.logger = MMLogger.get_instance(self._testMethodName) @@ -149,44 +137,37 @@ def test_init_weights(self): self.model.init_weights() assert torch.equal( - self.model.component1.conv1d.weight, - torch.full(self.model.component1.conv1d.weight.shape, 3.0)) - assert torch.equal( - self.model.component1.conv1d.bias, - torch.full(self.model.component1.conv1d.bias.shape, 4.0)) - assert torch.equal( - self.model.component2.conv2d.weight, - torch.full(self.model.component2.conv2d.weight.shape, 5.0)) - assert torch.equal( - self.model.component2.conv2d.bias, - torch.full(self.model.component2.conv2d.bias.shape, 6.0)) + self.model.component1.conv1d.weight, torch.full(self.model.component1.conv1d.weight.shape, 3.0) + ) + assert torch.equal(self.model.component1.conv1d.bias, torch.full(self.model.component1.conv1d.bias.shape, 4.0)) assert torch.equal( - self.model.component3.linear.weight, - torch.full(self.model.component3.linear.weight.shape, 1.0)) + self.model.component2.conv2d.weight, torch.full(self.model.component2.conv2d.weight.shape, 5.0) + ) + assert torch.equal(self.model.component2.conv2d.bias, torch.full(self.model.component2.conv2d.bias.shape, 6.0)) assert torch.equal( - self.model.component3.linear.bias, - torch.full(self.model.component3.linear.bias.shape, 2.0)) + self.model.component3.linear.weight, torch.full(self.model.component3.linear.weight.shape, 1.0) + ) + assert torch.equal(self.model.component3.linear.bias, torch.full(self.model.component3.linear.bias.shape, 2.0)) assert torch.equal( self.model.component4.linear.linear.weight, - torch.full(self.model.component4.linear.linear.weight.shape, 1.0)) + torch.full(self.model.component4.linear.linear.weight.shape, 1.0), + ) assert torch.equal( - self.model.component4.linear.linear.bias, - torch.full(self.model.component4.linear.linear.bias.shape, 2.0)) + self.model.component4.linear.linear.bias, torch.full(self.model.component4.linear.linear.bias.shape, 2.0) + ) assert torch.equal( self.model.component4.conv1d.conv1d.weight, - torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0)) + torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0), + ) assert torch.equal( - self.model.component4.conv1d.conv1d.bias, - torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0)) - assert torch.equal(self.model.reg.weight, - torch.full(self.model.reg.weight.shape, 1.0)) - assert torch.equal(self.model.reg.bias, - torch.full(self.model.reg.bias.shape, 2.0)) + self.model.component4.conv1d.conv1d.bias, torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0) + ) + assert torch.equal(self.model.reg.weight, torch.full(self.model.reg.weight.shape, 1.0)) + assert torch.equal(self.model.reg.bias, torch.full(self.model.reg.bias.shape, 2.0)) # Test build model from Pretrained weights class CustomLinear(BaseModule): 
- def __init__(self, init_cfg=None): super().__init__(init_cfg) self.linear = nn.Linear(1, 1) @@ -197,23 +178,17 @@ def init_weights(self): @FOOMODELS.register_module() class PratrainedModel(FooModel): - - def __init__(self, - component1=None, - component2=None, - component3=None, - component4=None, - init_cfg=None) -> None: - super().__init__(component1, component2, component3, - component4, init_cfg) + def __init__( + self, component1=None, component2=None, component3=None, component4=None, init_cfg=None + ) -> None: + super().__init__(component1, component2, component3, component4, init_cfg) self.linear = CustomLinear() - checkpoint_path = osp.join(self.temp_dir.name, 'test.pth') + checkpoint_path = osp.join(self.temp_dir.name, "test.pth") torch.save(self.model.state_dict(), checkpoint_path) model_cfg = copy.deepcopy(self.model_cfg) - model_cfg['type'] = 'PratrainedModel' - model_cfg['init_cfg'] = dict( - type='Pretrained', checkpoint=checkpoint_path) + model_cfg["type"] = "PratrainedModel" + model_cfg["init_cfg"] = dict(type="Pretrained", checkpoint=checkpoint_path) model = FOOMODELS.build(model_cfg) ori_layer_weight = model.linear.linear.weight.clone() ori_layer_bias = model.linear.linear.bias.clone() @@ -223,15 +198,13 @@ def __init__(self, self.assertTrue((ori_layer_bias != model.linear.linear.bias).any()) class FakeDDP(nn.Module): - def __init__(self, module) -> None: super().__init__() self.module = module # Test initialization of nested modules in DDPModule which define # `init_weights`. - with patch('mmengine.model.base_module.is_model_wrapper', - lambda x: isinstance(x, FakeDDP)): + with patch("mmengine.model.base_module.is_model_wrapper", lambda x: isinstance(x, FakeDDP)): model = FOOMODELS.build(model_cfg) model.ddp = FakeDDP(CustomLinear()) model.init_weights() @@ -265,7 +238,8 @@ def __init__(self, module) -> None: def test_dump_init_info(self): import os import shutil - dump_dir = 'tests/test_model/test_dump_info' + + dump_dir = "tests/test_model/test_dump_info" if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)): os.makedirs(dump_dir) for filename in os.listdir(dump_dir): @@ -275,13 +249,12 @@ def test_dump_init_info(self): elif os.path.isdir(file_path): shutil.rmtree(file_path) - MMLogger.get_instance('logger1') # add logger without FileHandler + MMLogger.get_instance("logger1") # add logger without FileHandler model1 = build_from_cfg(self.model_cfg, FOOMODELS) model1.init_weights() assert len(os.listdir(dump_dir)) == 0 - log_path = os.path.join(dump_dir, 'out.log') - MMLogger.get_instance( - 'logger2', log_file=log_path) # add logger with FileHandler + log_path = os.path.join(dump_dir, "out.log") + MMLogger.get_instance("logger2", log_file=log_path) # add logger with FileHandler model2 = build_from_cfg(self.model_cfg, FOOMODELS) model2.init_weights() assert len(os.listdir(dump_dir)) == 1 @@ -294,161 +267,101 @@ def test_dump_init_info(self): class TestModuleList(TestCase): - def test_modulelist_weight_init(self): models_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + dict(type="FooConv1d", init_cfg=dict(type="Constant", layer="Conv1d", val=0.0, bias=1.0)), + dict(type="FooConv2d", init_cfg=dict(type="Constant", layer="Conv2d", val=2.0, bias=3.0)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList(layers) modellist.init_weights() - self.assertTrue( - 
torch.equal(modellist[0].conv1d.weight, - torch.full(modellist[0].conv1d.weight.shape, 0.))) - self.assertTrue( - torch.equal(modellist[0].conv1d.bias, - torch.full(modellist[0].conv1d.bias.shape, 1.))) - self.assertTrue( - torch.equal(modellist[1].conv2d.weight, - torch.full(modellist[1].conv2d.weight.shape, 2.))) - self.assertTrue( - torch.equal(modellist[1].conv2d.bias, - torch.full(modellist[1].conv2d.bias.shape, 3.))) + self.assertTrue(torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.0))) + self.assertTrue(torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.0))) + self.assertTrue(torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.0))) + self.assertTrue(torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.0))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] - modellist = ModuleList( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + modellist = ModuleList(layers, init_cfg=dict(type="Constant", layer=["Conv1d", "Conv2d"], val=4.0, bias=5.0)) modellist.init_weights() - self.assertTrue( - torch.equal(modellist[0].conv1d.weight, - torch.full(modellist[0].conv1d.weight.shape, 0.))) - self.assertTrue( - torch.equal(modellist[0].conv1d.bias, - torch.full(modellist[0].conv1d.bias.shape, 1.))) - self.assertTrue( - torch.equal(modellist[1].conv2d.weight, - torch.full(modellist[1].conv2d.weight.shape, 2.))) - self.assertTrue( - torch.equal(modellist[1].conv2d.bias, - torch.full(modellist[1].conv2d.bias.shape, 3.))) + self.assertTrue(torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.0))) + self.assertTrue(torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.0))) + self.assertTrue(torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.0))) + self.assertTrue(torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.0))) class TestModuleDict(TestCase): - def test_moduledict_weight_init(self): models_cfg = dict( - foo_conv_1d=dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - foo_conv_2d=dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + foo_conv_1d=dict(type="FooConv1d", init_cfg=dict(type="Constant", layer="Conv1d", val=0.0, bias=1.0)), + foo_conv_2d=dict(type="FooConv2d", init_cfg=dict(type="Constant", layer="Conv2d", val=2.0, bias=3.0)), ) - layers = { - name: build_from_cfg(cfg, COMPONENTS) - for name, cfg in models_cfg.items() - } + layers = {name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items()} modeldict = ModuleDict(layers) modeldict.init_weights() self.assertTrue( torch.equal( - modeldict['foo_conv_1d'].conv1d.weight, - torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))) + modeldict["foo_conv_1d"].conv1d.weight, torch.full(modeldict["foo_conv_1d"].conv1d.weight.shape, 0.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_1d'].conv1d.bias, - torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))) + modeldict["foo_conv_1d"].conv1d.bias, torch.full(modeldict["foo_conv_1d"].conv1d.bias.shape, 1.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_2d'].conv2d.weight, - torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))) + modeldict["foo_conv_2d"].conv2d.weight, 
torch.full(modeldict["foo_conv_2d"].conv2d.weight.shape, 2.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_2d'].conv2d.bias, - torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))) + modeldict["foo_conv_2d"].conv2d.bias, torch.full(modeldict["foo_conv_2d"].conv2d.bias.shape, 3.0) + ) + ) # inner init_cfg has higher priority - layers = { - name: build_from_cfg(cfg, COMPONENTS) - for name, cfg in models_cfg.items() - } - modeldict = ModuleDict( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + layers = {name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items()} + modeldict = ModuleDict(layers, init_cfg=dict(type="Constant", layer=["Conv1d", "Conv2d"], val=4.0, bias=5.0)) modeldict.init_weights() self.assertTrue( torch.equal( - modeldict['foo_conv_1d'].conv1d.weight, - torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))) + modeldict["foo_conv_1d"].conv1d.weight, torch.full(modeldict["foo_conv_1d"].conv1d.weight.shape, 0.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_1d'].conv1d.bias, - torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))) + modeldict["foo_conv_1d"].conv1d.bias, torch.full(modeldict["foo_conv_1d"].conv1d.bias.shape, 1.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_2d'].conv2d.weight, - torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))) + modeldict["foo_conv_2d"].conv2d.weight, torch.full(modeldict["foo_conv_2d"].conv2d.weight.shape, 2.0) + ) + ) self.assertTrue( torch.equal( - modeldict['foo_conv_2d'].conv2d.bias, - torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))) + modeldict["foo_conv_2d"].conv2d.bias, torch.full(modeldict["foo_conv_2d"].conv2d.bias.shape, 3.0) + ) + ) class TestSequential(TestCase): - def test_sequential_model_weight_init(self): seq_model_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + dict(type="FooConv1d", init_cfg=dict(type="Constant", layer="Conv1d", val=0.0, bias=1.0)), + dict(type="FooConv2d", init_cfg=dict(type="Constant", layer="Conv2d", val=2.0, bias=3.0)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential(*layers) seq_model.init_weights() - self.assertTrue( - torch.equal(seq_model[0].conv1d.weight, - torch.full(seq_model[0].conv1d.weight.shape, 0.))) - self.assertTrue( - torch.equal(seq_model[0].conv1d.bias, - torch.full(seq_model[0].conv1d.bias.shape, 1.))) - self.assertTrue( - torch.equal(seq_model[1].conv2d.weight, - torch.full(seq_model[1].conv2d.weight.shape, 2.))) - self.assertTrue( - torch.equal(seq_model[1].conv2d.bias, - torch.full(seq_model[1].conv2d.bias.shape, 3.))) + self.assertTrue(torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.0))) + self.assertTrue(torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.0))) + self.assertTrue(torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.0))) + self.assertTrue(torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.0))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] - seq_model = Sequential( - *layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + seq_model = Sequential(*layers, init_cfg=dict(type="Constant", 
layer=["Conv1d", "Conv2d"], val=4.0, bias=5.0)) seq_model.init_weights() - self.assertTrue( - torch.equal(seq_model[0].conv1d.weight, - torch.full(seq_model[0].conv1d.weight.shape, 0.))) - self.assertTrue( - torch.equal(seq_model[0].conv1d.bias, - torch.full(seq_model[0].conv1d.bias.shape, 1.))) - self.assertTrue( - torch.equal(seq_model[1].conv2d.weight, - torch.full(seq_model[1].conv2d.weight.shape, 2.))) - self.assertTrue( - torch.equal(seq_model[1].conv2d.bias, - torch.full(seq_model[1].conv2d.bias.shape, 3.))) + self.assertTrue(torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.0))) + self.assertTrue(torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.0))) + self.assertTrue(torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.0))) + self.assertTrue(torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.0))) diff --git a/tests/test_model/test_efficient_conv_bn_eval.py b/tests/test_model/test_efficient_conv_bn_eval.py index eb91a6d090..7959cd099b 100644 --- a/tests/test_model/test_efficient_conv_bn_eval.py +++ b/tests/test_model/test_efficient_conv_bn_eval.py @@ -5,22 +5,22 @@ import torch from torch import nn -from mmengine.model.efficient_conv_bn_eval import \ - turn_on_efficient_conv_bn_eval_for_single_model +from mmengine.model.efficient_conv_bn_eval import turn_on_efficient_conv_bn_eval_for_single_model from mmengine.testing import assert_allclose from mmengine.utils import is_installed from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version -mmcv_is_installed = is_installed('mmcv') +mmcv_is_installed = is_installed("mmcv") -class BackboneModel(nn.Module): +class BackboneModel(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) if mmcv_is_installed: from mmcv.cnn import ConvModule + conv0 = nn.Conv2d(6, 6, 6) bn0 = nn.BatchNorm2d(6) self.mod1 = ConvModule.create_from_conv_bn(conv0, bn0) @@ -46,9 +46,7 @@ def forward(self, x): return x -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('1.8'), - reason='torch.fx needs Pytorch 1.8 or higher') +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version("1.8"), reason="torch.fx needs Pytorch 1.8 or higher") class TestEfficientConvBNEval(TestCase): """Test the turn_on_efficient_conv_bn_eval function.""" diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index a08ff67d77..b018dc65d3 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -7,16 +7,18 @@ from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DataParallel, DistributedDataParallel -from mmengine.model import (MMDistributedDataParallel, - MMSeparateDistributedDataParallel, - convert_sync_batchnorm, is_model_wrapper, - revert_sync_batchnorm) +from mmengine.model import ( + MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + convert_sync_batchnorm, + is_model_wrapper, + revert_sync_batchnorm, +) from mmengine.registry import MODEL_WRAPPERS, Registry from mmengine.utils import is_installed class ToyModule(nn.Module): - def __init__(self): super().__init__() self.layer1 = nn.Linear(1, 1) @@ -25,8 +27,7 @@ def add_module(self, name, module): raise ValueError() -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == "parrots", 
reason="not supported in parrots now") def test_revert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8)) @@ -40,19 +41,18 @@ def test_revert_syncbn(): revert_sync_batchnorm(conv) -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == "parrots", reason="not supported in parrots now") def test_convert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.BatchNorm2d(8)) # Test convert to mmcv SyncBatchNorm - if is_installed('mmcv'): + if is_installed("mmcv"): # MMCV SyncBatchNorm is only supported on distributed training. # torch 1.6 will throw an AssertionError, and higher version will # throw an RuntimeError with pytest.raises((RuntimeError, AssertionError)): - convert_sync_batchnorm(conv, implementation='mmcv') + convert_sync_batchnorm(conv, implementation="mmcv") # Test convert BN to Pytorch SyncBatchNorm converted_conv = convert_sync_batchnorm(conv) @@ -61,25 +61,26 @@ def test_convert_syncbn(): def test_is_model_wrapper(): # Test basic module wrapper. - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29510' - os.environ['RANK'] = str(0) - init_process_group(backend='gloo', rank=0, world_size=1) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29510" + os.environ["RANK"] = str(0) + init_process_group(backend="gloo", rank=0, world_size=1) model = nn.Linear(1, 1) for wrapper in [ - DistributedDataParallel, MMDistributedDataParallel, - MMSeparateDistributedDataParallel, DataParallel + DistributedDataParallel, + MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + DataParallel, ]: wrapper_model = wrapper(model) assert is_model_wrapper(wrapper_model) # Test `is_model_wrapper` can check model wrapper registered in custom # registry. - CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS) + CHILD_REGISTRY = Registry("test_is_model_wrapper", parent=MODEL_WRAPPERS) class CustomModelWrapper(nn.Module): - def __init__(self, model): super().__init__() self.module = model @@ -89,8 +90,11 @@ def __init__(self, model): CHILD_REGISTRY.register_module(module=CustomModelWrapper, force=True) for wrapper in [ - DistributedDataParallel, MMDistributedDataParallel, - MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper + DistributedDataParallel, + MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + DataParallel, + CustomModelWrapper, ]: wrapper_model = wrapper(model) assert is_model_wrapper(wrapper_model) @@ -98,8 +102,10 @@ def __init__(self, model): # Test `is_model_wrapper` will not check model wrapper in parent # registry from a child registry. 
for wrapper in [ - DistributedDataParallel, MMDistributedDataParallel, - MMSeparateDistributedDataParallel, DataParallel + DistributedDataParallel, + MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + DataParallel, ]: wrapper_model = wrapper(model) assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY) diff --git a/tests/test_model/test_test_aug_time.py b/tests/test_model/test_test_aug_time.py index d2b8c97190..48a6b0af24 100644 --- a/tests/test_model/test_test_aug_time.py +++ b/tests/test_model/test_test_aug_time.py @@ -11,26 +11,23 @@ class ToyTTAPipeline: - def __call__(self, result): return {key: [value] for key, value in result.items()} class ToyTestTimeAugModel(BaseTTAModel): - def merge_preds(self, data_samples_list): result = [sum(x) for x in data_samples_list] return result class ToyModel(BaseModel): - def __init__(self): super().__init__() # DDPWrapper requires at least one parameter. self.linear = torch.nn.Linear(1, 1) - def forward(self, inputs, data_samples, mode='tensor'): + def forward(self, inputs, data_samples, mode="tensor"): return data_samples @@ -56,7 +53,6 @@ def __getitem__(self, index): class TestBaseTTAModel(RunnerTestCase): - def setUp(self) -> None: super().setUp() DATASETS.register_module(module=ToyDatasetTTA, force=True) @@ -66,23 +62,19 @@ def setUp(self) -> None: def tearDown(self): super().tearDown() - DATASETS.module_dict.pop('ToyDatasetTTA', None) - MODELS.module_dict.pop('ToyTestTimeAugModel', None) - MODELS.module_dict.pop('ToyModel', None) - TRANSFORMS.module_dict.pop('ToyTTAPipeline', None) + DATASETS.module_dict.pop("ToyDatasetTTA", None) + MODELS.module_dict.pop("ToyTestTimeAugModel", None) + MODELS.module_dict.pop("ToyModel", None) + TRANSFORMS.module_dict.pop("ToyTTAPipeline", None) def test_test_step(self): model = ToyModel() tta_model = ToyTestTimeAugModel(model) - dict_dataset = [ - dict(inputs=[1, 2], data_samples=[3, 4]) for _ in range(10) - ] + dict_dataset = [dict(inputs=[1, 2], data_samples=[3, 4]) for _ in range(10)] tuple_dataset = [([1, 2], [3, 4]) for _ in range(10)] - dict_dataloader = DataLoader( - dict_dataset, batch_size=2, collate_fn=pseudo_collate) - tuple_dataloader = DataLoader( - tuple_dataset, batch_size=2, collate_fn=pseudo_collate) + dict_dataloader = DataLoader(dict_dataset, batch_size=2, collate_fn=pseudo_collate) + tuple_dataloader = DataLoader(tuple_dataset, batch_size=2, collate_fn=pseudo_collate) for data in dict_dataloader: result = tta_model.test_step(data) @@ -97,21 +89,20 @@ def test_init(self): tta_model = ToyTestTimeAugModel(model) self.assertIs(tta_model.module, model) # Test build from cfg. 
- model = dict(type='ToyModel') + model = dict(type="ToyModel") tta_model = ToyTestTimeAugModel(model) self.assertIsInstance(tta_model.module, ToyModel) def test_with_runner(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.model = dict( - type='ToyTestTimeAugModel', module=dict(type='ToyModel')) - cfg.test_dataloader.dataset = dict(type='ToyDatasetTTA') - cfg.test_dataloader.dataset['pipeline'] = dict(type='ToyTTAPipeline') + cfg.model = dict(type="ToyTestTimeAugModel", module=dict(type="ToyModel")) + cfg.test_dataloader.dataset = dict(type="ToyDatasetTTA") + cfg.test_dataloader.dataset["pipeline"] = dict(type="ToyTTAPipeline") runner = self.build_runner(cfg) runner.test() if torch.cuda.is_available() and torch.distributed.is_nccl_available(): - cfg.launcher = 'pytorch' + cfg.launcher = "pytorch" self.setup_dist_env() runner = self.build_runner(cfg) runner.test() diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index ea657acac1..d6e00ed9be 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -9,58 +9,59 @@ from torch.optim import SGD from mmengine.dist import all_gather, broadcast -from mmengine.model import (BaseDataPreprocessor, BaseModel, - ExponentialMovingAverage, - MMDistributedDataParallel, - MMSeparateDistributedDataParallel) +from mmengine.model import ( + BaseDataPreprocessor, + BaseModel, + ExponentialMovingAverage, + MMDistributedDataParallel, + MMSeparateDistributedDataParallel, +) from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version -if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): + +if digit_version(TORCH_VERSION) >= digit_version("2.0.0"): from mmengine.model import MMFullyShardedDataParallel # noqa: F401 class ToyDataPreprocessor(BaseDataPreprocessor): - def forward(self, data: dict, training: bool = False): self.called = True return super().forward(data, training) class ToyModel(BaseModel): - def __init__(self): super().__init__(data_preprocessor=ToyDataPreprocessor()) self.conv1 = nn.Conv2d(3, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1) - def forward(self, inputs, data_sample=None, mode='tensor'): + def forward(self, inputs, data_sample=None, mode="tensor"): x = self.conv1(inputs) x = self.conv2(x) - if mode == 'loss': + if mode == "loss": return dict(loss=x) - elif mode == 'predict': + elif mode == "predict": return x else: return x class ComplexModel(BaseModel): - def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 1, 1) self.conv2 = nn.Conv2d(3, 1, 1) def train_step(self, data, optim_wrapper): - inputs = self.data_preprocessor(data)['inputs'] + inputs = self.data_preprocessor(data)["inputs"] loss1 = self.conv1(inputs) - optim_wrapper['optim_wrapper1'].update_params(loss1) + optim_wrapper["optim_wrapper1"].update_params(loss1) loss2 = self.conv2(inputs) - optim_wrapper['optim_wrapper2'].update_params(loss2) + optim_wrapper["optim_wrapper2"].update_params(loss2) return dict(loss1=loss1, loss2=loss2) def val_step(self, data): @@ -74,13 +75,11 @@ def forward(self): class TestDistributedDataParallel(MultiProcessTestCase): - def setUp(self): super().setUp() self._spawn_processes() - @unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') + 
@unittest.skipIf(not torch.cuda.is_available(), reason="cuda should be available") def test_train_step(self): self._init_dist_env(self.rank, self.world_size) # Mixed precision training and gradient asynchronous should be valid at @@ -88,11 +87,10 @@ def test_train_step(self): model = ToyModel().cuda() ddp_model = MMDistributedDataParallel(module=model) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) - res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] + res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)["loss"] self.assertIs(res.dtype, torch.float16) grad = ddp_model.module.conv1.weight.grad all_grads = all_gather(grad) @@ -105,7 +103,7 @@ def test_train_step(self): # Test update params and clean grads. ddp_model.train_step(data, optim_wrapper=optim_wrapper) grad = ddp_model.module.conv1.weight.grad - if digit_version(torch.__version__) < digit_version('2.0.0'): + if digit_version(torch.__version__) < digit_version("2.0.0"): all_grads = all_gather(grad) assert_allclose(all_grads[0], torch.zeros_like(all_grads[0])) assert_allclose(all_grads[1], torch.zeros_like(all_grads[0])) @@ -113,14 +111,12 @@ def test_train_step(self): self.assertIsNone(grad) # Test enable detect_anomalous_params. - ddp_model = MMDistributedDataParallel( - module=model, detect_anomalous_params=True) + ddp_model = MMDistributedDataParallel(module=model, detect_anomalous_params=True) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) - res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] + res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)["loss"] def test_val_step(self): self._init_dist_env(self.rank, self.world_size) @@ -145,17 +141,14 @@ def test_test_step(self): def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29510' - os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29510" + os.environ["RANK"] = str(rank) + torch_dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) -@unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') +@unittest.skipIf(not torch.cuda.is_available(), reason="cuda should be available") class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): - def test_init(self): self._init_dist_env(self.rank, self.world_size) model = ComplexModel() @@ -163,8 +156,7 @@ def test_init(self): model.act = nn.ReLU() ddp_model = MMSeparateDistributedDataParallel(model.cuda()) self.assertIsInstance(ddp_model.module.ema, ExponentialMovingAverage) - self.assertIsInstance(ddp_model.module.conv1, - MMDistributedDataParallel) + self.assertIsInstance(ddp_model.module.conv1, MMDistributedDataParallel) self.assertIsInstance(ddp_model.module.act, nn.ReLU) def test_train_step(self): @@ -178,8 +170,7 @@ def test_train_step(self): optimizer2 = 
SGD(model.conv1.parameters(), lr=0.2) optim_wrapper1 = OptimWrapper(optimizer1, 1) optim_wrapper2 = OptimWrapper(optimizer2, 1) - optim_wrapper_dict = OptimWrapperDict( - optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) # Automatically sync grads of `optim_wrapper1` since @@ -212,30 +203,24 @@ def test_test_step(self): def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29515' - os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) - - -@unittest.skipIf( - torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp') -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('2.0.0'), - reason='fsdp needs Pytorch 2.0.0 or higher') -class TestMMFullyShardedDataParallel(MultiProcessTestCase): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29515" + os.environ["RANK"] = str(rank) + torch_dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + +@unittest.skipIf(torch.cuda.device_count() < 2, reason="need 2 gpu to test fsdp") +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version("2.0.0"), reason="fsdp needs Pytorch 2.0.0 or higher") +class TestMMFullyShardedDataParallel(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29520' - os.environ['RANK'] = str(rank) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29520" + os.environ["RANK"] = str(rank) num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) - torch_dist.init_process_group( - backend='nccl', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) def setUp(self) -> None: super().setUp() @@ -256,7 +241,7 @@ def test_train_step(self): # require_grad=False model = ToyModel() - for _, param in model.state_dict().items(): + for param in model.state_dict().values(): broadcast(param) model.conv1.requires_grad_(False) ori_weight = model.conv1.weight.clone() @@ -266,8 +251,7 @@ def wrap_policy(module, recurse=True, *args, **kwargs): return True return isinstance(module, nn.Conv2d) - fsdp_model = MMFullyShardedDataParallel( - module=model.cuda(), auto_wrap_policy=wrap_policy) + fsdp_model = MMFullyShardedDataParallel(module=model.cuda(), auto_wrap_policy=wrap_policy) optimizer = SGD(fsdp_model.parameters(), lr=0.1) optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 113aacd6c8..cc5bed9268 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -11,28 +11,35 @@ from mmengine.dist import get_rank from mmengine.logging import MMLogger -from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, - DefaultOptimWrapperConstructor, OptimWrapper, - build_optim_wrapper) -from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS, - DADAPTATION_OPTIMIZERS, - LION_OPTIMIZERS, - TORCH_OPTIMIZERS, - 
TRANSFORMERS_OPTIMIZERS) +from mmengine.optim import ( + OPTIM_WRAPPER_CONSTRUCTORS, + OPTIMIZERS, + DefaultOptimWrapperConstructor, + OptimWrapper, + build_optim_wrapper, +) +from mmengine.optim.optimizer.builder import ( + BITSANDBYTES_OPTIMIZERS, + DADAPTATION_OPTIMIZERS, + LION_OPTIMIZERS, + TORCH_OPTIMIZERS, + TRANSFORMERS_OPTIMIZERS, +) from mmengine.registry import DefaultScope, Registry, build_from_cfg from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available from mmengine.utils.version_utils import digit_version + MMCV_FULL_AVAILABLE = mmcv_full_available() if not MMCV_FULL_AVAILABLE: - sys.modules['mmcv.ops'] = MagicMock( - DeformConv2d=dict, ModulatedDeformConv2d=dict) + sys.modules["mmcv.ops"] = MagicMock(DeformConv2d=dict, ModulatedDeformConv2d=dict) def has_dadaptation() -> bool: try: import dadaptation # noqa: F401 + return True except ImportError: return False @@ -41,6 +48,7 @@ def has_dadaptation() -> bool: def has_lion() -> bool: try: import lion_pytorch # noqa: F401 + return True except ImportError: return False @@ -49,6 +57,7 @@ def has_lion() -> bool: def has_bitsandbytes() -> bool: try: import bitsandbytes # noqa: F401 + return True except ImportError: return False @@ -57,13 +66,13 @@ def has_bitsandbytes() -> bool: def has_transformers() -> bool: try: import transformers # noqa: F401 + return True except ImportError: return False class ExampleModel(nn.Module): - def __init__(self): super().__init__() self.param1 = nn.Parameter(torch.ones(1)) @@ -73,12 +82,11 @@ def __init__(self): self.sub = SubModel() if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + self.dcn = DeformConv2dPack(3, 4, kernel_size=3, deformable_groups=1) -class ExampleDuplicateModel(nn.Module): +class ExampleDuplicateModel(nn.Module): def __init__(self): super().__init__() self.param1 = nn.Parameter(torch.ones(1)) @@ -90,15 +98,14 @@ def __init__(self): self.conv3[0] = self.conv1[0] if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + + self.dcn = DeformConv2dPack(3, 4, kernel_size=3, deformable_groups=1) def forward(self, x): return x class SubModel(nn.Module): - def __init__(self): super().__init__() self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2) @@ -110,7 +117,6 @@ def forward(self, x): class PseudoDataParallel(nn.Module): - def __init__(self): super().__init__() self.module = ExampleModel() @@ -120,270 +126,286 @@ def forward(self, x): class TestBuilder(TestCase): - def setUp(self): self.model = ExampleModel() self.base_lr = 0.01 self.momentum = 0.0001 self.base_wd = 0.9 - def _check_default_optimizer(self, optimizer, model, prefix=''): + def _check_default_optimizer(self, optimizer, model, prefix=""): assert isinstance(optimizer, torch.optim.SGD) - assert optimizer.defaults['lr'] == self.base_lr - assert optimizer.defaults['momentum'] == self.momentum - assert optimizer.defaults['weight_decay'] == self.base_wd + assert optimizer.defaults["lr"] == self.base_lr + assert optimizer.defaults["momentum"] == self.momentum + assert optimizer.defaults["weight_decay"] == self.base_wd param_groups = optimizer.param_groups[0] if MMCV_FULL_AVAILABLE: param_names = [ - 'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', - 'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight', - 'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight', - 
'dcn.conv_offset.weight', 'dcn.conv_offset.bias' + "param1", + "conv1.weight", + "conv2.weight", + "conv2.bias", + "bn.weight", + "bn.bias", + "sub.param1", + "sub.conv1.weight", + "sub.conv1.bias", + "sub.gn.weight", + "sub.gn.bias", + "dcn.weight", + "dcn.conv_offset.weight", + "dcn.conv_offset.bias", ] else: param_names = [ - 'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', - 'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight', - 'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias' + "param1", + "conv1.weight", + "conv2.weight", + "conv2.bias", + "bn.weight", + "bn.bias", + "sub.param1", + "sub.conv1.weight", + "sub.conv1.bias", + "sub.gn.weight", + "sub.gn.bias", ] param_dict = dict(model.named_parameters()) - assert len(param_groups['params']) == len(param_names) - for i in range(len(param_groups['params'])): - assert torch.equal(param_groups['params'][i], - param_dict[prefix + param_names[i]]) - - def _check_sgd_optimizer(self, - optimizer, - model, - prefix='', - bias_lr_mult=1, - bias_decay_mult=1, - norm_decay_mult=1, - dwconv_decay_mult=1, - dcn_offset_lr_mult=1, - flat_decay_mult=1, - bypass_duplicate=False): + assert len(param_groups["params"]) == len(param_names) + for i in range(len(param_groups["params"])): + assert torch.equal(param_groups["params"][i], param_dict[prefix + param_names[i]]) + + def _check_sgd_optimizer( + self, + optimizer, + model, + prefix="", + bias_lr_mult=1, + bias_decay_mult=1, + norm_decay_mult=1, + dwconv_decay_mult=1, + dcn_offset_lr_mult=1, + flat_decay_mult=1, + bypass_duplicate=False, + ): param_groups = optimizer.param_groups assert isinstance(optimizer, torch.optim.SGD) - assert optimizer.defaults['lr'] == self.base_lr - assert optimizer.defaults['momentum'] == self.momentum - assert optimizer.defaults['weight_decay'] == self.base_wd + assert optimizer.defaults["lr"] == self.base_lr + assert optimizer.defaults["momentum"] == self.momentum + assert optimizer.defaults["weight_decay"] == self.base_wd model_parameters = list(model.parameters()) assert len(param_groups) == len(model_parameters) for i, param in enumerate(model_parameters): param_group = param_groups[i] - assert torch.equal(param_group['params'][0], param) - assert param_group['momentum'] == self.momentum + assert torch.equal(param_group["params"][0], param) + assert param_group["momentum"] == self.momentum # param1 param1 = param_groups[0] - assert param1['lr'] == self.base_lr - assert param1['weight_decay'] == self.base_wd * flat_decay_mult + assert param1["lr"] == self.base_lr + assert param1["weight_decay"] == self.base_wd * flat_decay_mult # conv1.weight conv1_weight = param_groups[1] - assert conv1_weight['lr'] == self.base_lr - assert conv1_weight['weight_decay'] == self.base_wd + assert conv1_weight["lr"] == self.base_lr + assert conv1_weight["weight_decay"] == self.base_wd # conv2.weight conv2_weight = param_groups[2] - assert conv2_weight['lr'] == self.base_lr - assert conv2_weight['weight_decay'] == self.base_wd + assert conv2_weight["lr"] == self.base_lr + assert conv2_weight["weight_decay"] == self.base_wd # conv2.bias conv2_bias = param_groups[3] - assert conv2_bias['lr'] == self.base_lr * bias_lr_mult - assert conv2_bias['weight_decay'] == self.base_wd * bias_decay_mult + assert conv2_bias["lr"] == self.base_lr * bias_lr_mult + assert conv2_bias["weight_decay"] == self.base_wd * bias_decay_mult # bn.weight bn_weight = param_groups[4] - assert bn_weight['lr'] == self.base_lr - assert bn_weight['weight_decay'] == self.base_wd * norm_decay_mult + assert 
bn_weight["lr"] == self.base_lr + assert bn_weight["weight_decay"] == self.base_wd * norm_decay_mult # bn.bias bn_bias = param_groups[5] - assert bn_bias['lr'] == self.base_lr - assert bn_bias['weight_decay'] == self.base_wd * norm_decay_mult + assert bn_bias["lr"] == self.base_lr + assert bn_bias["weight_decay"] == self.base_wd * norm_decay_mult # sub.param1 sub_param1 = param_groups[6] - assert sub_param1['lr'] == self.base_lr - assert sub_param1['weight_decay'] == self.base_wd * flat_decay_mult + assert sub_param1["lr"] == self.base_lr + assert sub_param1["weight_decay"] == self.base_wd * flat_decay_mult # sub.conv1.weight sub_conv1_weight = param_groups[7] - assert sub_conv1_weight['lr'] == self.base_lr - assert sub_conv1_weight[ - 'weight_decay'] == self.base_wd * dwconv_decay_mult + assert sub_conv1_weight["lr"] == self.base_lr + assert sub_conv1_weight["weight_decay"] == self.base_wd * dwconv_decay_mult # sub.conv1.bias sub_conv1_bias = param_groups[8] - assert sub_conv1_bias['lr'] == self.base_lr * bias_lr_mult - assert sub_conv1_bias['weight_decay'] == self.base_wd * bias_decay_mult + assert sub_conv1_bias["lr"] == self.base_lr * bias_lr_mult + assert sub_conv1_bias["weight_decay"] == self.base_wd * bias_decay_mult # sub.gn.weight sub_gn_weight = param_groups[9] - assert sub_gn_weight['lr'] == self.base_lr - assert sub_gn_weight['weight_decay'] == self.base_wd * norm_decay_mult + assert sub_gn_weight["lr"] == self.base_lr + assert sub_gn_weight["weight_decay"] == self.base_wd * norm_decay_mult # sub.gn.bias sub_gn_bias = param_groups[10] - assert sub_gn_bias['lr'] == self.base_lr - assert sub_gn_bias['weight_decay'] == self.base_wd * norm_decay_mult + assert sub_gn_bias["lr"] == self.base_lr + assert sub_gn_bias["weight_decay"] == self.base_wd * norm_decay_mult # test dcn which requires cuda is available and # mmcv-full has been installed if torch.cuda.is_available() and MMCV_FULL_AVAILABLE: dcn_conv_weight = param_groups[11] - assert dcn_conv_weight['lr'] == self.base_lr - assert dcn_conv_weight['weight_decay'] == self.base_wd + assert dcn_conv_weight["lr"] == self.base_lr + assert dcn_conv_weight["weight_decay"] == self.base_wd dcn_offset_weight = param_groups[12] - assert dcn_offset_weight['lr'] == self.base_lr * dcn_offset_lr_mult - assert dcn_offset_weight['weight_decay'] == self.base_wd + assert dcn_offset_weight["lr"] == self.base_lr * dcn_offset_lr_mult + assert dcn_offset_weight["weight_decay"] == self.base_wd dcn_offset_bias = param_groups[13] - assert dcn_offset_bias['lr'] == self.base_lr * dcn_offset_lr_mult - assert dcn_offset_bias['weight_decay'] == self.base_wd + assert dcn_offset_bias["lr"] == self.base_lr * dcn_offset_lr_mult + assert dcn_offset_bias["weight_decay"] == self.base_wd def test_torch_optimizers(self): torch_optimizers = [ - 'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS', - 'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam' + "ASGD", + "Adadelta", + "Adagrad", + "Adam", + "AdamW", + "Adamax", + "LBFGS", + "Optimizer", + "RMSprop", + "Rprop", + "SGD", + "SparseAdam", ] assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS)) - @unittest.skipIf(not has_dadaptation(), 'dadaptation is not installed') + @unittest.skipIf(not has_dadaptation(), "dadaptation is not installed") def test_dadaptation_optimizers(self): - dadaptation_optimizers = ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'] - assert set(dadaptation_optimizers).issubset( - set(DADAPTATION_OPTIMIZERS)) + dadaptation_optimizers = ["DAdaptAdaGrad", "DAdaptAdam", "DAdaptSGD"] + 
assert set(dadaptation_optimizers).issubset(set(DADAPTATION_OPTIMIZERS)) - @unittest.skipIf(not has_lion(), 'lion-pytorch is not installed') + @unittest.skipIf(not has_lion(), "lion-pytorch is not installed") def test_lion_optimizers(self): - assert 'Lion' in LION_OPTIMIZERS + assert "Lion" in LION_OPTIMIZERS - @unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed') + @unittest.skipIf(not has_bitsandbytes(), "bitsandbytes is not installed") def test_bitsandbytes_optimizers(self): bitsandbytes_optimizers = [ - 'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit', - 'PagedAdamW8bit', 'LAMB8bit', 'LARS8bit', 'RMSprop8bit', - 'Lion8bit', 'PagedLion8bit', 'SGD8bit' + "AdamW8bit", + "Adam8bit", + "Adagrad8bit", + "PagedAdam8bit", + "PagedAdamW8bit", + "LAMB8bit", + "LARS8bit", + "RMSprop8bit", + "Lion8bit", + "PagedLion8bit", + "SGD8bit", ] - assert set(bitsandbytes_optimizers).issubset( - set(BITSANDBYTES_OPTIMIZERS)) + assert set(bitsandbytes_optimizers).issubset(set(BITSANDBYTES_OPTIMIZERS)) - @unittest.skipIf(not has_transformers(), 'transformers is not installed') + @unittest.skipIf(not has_transformers(), "transformers is not installed") def test_transformers_optimizers(self): - transformers_optimizers = ['Adafactor'] - assert set(transformers_optimizers).issubset( - set(TRANSFORMERS_OPTIMIZERS)) + transformers_optimizers = ["Adafactor"] + assert set(transformers_optimizers).issubset(set(TRANSFORMERS_OPTIMIZERS)) def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # test build optimizer without type in optim_wrapper_cfg optim_wrapper_cfg = dict( - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum) + ) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapper) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # test build function with invalid ``constructor`` with self.assertRaises(KeyError): - optim_wrapper_cfg['constructor'] = 'INVALID_CONSTRUCTOR' + optim_wrapper_cfg["constructor"] = "INVALID_CONSTRUCTOR" build_optim_wrapper(self.model, optim_wrapper_cfg) # test build function with invalid ``paramwise_cfg`` with self.assertRaises(KeyError): - optim_wrapper_cfg['paramwise_cfg'] = dict(invalid_mult=1) + optim_wrapper_cfg["paramwise_cfg"] = dict(invalid_mult=1) build_optim_wrapper(self.model, optim_wrapper_cfg) - optim_wrapper_cfg.pop('optimizer') - optim_wrapper_cfg.pop('constructor') - optim_wrapper_cfg.pop('paramwise_cfg') + optim_wrapper_cfg.pop("optimizer") + optim_wrapper_cfg.pop("constructor") + optim_wrapper_cfg.pop("paramwise_cfg") self.assertRaisesRegex( - AssertionError, '`optim_wrapper_cfg` must contain', - lambda: build_optim_wrapper(self.model, optim_wrapper_cfg)) + AssertionError, + "`optim_wrapper_cfg` must contain", + lambda: build_optim_wrapper(self.model, optim_wrapper_cfg), + ) def test_build_default_optimizer_constructor(self): optim_wrapper = dict( - type='OptimWrapper', 
- optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + flat_decay_mult=0.3, + ) optim_constructor_cfg = dict( - type='DefaultOptimWrapperConstructor', - optim_wrapper_cfg=optim_wrapper, - paramwise_cfg=paramwise_cfg) - optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - optim_constructor_cfg) + type="DefaultOptimWrapperConstructor", optim_wrapper_cfg=optim_wrapper, paramwise_cfg=paramwise_cfg + ) + optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(optim_constructor_cfg) optim_wrapper = optim_constructor(self.model) - self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, - **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, **paramwise_cfg) def test_build_custom_optimizer_constructor(self): optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) @OPTIM_WRAPPER_CONSTRUCTORS.register_module() class MyOptimizerConstructor(DefaultOptimWrapperConstructor): - def __call__(self, model): - if hasattr(model, 'module'): + if hasattr(model, "module"): model = model.module - conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.) + conv1_lr_mult = self.paramwise_cfg.get("conv1_lr_mult", 1.0) params = [] for name, param in model.named_parameters(): - param_group = {'params': [param]} - if name.startswith('conv1') and param.requires_grad: - param_group['lr'] = self.base_lr * conv1_lr_mult + param_group = {"params": [param]} + if name.startswith("conv1") and param.requires_grad: + param_group["lr"] = self.base_lr * conv1_lr_mult params.append(param_group) - self.optimizer_cfg['params'] = params + self.optimizer_cfg["params"] = params return build_from_cfg(self.optimizer_cfg, OPTIMIZERS) paramwise_cfg = dict(conv1_lr_mult=5) optim_constructor_cfg = dict( - type='MyOptimizerConstructor', - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg) - optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - optim_constructor_cfg) + type="MyOptimizerConstructor", optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg + ) + optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(optim_constructor_cfg) optimizer = optim_constructor(self.model) param_groups = optimizer.param_groups assert isinstance(optimizer, torch.optim.SGD) - assert optimizer.defaults['lr'] == self.base_lr - assert optimizer.defaults['momentum'] == self.momentum - assert optimizer.defaults['weight_decay'] == self.base_wd + assert optimizer.defaults["lr"] == self.base_lr + assert optimizer.defaults["momentum"] == self.momentum + assert optimizer.defaults["weight_decay"] == self.base_wd for i, param in enumerate(self.model.parameters()): param_group = param_groups[i] - assert torch.equal(param_group['params'][0], param) - assert param_group['momentum'] == self.momentum + assert torch.equal(param_group["params"][0], param) + assert param_group["momentum"] == self.momentum # conv1.weight - assert param_groups[1][ - 'lr'] == self.base_lr * paramwise_cfg['conv1_lr_mult'] - assert param_groups[1]['weight_decay'] == self.base_wd + assert 
param_groups[1]["lr"] == self.base_lr * paramwise_cfg["conv1_lr_mult"] + assert param_groups[1]["weight_decay"] == self.base_wd def test_default_optimizer_constructor(self): with self.assertRaises(TypeError): @@ -394,146 +416,113 @@ def test_default_optimizer_constructor(self): with self.assertRaises(TypeError): # paramwise_cfg must be a dict or None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) - paramwise_cfg = ['error'] - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_wrapper_cfg = dict(type="OptimWrapper", optimizer=dict(lr=0.0001, weight_decay=None)) + paramwise_cfg = ["error"] + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) with self.assertRaises(ValueError): # bias_decay_mult/norm_decay_mult is specified but weight_decay # is None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) + optim_wrapper_cfg = dict(type="OptimWrapper", optimizer=dict(lr=0.0001, weight_decay=None)) paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) # basic config with ExampleModel optimizer_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) optim_wrapper = optim_constructor(self.model) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # Support building custom optimizers - CUSTOM_OPTIMIZERS = Registry( - 'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS) + CUSTOM_OPTIMIZERS = Registry("custom optimizer", scope="custom optimizer", parent=OPTIMIZERS) class CustomOptimizer(torch.optim.SGD): - def __init__(self, model_params, *args, **kwargs): super().__init__(params=model_params, *args, **kwargs) CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer) - optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), ) - with DefaultScope.overwrite_default_scope('custom optimizer'): + optimizer_cfg = dict( + optimizer=dict(type="CustomOptimizer", lr=0.1), + ) + with DefaultScope.overwrite_default_scope("custom optimizer"): optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) optim_wrapper = optim_constructor(self.model) - OPTIMIZERS.children.pop('custom optimizer') + OPTIMIZERS.children.pop("custom optimizer") def test_default_optimizer_constructor_with_model_wrapper(self): # basic config with pseudo data parallel model = PseudoDataParallel() optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = None optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, model, prefix="module.") # paramwise_cfg with pseudo data parallel model = 
PseudoDataParallel() optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + flat_decay_mult=0.3, + ) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, model, prefix="module.", **paramwise_cfg) # basic config with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(ExampleModel()) optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = None - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, model, prefix="module.") # paramwise_cfg with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(self.model) optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + flat_decay_mult=0.3, + ) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, - model, - prefix='module.', - **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, model, prefix="module.", **paramwise_cfg) def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): # Empty paramwise_cfg with ExampleModel optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict() - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(self.model) self._check_default_optimizer(optim_wrapper.optimizer, self.model) @@ -542,64 +531,49 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): for param in model.parameters(): param.requires_grad = False optim_wrapper_cfg = dict( - type='OptimWrapper', - 
optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict() - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) self._check_default_optimizer(optim_wrapper.optimizer, model) def test_default_optimizer_constructor_with_paramwise_cfg(self): # paramwise_cfg with ExampleModel optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + flat_decay_mult=0.3, + ) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(self.model) - self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, - **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, **paramwise_cfg) def test_default_optimizer_constructor_no_grad(self): # paramwise_cfg with ExampleModel and no grad optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1) + bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1 + ) self.model.conv1.requires_grad_(False) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(self.model) all_params = [] for pg in optim_wrapper.param_groups: - all_params.extend(map(id, pg['params'])) + all_params.extend(map(id, pg["params"])) self.assertNotIn(id(self.model.conv1.weight), all_params) self.assertIn(id(self.model.conv2.weight), all_params) @@ -607,23 +581,13 @@ def test_default_optimizer_constructor_bypass_duplicate(self): # paramwise_cfg with bypass_duplicate option model = ExampleDuplicateModel() optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) + paramwise_cfg = dict(bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1) - with self.assertRaisesRegex( - ValueError, - 'some parameters appear in more than one parameter group'): - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + with self.assertRaisesRegex(ValueError, "some parameters appear in more than one 
parameter group"): + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optim_constructor(model) paramwise_cfg = dict( @@ -633,78 +597,68 @@ def test_default_optimizer_constructor_bypass_duplicate(self): dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, flat_decay_mult=0.3, - bypass_duplicate=True) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + bypass_duplicate=True, + ) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): # Warning should be raised since conv3.0 is a duplicate param. optim_constructor(model) optim_wrapper = optim_constructor(model) model_parameters = list(model.parameters()) num_params = 14 if MMCV_FULL_AVAILABLE else 11 - assert len(optim_wrapper.optimizer.param_groups) == len( - model_parameters) == num_params - self._check_sgd_optimizer(optim_wrapper.optimizer, model, - **paramwise_cfg) + assert len(optim_wrapper.optimizer.param_groups) == len(model_parameters) == num_params + self._check_sgd_optimizer(optim_wrapper.optimizer, model, **paramwise_cfg) # test DefaultOptimWrapperConstructor when the params in shared # modules do not require grad model.conv1[0].requires_grad_(False) - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): # Warning should be raised since conv3.0 is a duplicate param. optim_constructor(model) optim_wrapper = optim_constructor(model) model_parameters = list(model.parameters()) num_params = 14 if MMCV_FULL_AVAILABLE else 11 - assert len(optim_wrapper.optimizer.param_groups - ) == len(model_parameters) - 1 == num_params - 1 + assert len(optim_wrapper.optimizer.param_groups) == len(model_parameters) - 1 == num_params - 1 def test_default_optimizer_constructor_custom_key(self): # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="OptimWrapper", + optimizer=dict(type="SGD", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum), + ) paramwise_cfg = dict( custom_keys={ - 'param1': dict(lr_mult=10), - 'sub': dict(lr_mult=0.1, decay_mult=0), - 'sub.gn': dict(lr_mult=0.01), - 'non_exist_key': dict(lr_mult=0.0) + "param1": dict(lr_mult=10), + "sub": dict(lr_mult=0.1, decay_mult=0), + "sub.gn": dict(lr_mult=0.01), + "non_exist_key": dict(lr_mult=0.0), }, - norm_decay_mult=0.5) + norm_decay_mult=0.5, + ) with self.assertRaises(TypeError): # custom_keys should be a dict paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001]) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg_) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg_) optimizer = optim_constructor(self.model) with self.assertRaises(ValueError): # if 'decay_mult' is specified in custom_keys, weight_decay # should be specified - optim_wrapper_cfg_ = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)) - paramwise_cfg_ = dict( - custom_keys={'.backbone': dict(decay_mult=0.5)}) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg_, paramwise_cfg_) + optim_wrapper_cfg_ = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)) + paramwise_cfg_ = 
dict(custom_keys={".backbone": dict(decay_mult=0.5)}) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg_, paramwise_cfg_) optim_constructor(self.model) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optimizer = optim_constructor(self.model).optimizer # check optimizer type and default config assert isinstance(optimizer, torch.optim.SGD) - assert optimizer.defaults['lr'] == self.base_lr - assert optimizer.defaults['momentum'] == self.momentum - assert optimizer.defaults['weight_decay'] == self.base_wd + assert optimizer.defaults["lr"] == self.base_lr + assert optimizer.defaults["momentum"] == self.momentum + assert optimizer.defaults["weight_decay"] == self.base_wd # check params groups param_groups = optimizer.param_groups @@ -713,67 +667,68 @@ def test_default_optimizer_constructor_custom_key(self): group_settings = [] # group 1, matches of 'param1' # 'param1' is the longest match for 'sub.param1' - groups.append(['param1', 'sub.param1']) - group_settings.append({ - 'lr': self.base_lr * 10, - 'momentum': self.momentum, - 'weight_decay': self.base_wd, - }) + groups.append(["param1", "sub.param1"]) + group_settings.append( + { + "lr": self.base_lr * 10, + "momentum": self.momentum, + "weight_decay": self.base_wd, + } + ) # group 2, matches of 'sub.gn' - groups.append(['sub.gn.weight', 'sub.gn.bias']) - group_settings.append({ - 'lr': self.base_lr * 0.01, - 'momentum': self.momentum, - 'weight_decay': self.base_wd, - }) + groups.append(["sub.gn.weight", "sub.gn.bias"]) + group_settings.append( + { + "lr": self.base_lr * 0.01, + "momentum": self.momentum, + "weight_decay": self.base_wd, + } + ) # group 3, matches of 'sub' - groups.append(['sub.conv1.weight', 'sub.conv1.bias']) - group_settings.append({ - 'lr': self.base_lr * 0.1, - 'momentum': self.momentum, - 'weight_decay': 0, - }) + groups.append(["sub.conv1.weight", "sub.conv1.bias"]) + group_settings.append( + { + "lr": self.base_lr * 0.1, + "momentum": self.momentum, + "weight_decay": 0, + } + ) # group 4, bn is configured by 'norm_decay_mult' - groups.append(['bn.weight', 'bn.bias']) - group_settings.append({ - 'lr': self.base_lr, - 'momentum': self.momentum, - 'weight_decay': self.base_wd * 0.5, - }) + groups.append(["bn.weight", "bn.bias"]) + group_settings.append( + { + "lr": self.base_lr, + "momentum": self.momentum, + "weight_decay": self.base_wd * 0.5, + } + ) # group 5, default group - groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias']) - group_settings.append({ - 'lr': self.base_lr, - 'momentum': self.momentum, - 'weight_decay': self.base_wd - }) + groups.append(["conv1.weight", "conv2.weight", "conv2.bias"]) + group_settings.append({"lr": self.base_lr, "momentum": self.momentum, "weight_decay": self.base_wd}) num_params = 14 if MMCV_FULL_AVAILABLE else 11 assert len(param_groups) == num_params for i, (name, param) in enumerate(self.model.named_parameters()): - assert torch.equal(param_groups[i]['params'][0], param) - for group, settings in zip(groups, group_settings): + assert torch.equal(param_groups[i]["params"][0], param) + for group, settings in zip(groups, group_settings, strict=False): if name in group: for setting in settings: - assert param_groups[i][setting] == settings[ - setting], f'{name} {setting}' + assert param_groups[i][setting] == settings[setting], f"{name} {setting}" # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel 2 
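# A reading aid for the custom_keys hunks above: a minimal sketch of the same
# DefaultOptimWrapperConstructor usage outside the test harness. ToyDet and the
# multipliers are illustrative only, not part of this patch. Keys in
# `custom_keys` are substring-matched against parameter names and the longest
# match wins, as the grouping comments in this test verify.
import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor

class ToyDet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)
        self.head = nn.Conv2d(8, 1, 1)

optim_wrapper_cfg = dict(
    type="OptimWrapper",
    optimizer=dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001),
)
paramwise_cfg = dict(
    custom_keys={"backbone": dict(lr_mult=0.1, decay_mult=0.5)},  # backbone lr scaled by 0.1
    norm_decay_mult=0.0,  # no weight decay on normalization layers
)
optim_wrapper = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg)(ToyDet())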
optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', lr=self.base_lr, momentum=self.momentum)) - paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)}) + type="OptimWrapper", optimizer=dict(type="SGD", lr=self.base_lr, momentum=self.momentum) + ) + paramwise_cfg = dict(custom_keys={"param1": dict(lr_mult=10)}) - optim_constructor = DefaultOptimWrapperConstructor( - optim_wrapper_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg) optimizer = optim_constructor(self.model).optimizer # check optimizer type and default config assert isinstance(optimizer, torch.optim.SGD) - assert optimizer.defaults['lr'] == self.base_lr - assert optimizer.defaults['momentum'] == self.momentum - assert optimizer.defaults['weight_decay'] == 0 + assert optimizer.defaults["lr"] == self.base_lr + assert optimizer.defaults["momentum"] == self.momentum + assert optimizer.defaults["weight_decay"] == 0 # check params groups param_groups = optimizer.param_groups @@ -781,63 +736,61 @@ def test_default_optimizer_constructor_custom_key(self): groups = [] group_settings = [] # group 1, matches of 'param1' - groups.append(['param1', 'sub.param1']) - group_settings.append({ - 'lr': self.base_lr * 10, - 'momentum': self.momentum, - 'weight_decay': 0, - }) + groups.append(["param1", "sub.param1"]) + group_settings.append( + { + "lr": self.base_lr * 10, + "momentum": self.momentum, + "weight_decay": 0, + } + ) # group 2, default group - groups.append([ - 'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight', - 'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', - 'bn.weight', 'bn.bias' - ]) - group_settings.append({ - 'lr': self.base_lr, - 'momentum': self.momentum, - 'weight_decay': 0 - }) + groups.append( + [ + "sub.conv1.weight", + "sub.conv1.bias", + "sub.gn.weight", + "sub.gn.bias", + "conv1.weight", + "conv2.weight", + "conv2.bias", + "bn.weight", + "bn.bias", + ] + ) + group_settings.append({"lr": self.base_lr, "momentum": self.momentum, "weight_decay": 0}) num_params = 14 if MMCV_FULL_AVAILABLE else 11 assert len(param_groups) == num_params for i, (name, param) in enumerate(self.model.named_parameters()): - assert torch.equal(param_groups[i]['params'][0], param) - for group, settings in zip(groups, group_settings): + assert torch.equal(param_groups[i]["params"][0], param) + for group, settings in zip(groups, group_settings, strict=False): if name in group: for setting in settings: - assert param_groups[i][setting] == settings[ - setting], f'{name} {setting}' + assert param_groups[i][setting] == settings[setting], f"{name} {setting}" @unittest.skipIf( - (digit_version(TORCH_VERSION) < digit_version('1.8.0')) - or not is_available(), - reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.') + (digit_version(TORCH_VERSION) < digit_version("1.8.0")) or not is_available(), + reason="ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.", +) class TestZeroOptimizer(MultiProcessTestCase): - def setUp(self): super().setUp() self._spawn_processes() def _check_default_optimizer(self, optimizer, model): self.assertIsInstance(optimizer.optim, torch.optim.SGD) - self.assertEqual(optimizer.defaults['lr'], self.base_lr) - self.assertEqual(optimizer.defaults['momentum'], self.momentum) - self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd) + self.assertEqual(optimizer.defaults["lr"], self.base_lr) + self.assertEqual(optimizer.defaults["momentum"], self.momentum) + 
self.assertEqual(optimizer.defaults["weight_decay"], self.base_wd) param_groups = optimizer.param_groups params_set = set(model.parameters()) - self.assertEqual( - sum(len(param_group['params']) for param_group in param_groups), - len(params_set)) - self.assertTrue( - all(param in params_set for param_group in param_groups - for param in param_group['params'])) + self.assertEqual(sum(len(param_group["params"]) for param_group in param_groups), len(params_set)) + self.assertTrue(all(param in params_set for param_group in param_groups for param in param_group["params"])) state_dict = optimizer.state_dict() if get_rank() == 0: - self.assertEqual( - sum(len(pg['params']) for pg in state_dict['param_groups']), - len(params_set)) + self.assertEqual(sum(len(pg["params"]) for pg in state_dict["param_groups"]), len(params_set)) else: self.assertEqual(state_dict, {}) @@ -851,11 +804,13 @@ def test_zero_redundancy_optimizer(self): # test build function optim_wrapper_cfg = dict( optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', + type="ZeroRedundancyOptimizer", + optimizer_type="SGD", lr=self.base_lr, weight_decay=self.base_wd, - momentum=self.momentum)) + momentum=self.momentum, + ) + ) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) @@ -863,15 +818,15 @@ def test_zero_redundancy_optimizer(self): with self.assertRaises(TypeError): optim_wrapper_cfg = dict( optimizer=dict( - type='ZeroRedundancyOptimizer', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + type="ZeroRedundancyOptimizer", lr=self.base_lr, weight_decay=self.base_wd, momentum=self.momentum + ) + ) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) @unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('1.12.0'), - reason='ZeRO started to support param groups since pytorch 1.12.0') + digit_version(TORCH_VERSION) < digit_version("1.12.0"), + reason="ZeRO started to support param groups since pytorch 1.12.0", + ) def test_zero_redundancy_optimizer_with_paramwise_cfg(self): self._init_dist_env(self.rank, self.world_size) model = ExampleModel() @@ -881,25 +836,24 @@ def test_zero_redundancy_optimizer_with_paramwise_cfg(self): # test build function paramwise_cfg = dict( - custom_keys={ - 'conv1': dict(lr_mult=0.0, decay_mult=0.0), - 'conv2': dict(lr_mult=1.0, decay_mult=2.0) - }) + custom_keys={"conv1": dict(lr_mult=0.0, decay_mult=0.0), "conv2": dict(lr_mult=1.0, decay_mult=2.0)} + ) optim_wrapper_cfg = dict( optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', + type="ZeroRedundancyOptimizer", + optimizer_type="SGD", lr=self.base_lr, weight_decay=self.base_wd, - momentum=self.momentum), - paramwise_cfg=paramwise_cfg) + momentum=self.momentum, + ), + paramwise_cfg=paramwise_cfg, + ) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29510' - os.environ['RANK'] = str(rank) - torch.distributed.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29510" + os.environ["RANK"] = str(rank) + torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py 
b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..819c1caed8 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -14,34 +14,30 @@ from mmengine.dist import all_gather from mmengine.logging import MessageHub, MMLogger -from mmengine.optim import (AmpOptimWrapper, ApexOptimWrapper, - DefaultOptimWrapperConstructor, OptimWrapper) +from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, DefaultOptimWrapperConstructor, OptimWrapper from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version + is_apex_available = False try: import apex.amp as apex_amp + is_apex_available = True except ImportError: pass -amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None] -torch_dtypes = [ - torch.float16 if dtype is None else getattr(torch, dtype) - for dtype in amp_valid_dtypes -] +amp_valid_dtypes = ["float64", "float32", "float16", "bfloat16", None] +torch_dtypes = [torch.float16 if dtype is None else getattr(torch, dtype) for dtype in amp_valid_dtypes] def bf16_supported() -> bool: - return (hasattr(torch.cuda, 'is_bf16_supported') - and torch.cuda.is_bf16_supported()) + return hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() class ToyModel(nn.Module): - def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) @@ -55,7 +51,6 @@ def forward(self, x): class ToyModel2(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 1, 1) @@ -76,8 +71,8 @@ def setUp(self) -> None: def run_test(self, test_name: str, parent_pipe) -> None: self.model = ToyModel() self.optimizer = SGD(self.model.parameters(), lr=0.1) - self.logger = MMLogger.get_instance('test_optim_wrapper') - self.message_hub = MessageHub.get_instance('test_optim_wrapper_init') + self.logger = MMLogger.get_instance("test_optim_wrapper") + self.message_hub = MessageHub.get_instance("test_optim_wrapper_init") super().run_test(test_name, parent_pipe) def test_init(self): @@ -90,17 +85,16 @@ def test_init(self): self.assertEqual(optim_wrapper._max_counts, -1) self.assertEqual(optim_wrapper._remainder_counts, -1) - with self.assertRaisesRegex(AssertionError, - 'If `clip_grad` is not None'): + with self.assertRaisesRegex(AssertionError, "If `clip_grad` is not None"): OptimWrapper(self.optimizer, clip_grad=[]) def test_update_params(self): # Test update params every iteration. optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1) self._mock_method(optim_wrapper) - loss = torch.tensor(1.) + loss = torch.tensor(1.0) optim_wrapper.update_params(loss) - self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.)) + self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.0)) optim_wrapper.step.assert_called_with() optim_wrapper.zero_grad.assert_called_with() @@ -108,15 +102,15 @@ def test_update_params(self): optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) self._mock_method(optim_wrapper) # `iter=0`, accumulate gradient and do not update params. - loss = torch.tensor(1.) + loss = torch.tensor(1.0) optim_wrapper.update_params(loss) - self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.) 
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.0) / 3.0)
         optim_wrapper.step.assert_not_called()
         optim_wrapper.zero_grad.assert_not_called()

         # gradient accumulate
         optim_wrapper.update_params(loss)
-        self.assertEqual(optim_wrapper._inner_count, 2.)
+        self.assertEqual(optim_wrapper._inner_count, 2.0)

         # `iter=2`, update params.
         optim_wrapper.update_params(loss)
@@ -129,7 +123,7 @@ def test_update_params(self):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_not_called()
         optim_wrapper.zero_grad.assert_not_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.0) / 3.0)
         self._mock_method(optim_wrapper)

         # After calling `initialize_iter_status`, params will be updated at the
@@ -138,7 +132,7 @@ def test_update_params(self):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_called()
         optim_wrapper.zero_grad.assert_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.0))
         self._mock_method(optim_wrapper)

         # optim_wrapper.step should not be called at iteration 97 98, and the
@@ -148,7 +142,7 @@ def test_update_params(self):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_not_called()
         optim_wrapper.zero_grad.assert_not_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.0) / 3)

     def test_initialize_iter_status(self):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
@@ -160,7 +154,7 @@ def test_initialize_iter_status(self):
         with self.assertLogs(self.logger) as cm:
             optim_wrapper.initialize_count_status(self.model, 2, 100)
         self.assertEqual(len(cm.output), 1)
-        self.assertRegex(cm.records[0].msg, 'Resumed iteration number')
+        self.assertRegex(cm.records[0].msg, "Resumed iteration number")

         # Model with batch norm will output warning.
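# The accumulation assertions above map onto this usage pattern. A minimal,
# self-contained sketch (the toy model, data and counts are illustrative, not
# part of the patch): with accumulative_counts=3 the wrapper scales each loss
# by 1/3 and only calls step()/zero_grad() on every third update_params().
import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = torch.nn.Linear(2, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1), accumulative_counts=3)
optim_wrapper.initialize_count_status(model, 0, 9)  # (model, current iter, max iters)
for _ in range(9):
    loss = model(torch.randn(4, 2)).mean()
    optim_wrapper.update_params(loss)  # the optimizer steps on the 3rd, 6th and 9th call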
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) @@ -168,7 +162,7 @@ def test_initialize_iter_status(self): with self.assertLogs(self.logger) as cm: optim_wrapper.initialize_count_status(model, 0, 99) self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.records[0].msg, 'Gradient accumulative') + self.assertRegex(cm.records[0].msg, "Gradient accumulative") def test_ger_lr(self): model = ToyModel() @@ -176,23 +170,20 @@ def test_ger_lr(self): optim_wrapper = OptimWrapper(optim) self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1])) model = ToyModel() - optimizer_cfg = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1)) - paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)}) - optim_constructor = DefaultOptimWrapperConstructor( - optimizer_cfg, paramwise_cfg) + optimizer_cfg = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.1)) + paramwise_cfg = dict(custom_keys={"conv1.weight": dict(lr_mult=0.1)}) + optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self.assertEqual(optim_wrapper.get_lr(), - dict(base_lr=[0.1], lr=[0.1 * 0.1] + [0.1] * 5)) + self.assertEqual(optim_wrapper.get_lr(), dict(base_lr=[0.1], lr=[0.1 * 0.1] + [0.1] * 5)) def test_get_momentum(self): # Get momentum from SGD model = ToyModel() - optim = SGD(model.parameters(), lr=0., momentum=0.8) + optim = SGD(model.parameters(), lr=0.0, momentum=0.8) optim_wrapper = OptimWrapper(optim) self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8])) # Get momentum from Adam - optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9)) + optim = Adam(model.parameters(), lr=0.0, betas=(0.9, 0.9)) optim_wrapper = OptimWrapper(optim) self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9])) @@ -204,8 +195,7 @@ def test_backward(self): def test_zero_grad(self): optimizer = MagicMock(spec=Optimizer) - optimizer.defaults = { - } # adjust this line according to what OptimWrapper expects + optimizer.defaults = {} # adjust this line according to what OptimWrapper expects optimizer.param_groups = [{}] optim_wrapper = OptimWrapper(optimizer) optim_wrapper.zero_grad() @@ -213,8 +203,7 @@ def test_zero_grad(self): def test_step(self): optimizer = MagicMock(spec=Optimizer) - optimizer.defaults = { - } # adjust this line according to what OptimWrapper expects + optimizer.defaults = {} # adjust this line according to what OptimWrapper expects optimizer.param_groups = [{}] optim_wrapper = OptimWrapper(optimizer) optim_wrapper.step() @@ -223,30 +212,27 @@ def test_step(self): # TODO: This unit test could cause CI to fail with some probability, which # is caused by MultiProcessTestCase. This problem should be solved # in the future). 
- @unittest.skipIf(True, reason='Solved in the future') + @unittest.skipIf(True, reason="Solved in the future") def test_clip_grads(self): # Test `clip_grad` with `clip_norm_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(max_norm=35)) + optim_wrapper = OptimWrapper(self.optimizer, clip_grad=dict(max_norm=35)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() optim_wrapper._clip_grad() log_scalars = self.message_hub.log_scalars - self.assertIn('train/grad_norm', log_scalars) + self.assertIn("train/grad_norm", log_scalars) self.message_hub._log_scalars.clear() # Test `clip_grad` with `clip_value_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(type='value', clip_value=0.5)) + optim_wrapper = OptimWrapper(self.optimizer, clip_grad=dict(type="value", clip_value=0.5)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() optim_wrapper._clip_grad() - self.assertNotIn('train/grad_norm', log_scalars) + self.assertNotIn("train/grad_norm", log_scalars) def test_state_dict(self): optim_wrapper = OptimWrapper(self.optimizer) - self.assertEqual(optim_wrapper.state_dict(), - self.optimizer.state_dict()) + self.assertEqual(optim_wrapper.state_dict(), self.optimizer.state_dict()) def test_load_state_dict(self): optim_wrapper = OptimWrapper(self.optimizer) @@ -257,8 +243,7 @@ def test_load_state_dict(self): def test_param_groups(self): optim_wrapper = OptimWrapper(self.optimizer) - self.assertEqual(optim_wrapper.param_groups, - self.optimizer.param_groups) + self.assertEqual(optim_wrapper.param_groups, self.optimizer.param_groups) def test_optim_context(self): self._init_dist_env(self.rank, self.world_size) @@ -297,16 +282,14 @@ def test_optim_context(self): def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29515' - os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29515" + os.environ["RANK"] = str(rank) + torch_dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) # TODO Test the real interface after add testing tool function which can # test the function or method is read called. 
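# For reference, the two clip_grad forms exercised in test_clip_grads above,
# outside the mocked setting (model and thresholds are illustrative):
# dict(max_norm=...) clips by total gradient norm and logs train/grad_norm,
# while dict(type="value", clip_value=...) clips each gradient element.
import torch
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = torch.nn.Linear(2, 1)
by_norm = OptimWrapper(SGD(model.parameters(), lr=0.1), clip_grad=dict(max_norm=35))
by_value = OptimWrapper(SGD(model.parameters(), lr=0.1), clip_grad=dict(type="value", clip_value=0.5))
loss = model(torch.randn(4, 2)).mean()
by_norm.update_params(loss)  # gradients are clipped just before the optimizer step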
def _mock_method(self, optim_wrapper): - def mock_methd(loss): optim_wrapper._inner_count += 1 optim_wrapper.scaled_loss = loss @@ -316,31 +299,28 @@ def mock_methd(loss): optim_wrapper.zero_grad = MagicMock() -@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex') +@unittest.skipIf(not torch.cuda.is_available(), reason="need gpu to test Apex") class TestApexOptimWrapper(TestCase): - def setUp(self) -> None: self.model = ToyModel().cuda() self.optimizer = SGD(self.model.parameters(), lr=0.1) @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) def test_init(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): pass @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) def test_step(self): optimizer = MagicMock(spec=Optimizer) - apex_optim_wrapper = ApexOptimWrapper( - optimizer=optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @@ -348,63 +328,54 @@ def test_step(self): @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) def test_backward(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) def test_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.update_params(loss) state_dict = apex_optim_wrapper.state_dict() - amp_state_dict = state_dict.pop('apex_amp') + amp_state_dict = state_dict.pop("apex_amp") optim_state_dict = state_dict - self.assertDictEqual(optim_state_dict, - apex_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(optim_state_dict, apex_optim_wrapper.optimizer.state_dict()) self.assertDictEqual(amp_state_dict, apex_amp.state_dict()) @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) 
def test_load_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): # Test load from optimizer optimizer = SGD(self.model.parameters(), lr=0.1) apex_optim_wrapper.load_state_dict(optimizer.state_dict()) - self.assertDictEqual(optimizer.state_dict(), - apex_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(optimizer.state_dict(), apex_optim_wrapper.optimizer.state_dict()) # Test load from optim_wrapper apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - apex_optim_wrapper_ = ApexOptimWrapper( - optimizer=SGD(self.model.parameters(), lr=0.1)) - apex_optim_wrapper_.load_state_dict( - apex_optim_wrapper.state_dict()) - self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), - apex_optim_wrapper_.optimizer.state_dict()) + apex_optim_wrapper_ = ApexOptimWrapper(optimizer=SGD(self.model.parameters(), lr=0.1)) + apex_optim_wrapper_.load_state_dict(apex_optim_wrapper.state_dict()) + self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), apex_optim_wrapper_.optimizer.state_dict()) @unittest.skipIf( not is_apex_available, - reason='`apex` is not available, Please install apex from ' - 'https://www.github.com/nvidia/apex') + reason="`apex` is not available, Please install apex from https://www.github.com/nvidia/apex", + ) def test_optim_context(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, opt_level="O1", loss_scale=1) with apex_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) @@ -412,80 +383,66 @@ def test_optim_context(self): class TestAmpOptimWrapper(TestCase): - def setUp(self) -> None: self.model = ToyModel() self.optimizer = SGD(self.model.parameters(), lr=0.1) @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_init(self): # Test with default arguments. amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dynamic. - amp_optim_wrapper = AmpOptimWrapper( - 'dynamic', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper("dynamic", optimizer=self.optimizer) self.assertIsNone(amp_optim_wrapper._scale_update_param) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dtype float16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='float16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype="float16", optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16) # Test with dtype bfloat16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='bfloat16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype="bfloat16", optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16) # Test with dict loss_scale. 
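# The AmpOptimWrapper cases above configure automatic mixed precision. A small
# CUDA-only sketch (model and data are illustrative; the dtype argument is
# skipped in these tests on PyTorch < 1.10): optim_context() enables autocast
# around the forward pass and update_params() goes through the GradScaler
# before stepping.
import torch
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = torch.nn.Linear(2, 1).cuda()
amp_wrapper = AmpOptimWrapper(optimizer=SGD(model.parameters(), lr=0.1), dtype="float16")
with amp_wrapper.optim_context(model):
    loss = model(torch.randn(4, 2).cuda()).mean()
amp_wrapper.update_params(loss)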
- amp_optim_wrapper = AmpOptimWrapper( - dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) self.assertIsNone(amp_optim_wrapper._scale_update_param) - with self.assertRaisesRegex(TypeError, - 'loss_scale must be of type float'): - AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown') + with self.assertRaisesRegex(TypeError, "loss_scale must be of type float"): + AmpOptimWrapper(optimizer=self.optimizer, loss_scale="unknown") - @parameterized.expand(list(zip(amp_valid_dtypes))) + @parameterized.expand(list(zip(amp_valid_dtypes, strict=False))) @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): - raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' - 'support `dtype` argument in autocast') - if dtype == 'bfloat16' and not bf16_supported(): - raise unittest.SkipTest('bfloat16 not supported by device') + if dtype is not None and (digit_version(TORCH_VERSION) < digit_version("1.10.0")): + raise unittest.SkipTest("Require PyTorch version >= 1.10.0 to support `dtype` argument in autocast") + if dtype == "bfloat16" and not bf16_supported(): + raise unittest.SkipTest("bfloat16 not supported by device") optimizer = MagicMock(spec=Optimizer) - optimizer.defaults = { - } # adjust this line according to what OptimWrapper expects + optimizer.defaults = {} # adjust this line according to what OptimWrapper expects optimizer.param_groups = [{}] amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype) amp_optim_wrapper.loss_scaler = MagicMock() amp_optim_wrapper.step() - amp_optim_wrapper.loss_scaler.step.assert_called_with( - amp_optim_wrapper.optimizer) - amp_optim_wrapper.loss_scaler.update.assert_called_with( - amp_optim_wrapper._scale_update_param) + amp_optim_wrapper.loss_scaler.step.assert_called_with(amp_optim_wrapper.optimizer) + amp_optim_wrapper.loss_scaler.update.assert_called_with(amp_optim_wrapper._scale_update_param) - @parameterized.expand(list(zip(amp_valid_dtypes))) + @parameterized.expand(list(zip(amp_valid_dtypes, strict=False))) @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): - raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' - 'support `dtype` argument in autocast') - if dtype == 'bfloat16' and not bf16_supported(): - raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + if dtype is not None and (digit_version(TORCH_VERSION) < digit_version("1.10.0")): + raise unittest.SkipTest("Require PyTorch version >= 1.10.0 to support `dtype` argument in autocast") + if dtype == "bfloat16" and not bf16_supported(): + raise unittest.SkipTest("bfloat16 not supported by device") + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, dtype=dtype) loss_scaler = MagicMock() scale_return = 
MagicMock() scale_fn = MagicMock(return_value=scale_return) @@ -497,56 +454,47 @@ def test_backward(self, dtype): scale_return.backward.assert_called_with() @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_state_dict(self): self.model = self.model.cuda() amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) amp_optim_wrapper.update_params(loss) state_dict = amp_optim_wrapper.state_dict() - scalar_state_dict = state_dict.pop('loss_scaler') + scalar_state_dict = state_dict.pop("loss_scaler") optim_state_dict = state_dict - self.assertDictEqual(optim_state_dict, - amp_optim_wrapper.optimizer.state_dict()) - self.assertDictEqual(scalar_state_dict, - amp_optim_wrapper.loss_scaler.state_dict()) + self.assertDictEqual(optim_state_dict, amp_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(scalar_state_dict, amp_optim_wrapper.loss_scaler.state_dict()) @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_load_state_dict(self): amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) self.model = self.model.cuda() # Test load from optimizer optimizer = SGD(self.model.parameters(), lr=0.1) amp_optim_wrapper.load_state_dict(optimizer.state_dict()) - self.assertDictEqual(optimizer.state_dict(), - amp_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(optimizer.state_dict(), amp_optim_wrapper.optimizer.state_dict()) # Test load from optim_wrapper amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) - amp_optim_wrapper_ = AmpOptimWrapper( - optimizer=SGD(self.model.parameters(), lr=0.1)) + amp_optim_wrapper_ = AmpOptimWrapper(optimizer=SGD(self.model.parameters(), lr=0.1)) amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict()) - self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(), - amp_optim_wrapper_.optimizer.state_dict()) - self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(), - amp_optim_wrapper_.loss_scaler.state_dict()) + self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(), amp_optim_wrapper_.optimizer.state_dict()) + self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(), amp_optim_wrapper_.loss_scaler.state_dict()) - @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes))) + @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes, strict=False))) @unittest.skipIf( - not torch.cuda.is_available(), - reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + not torch.cuda.is_available(), reason="`torch.cuda.amp` is only available when pytorch-gpu installed" + ) def test_optim_context(self, dtype, target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): - raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' - 'support `dtype` argument in autocast') - if dtype == 'bfloat16' and not bf16_supported(): - raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + if dtype is not None and (digit_version(TORCH_VERSION) < digit_version("1.10.0")): + raise unittest.SkipTest("Require PyTorch version >= 1.10.0 to support 
`dtype` argument in autocast") + if dtype == "bfloat16" and not bf16_supported(): + raise unittest.SkipTest("bfloat16 not supported by device") + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, dtype=dtype) with amp_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py index 3925a33ac9..8ee8abed23 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py @@ -10,7 +10,6 @@ class TestOptimWrapperDict(TestCase): - def setUp(self) -> None: self.model1 = nn.Linear(1, 1) self.model2 = nn.Linear(1, 1) @@ -18,33 +17,27 @@ def setUp(self) -> None: self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9) self.optim_wrapper1 = OptimWrapper(self.optim1) self.optim_wrapper2 = OptimWrapper(self.optim2) - self.optimizers_wrappers = dict( - optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) + self.optimizers_wrappers = dict(optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) def test_init(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertEqual(optim_wrapper_dict.optim_wrappers, - self.optimizers_wrappers) - with self.assertRaisesRegex(AssertionError, - '`OptimWrapperDict` only accept'): + self.assertEqual(optim_wrapper_dict.optim_wrappers, self.optimizers_wrappers) + with self.assertRaisesRegex(AssertionError, "`OptimWrapperDict` only accept"): OptimWrapperDict(**dict(optim1=self.optim1, optim2=self.optim2)) def test_update_params(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - with self.assertRaisesRegex(NotImplementedError, - '`update_params` should be called'): + with self.assertRaisesRegex(NotImplementedError, "`update_params` should be called"): optim_wrapper_dict.update_params(1) def test_backward(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - with self.assertRaisesRegex(NotImplementedError, - '`backward` should be called'): + with self.assertRaisesRegex(NotImplementedError, "`backward` should be called"): optim_wrapper_dict.backward(1) def test_step(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - with self.assertRaisesRegex(NotImplementedError, - '`step` should be called'): + with self.assertRaisesRegex(NotImplementedError, "`step` should be called"): optim_wrapper_dict.step() def test_zero_grad(self): @@ -55,7 +48,7 @@ def test_zero_grad(self): self.assertTrue((self.model1.weight.grad != 0).any()) self.assertTrue((self.model2.weight.grad != 0).any()) optim_wrapper_dict.zero_grad() - if digit_version(torch.__version__) < digit_version('2.0.0'): + if digit_version(torch.__version__) < digit_version("2.0.0"): self.assertTrue((self.model1.weight.grad == 0).all()) self.assertTrue((self.model2.weight.grad == 0).all()) else: @@ -64,8 +57,7 @@ def test_zero_grad(self): def test_optim_context(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - with self.assertRaisesRegex(NotImplementedError, - '`optim_context` should be called'): + with self.assertRaisesRegex(NotImplementedError, "`optim_context` should be called"): with optim_wrapper_dict.optim_context(self.model1): yield @@ -76,30 +68,26 @@ def test_initialize_count_status(self): def test_param_groups(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - 
self.assertEqual(optim_wrapper_dict.param_groups['optim1'], - self.optim1.param_groups) - self.assertEqual(optim_wrapper_dict.param_groups['optim2'], - self.optim2.param_groups) + self.assertEqual(optim_wrapper_dict.param_groups["optim1"], self.optim1.param_groups) + self.assertEqual(optim_wrapper_dict.param_groups["optim2"], self.optim2.param_groups) def test_get_lr(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) lr = optim_wrapper_dict.get_lr() - self.assertEqual(lr['optim1.lr'], [0.1]) - self.assertEqual(lr['optim2.lr'], [0.2]) + self.assertEqual(lr["optim1.lr"], [0.1]) + self.assertEqual(lr["optim2.lr"], [0.2]) def test_get_momentum(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) momentum = optim_wrapper_dict.get_momentum() - self.assertEqual(momentum['optim1.momentum'], [0.8]) - self.assertEqual(momentum['optim2.momentum'], [0.9]) + self.assertEqual(momentum["optim1.momentum"], [0.8]) + self.assertEqual(momentum["optim2.momentum"], [0.9]) def test_state_dict(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) state_dict = optim_wrapper_dict.state_dict() - self.assertEqual(state_dict['optim1'], - self.optim_wrapper1.state_dict()) - self.assertEqual(state_dict['optim2'], - self.optim_wrapper2.state_dict()) + self.assertEqual(state_dict["optim1"], self.optim_wrapper1.state_dict()) + self.assertEqual(state_dict["optim2"], self.optim_wrapper2.state_dict()) def test_load_state_dict(self): # Test OptimWrapperDict can load from saved state dict. @@ -111,38 +99,28 @@ def test_load_state_dict(self): optim_wrapper_load2 = OptimWrapper(optim2) optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers) - optim_wrapper_dict_load = OptimWrapperDict( - optim1=optim_wrapper_load1, optim2=optim_wrapper_load2) + optim_wrapper_dict_load = OptimWrapperDict(optim1=optim_wrapper_load1, optim2=optim_wrapper_load2) state_dict = optim_wrapper_dict_save.state_dict() optim_wrapper_dict_load.load_state_dict(state_dict) - self.assertDictEqual(optim_wrapper_dict_load.state_dict(), - optim_wrapper_dict_save.state_dict()) + self.assertDictEqual(optim_wrapper_dict_load.state_dict(), optim_wrapper_dict_save.state_dict()) def test_items(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.items()), - list(self.optimizers_wrappers.items())) + self.assertListEqual(list(optim_wrapper_dict.items()), list(self.optimizers_wrappers.items())) def test_values(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.values()), - list(self.optimizers_wrappers.values())) + self.assertListEqual(list(optim_wrapper_dict.values()), list(self.optimizers_wrappers.values())) def test_keys(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.keys()), - list(self.optimizers_wrappers.keys())) + self.assertListEqual(list(optim_wrapper_dict.keys()), list(self.optimizers_wrappers.keys())) def test_getitem(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertIs(self.optimizers_wrappers['optim1'], - optim_wrapper_dict['optim1']) - self.assertIs(self.optimizers_wrappers['optim2'], - optim_wrapper_dict['optim2']) + self.assertIs(self.optimizers_wrappers["optim1"], optim_wrapper_dict["optim1"]) + self.assertIs(self.optimizers_wrappers["optim2"], optim_wrapper_dict["optim2"]) def test_len(self): optim_wrapper_dict = 
OptimWrapperDict(**self.optimizers_wrappers) @@ -150,9 +128,9 @@ def test_len(self): def test_contain(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertIn('optim1', optim_wrapper_dict) + self.assertIn("optim1", optim_wrapper_dict) def test_repr(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) desc = repr(optim_wrapper_dict) - self.assertRegex(desc, 'name: optim1') + self.assertRegex(desc, "name: optim1") diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 22787e4709..0bb7bebb21 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -6,16 +6,23 @@ import torch.nn.functional as F import torch.optim as optim -from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, - CosineRestartLR, ExponentialLR, LinearLR, - MultiStepLR, OneCycleLR, PolyLR, - ReduceOnPlateauLR, StepLR, - _ParamScheduler) +from mmengine.optim.scheduler import ( + ConstantLR, + CosineAnnealingLR, + CosineRestartLR, + ExponentialLR, + LinearLR, + MultiStepLR, + OneCycleLR, + PolyLR, + ReduceOnPlateauLR, + StepLR, + _ParamScheduler, +) from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): - def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) @@ -26,7 +33,6 @@ def forward(self, x): class TestLRScheduler(TestCase): - def setUp(self): """Setup the model and optimizer which are used in every test method. @@ -36,23 +42,26 @@ def setUp(self): self.model = ToyModel() lr = 0.05 self.layer2_mult = 10 - self.optimizer = optim.SGD([{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'lr': lr * self.layer2_mult, - }], - lr=lr, - momentum=0.01, - weight_decay=5e-4) + self.optimizer = optim.SGD( + [ + {"params": self.model.conv1.parameters()}, + { + "params": self.model.conv2.parameters(), + "lr": lr * self.layer2_mult, + }, + ], + lr=lr, + momentum=0.01, + weight_decay=5e-4, + ) def test_base_scheduler_step(self): with self.assertRaises(NotImplementedError): - _ParamScheduler(self.optimizer, param_name='lr') + _ParamScheduler(self.optimizer, param_name="lr") def test_invalid_optimizer(self): - with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): - StepLR('invalid_optimizer', step_size=1) + with self.assertRaisesRegex(TypeError, "should be an Optimizer"): + StepLR("invalid_optimizer", step_size=1) def test_overwrite_optimzer_step(self): # raise warning if the counter in optimizer.step() is overwritten @@ -63,13 +72,11 @@ def overwrite_fun(): self.optimizer.step = overwrite_fun self.optimizer.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - scheduler.step) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", scheduler.step) def test_resume(self): # test invalid case: optimizer and scheduler are not both resumed - with self.assertRaisesRegex(KeyError, - "param 'initial_lr' is not specified"): + with self.assertRaisesRegex(KeyError, "param 'initial_lr' is not specified"): StepLR(self.optimizer, gamma=0.1, step_size=3, last_step=10) # test manually resume with ``last_step`` instead of load_state_dict @@ -79,7 +86,7 @@ def test_resume(self): results = [] for epoch in range(5): - results.append(self.optimizer.param_groups[0]['lr']) + results.append(self.optimizer.param_groups[0]["lr"]) # The order should be # train_epoch() -> save_checkpoint() -> scheduler.step(). 
# Break at here to simulate the checkpoint is saved before @@ -89,17 +96,17 @@ def test_resume(self): scheduler.step() scheduler2 = ExponentialLR(self.optimizer, gamma=0.9, last_step=4) for epoch in range(6): - results.append(self.optimizer.param_groups[0]['lr']) + results.append(self.optimizer.param_groups[0]["lr"]) scheduler2.step() for epoch in range(epochs): assert_allclose( targets[epoch], results[epoch], - msg='lr is wrong in epoch {}: expected {}, got {}'.format( - epoch, targets[epoch], results[epoch]), + msg=f"lr is wrong in epoch {epoch}: expected {targets[epoch]}, got {results[epoch]}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_before_optim_warning(self): """Warns if scheduler is used before optimizer.""" @@ -110,43 +117,38 @@ def call_sch_before_optim(): self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim) # check warning when resume - for i, group in enumerate(self.optimizer.param_groups): - group['initial_lr'] = 0.01 + for _i, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, last_step=10) + scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3, last_step=10) scheduler.step() self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim_resume) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim_resume) def test_get_last_value(self): epochs = 10 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = StepLR(self.optimizer, 3, gamma=0.1) for epoch in range(epochs): result = scheduler.get_last_value() self.optimizer.step() scheduler.step() target = [t[epoch] for t in targets] - for t, r in zip(target, result): + for t, r in zip(target, result, strict=False): assert_allclose( target, result, - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, t, r), + msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_step_count(self): iteration = 10 @@ -154,7 +156,7 @@ def test_scheduler_step_count(self): self.assertEqual(scheduler.last_step, 0) target = [i + 1 for i in range(iteration)] step_counts = [] - for i in range(iteration): + for _i in range(iteration): self.optimizer.step() scheduler.step() step_counts.append(scheduler.last_step) @@ -162,8 +164,7 @@ def test_scheduler_step_count(self): def test_effective_interval(self): # check invalid begin end - with self.assertRaisesRegex(ValueError, - 'end should be larger than begin'): + with self.assertRaisesRegex(ValueError, "end should be larger than begin"): StepLR(self.optimizer, gamma=0.1, step_size=3, begin=10, end=5) # lr = 0.05 if epoch == 0 @@ -176,28 +177,13 @@ def test_effective_interval(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = 
LinearLR( - self.optimizer, - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [0.05] * begin + [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters - begin) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = LinearLR(self.optimizer, start_factor=start_factor, begin=begin, end=begin + iters + 1) self._test_scheduler_value(scheduler, targets, epochs) - def _test_scheduler_value(self, - schedulers, - targets, - epochs=10, - param_name='lr', - step_kwargs=None): + def _test_scheduler_value(self, schedulers, targets, epochs=10, param_name="lr", step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] if step_kwargs is None: @@ -207,20 +193,15 @@ def _test_scheduler_value(self, assert len(step_kwargs) == epochs assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): - for param_group, target in zip(self.optimizer.param_groups, - targets): + for param_group, target in zip(self.optimizer.param_groups, targets, strict=False): assert_allclose( target[epoch], param_group[param_name], - msg='{} is wrong in epoch {}: expected {}, got {}'.format( - param_name, epoch, target[epoch], - param_group[param_name]), + msg=f"{param_name} is wrong in epoch {epoch}: expected {target[epoch]}, got {param_group[param_name]}", atol=1e-5, - rtol=0) - [ - scheduler.step(**step_kwargs[epoch][i]) - for i, scheduler in enumerate(schedulers) - ] + rtol=0, + ) + [scheduler.step(**step_kwargs[epoch][i]) for i, scheduler in enumerate(schedulers)] def test_step_scheduler(self): # lr = 0.05 if epoch < 3 @@ -228,13 +209,9 @@ def test_step_scheduler(self): # lr = 0.0005 if 6 <= epoch < 9 # lr = 0.00005 if epoch >=9 epochs = 10 - single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, verbose=True) + single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3, verbose=True) self._test_scheduler_value(scheduler, targets, epochs) def test_multi_step_scheduler(self): @@ -243,13 +220,9 @@ def test_multi_step_scheduler(self): # lr = 0.0005 if 5 <= epoch < 9 # lr = 0.00005 if epoch >= 9 epochs = 10 - single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]) + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = MultiStepLR(self.optimizer, gamma=0.1, milestones=[2, 5, 9]) self._test_scheduler_value(scheduler, targets, epochs) def test_constant_scheduler(self): @@ -261,9 +234,7 @@ def test_constant_scheduler(self): # lr = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 4 + [0.05] * 6 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = ConstantLR(self.optimizer, factor=1.0 / 2, end=5) self._test_scheduler_value(scheduler, targets, epochs) @@ -284,24 +255,16 @@ def test_linear_scheduler(self): 
epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = LinearLR( - self.optimizer, start_factor=start_factor, end=iters + 1) + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = LinearLR(self.optimizer, start_factor=start_factor, end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler(self): epochs = 10 single_targets = [0.05 * (0.9**x) for x in range(epochs)] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = ExponentialLR(self.optimizer, gamma=0.9) self._test_scheduler_value(scheduler, targets, epochs) @@ -309,19 +272,13 @@ def test_cos_anneal_scheduler(self): epochs = 12 t = 10 eta_min = 1e-10 - single_targets = [ - eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + single_targets = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingLR( - self.optimizer, begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingLR(self.optimizer, begin=5, end=100, eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -329,47 +286,35 @@ def test_poly_scheduler(self): power = 0.9 min_lr = 0.001 iters = 4 - targets_layer1 = [ - min_lr + (0.05 - min_lr) * (1 - i / iters)**power - for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + targets_layer1 = [min_lr + (0.05 - min_lr) * (1 - i / iters) ** power for i in range(iters)] + [min_lr] * ( + epochs - iters + ) targets_layer2 = [ - min_lr + (0.05 * self.layer2_mult - min_lr) * - (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters) ** power for i in range(iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyLR( - self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + scheduler = PolyLR(self.optimizer, power=power, eta_min=min_lr, end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + CosineRestartLR(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0, eta_min_ratio=0.1) with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartLR(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5, 0.0], eta_min=0) single_targets = [ - 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, - 0.0086372, 
0.0023872, 0.0023872 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] - scheduler = CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) + 0.05, + 0.0426776, + 0.025, + 0.00732233, + 0.025, + 0.022612712, + 0.01636271, + 0.0086372, + 0.0023872, + 0.0023872, + ] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] + scheduler = CosineRestartLR(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): @@ -378,7 +323,7 @@ def test_reduce_on_plateau_scheduler(self): # Test error in __init__ method with self.assertRaises(TypeError): - ReduceOnPlateauLR('invalid_optimizer') + ReduceOnPlateauLR("invalid_optimizer") with self.assertRaises(ValueError): ReduceOnPlateauLR(self.optimizer, begin=10, end=5) with self.assertRaises(AssertionError): @@ -396,25 +341,35 @@ def test_reduce_on_plateau_scheduler(self): with self.assertRaises(ValueError): ReduceOnPlateauLR(self.optimizer, threshold=-1.0) with self.assertRaises(ValueError): - ReduceOnPlateauLR(self.optimizer, rule='foo') + ReduceOnPlateauLR(self.optimizer, rule="foo") with self.assertRaises(ValueError): - ReduceOnPlateauLR(self.optimizer, threshold_rule='foo') + ReduceOnPlateauLR(self.optimizer, threshold_rule="foo") # Test error in step method - scheduler = ReduceOnPlateauLR(self.optimizer, monitor='loss') + scheduler = ReduceOnPlateauLR(self.optimizer, monitor="loss") assert scheduler.step() is None with self.assertRaises(TypeError): - scheduler.step(('foo', 1.0)) + scheduler.step(("foo", 1.0)) metrics = dict(loss_foo=1.0) with self.assertRaises(KeyError): scheduler.step(metrics) # Test scheduler value - def _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, - min_value): + def _test_value( + epochs, + targets, + metrics_list, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ): lr = 0.05 momentum = 0.01 weight_decay = 5e-4 @@ -429,19 +384,21 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + self._test_scheduler_value(scheduler, targets, epochs=epochs, step_kwargs=metrics_list) # reset the state of optimizers - self.optimizer = optim.SGD([{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'lr': lr * self.layer2_mult, - }], - lr=lr, - momentum=momentum, - weight_decay=weight_decay) + self.optimizer = optim.SGD( + [ + {"params": self.model.conv1.parameters()}, + { + "params": self.model.conv2.parameters(), + "lr": lr * self.layer2_mult, + }, + ], + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + ) epochs = 10 factor = 0.1 @@ -449,91 +406,83 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, patience = 2 # rule(less) and threshold_rule(rel) - rule, threshold_rule = 'less', 'rel' + rule, threshold_rule = "less", "rel" threshold = 0.01 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] 
+        monitor = "loss"
+        metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0]
         metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
-        single_targets = [
-            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
-        ]
-        targets = [
-            single_targets, [t * self.layer2_mult for t in single_targets]
-        ]
+        single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005]
+        targets = [single_targets, [t * self.layer2_mult for t in single_targets]]
 
-        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
-                    patience, threshold, threshold_rule, cooldown, 0.0)
+        _test_value(
+            epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0
+        )
 
         # rule(less) and threshold_rule(abs)
-        rule, threshold_rule = 'less', 'abs'
+        rule, threshold_rule = "less", "abs"
         threshold = 0.9
-        monitor = 'loss'
-        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        monitor = "loss"
+        metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0]
         metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
-        single_targets = [
-            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
-        ]
-        targets = [
-            single_targets, [t * self.layer2_mult for t in single_targets]
-        ]
+        single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005]
+        targets = [single_targets, [t * self.layer2_mult for t in single_targets]]
 
-        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
-                    patience, threshold, threshold_rule, cooldown, 0.0)
+        _test_value(
+            epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0
+        )
 
         # rule(greater) and threshold_rule(rel)
-        rule, threshold_rule = 'greater', 'rel'
+        rule, threshold_rule = "greater", "rel"
         threshold = 0.01
-        monitor = 'bbox_mAP'
-        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        monitor = "bbox_mAP"
+        metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]
         metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
-        single_targets = [
-            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
-        ]
-        targets = [
-            single_targets, [t * self.layer2_mult for t in single_targets]
-        ]
+        single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005]
+        targets = [single_targets, [t * self.layer2_mult for t in single_targets]]
 
-        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
-                    patience, threshold, threshold_rule, cooldown, 0.0)
+        _test_value(
+            epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0
+        )
 
         # rule(greater) and threshold_rule(abs)
-        rule, threshold_rule = 'greater', 'abs'
+        rule, threshold_rule = "greater", "abs"
         threshold = 0.9
-        monitor = 'bbox_mAP'
-        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        monitor = "bbox_mAP"
+        metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]
         metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
-        single_targets = [
-            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
-        ]
-        targets = [
-            single_targets, [t * self.layer2_mult for t in single_targets]
-        ]
+        single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005]
+        targets = [single_targets, [t * self.layer2_mult for t in single_targets]]
 
-        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
-                    patience, threshold, threshold_rule, cooldown, 0.0)
+        _test_value(
+            epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0
+        )
 
         # change min_value
         min_value = 0.01
-        rule, threshold_rule = 'less', 'rel'
+        rule, threshold_rule = "less", "rel"
         threshold = 0.01
-        monitor = 'loss'
-        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        monitor = "loss"
+        metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0]
         metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
-        single_targets_1 = [
-            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value,
-            min_value
-        ]
+        single_targets_1 = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, min_value]
         single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05]
         targets = [single_targets_1, single_targets_2]
 
-        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
-                    patience, threshold, threshold_rule, cooldown, min_value)
-
-    def _check_scheduler_state_dict(self,
-                                    construct,
-                                    construct2,
-                                    epochs=10,
-                                    step_kwargs=None):
+        _test_value(
+            epochs,
+            targets,
+            metrics_list,
+            monitor,
+            rule,
+            factor,
+            patience,
+            threshold,
+            threshold_rule,
+            cooldown,
+            min_value,
+        )
+
+    def _check_scheduler_state_dict(self, construct, construct2, epochs=10, step_kwargs=None):
         if step_kwargs is None:
             step_kwargs = [{} for _ in range(epochs)]
         else:  # step_kwargs is not None
@@ -545,65 +494,57 @@ def _check_scheduler_state_dict(self,
         scheduler_copy = construct2()
         scheduler_copy.load_state_dict(scheduler.state_dict())
         for key in scheduler.__dict__.keys():
-            if key != 'optimizer':
-                self.assertEqual(scheduler.__dict__[key],
-                                 scheduler_copy.__dict__[key])
-        self.assertEqual(scheduler.get_last_value(),
-                         scheduler_copy.get_last_value())
+            if key != "optimizer":
+                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
+        self.assertEqual(scheduler.get_last_value(), scheduler_copy.get_last_value())
 
     def test_step_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
             lambda: StepLR(self.optimizer, gamma=0.1, step_size=3),
-            lambda: StepLR(self.optimizer, gamma=0.01 / 2, step_size=1))
+            lambda: StepLR(self.optimizer, gamma=0.01 / 2, step_size=1),
+        )
 
     def test_multi_step_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
-            lambda: MultiStepLR(
-                self.optimizer, gamma=0.1, milestones=[2, 5, 9]),
-            lambda: MultiStepLR(
-                self.optimizer, gamma=0.01, milestones=[1, 4, 6]))
+            lambda: MultiStepLR(self.optimizer, gamma=0.1, milestones=[2, 5, 9]),
+            lambda: MultiStepLR(self.optimizer, gamma=0.01, milestones=[1, 4, 6]),
+        )
 
     def test_exp_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
-            lambda: ExponentialLR(self.optimizer, gamma=0.1),
-            lambda: ExponentialLR(self.optimizer, gamma=0.01))
+            lambda: ExponentialLR(self.optimizer, gamma=0.1), lambda: ExponentialLR(self.optimizer, gamma=0.01)
+        )
 
     def test_cosine_scheduler_state_dict(self):
         epochs = 10
         eta_min = 1e-10
self._check_scheduler_state_dict( - lambda: CosineAnnealingLR( - self.optimizer, T_max=epochs, eta_min=eta_min), - lambda: CosineAnnealingLR( - self.optimizer, T_max=epochs // 2, eta_min=eta_min / 2), - epochs=epochs) + lambda: CosineAnnealingLR(self.optimizer, T_max=epochs, eta_min=eta_min), + lambda: CosineAnnealingLR(self.optimizer, T_max=epochs // 2, eta_min=eta_min / 2), + epochs=epochs, + ) def test_linear_scheduler_state_dict(self): epochs = 10 self._check_scheduler_state_dict( lambda: LinearLR(self.optimizer, start_factor=1 / 3), lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3), - epochs=epochs) + epochs=epochs, + ) def test_poly_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001), lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002), - epochs=10) + epochs=10, + ) def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), - epochs=10) + lambda: CosineRestartLR(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0), + lambda: CosineRestartLR(self.optimizer, periods=[4, 6], restart_weights=[1, 0.5], eta_min=0), + epochs=10, + ) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 @@ -611,47 +552,45 @@ def test_reduce_on_plateau_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: ReduceOnPlateauLR( self.optimizer, - monitor='loss', - rule='less', + monitor="loss", + rule="less", factor=0.01, patience=5, threshold=1e-4, - threshold_rule='rel', + threshold_rule="rel", cooldown=0, min_value=0.0, - eps=1e-8), + eps=1e-8, + ), lambda: ReduceOnPlateauLR( self.optimizer, - monitor='loss_foo', - rule='greater', + monitor="loss_foo", + rule="greater", factor=0.05, patience=10, threshold=1e-5, - threshold_rule='abs', + threshold_rule="abs", cooldown=5, min_value=0.1, - eps=1e-9), + eps=1e-9, + ), epochs=epochs, - step_kwargs=metrics_list) + step_kwargs=metrics_list, + ) def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=-1) + scheduler = StepLR.build_iter_from_epoch(self.optimizer, gamma=0.1, step_size=2, epoch_length=-1) # lr = 0.05 if epoch < 2 # lr = 0.005 if 2 <= epoch < 4 epochs = 4 epoch_length = 7 single_targets = [0.05] * 2 * epoch_length + [0.005] * 2 * epoch_length - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length) - self._test_scheduler_value( - scheduler, targets, epochs * epoch_length, param_name='lr') + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = StepLR.build_iter_from_epoch(self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length, param_name="lr") def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.05 if epoch < 2 @@ -660,18 +599,16 @@ def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.00005 if epoch >= 9 epochs = 10 epoch_length = 7 - single_targets = [0.05 - ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [ - 0.0005 - ] * 4 * epoch_length + [0.00005] * 3 * 
epoch_length - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + single_targets = ( + [0.05] * 2 * epoch_length + + [0.005] * 3 * epoch_length + + [0.0005] * 4 * epoch_length + + [0.00005] * 3 * epoch_length + ) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = MultiStepLR.build_iter_from_epoch( - self.optimizer, - gamma=0.1, - milestones=[2, 5, 9], - epoch_length=epoch_length) + self.optimizer, gamma=0.1, milestones=[2, 5, 9], epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_constant_scheduler_convert_iterbased(self): @@ -679,13 +616,9 @@ def test_constant_scheduler_convert_iterbased(self): # lr = 0.005 if 5 <= epoch epochs = 10 epoch_length = 7 - single_targets = [0.025] * (5 * epoch_length - - 1) + [0.05] * (5 * epoch_length + 1) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = ConstantLR.build_iter_from_epoch( - self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length) + single_targets = [0.025] * (5 * epoch_length - 1) + [0.05] * (5 * epoch_length + 1) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = ConstantLR.build_iter_from_epoch(self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_linear_scheduler_convert_iterbased(self): @@ -695,33 +628,21 @@ def test_linear_scheduler_convert_iterbased(self): epoch_length = 11 iters = end * epoch_length - 1 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs * epoch_length - iters) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs * epoch_length - iters) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = LinearLR.build_iter_from_epoch( - self.optimizer, - start_factor=start_factor, - end=end, - epoch_length=epoch_length) + self.optimizer, start_factor=start_factor, end=end, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler_convert_iterbased(self): epochs = 10 epoch_length = 7 - single_targets = [ - 0.05 * (0.9**x) for x in range(epochs * epoch_length) - ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = ExponentialLR.build_iter_from_epoch( - self.optimizer, gamma=0.9, epoch_length=epoch_length) + single_targets = [0.05 * (0.9**x) for x in range(epochs * epoch_length)] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = ExponentialLR.build_iter_from_epoch(self.optimizer, gamma=0.9, epoch_length=epoch_length) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_cos_anneal_scheduler_convert_iterbased(self): @@ -730,18 +651,13 @@ def test_cos_anneal_scheduler_convert_iterbased(self): eta_min = 1e-10 epoch_length = 11 single_targets = [ - eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / t / epoch_length)) / 2 + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t / epoch_length)) / 2 for x in range(epochs * epoch_length) ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + 
targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = CosineAnnealingLR.build_iter_from_epoch( - self.optimizer, - T_max=t, - eta_min=eta_min, - epoch_length=epoch_length) + self.optimizer, T_max=t, eta_min=eta_min, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs) def test_poly_scheduler_convert_iterbased(self): @@ -752,37 +668,25 @@ def test_poly_scheduler_convert_iterbased(self): epoch_length = 11 iters = end * epoch_length - 1 - targets_layer1 = [ - min_lr + (0.05 - min_lr) * (1 - i / iters)**power - for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + targets_layer1 = [min_lr + (0.05 - min_lr) * (1 - i / iters) ** power for i in range(iters)] + [min_lr] * ( + epochs - iters + ) targets_layer2 = [ - min_lr + (0.05 * self.layer2_mult - min_lr) * - (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters) ** power for i in range(iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] scheduler = PolyLR.build_iter_from_epoch( - self.optimizer, - power=power, - eta_min=min_lr, - end=end, - epoch_length=epoch_length) + self.optimizer, power=power, eta_min=min_lr, end=end, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs=10) def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 - single_targets = [0.025, 0.03125, 0.0375, 0.04375 - ] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearLR( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) + single_targets = [0.025, 0.03125, 0.0375, 0.04375] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearLR(self.optimizer, start_factor=1 / 2, begin=0, end=5) + scheduler2 = MultiStepLR(self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_without_overlap_exp_cosine(self): @@ -793,30 +697,21 @@ def test_multi_scheduler_without_overlap_exp_cosine(self): eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] single_targets = single_targets1 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler2 = CosineAnnealingLR( - self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler2 = CosineAnnealingLR(self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_overlap(self): # use Exp in the first 5 epochs and then use Cosine epochs = 10 - single_targets = [0.025, 0.03125, 0.0375, 0.004375 - ] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearLR( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepLR( - self.optimizer, gamma=0.1, 
milestones=[3, 6, 9]) + single_targets = [0.025, 0.03125, 0.0375, 0.004375] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearLR(self.optimizer, start_factor=1 / 2, begin=0, end=5) + scheduler2 = MultiStepLR(self.optimizer, gamma=0.1, milestones=[3, 6, 9]) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_gap(self): @@ -828,40 +723,31 @@ def test_multi_scheduler_with_gap(self): eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] - single_targets = single_targets1 + [single_targets1[-1] - ] * 5 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler2 = CosineAnnealingLR( - self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) + single_targets = single_targets1 + [single_targets1[-1]] * 5 + single_targets2 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler2 = CosineAnnealingLR(self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_onecycle_lr(self): # test linear annealing - target = [1., 13., 25., 21.5, 18., 14.5, 11., 7.5, 4., 0.5] - scheduler = OneCycleLR( - self.optimizer, - eta_max=25, - final_div_factor=2, - total_steps=10, - anneal_strategy='linear') + target = [1.0, 13.0, 25.0, 21.5, 18.0, 14.5, 11.0, 7.5, 4.0, 0.5] + scheduler = OneCycleLR(self.optimizer, eta_max=25, final_div_factor=2, total_steps=10, anneal_strategy="linear") self._test_scheduler_value(scheduler, [target], 10) # test linear annealing three phase - target = [1., 9., 17., 25., 17., 9., 1., 0.75, 0.5, 0.25] + target = [1.0, 9.0, 17.0, 25.0, 17.0, 9.0, 1.0, 0.75, 0.5, 0.25] scheduler = OneCycleLR( self.optimizer, eta_max=25, div_factor=25, total_steps=10, - anneal_strategy='linear', + anneal_strategy="linear", pct_start=0.4, final_div_factor=4, - three_phase=True) + three_phase=True, + ) self._test_scheduler_value(scheduler, [target], 10) # test cosine annealing @@ -870,14 +756,16 @@ def annealing_cos(start, end, pct): return end + (start - end) / 2.0 * cos_out target = [ - 1., 13., 25., + 1.0, + 13.0, + 25.0, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0), annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0), - annealing_cos(25, 0.5, 6 / 7.0), 0.5 + annealing_cos(25, 0.5, 6 / 7.0), + 0.5, ] - scheduler = OneCycleLR( - self.optimizer, eta_max=25, final_div_factor=2, total_steps=10) + scheduler = OneCycleLR(self.optimizer, eta_max=25, final_div_factor=2, total_steps=10) self._test_scheduler_value(scheduler, [target], 10) diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 60a9713ee2..63d4f1014b 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -7,19 +7,24 @@ import torch.optim as optim # yapf: disable -from mmengine.optim.scheduler import (ConstantMomentum, - CosineAnnealingMomentum, - CosineRestartMomentum, - ExponentialMomentum, LinearMomentum, - MultiStepMomentum, PolyMomentum, - ReduceOnPlateauMomentum, StepMomentum, - _ParamScheduler) +from mmengine.optim.scheduler import 
( + ConstantMomentum, + CosineAnnealingMomentum, + CosineRestartMomentum, + ExponentialMomentum, + LinearMomentum, + MultiStepMomentum, + PolyMomentum, + ReduceOnPlateauMomentum, + StepMomentum, + _ParamScheduler, +) + # yapf: enable from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): - def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) @@ -30,7 +35,6 @@ def forward(self, x): class TestMomentumScheduler(TestCase): - def setUp(self): """Setup the model and optimizer which are used in every test method. @@ -40,31 +44,27 @@ def setUp(self): self.model = ToyModel() momentum = 0.05 self.layer2_mult = 10 - self.optimizer = optim.SGD([{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'momentum': momentum * self.layer2_mult - }], - lr=0.01, - momentum=momentum, - weight_decay=5e-4) + self.optimizer = optim.SGD( + [ + {"params": self.model.conv1.parameters()}, + {"params": self.model.conv2.parameters(), "momentum": momentum * self.layer2_mult}, + ], + lr=0.01, + momentum=momentum, + weight_decay=5e-4, + ) self.optimizer_with_betas = optim.Adam( - [{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'betas': (momentum * self.layer2_mult, 0.999) - }], + [ + {"params": self.model.conv1.parameters()}, + {"params": self.model.conv2.parameters(), "betas": (momentum * self.layer2_mult, 0.999)}, + ], lr=0.01, betas=(momentum, 0.999), - weight_decay=5e-4) + weight_decay=5e-4, + ) def test_invalid_optimizer(self): - with self.assertRaisesRegex( - ValueError, - 'optimizer must support momentum when using momentum scheduler' - ): + with self.assertRaisesRegex(ValueError, "optimizer must support momentum when using momentum scheduler"): optimizer = optim.ASGD( self.model.parameters(), lr=0.01, @@ -80,13 +80,11 @@ def overwrite_fun(): self.optimizer.step = overwrite_fun self.optimizer.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - scheduler.step) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", scheduler.step) def test_resume(self): # test invalid case: optimizer and scheduler are not both resumed - with self.assertRaisesRegex( - KeyError, "param 'initial_momentum' is not specified"): + with self.assertRaisesRegex(KeyError, "param 'initial_momentum' is not specified"): StepMomentum(self.optimizer, gamma=0.1, step_size=3, last_step=10) # test manually resume with ``last_step`` instead of load_state_dict @@ -96,7 +94,7 @@ def test_resume(self): results = [] for epoch in range(5): - results.append(self.optimizer.param_groups[0]['momentum']) + results.append(self.optimizer.param_groups[0]["momentum"]) # The order should be # train_epoch() -> save_checkpoint() -> scheduler.step(). # Break at here to simulate the checkpoint is saved before @@ -104,20 +102,19 @@ def test_resume(self): if epoch == 4: break scheduler.step() - scheduler2 = ExponentialMomentum( - self.optimizer, gamma=0.9, last_step=4) + scheduler2 = ExponentialMomentum(self.optimizer, gamma=0.9, last_step=4) for epoch in range(6): - results.append(self.optimizer.param_groups[0]['momentum']) + results.append(self.optimizer.param_groups[0]["momentum"]) scheduler2.step() for epoch in range(epochs): assert_allclose( targets[epoch], results[epoch], - msg='momentum is wrong in epoch {}: expected {}, got {}'. 
- format(epoch, targets[epoch], results[epoch]), + msg=f"momentum is wrong in epoch {epoch}: expected {targets[epoch]}, got {results[epoch]}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_before_optim_warning(self): """Warns if scheduler is used before optimizer.""" @@ -128,43 +125,38 @@ def call_sch_before_optim(): self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim) # check warning when resume - for i, group in enumerate(self.optimizer.param_groups): - group['initial_momentum'] = 0.01 + for _i, group in enumerate(self.optimizer.param_groups): + group["initial_momentum"] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepMomentum( - self.optimizer, gamma=0.1, step_size=3, last_step=10) + scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3, last_step=10) scheduler.step() self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim_resume) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim_resume) def test_get_last_value(self): epochs = 10 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] scheduler = StepMomentum(self.optimizer, 3, gamma=0.1) for epoch in range(epochs): result = scheduler.get_last_value() self.optimizer.step() scheduler.step() target = [t[epoch] for t in targets] - for t, r in zip(target, result): + for t, r in zip(target, result, strict=False): assert_allclose( target, result, - msg='momentum is wrong in epoch {}: expected {}, got {}'. 
- format(epoch, t, r), + msg=f"momentum is wrong in epoch {epoch}: expected {t}, got {r}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_step_count(self): iteration = 10 @@ -172,7 +164,7 @@ def test_scheduler_step_count(self): self.assertEqual(scheduler.last_step, 0) target = [i + 1 for i in range(iteration)] step_counts = [] - for i in range(iteration): + for _i in range(iteration): self.optimizer.step() scheduler.step() step_counts.append(scheduler.last_step) @@ -180,10 +172,8 @@ def test_scheduler_step_count(self): def test_effective_interval(self): # check invalid begin end - with self.assertRaisesRegex(ValueError, - 'end should be larger than begin'): - StepMomentum( - self.optimizer, gamma=0.1, step_size=3, begin=10, end=5) + with self.assertRaisesRegex(ValueError, "end should be larger than begin"): + StepMomentum(self.optimizer, gamma=0.1, step_size=3, begin=10, end=5) # momentum = 0.05 if epoch == 0 # momentum = 0.025 if epoch == 1 @@ -195,29 +185,13 @@ def test_effective_interval(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = LinearMomentum( - self.optimizer, - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [0.05] * begin + [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters - begin) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = LinearMomentum(self.optimizer, start_factor=start_factor, begin=begin, end=begin + iters + 1) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - def _test_scheduler_value(self, - optimizer, - schedulers, - targets, - epochs=10, - param_name='momentum', - step_kwargs=None): + def _test_scheduler_value(self, optimizer, schedulers, targets, epochs=10, param_name="momentum", step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] if step_kwargs is None: @@ -227,28 +201,25 @@ def _test_scheduler_value(self, assert len(step_kwargs) == epochs assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): - for param_group, target in zip(optimizer.param_groups, targets): + for param_group, target in zip(optimizer.param_groups, targets, strict=False): assert_allclose( target[epoch], param_group[param_name], - msg='{} is wrong in epoch {}: expected {}, got {}'.format( - param_name, epoch, target[epoch], - param_group[param_name]), + msg=f"{param_name} is wrong in epoch {epoch}: expected {target[epoch]}, got {param_group[param_name]}", atol=1e-5, - rtol=0) - if 'betas' in optimizer.defaults: + rtol=0, + ) + if "betas" in optimizer.defaults: assert_allclose( target[epoch], - param_group['betas'][0], - msg='{} is wrong in epoch {}: expected {}, got {}'. 
- format('betas_0', epoch, target[epoch], - param_group['betas'][0]), + param_group["betas"][0], + msg="{} is wrong in epoch {}: expected {}, got {}".format( + "betas_0", epoch, target[epoch], param_group["betas"][0] + ), atol=1e-5, - rtol=0) - [ - scheduler.step(**step_kwargs[epoch][i]) - for i, scheduler in enumerate(schedulers) - ] + rtol=0, + ) + [scheduler.step(**step_kwargs[epoch][i]) for i, scheduler in enumerate(schedulers)] def test_step_scheduler(self): # momentum = 0.05 if epoch < 3 @@ -256,19 +227,13 @@ def test_step_scheduler(self): # momentum = 0.0005 if 6 <= epoch < 9 # momentum = 0.00005 if epoch >=9 epochs = 10 - single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepMomentum( - self.optimizer, gamma=0.1, step_size=3, verbose=True) + single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3, verbose=True) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = StepMomentum( - self.optimizer_with_betas, gamma=0.1, step_size=3, verbose=True) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + scheduler = StepMomentum(self.optimizer_with_betas, gamma=0.1, step_size=3, verbose=True) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) def test_multi_step_scheduler(self): # momentum = 0.05 if epoch < 2 @@ -276,19 +241,13 @@ def test_multi_step_scheduler(self): # momentum = 0.0005 if 5 <= epoch < 9 # momentum = 0.00005 if epoch >= 9 epochs = 10 - single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]) + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = MultiStepMomentum(self.optimizer, gamma=0.1, milestones=[2, 5, 9]) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = MultiStepMomentum( - self.optimizer_with_betas, gamma=0.1, milestones=[2, 5, 9]) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + scheduler = MultiStepMomentum(self.optimizer_with_betas, gamma=0.1, milestones=[2, 5, 9]) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) def test_constant_scheduler(self): # factor should between 0~1 @@ -299,16 +258,12 @@ def test_constant_scheduler(self): # momentum = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 4 + [0.05] * 6 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = ConstantMomentum(self.optimizer, factor=1.0 / 2, end=5) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = ConstantMomentum( - self.optimizer_with_betas, factor=1.0 / 2, end=5) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + scheduler = ConstantMomentum(self.optimizer_with_betas, factor=1.0 / 2, end=5) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) def test_linear_scheduler(self): with 
self.assertRaises(ValueError): @@ -327,61 +282,39 @@ def test_linear_scheduler(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = LinearMomentum( - self.optimizer, start_factor=start_factor, end=iters + 1) + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = LinearMomentum(self.optimizer, start_factor=start_factor, end=iters + 1) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = LinearMomentum( - self.optimizer_with_betas, - start_factor=start_factor, - end=iters + 1) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + scheduler = LinearMomentum(self.optimizer_with_betas, start_factor=start_factor, end=iters + 1) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) def test_exp_scheduler(self): epochs = 10 single_targets = [0.05 * (0.9**x) for x in range(epochs)] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = ExponentialMomentum(self.optimizer, gamma=0.9) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) scheduler = ExponentialMomentum(self.optimizer_with_betas, gamma=0.9) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) def test_cos_anneal_scheduler(self): epochs = 12 t = 10 eta_min = 1e-10 - single_targets = [ - eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = CosineAnnealingMomentum( - self.optimizer, T_max=t, eta_min=eta_min) + single_targets = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = CosineAnnealingMomentum(self.optimizer, T_max=t, eta_min=eta_min) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = CosineAnnealingMomentum( - self.optimizer_with_betas, T_max=t, eta_min=eta_min) - self._test_scheduler_value(self.optimizer_with_betas, scheduler, - targets, epochs) + scheduler = CosineAnnealingMomentum(self.optimizer_with_betas, T_max=t, eta_min=eta_min) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingMomentum( - self.optimizer, begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingMomentum(self.optimizer, begin=5, end=100, eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -389,66 +322,46 @@ def test_poly_scheduler(self): power = 0.9 min_lr = 0.001 iters = 4 - layer1_targets = [ - min_lr + (0.05 - min_lr) * (1 - i / iters)**power - for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + layer1_targets = [min_lr + (0.05 - min_lr) * (1 - i / iters) ** power for i in range(iters)] + [min_lr] * ( + 
epochs - iters + ) layer2_targets = [ - min_lr + (0.05 * self.layer2_mult - min_lr) * - (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters) ** power for i in range(iters) + ] + [min_lr] * (epochs - iters) targets = [layer1_targets, layer2_targets] - scheduler = PolyMomentum( - self.optimizer, power=power, eta_min=min_lr, end=iters + 1) - self._test_scheduler_value( - self.optimizer, scheduler, targets, epochs=10) + scheduler = PolyMomentum(self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + self._test_scheduler_value(self.optimizer, scheduler, targets, epochs=10) - scheduler = PolyMomentum( - self.optimizer_with_betas, - power=power, - eta_min=min_lr, - end=iters + 1) - self._test_scheduler_value( - self.optimizer_with_betas, scheduler, targets, epochs=10) + scheduler = PolyMomentum(self.optimizer_with_betas, power=power, eta_min=min_lr, end=iters + 1) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0, eta_min_ratio=0.1 + ) with self.assertRaises(AssertionError): - CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartMomentum(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5, 0.0], eta_min=0) single_targets = [ - 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, - 0.0086372, 0.0023872, 0.0023872 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] + 0.05, + 0.0426776, + 0.025, + 0.00732233, + 0.025, + 0.022612712, + 0.01636271, + 0.0086372, + 0.0023872, + 0.0023872, ] - scheduler = CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) - self._test_scheduler_value( - self.optimizer, scheduler, targets, epochs=10) + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] + scheduler = CosineRestartMomentum(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0) + self._test_scheduler_value(self.optimizer, scheduler, targets, epochs=10) scheduler = CosineRestartMomentum( - self.optimizer_with_betas, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) - self._test_scheduler_value( - self.optimizer_with_betas, scheduler, targets, epochs=10) + self.optimizer_with_betas, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0 + ) + self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): # inherit _ParamScheduler but not call super().__init__(), @@ -474,30 +387,40 @@ def test_reduce_on_plateau_scheduler(self): ReduceOnPlateauMomentum(self.optimizer, factor=2.0) ReduceOnPlateauMomentum(self.optimizer, min_value=[0.1, 0.1]) with self.assertRaises(ValueError): - ReduceOnPlateauMomentum( - self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1]) + ReduceOnPlateauMomentum(self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1]) with self.assertRaises(ValueError): ReduceOnPlateauMomentum(self.optimizer, threshold=-1.0) with self.assertRaises(ValueError): - ReduceOnPlateauMomentum(self.optimizer, rule='foo') + ReduceOnPlateauMomentum(self.optimizer, rule="foo") with self.assertRaises(ValueError): - ReduceOnPlateauMomentum(self.optimizer, 
threshold_rule='foo') + ReduceOnPlateauMomentum(self.optimizer, threshold_rule="foo") # Test error in step method - scheduler = ReduceOnPlateauMomentum(self.optimizer, monitor='loss') + scheduler = ReduceOnPlateauMomentum(self.optimizer, monitor="loss") assert scheduler.step() is None with self.assertRaises(TypeError): - scheduler.step(('foo', 1.0)) + scheduler.step(("foo", 1.0)) metrics = dict(loss_foo=1.0) with self.assertRaises(KeyError): scheduler.step(metrics) # Test scheduler value - def _test_value(epochs, targets, metrics_list, optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, min_value): + def _test_value( + epochs, + targets, + metrics_list, + optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ): lr = 0.01 momentum = 0.05 weight_decay = 5e-4 @@ -512,33 +435,27 @@ def _test_value(epochs, targets, metrics_list, optimizer, monitor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - optimizer, - scheduler, - targets, - epochs=epochs, - step_kwargs=metrics_list) + self._test_scheduler_value(optimizer, scheduler, targets, epochs=epochs, step_kwargs=metrics_list) # reset the state of optimizers - self.optimizer = optim.SGD([{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'momentum': momentum * self.layer2_mult - }], - lr=lr, - momentum=momentum, - weight_decay=weight_decay) + self.optimizer = optim.SGD( + [ + {"params": self.model.conv1.parameters()}, + {"params": self.model.conv2.parameters(), "momentum": momentum * self.layer2_mult}, + ], + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + ) self.optimizer_with_betas = optim.Adam( - [{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'betas': (momentum * self.layer2_mult, 0.999) - }], + [ + {"params": self.model.conv1.parameters()}, + {"params": self.model.conv2.parameters(), "betas": (momentum * self.layer2_mult, 0.999)}, + ], lr=lr, betas=(momentum, 0.999), - weight_decay=weight_decay) + weight_decay=weight_decay, + ) epochs = 10 factor = 0.1 @@ -546,100 +463,143 @@ def _test_value(epochs, targets, metrics_list, optimizer, monitor, patience = 2 # rule(less) and threshold_rule(rel) - rule, threshold_rule = 'less', 'rel' + rule, threshold_rule = "less", "rel" threshold = 0.01 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, self.optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, 0.0) + _test_value( + epochs, + targets, + metrics_list, + self.optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + 0.0, + ) # rule(less) and threshold_rule(abs) - rule, threshold_rule = 'less', 'abs' + rule, threshold_rule = "less", "abs" threshold = 0.9 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] 
+ monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, self.optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, 0.0) + _test_value( + epochs, + targets, + metrics_list, + self.optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + 0.0, + ) # rule(greater) and threshold_rule(rel) - rule, threshold_rule = 'greater', 'rel' + rule, threshold_rule = "greater", "rel" threshold = 0.01 - monitor = 'bbox_mAP' - metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + monitor = "bbox_mAP" + metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, self.optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, 0.0) + _test_value( + epochs, + targets, + metrics_list, + self.optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + 0.0, + ) # rule(greater) and threshold_rule(abs) - rule, threshold_rule = 'greater', 'abs' + rule, threshold_rule = "greater", "abs" threshold = 0.9 - monitor = 'bbox_mAP' - metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + monitor = "bbox_mAP" + metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, self.optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, 0.0) + _test_value( + epochs, + targets, + metrics_list, + self.optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + 0.0, + ) # change min_value min_value = 0.01 - rule, threshold_rule = 'less', 'rel' + rule, threshold_rule = "less", "rel" threshold = 0.01 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] 
+ monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets_1 = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, - min_value - ] + single_targets_1 = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, min_value] single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05] targets = [single_targets_1, single_targets_2] - _test_value(epochs, targets, metrics_list, self.optimizer, monitor, - rule, factor, patience, threshold, threshold_rule, - cooldown, min_value) - - _test_value(epochs, targets, metrics_list, self.optimizer_with_betas, - monitor, rule, factor, patience, threshold, threshold_rule, - cooldown, min_value) - - def _check_scheduler_state_dict(self, - construct, - construct2, - epochs=10, - step_kwargs=None): + _test_value( + epochs, + targets, + metrics_list, + self.optimizer, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ) + + _test_value( + epochs, + targets, + metrics_list, + self.optimizer_with_betas, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ) + + def _check_scheduler_state_dict(self, construct, construct2, epochs=10, step_kwargs=None): if step_kwargs is None: step_kwargs = [{} for _ in range(epochs)] else: # step_kwargs is not None @@ -651,66 +611,58 @@ def _check_scheduler_state_dict(self, scheduler_copy = construct2() scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): - if key != 'optimizer': - self.assertEqual(scheduler.__dict__[key], - scheduler_copy.__dict__[key]) - self.assertEqual(scheduler.get_last_value(), - scheduler_copy.get_last_value()) + if key != "optimizer": + self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) + self.assertEqual(scheduler.get_last_value(), scheduler_copy.get_last_value()) def test_step_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: StepMomentum(self.optimizer, gamma=0.1, step_size=3), - lambda: StepMomentum(self.optimizer, gamma=0.01 / 2, step_size=1)) + lambda: StepMomentum(self.optimizer, gamma=0.01 / 2, step_size=1), + ) def test_multi_step_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]), - lambda: MultiStepMomentum( - self.optimizer, gamma=0.01, milestones=[1, 4, 6])) + lambda: MultiStepMomentum(self.optimizer, gamma=0.1, milestones=[2, 5, 9]), + lambda: MultiStepMomentum(self.optimizer, gamma=0.01, milestones=[1, 4, 6]), + ) def test_exp_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: ExponentialMomentum(self.optimizer, gamma=0.1), - lambda: ExponentialMomentum(self.optimizer, gamma=0.01)) + lambda: ExponentialMomentum(self.optimizer, gamma=0.01), + ) def test_cosine_scheduler_state_dict(self): epochs = 10 eta_min = 1e-10 self._check_scheduler_state_dict( - lambda: CosineAnnealingMomentum( - self.optimizer, T_max=epochs, eta_min=eta_min), - lambda: CosineAnnealingMomentum( - self.optimizer, T_max=epochs // 2, eta_min=eta_min / 2), - epochs=epochs) + lambda: CosineAnnealingMomentum(self.optimizer, T_max=epochs, eta_min=eta_min), + lambda: CosineAnnealingMomentum(self.optimizer, T_max=epochs // 2, eta_min=eta_min / 2), + epochs=epochs, + ) def test_linear_scheduler_state_dict(self): epochs = 10 self._check_scheduler_state_dict( lambda: LinearMomentum(self.optimizer, start_factor=1 / 3), 
- lambda: LinearMomentum( - self.optimizer, start_factor=0, end_factor=0.3), - epochs=epochs) + lambda: LinearMomentum(self.optimizer, start_factor=0, end_factor=0.3), + epochs=epochs, + ) def test_poly_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001), lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002), - epochs=10) + epochs=10, + ) def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartMomentum( - self.optimizer, - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), - epochs=10) + lambda: CosineRestartMomentum(self.optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0), + lambda: CosineRestartMomentum(self.optimizer, periods=[4, 6], restart_weights=[1, 0.5], eta_min=0), + epochs=10, + ) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 @@ -718,101 +670,79 @@ def test_reduce_on_plateau_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: ReduceOnPlateauMomentum( self.optimizer, - monitor='loss', - rule='less', + monitor="loss", + rule="less", factor=0.01, patience=5, threshold=1e-4, - threshold_rule='rel', + threshold_rule="rel", cooldown=0, min_value=0.0, - eps=1e-8), + eps=1e-8, + ), lambda: ReduceOnPlateauMomentum( self.optimizer, - monitor='loss_foo', - rule='greater', + monitor="loss_foo", + rule="greater", factor=0.05, patience=10, threshold=1e-5, - threshold_rule='abs', + threshold_rule="abs", cooldown=5, min_value=0.1, - eps=1e-9), + eps=1e-9, + ), epochs=epochs, - step_kwargs=metrics_list) + step_kwargs=metrics_list, + ) def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 - single_targets = [0.025, 0.03125, 0.0375, 0.04375 - ] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearMomentum( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) - self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], - targets, epochs) + single_targets = [0.025, 0.03125, 0.0375, 0.04375] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearMomentum(self.optimizer, start_factor=1 / 2, begin=0, end=5) + scheduler2 = MultiStepMomentum(self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) + self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_without_overlap_exp_cosine(self): # use Exp in the first 5 epochs and then use Cosine epochs = 10 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialMomentum( - self.optimizer, gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialMomentum(self.optimizer, gamma=0.9, begin=0, end=5) eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] single_targets = single_targets1 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler2 = CosineAnnealingMomentum( - self.optimizer, 
T_max=5, eta_min=eta_min, begin=5, end=10) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler2 = CosineAnnealingMomentum(self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10) - self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], - targets, epochs) + self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_overlap(self): # use Linear at first 5 epochs together with MultiStep epochs = 10 - single_targets = [0.025, 0.03125, 0.0375, 0.004375 - ] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearMomentum( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[3, 6, 9]) - self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], - targets, epochs) + single_targets = [0.025, 0.03125, 0.0375, 0.004375] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearMomentum(self.optimizer, start_factor=1 / 2, begin=0, end=5) + scheduler2 = MultiStepMomentum(self.optimizer, gamma=0.1, milestones=[3, 6, 9]) + self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_gap(self): # use Exp in the first 5 epochs and the last 5 epochs use Cosine # no scheduler in the middle 5 epochs epochs = 15 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialMomentum( - self.optimizer, gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialMomentum(self.optimizer, gamma=0.9, begin=0, end=5) eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) - ] - single_targets = single_targets1 + [single_targets1[-1] - ] * 5 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] - scheduler2 = CosineAnnealingMomentum( - self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) + single_targets = single_targets1 + [single_targets1[-1]] * 5 + single_targets2 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler2 = CosineAnnealingMomentum(self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) - self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], - targets, epochs) + self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index a13072dc6e..d2be713743 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -9,23 +9,27 @@ import torch.optim as optim from mmengine.optim import OptimWrapper + # yapf: disable -from mmengine.optim.scheduler import (ConstantParamScheduler, - CosineAnnealingParamScheduler, - CosineRestartParamScheduler, - ExponentialParamScheduler, - LinearParamScheduler, - MultiStepParamScheduler, - OneCycleParamScheduler, - PolyParamScheduler, - ReduceOnPlateauParamScheduler, - StepParamScheduler, _ParamScheduler) +from mmengine.optim.scheduler import ( + ConstantParamScheduler, + CosineAnnealingParamScheduler, + CosineRestartParamScheduler, + 
ExponentialParamScheduler, + LinearParamScheduler, + MultiStepParamScheduler, + OneCycleParamScheduler, + PolyParamScheduler, + ReduceOnPlateauParamScheduler, + StepParamScheduler, + _ParamScheduler, +) + # yapf: enable from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): - def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) @@ -36,7 +40,6 @@ def forward(self, x): class TestParameterScheduler(TestCase): - def setUp(self): """Setup the model and optimizer which are used in every test method. @@ -49,61 +52,53 @@ def setUp(self): momentum = 0.01 weight_decay = 5e-4 self.optimizer = optim.SGD( - [{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'lr': lr * self.layer2_mult, - 'momentum': momentum * self.layer2_mult, - 'weight_decay': weight_decay * self.layer2_mult - }], + [ + {"params": self.model.conv1.parameters()}, + { + "params": self.model.conv2.parameters(), + "lr": lr * self.layer2_mult, + "momentum": momentum * self.layer2_mult, + "weight_decay": weight_decay * self.layer2_mult, + }, + ], lr=lr, momentum=momentum, - weight_decay=weight_decay) + weight_decay=weight_decay, + ) self.temp_dir = tempfile.TemporaryDirectory() def test_base_scheduler_step(self): with self.assertRaises(NotImplementedError): - _ParamScheduler(self.optimizer, param_name='lr') + _ParamScheduler(self.optimizer, param_name="lr") def test_invalid_optimizer(self): - with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): - StepParamScheduler( - 'invalid_optimizer', step_size=1, param_name='lr') + with self.assertRaisesRegex(TypeError, "should be an Optimizer"): + StepParamScheduler("invalid_optimizer", step_size=1, param_name="lr") def test_overwrite_optimzer_step(self): # raise warning if the counter in optimizer.step() is overwritten - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + scheduler = ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.9) def overwrite_fun(): pass self.optimizer.step = overwrite_fun self.optimizer.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - scheduler.step) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", scheduler.step) def test_resume(self): # test invalid case: optimizer and scheduler are not both resumed - with self.assertRaisesRegex(KeyError, - "param 'initial_lr' is not specified"): - StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - last_step=10) + with self.assertRaisesRegex(KeyError, "param 'initial_lr' is not specified"): + StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3, last_step=10) # test manually resume with ``last_step`` instead of load_state_dict epochs = 10 targets = [0.05 * (0.9**x) for x in range(epochs)] - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + scheduler = ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.9) results = [] for epoch in range(5): - results.append(self.optimizer.param_groups[0]['lr']) + results.append(self.optimizer.param_groups[0]["lr"]) # The order should be # train_epoch() -> save_checkpoint() -> scheduler.step(). 
# Break at here to simulate the checkpoint is saved before @@ -111,85 +106,71 @@ def test_resume(self): if epoch == 4: break scheduler.step() - scheduler2 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, last_step=4) + scheduler2 = ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.9, last_step=4) for epoch in range(6): - results.append(self.optimizer.param_groups[0]['lr']) + results.append(self.optimizer.param_groups[0]["lr"]) scheduler2.step() for epoch in range(epochs): assert_allclose( targets[epoch], results[epoch], - msg='lr is wrong in epoch {}: expected {}, got {}'.format( - epoch, targets[epoch], results[epoch]), + msg=f"lr is wrong in epoch {epoch}: expected {targets[epoch]}, got {results[epoch]}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_before_optim_warning(self): """Warns if scheduler is used before optimizer.""" def call_sch_before_optim(): - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, step_size=3) + scheduler = StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3) scheduler.step() self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim) # check warning when resume - for i, group in enumerate(self.optimizer.param_groups): - group['initial_lr'] = 0.01 + for _i, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - last_step=10) + scheduler = StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3, last_step=10) scheduler.step() self.optimizer.step() # check warning doc link - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', - call_sch_before_optim_resume) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", call_sch_before_optim_resume) def test_get_last_value(self): epochs = 10 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', step_size=3, gamma=0.1) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = StepParamScheduler(self.optimizer, param_name="lr", step_size=3, gamma=0.1) for epoch in range(epochs): result = scheduler.get_last_value() - if isinstance(scheduler.optimizer, OptimWrapper) \ - and scheduler.optimizer.base_param_settings is not None: + if isinstance(scheduler.optimizer, OptimWrapper) and scheduler.optimizer.base_param_settings is not None: result.pop() self.optimizer.step() scheduler.step() target = [t[epoch] for t in targets] - for t, r in zip(target, result): + for t, r in zip(target, result, strict=False): assert_allclose( target, result, - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, t, r), + msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}", atol=1e-5, - rtol=0) + rtol=0, + ) def test_scheduler_step_count(self): iteration = 10 - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, step_size=3) + scheduler = StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3) self.assertEqual(scheduler.last_step, 0) target = [i + 1 for i in range(iteration)] step_counts = [] - for i in 
range(iteration): + for _i in range(iteration): self.optimizer.step() scheduler.step() step_counts.append(scheduler.last_step) @@ -197,15 +178,8 @@ def test_scheduler_step_count(self): def test_effective_interval(self): # check invalid begin end - with self.assertRaisesRegex(ValueError, - 'end should be larger than begin'): - StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - begin=10, - end=5) + with self.assertRaisesRegex(ValueError, "end should be larger than begin"): + StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3, begin=10, end=5) # lr = 0.05 if epoch == 0 # lr = 0.025 if epoch == 1 @@ -217,34 +191,19 @@ def test_effective_interval(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [0.05] * begin + [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters - begin) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + self.optimizer, param_name="lr", start_factor=start_factor, begin=begin, end=begin + iters + 1 + ) self._test_scheduler_value(scheduler, targets, epochs) def test_param_name(self): with self.assertRaises(KeyError): - StepParamScheduler( - self.optimizer, param_name='invalid_name', step_size=10) - - def _test_scheduler_value(self, - schedulers, - targets, - epochs=10, - param_name='lr', - step_kwargs=None): + StepParamScheduler(self.optimizer, param_name="invalid_name", step_size=10) + + def _test_scheduler_value(self, schedulers, targets, epochs=10, param_name="lr", step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] if step_kwargs is None: @@ -254,20 +213,15 @@ def _test_scheduler_value(self, assert len(step_kwargs) == epochs assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): - for param_group, target in zip(self.optimizer.param_groups, - targets): + for param_group, target in zip(self.optimizer.param_groups, targets, strict=False): assert_allclose( target[epoch], param_group[param_name], - msg='{} is wrong in epoch {}: expected {}, got {}'.format( - param_name, epoch, target[epoch], - param_group[param_name]), + msg=f"{param_name} is wrong in epoch {epoch}: expected {target[epoch]}, got {param_group[param_name]}", atol=1e-5, - rtol=0) - [ - scheduler.step(**step_kwargs[epoch][i]) - for i, scheduler in enumerate(schedulers) - ] + rtol=0, + ) + [scheduler.step(**step_kwargs[epoch][i]) for i, scheduler in enumerate(schedulers)] def test_step_scheduler(self): # lr = 0.05 if epoch < 3 @@ -275,30 +229,18 @@ def test_step_scheduler(self): # lr = 0.0005 if 6 <= epoch < 9 # lr = 0.00005 if epoch >=9 epochs = 10 - single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - verbose=True) + single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 + targets = [single_targets, [x * 
self.layer2_mult for x in single_targets]] + scheduler = StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3, verbose=True) self._test_scheduler_value(scheduler, targets, epochs) # momentum = 0.01 if epoch < 2 # momentum = 0.001 if 2 <= epoch < 4 epochs = 4 single_targets = [0.01] * 2 + [0.001] * 2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = StepParamScheduler( - self.optimizer, param_name='momentum', gamma=0.1, step_size=2) - self._test_scheduler_value( - scheduler, targets, epochs, param_name='momentum') + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = StepParamScheduler(self.optimizer, param_name="momentum", gamma=0.1, step_size=2) + self._test_scheduler_value(scheduler, targets, epochs, param_name="momentum") def test_multi_step_scheduler(self): # lr = 0.05 if epoch < 2 @@ -306,44 +248,33 @@ def test_multi_step_scheduler(self): # lr = 0.0005 if 5 <= epoch < 9 # lr = 0.00005 if epoch >= 9 epochs = 10 - single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005 - ] * 3 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = MultiStepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, milestones=[2, 5, 9]) + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = MultiStepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, milestones=[2, 5, 9]) self._test_scheduler_value(scheduler, targets, epochs) def test_constant_scheduler(self): # factor should between 0~1 with self.assertRaises(ValueError): - ConstantParamScheduler(self.optimizer, param_name='lr', factor=99) + ConstantParamScheduler(self.optimizer, param_name="lr", factor=99) # lr = 0.025 if epoch < 5 # lr = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 4 + [0.05] * 6 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = ConstantParamScheduler( - self.optimizer, param_name='lr', factor=1.0 / 2, end=5) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = ConstantParamScheduler(self.optimizer, param_name="lr", factor=1.0 / 2, end=5) self._test_scheduler_value(scheduler, targets, epochs) def test_linear_scheduler(self): with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', start_factor=10, end=900) + LinearParamScheduler(self.optimizer, param_name="lr", start_factor=10, end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', start_factor=-1, end=900) + LinearParamScheduler(self.optimizer, param_name="lr", start_factor=-1, end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', end_factor=1.001, end=900) + LinearParamScheduler(self.optimizer, param_name="lr", end_factor=1.001, end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', end_factor=-0.00001, end=900) + LinearParamScheduler(self.optimizer, param_name="lr", end_factor=-0.00001, end=900) # lr = 0.025 if epoch == 0 # lr = 0.03125 if epoch == 1 # lr = 0.0375 if epoch == 2 @@ -352,77 +283,48 @@ def test_linear_scheduler(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [x * 0.05 for x in interpolation] + 
[0.05] * ( - epochs - iters) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=start_factor, - end=iters + 1) + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = LinearParamScheduler(self.optimizer, param_name="lr", start_factor=start_factor, end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler(self): epochs = 10 single_targets = [0.05 * (0.9**x) for x in range(epochs)] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler = ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.9) self._test_scheduler_value(scheduler, targets, epochs) def test_cos_anneal_scheduler(self): with self.assertRaises(AssertionError): - CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=10, - eta_min=0, - eta_min_ratio=0.1) + CosineAnnealingParamScheduler(self.optimizer, param_name="lr", T_max=10, eta_min=0, eta_min_ratio=0.1) epochs = 12 t = 10 eta_min = 5e-3 - targets1 = [ - eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] - targets2 = [ - eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] + targets1 = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] + targets2 = [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] targets = [targets1, targets2] - scheduler = CosineAnnealingParamScheduler( - self.optimizer, param_name='lr', T_max=t, eta_min=eta_min) + scheduler = CosineAnnealingParamScheduler(self.optimizer, param_name="lr", T_max=t, eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) # Test `eta_min_ratio` self.setUp() eta_min_ratio = 1e-3 targets1 = [ - 0.05 * eta_min_ratio + (0.05 - 0.05 * eta_min_ratio) * - (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs) + 0.05 * eta_min_ratio + (0.05 - 0.05 * eta_min_ratio) * (1 + math.cos(math.pi * x / t)) / 2 + for x in range(epochs) ] targets2 = [ - 0.5 * eta_min_ratio + (0.5 - 0.5 * eta_min_ratio) * - (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs) + 0.5 * eta_min_ratio + (0.5 - 0.5 * eta_min_ratio) * (1 + math.cos(math.pi * x / t)) / 2 + for x in range(epochs) ] targets = [targets1, targets2] - scheduler = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=t, - eta_min_ratio=eta_min_ratio) + scheduler = CosineAnnealingParamScheduler(self.optimizer, param_name="lr", T_max=t, eta_min_ratio=eta_min_ratio) self._test_scheduler_value(scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingParamScheduler( - self.optimizer, param_name='lr', begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingParamScheduler(self.optimizer, param_name="lr", begin=5, end=100, eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -430,76 +332,54 @@ def test_poly_scheduler(self): power = 0.9 min_lr = 0.001 iters = 4 - targets_layer1 = [ - min_lr + (0.05 - min_lr) * (1 - i / iters)**power - for i 
in range(iters) - ] + [min_lr] * ( - epochs - iters) + targets_layer1 = [min_lr + (0.05 - min_lr) * (1 - i / iters) ** power for i in range(iters)] + [min_lr] * ( + epochs - iters + ) targets_layer2 = [ - min_lr + (0.05 * self.layer2_mult - min_lr) * - (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters) ** power for i in range(iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyParamScheduler( - self.optimizer, - param_name='lr', - power=power, - eta_min=min_lr, - end=iters + 1) + scheduler = PolyParamScheduler(self.optimizer, param_name="lr", power=power, eta_min=min_lr, end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + self.optimizer, param_name="lr", periods=[4, 5], restart_weights=[1, 0.5], eta_min=0, eta_min_ratio=0.1 + ) with self.assertRaises(AssertionError): CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + self.optimizer, param_name="lr", periods=[4, 5], restart_weights=[1, 0.5, 0.0], eta_min=0 + ) single_targets = [ - 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, - 0.0086372, 0.0023872, 0.0023872 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + 0.05, + 0.0426776, + 0.025, + 0.00732233, + 0.025, + 0.022612712, + 0.01636271, + 0.0086372, + 0.0023872, + 0.0023872, + ] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] # Test with non-zero eta-min. 
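Read alongside the construction just below, a hedged sketch of how these cosine-restart arguments combine on a throwaway optimizer (the argument values mirror the test; the loop length is arbitrary):

import torch
from torch.optim import SGD
from mmengine.optim.scheduler import CosineRestartParamScheduler

optimizer = SGD(torch.nn.Conv2d(1, 1, 1).parameters(), lr=0.05)

# Two cosine cycles: 4 epochs weighted 1.0, then 5 epochs restarted at half the
# base value, which is what the target list above encodes.
scheduler = CosineRestartParamScheduler(
    optimizer,
    param_name="lr",
    periods=[4, 5],
    restart_weights=[1, 0.5],
    eta_min=0,
)
for _ in range(9):
    optimizer.step()
    scheduler.step()

The two AssertionError cases above capture the constructor's invariants: periods and restart_weights must have equal length, and eta_min cannot be combined with eta_min_ratio.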
scheduler = CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) + self.optimizer, param_name="lr", periods=[4, 5], restart_weights=[1, 0.5], eta_min=0 + ) self._test_scheduler_value(scheduler, targets, epochs=10) epochs = 10 t = 10 eta_min = 5e-3 - targets1 = [ - eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] - targets2 = [ - eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 - for x in range(epochs) - ] + targets1 = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] + targets2 = [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)] targets = [targets1, targets2] scheduler = CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[t], - restart_weights=[1], - eta_min=eta_min) + self.optimizer, param_name="lr", periods=[t], restart_weights=[1], eta_min=eta_min + ) self._test_scheduler_value(scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): @@ -508,55 +388,59 @@ def test_reduce_on_plateau_scheduler(self): # Test error in __init__ method with self.assertRaises(TypeError): - ReduceOnPlateauParamScheduler('invalid_optimizer', param_name='lr') + ReduceOnPlateauParamScheduler("invalid_optimizer", param_name="lr") with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', begin=10, end=5) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", begin=10, end=5) with self.assertRaises(AssertionError): - ReduceOnPlateauParamScheduler(self.optimizer, 'lr', by_epoch=False) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", by_epoch=False) for last_step in (1.5, -2): with self.assertRaises(AssertionError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', last_step=last_step) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", last_step=last_step) with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler(self.optimizer, 'lr', factor=2.0) - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', min_value=[0.1, 0.1]) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", factor=2.0) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", min_value=[0.1, 0.1]) with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', min_value=[0.1, 0.1, 0.1, 0.1]) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", min_value=[0.1, 0.1, 0.1, 0.1]) with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler(self.optimizer, 'lr', threshold=-1.0) + ReduceOnPlateauParamScheduler(self.optimizer, "lr", threshold=-1.0) with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler(self.optimizer, 'lr', rule='foo') + ReduceOnPlateauParamScheduler(self.optimizer, "lr", rule="foo") with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', threshold_rule='foo') + ReduceOnPlateauParamScheduler(self.optimizer, "lr", threshold_rule="foo") # Test error in step method - scheduler = ReduceOnPlateauParamScheduler( - self.optimizer, param_name='lr', monitor='loss') + scheduler = ReduceOnPlateauParamScheduler(self.optimizer, param_name="lr", monitor="loss") assert scheduler.step() is None with self.assertRaises(TypeError): - scheduler.step(('foo', 1.0)) + scheduler.step(("foo", 1.0)) metrics = dict(loss_foo=1.0) with self.assertRaises(KeyError): scheduler.step(metrics) # Test scheduler value - def _test_value(epochs, targets, metrics_list, monitor, rule, factor, 
- patience, threshold, threshold_rule, cooldown, - min_value): + def _test_value( + epochs, + targets, + metrics_list, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ): lr = 0.05 momentum = 0.01 weight_decay = 5e-4 scheduler = ReduceOnPlateauParamScheduler( self.optimizer, - param_name='lr', + param_name="lr", monitor=monitor, rule=rule, factor=factor, @@ -566,22 +450,23 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + self._test_scheduler_value(scheduler, targets, epochs=epochs, step_kwargs=metrics_list) # reset the state of optimizers self.optimizer = optim.SGD( - [{ - 'params': self.model.conv1.parameters() - }, { - 'params': self.model.conv2.parameters(), - 'lr': lr * self.layer2_mult, - 'momentum': momentum * self.layer2_mult, - 'weight_decay': weight_decay * self.layer2_mult - }], + [ + {"params": self.model.conv1.parameters()}, + { + "params": self.model.conv2.parameters(), + "lr": lr * self.layer2_mult, + "momentum": momentum * self.layer2_mult, + "weight_decay": weight_decay * self.layer2_mult, + }, + ], lr=lr, momentum=momentum, - weight_decay=weight_decay) + weight_decay=weight_decay, + ) epochs = 10 factor = 0.1 @@ -589,91 +474,83 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, patience = 2 # rule(less) and threshold_rule(rel) - rule, threshold_rule = 'less', 'rel' + rule, threshold_rule = "less", "rel" threshold = 0.01 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, 0.0) + _test_value( + epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0 + ) # rule(less) and threshold_rule(abs) - rule, threshold_rule = 'less', 'abs' + rule, threshold_rule = "less", "abs" threshold = 0.9 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] 
+ monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, 0.0) + _test_value( + epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0 + ) # rule(greater) and threshold_rule(rel) - rule, threshold_rule = 'greater', 'rel' + rule, threshold_rule = "greater", "rel" threshold = 0.01 - monitor = 'bbox_mAP' - metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + monitor = "bbox_mAP" + metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, 0.0) + _test_value( + epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0 + ) # rule(greater) and threshold_rule(abs) - rule, threshold_rule = 'greater', 'abs' + rule, threshold_rule = "greater", "abs" threshold = 0.9 - monitor = 'bbox_mAP' - metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + monitor = "bbox_mAP" + metric_values = [1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 - ] - targets = [ - single_targets, [t * self.layer2_mult for t in single_targets] - ] + single_targets = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005] + targets = [single_targets, [t * self.layer2_mult for t in single_targets]] - _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, 0.0) + _test_value( + epochs, targets, metrics_list, monitor, rule, factor, patience, threshold, threshold_rule, cooldown, 0.0 + ) # change min_value min_value = 0.01 - rule, threshold_rule = 'less', 'rel' + rule, threshold_rule = "less", "rel" threshold = 0.01 - monitor = 'loss' - metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] 
+ monitor = "loss" + metric_values = [10.0, 9.0, 8.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0] metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] - single_targets_1 = [ - 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, - min_value - ] + single_targets_1 = [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, min_value] single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05] targets = [single_targets_1, single_targets_2] - _test_value(epochs, targets, metrics_list, monitor, rule, factor, - patience, threshold, threshold_rule, cooldown, min_value) - - def _check_scheduler_state_dict(self, - construct, - construct2, - epochs=10, - step_kwargs=None): + _test_value( + epochs, + targets, + metrics_list, + monitor, + rule, + factor, + patience, + threshold, + threshold_rule, + cooldown, + min_value, + ) + + def _check_scheduler_state_dict(self, construct, construct2, epochs=10, step_kwargs=None): if step_kwargs is None: step_kwargs = [{} for _ in range(epochs)] else: # step_kwargs is not None @@ -683,92 +560,68 @@ def _check_scheduler_state_dict(self, scheduler.optimizer.step() scheduler.step(**step_kwargs[epoch]) scheduler_copy = construct2() - torch.save(scheduler.state_dict(), - osp.join(self.temp_dir.name, 'tmp.pth')) - state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth')) + torch.save(scheduler.state_dict(), osp.join(self.temp_dir.name, "tmp.pth")) + state_dict = torch.load(osp.join(self.temp_dir.name, "tmp.pth")) scheduler_copy.load_state_dict(state_dict) for key in scheduler.__dict__.keys(): - if key != 'optimizer': - self.assertEqual(scheduler.__dict__[key], - scheduler_copy.__dict__[key]) - self.assertEqual(scheduler.get_last_value(), - scheduler_copy.get_last_value()) + if key != "optimizer": + self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) + self.assertEqual(scheduler.get_last_value(), scheduler_copy.get_last_value()) def test_step_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, step_size=3), - lambda: StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.01 / 2, step_size=1)) + lambda: StepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, step_size=3), + lambda: StepParamScheduler(self.optimizer, param_name="lr", gamma=0.01 / 2, step_size=1), + ) def test_multi_step_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - milestones=[2, 5, 9]), lambda: MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.01, - milestones=[1, 4, 6])) + lambda: MultiStepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, milestones=[2, 5, 9]), + lambda: MultiStepParamScheduler(self.optimizer, param_name="lr", gamma=0.01, milestones=[1, 4, 6]), + ) def test_exp_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.1), - lambda: ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.01)) + lambda: ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.1), + lambda: ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.01), + ) def test_cosine_scheduler_state_dict(self): epochs = 10 eta_min = 1e-10 self._check_scheduler_state_dict( + lambda: CosineAnnealingParamScheduler(self.optimizer, param_name="lr", T_max=epochs, eta_min=eta_min), lambda: CosineAnnealingParamScheduler( - 
self.optimizer, param_name='lr', T_max=epochs, eta_min=eta_min + self.optimizer, param_name="lr", T_max=epochs // 2, eta_min=eta_min / 2 ), - lambda: CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=epochs // 2, - eta_min=eta_min / 2), - epochs=epochs) + epochs=epochs, + ) def test_linear_scheduler_state_dict(self): epochs = 10 self._check_scheduler_state_dict( - lambda: LinearParamScheduler( - self.optimizer, param_name='lr', start_factor=1 / 3), - lambda: LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=0, - end_factor=0.3), - epochs=epochs) + lambda: LinearParamScheduler(self.optimizer, param_name="lr", start_factor=1 / 3), + lambda: LinearParamScheduler(self.optimizer, param_name="lr", start_factor=0, end_factor=0.3), + epochs=epochs, + ) def test_poly_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: PolyParamScheduler( - self.optimizer, param_name='lr', power=0.5, eta_min=0.001), - lambda: PolyParamScheduler( - self.optimizer, param_name='lr', power=0.8, eta_min=0.002), - epochs=10) + lambda: PolyParamScheduler(self.optimizer, param_name="lr", power=0.5, eta_min=0.001), + lambda: PolyParamScheduler(self.optimizer, param_name="lr", power=0.8, eta_min=0.002), + epochs=10, + ) def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), + self.optimizer, param_name="lr", periods=[4, 5], restart_weights=[1, 0.5], eta_min=0 + ), lambda: CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), - epochs=10) + self.optimizer, param_name="lr", periods=[4, 6], restart_weights=[1, 0.5], eta_min=0 + ), + epochs=10, + ) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 @@ -776,57 +629,51 @@ def test_reduce_on_plateau_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: ReduceOnPlateauParamScheduler( self.optimizer, - param_name='lr', - monitor='loss', - rule='less', + param_name="lr", + monitor="loss", + rule="less", factor=0.01, patience=5, threshold=1e-4, - threshold_rule='rel', + threshold_rule="rel", cooldown=0, min_value=0.0, - eps=1e-8), + eps=1e-8, + ), lambda: ReduceOnPlateauParamScheduler( self.optimizer, - param_name='lr', - monitor='loss_foo', - rule='greater', + param_name="lr", + monitor="loss_foo", + rule="greater", factor=0.05, patience=10, threshold=1e-5, - threshold_rule='abs', + threshold_rule="abs", cooldown=5, min_value=0.1, - eps=1e-9), + eps=1e-9, + ), epochs=epochs, - step_kwargs=metrics_list) + step_kwargs=metrics_list, + ) def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): scheduler = StepParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='momentum', - gamma=0.1, - step_size=2, - epoch_length=-1) + self.optimizer, param_name="momentum", gamma=0.1, step_size=2, epoch_length=-1 + ) # momentum = 0.01 if epoch < 2 # momentum = 0.001 if 2 <= epoch < 4 epochs = 4 epoch_length = 7 single_targets = [0.01] * 2 * epoch_length + [0.001] * 2 * epoch_length - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = StepParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='momentum', - gamma=0.1, - step_size=2, - epoch_length=epoch_length) - 
self._test_scheduler_value( - scheduler, targets, epochs * epoch_length, param_name='momentum') + self.optimizer, param_name="momentum", gamma=0.1, step_size=2, epoch_length=epoch_length + ) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length, param_name="momentum") def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.05 if epoch < 2 @@ -835,19 +682,16 @@ def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.00005 if epoch >= 9 epochs = 10 epoch_length = 7 - single_targets = [0.05 - ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [ - 0.0005 - ] * 4 * epoch_length + [0.00005] * 3 * epoch_length - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + single_targets = ( + [0.05] * 2 * epoch_length + + [0.005] * 3 * epoch_length + + [0.0005] * 4 * epoch_length + + [0.00005] * 3 * epoch_length + ) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = MultiStepParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - gamma=0.1, - milestones=[2, 5, 9], - epoch_length=epoch_length) + self.optimizer, param_name="lr", gamma=0.1, milestones=[2, 5, 9], epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_constant_scheduler_convert_iterbased(self): @@ -855,17 +699,11 @@ def test_constant_scheduler_convert_iterbased(self): # lr = 0.005 if 5 <= epoch epochs = 10 epoch_length = 7 - single_targets = [0.025] * (5 * epoch_length - - 1) + [0.05] * (5 * epoch_length + 1) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + single_targets = [0.025] * (5 * epoch_length - 1) + [0.05] * (5 * epoch_length + 1) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = ConstantParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - factor=1.0 / 2, - end=5, - epoch_length=epoch_length) + self.optimizer, param_name="lr", factor=1.0 / 2, end=5, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_linear_scheduler_convert_iterbased(self): @@ -875,37 +713,23 @@ def test_linear_scheduler_convert_iterbased(self): epoch_length = 11 iters = end * epoch_length - 1 - interpolation = [ - start_factor + i * (1 - start_factor) / iters for i in range(iters) - ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs * epoch_length - iters) - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs * epoch_length - iters) + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = LinearParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - start_factor=start_factor, - end=end, - epoch_length=epoch_length) + self.optimizer, param_name="lr", start_factor=start_factor, end=end, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler_convert_iterbased(self): epochs = 10 epoch_length = 7 - single_targets = [ - 0.05 * (0.9**x) for x in range(epochs * epoch_length) - ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + single_targets = [0.05 * (0.9**x) for x in range(epochs * epoch_length)] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = 
ExponentialParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - gamma=0.9, - epoch_length=epoch_length) + self.optimizer, param_name="lr", gamma=0.9, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_cos_anneal_scheduler_convert_iterbased(self): @@ -914,19 +738,13 @@ def test_cos_anneal_scheduler_convert_iterbased(self): eta_min = 1e-10 epoch_length = 11 single_targets = [ - eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / t / epoch_length)) / 2 + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t / epoch_length)) / 2 for x in range(epochs * epoch_length) ] - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler = CosineAnnealingParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - T_max=t, - eta_min=eta_min, - epoch_length=epoch_length) + self.optimizer, param_name="lr", T_max=t, eta_min=eta_min, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs) def test_poly_scheduler_convert_iterbased(self): @@ -937,91 +755,54 @@ def test_poly_scheduler_convert_iterbased(self): epoch_length = 11 iters = end * epoch_length - 1 - targets_layer1 = [ - min_lr + (0.05 - min_lr) * (1 - i / iters)**power - for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + targets_layer1 = [min_lr + (0.05 - min_lr) * (1 - i / iters) ** power for i in range(iters)] + [min_lr] * ( + epochs - iters + ) targets_layer2 = [ - min_lr + (0.05 * self.layer2_mult - min_lr) * - (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters) ** power for i in range(iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] scheduler = PolyParamScheduler.build_iter_from_epoch( - self.optimizer, - param_name='lr', - power=power, - eta_min=min_lr, - end=end, - epoch_length=epoch_length) + self.optimizer, param_name="lr", power=power, eta_min=min_lr, end=end, epoch_length=epoch_length + ) self._test_scheduler_value(scheduler, targets, epochs=10) def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 - single_targets = [0.025, 0.03125, 0.0375, 0.04375 - ] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=1 / 2, - begin=0, - end=5) + single_targets = [0.025, 0.03125, 0.0375, 0.04375] + [0.05] * 4 + [0.005] * 3 + [0.0005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearParamScheduler(self.optimizer, param_name="lr", start_factor=1 / 2, begin=0, end=5) scheduler2 = MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - milestones=[3, 6], - begin=5, - end=12) + self.optimizer, param_name="lr", gamma=0.1, milestones=[3, 6], begin=5, end=12 + ) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_without_overlap_exp_cosine(self): # use Exp in the first 5 epochs and then use Cosine epochs = 10 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialParamScheduler(self.optimizer, 
param_name="lr", gamma=0.9, begin=0, end=5) eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] single_targets = single_targets1 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler2 = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=5, - eta_min=eta_min, - begin=5, - end=10) + self.optimizer, param_name="lr", T_max=5, eta_min=eta_min, begin=5, end=10 + ) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_overlap(self): # use Linear at first 5 epochs together with MultiStep epochs = 10 - single_targets = [0.025, 0.03125, 0.0375, 0.004375 - ] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] - ] - scheduler1 = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=1 / 2, - begin=0, - end=5) - scheduler2 = MultiStepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, milestones=[3, 6, 9]) + single_targets = [0.025, 0.03125, 0.0375, 0.004375] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 1 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] + scheduler1 = LinearParamScheduler(self.optimizer, param_name="lr", start_factor=1 / 2, begin=0, end=5) + scheduler2 = MultiStepParamScheduler(self.optimizer, param_name="lr", gamma=0.1, milestones=[3, 6, 9]) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_gap(self): @@ -1029,49 +810,33 @@ def test_multi_scheduler_with_gap(self): # no scheduler in the middle 5 epochs epochs = 15 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialParamScheduler(self.optimizer, param_name="lr", gamma=0.9, begin=0, end=5) eta_min = 1e-10 single_targets2 = [ - eta_min + (single_targets1[-1] - eta_min) * - (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) - ] - single_targets = single_targets1 + [single_targets1[-1] - ] * 5 + single_targets2 - targets = [ - single_targets, [x * self.layer2_mult for x in single_targets] + eta_min + (single_targets1[-1] - eta_min) * (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5) ] + single_targets = single_targets1 + [single_targets1[-1]] * 5 + single_targets2 + targets = [single_targets, [x * self.layer2_mult for x in single_targets]] scheduler2 = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=5, - eta_min=eta_min, - begin=10, - end=15) + self.optimizer, param_name="lr", T_max=5, eta_min=eta_min, begin=10, end=15 + ) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_onecycle_scheduler(self): # test invalid total steps with self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, param_name='lr', total_steps=-1) + OneCycleParamScheduler(self.optimizer, param_name="lr", total_steps=-1) # test invalid pct_start with self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, param_name='lr', total_steps=10, pct_start=-1) + OneCycleParamScheduler(self.optimizer, param_name="lr", total_steps=10, pct_start=-1) # test invalid anneal_strategy with 
self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, - param_name='lr', - total_steps=10, - anneal_strategy='a') + OneCycleParamScheduler(self.optimizer, param_name="lr", total_steps=10, anneal_strategy="a") class TestParameterSchedulerOptimWrapper(TestParameterScheduler): - def setUp(self): super().setUp() self.optimizer = OptimWrapper(optimizer=self.optimizer) diff --git a/tests/test_registry/test_build_functions.py b/tests/test_registry/test_build_functions.py index 80094ae107..34433f3c5c 100644 --- a/tests/test_registry/test_build_functions.py +++ b/tests/test_registry/test_build_functions.py @@ -1,40 +1,34 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest -from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin, - Registry, build_from_cfg, build_model_from_cfg) +from mmengine import PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin, Registry, build_from_cfg, build_model_from_cfg from mmengine.utils import is_installed -@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) +@pytest.mark.parametrize("cfg_type", [dict, ConfigDict, Config]) def test_build_from_cfg(cfg_type): - BACKBONES = Registry('backbone') + BACKBONES = Registry("backbone") @BACKBONES.register_module() class ResNet: - def __init__(self, depth, stages=4): self.depth = depth self.stages = stages @BACKBONES.register_module() class ResNeXt: - def __init__(self, depth, stages=4): self.depth = depth self.stages = stages # test `cfg` parameter # `cfg` should be a dict, ConfigDict or Config object - with pytest.raises( - TypeError, - match=('cfg should be a dict, ConfigDict or Config, but got ' - "")): - cfg = 'ResNet' + with pytest.raises(TypeError, match=("cfg should be a dict, ConfigDict or Config, but got ")): + cfg = "ResNet" model = build_from_cfg(cfg, BACKBONES) # `cfg` is a dict, ConfigDict or Config object - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 @@ -46,14 +40,12 @@ def __init__(self, depth, stages=4): model = build_from_cfg(cfg, BACKBONES) # cfg['type'] should be a str or class - with pytest.raises( - TypeError, - match="type must be a str or valid type, but got "): + with pytest.raises(TypeError, match="type must be a str or valid type, but got "): cfg = dict(type=1000) cfg = cfg_type(cfg) model = build_from_cfg(cfg, BACKBONES) - cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3)) + cfg = cfg_type(dict(type="ResNeXt", depth=50, stages=3)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNeXt) assert model.depth == 50 and model.stages == 3 @@ -65,80 +57,72 @@ def __init__(self, depth, stages=4): # non-registered class with pytest.raises( - KeyError, - match='VGG is not in the test_build_functions::backbone registry', + KeyError, + match="VGG is not in the test_build_functions::backbone registry", ): - cfg = cfg_type(dict(type='VGG')) + cfg = cfg_type(dict(type="VGG")) model = build_from_cfg(cfg, BACKBONES) # `cfg` contains unexpected arguments with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', non_existing_arg=50)) + cfg = cfg_type(dict(type="ResNet", non_existing_arg=50)) model = build_from_cfg(cfg, BACKBONES) # test `default_args` parameter - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3))) assert isinstance(model, ResNet) assert 
model.depth == 50 and model.stages == 3 # default_args must be a dict or None with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES, default_args=1) # cfg or default_args should contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) + model = build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) # "type" defined using default_args cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) + model = build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(type="ResNet"))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) + model = build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 # test `registry` parameter # incorrect registry type - with pytest.raises( - TypeError, - match=('registry must be a mmengine.Registry object, but got ' - "")): - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, 'BACKBONES') + with pytest.raises(TypeError, match=("registry must be a mmengine.Registry object, but got ")): + cfg = cfg_type(dict(type="ResNet", depth=50)) + model = build_from_cfg(cfg, "BACKBONES") - VISUALIZER = Registry('visualizer') + VISUALIZER = Registry("visualizer") @VISUALIZER.register_module() class Visualizer(ManagerMixin): - def __init__(self, name): super().__init__(name) with pytest.raises(RuntimeError): Visualizer.get_current_instance() - cfg = dict(type='Visualizer', name='visualizer') + cfg = dict(type="Visualizer", name="visualizer") build_from_cfg(cfg, VISUALIZER) Visualizer.get_current_instance() -@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch') +@pytest.mark.skipif(not is_installed("torch"), reason="tests requires torch") def test_build_model_from_cfg(): import torch.nn as nn - BACKBONES = Registry('backbone', build_func=build_model_from_cfg) + BACKBONES = Registry("backbone", build_func=build_model_from_cfg) @BACKBONES.register_module() class ResNet(nn.Module): - def __init__(self, depth, stages=4): super().__init__() self.depth = depth @@ -149,7 +133,6 @@ def forward(self, x): @BACKBONES.register_module() class ResNeXt(nn.Module): - def __init__(self, depth, stages=4): super().__init__() self.depth = depth @@ -158,20 +141,17 @@ def __init__(self, depth, stages=4): def forward(self, x): return x - cfg = dict(type='ResNet', depth=50) + cfg = dict(type="ResNet", depth=50) model = BACKBONES.build(cfg) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 - cfg = dict(type='ResNeXt', depth=50, stages=3) + cfg = dict(type="ResNeXt", depth=50, stages=3) model = BACKBONES.build(cfg) assert isinstance(model, ResNeXt) assert model.depth == 50 and model.stages == 3 - cfg = [ - dict(type='ResNet', depth=50), - dict(type='ResNeXt', depth=50, stages=3) - ] + cfg = [dict(type="ResNet", depth=50), dict(type="ResNeXt", depth=50, stages=3)] model = BACKBONES.build(cfg) assert isinstance(model, nn.Sequential) assert isinstance(model[0], ResNet) @@ -180,41 +160,38 @@ def forward(self, x): assert model[1].depth == 50 and model[1].stages == 3 # 
test inherit `build_func` from parent - NEW_MODELS = Registry('models', parent=BACKBONES, scope='new') + NEW_MODELS = Registry("models", parent=BACKBONES, scope="new") assert NEW_MODELS.build_func is build_model_from_cfg # test specify `build_func` def pseudo_build(cfg): return cfg - NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build) + NEW_MODELS = Registry("models", parent=BACKBONES, build_func=pseudo_build) assert NEW_MODELS.build_func is pseudo_build -@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch') +@pytest.mark.skipif(not is_installed("torch"), reason="tests requires torch") def test_build_scheduler_from_cfg(): import torch.nn as nn from torch.optim import SGD + model = nn.Conv2d(1, 1, 1) optimizer = SGD(model.parameters(), lr=0.1) - cfg = dict( - type='LinearParamScheduler', - optimizer=optimizer, - param_name='lr', - begin=0, - end=100) + cfg = dict(type="LinearParamScheduler", optimizer=optimizer, param_name="lr", begin=0, end=100) scheduler = PARAM_SCHEDULERS.build(cfg) assert scheduler.begin == 0 assert scheduler.end == 100 cfg = dict( - type='LinearParamScheduler', + type="LinearParamScheduler", convert_to_iter_based=True, optimizer=optimizer, - param_name='lr', + param_name="lr", begin=0, end=100, - epoch_length=10) + epoch_length=10, + ) scheduler = PARAM_SCHEDULERS.build(cfg) assert scheduler.begin == 0 diff --git a/tests/test_registry/test_default_scope.py b/tests/test_registry/test_default_scope.py index 0798f4a2c7..993f677b7b 100644 --- a/tests/test_registry/test_default_scope.py +++ b/tests/test_registry/test_default_scope.py @@ -7,42 +7,36 @@ class TestDefaultScope: - def test_scope(self): - default_scope = DefaultScope.get_instance('name1', scope_name='mmdet') - assert default_scope.scope_name == 'mmdet' + default_scope = DefaultScope.get_instance("name1", scope_name="mmdet") + assert default_scope.scope_name == "mmdet" # `DefaultScope.get_instance` must have `scope_name` argument. with pytest.raises(TypeError): - DefaultScope.get_instance('name2') + DefaultScope.get_instance("name2") def test_get_current_instance(self): DefaultScope._instance_dict = OrderedDict() assert DefaultScope.get_current_instance() is None - DefaultScope.get_instance('instance_name', scope_name='mmengine') + DefaultScope.get_instance("instance_name", scope_name="mmengine") default_scope = DefaultScope.get_current_instance() - assert default_scope.scope_name == 'mmengine' + assert default_scope.scope_name == "mmengine" def test_overwrite_default_scope(self): - origin_scope = DefaultScope.get_instance( - 'test_overwrite_default_scope', scope_name='origin_scope') + origin_scope = DefaultScope.get_instance("test_overwrite_default_scope", scope_name="origin_scope") with DefaultScope.overwrite_default_scope(scope_name=None): - assert DefaultScope.get_current_instance( - ).scope_name == 'origin_scope' - with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'): - assert DefaultScope.get_current_instance( - ).scope_name == 'test_overwrite' - assert DefaultScope.get_current_instance( - ).scope_name == origin_scope.scope_name == 'origin_scope' + assert DefaultScope.get_current_instance().scope_name == "origin_scope" + with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"): + assert DefaultScope.get_current_instance().scope_name == "test_overwrite" + assert DefaultScope.get_current_instance().scope_name == origin_scope.scope_name == "origin_scope" # Test overwrite default scope immediately. # Test sequentially overwrite. 
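Before the sequential and nested overwrite cases below, a minimal sketch of the scope-switching they rely on (the instance and scope names here are illustrative):

from mmengine.registry import DefaultScope

# Register a named scope instance; scope_name is a required keyword.
DefaultScope.get_instance("sketch_instance", scope_name="mmengine")
assert DefaultScope.get_current_instance().scope_name == "mmengine"

# overwrite_default_scope temporarily swaps the active scope and restores the
# previous one on exit.
with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"):
    assert DefaultScope.get_current_instance().scope_name == "test_overwrite"
assert DefaultScope.get_current_instance().scope_name == "mmengine"

Because the previous scope is restored on exit, nesting the context manager, as the test does next, leaves the outer scope intact.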
- with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'): + with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"): pass - with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'): + with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"): pass # Test nested overwrite. - with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'): - with DefaultScope.overwrite_default_scope( - scope_name='test_overwrite'): + with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"): + with DefaultScope.overwrite_default_scope(scope_name="test_overwrite"): pass diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index eb99b3dc8e..bca41d152a 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -5,16 +5,14 @@ import pytest from mmengine.config import Config, ConfigDict # type: ignore -from mmengine.registry import (DefaultScope, Registry, build_from_cfg, - build_model_from_cfg) +from mmengine.registry import DefaultScope, Registry, build_from_cfg, build_model_from_cfg from mmengine.utils import ManagerMixin, is_installed class TestRegistry: - def test_init(self): - CATS = Registry('cat') - assert CATS.name == 'cat' + CATS = Registry("cat") + assert CATS.name == "cat" assert CATS.module_dict == {} assert CATS.build_func is build_from_cfg assert len(CATS) == 0 @@ -23,48 +21,48 @@ def test_init(self): def build_func(cfg, registry, default_args): pass - CATS = Registry('cat', build_func=build_func) + CATS = Registry("cat", build_func=build_func) assert CATS.build_func is build_func # test `parent` parameter # `parent` is either None or a `Registry` instance with pytest.raises(AssertionError): - CATS = Registry('little_cat', parent='cat', scope='little_cat') + CATS = Registry("little_cat", parent="cat", scope="little_cat") - LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat') + LITTLECATS = Registry("little_cat", parent=CATS, scope="little_cat") assert LITTLECATS.parent is CATS - assert CATS._children.get('little_cat') is LITTLECATS + assert CATS._children.get("little_cat") is LITTLECATS # test `scope` parameter # `scope` is either None or a string with pytest.raises(AssertionError): - CATS = Registry('cat', scope=1) + CATS = Registry("cat", scope=1) - CATS = Registry('cat') - assert CATS.scope == 'test_registry' + CATS = Registry("cat") + assert CATS.scope == "test_registry" - CATS = Registry('cat', scope='cat') - assert CATS.scope == 'cat' + CATS = Registry("cat", scope="cat") + assert CATS.scope == "cat" def test_split_scope_key(self): - DOGS = Registry('dogs') + DOGS = Registry("dogs") - scope, key = DOGS.split_scope_key('BloodHound') - assert scope is None and key == 'BloodHound' - scope, key = DOGS.split_scope_key('hound.BloodHound') - assert scope == 'hound' and key == 'BloodHound' - scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund') - assert scope == 'hound' and key == 'little_hound.Dachshund' + scope, key = DOGS.split_scope_key("BloodHound") + assert scope is None and key == "BloodHound" + scope, key = DOGS.split_scope_key("hound.BloodHound") + assert scope == "hound" and key == "BloodHound" + scope, key = DOGS.split_scope_key("hound.little_hound.Dachshund") + assert scope == "hound" and key == "little_hound.Dachshund" def test_register_module(self): - CATS = Registry('cat') + CATS = Registry("cat") @CATS.register_module() def muchkin(size): pass - assert CATS.get('muchkin') is muchkin - assert 'muchkin' in CATS 
+ assert CATS.get("muchkin") is muchkin + assert "muchkin" in CATS # test `name` parameter which must be either of None, a string or a # sequence of string @@ -74,31 +72,30 @@ class BritishShorthair: pass assert len(CATS) == 2 - assert CATS.get('BritishShorthair') is BritishShorthair + assert CATS.get("BritishShorthair") is BritishShorthair # `name` is a string - @CATS.register_module(name='Munchkin') + @CATS.register_module(name="Munchkin") class Munchkin: pass assert len(CATS) == 3 - assert CATS.get('Munchkin') is Munchkin - assert 'Munchkin' in CATS + assert CATS.get("Munchkin") is Munchkin + assert "Munchkin" in CATS # `name` is a sequence of string - @CATS.register_module(name=['Siamese', 'Siamese2']) + @CATS.register_module(name=["Siamese", "Siamese2"]) class SiameseCat: pass - assert CATS.get('Siamese') is SiameseCat - assert CATS.get('Siamese2') is SiameseCat + assert CATS.get("Siamese") is SiameseCat + assert CATS.get("Siamese2") is SiameseCat assert len(CATS) == 5 # `name` is an invalid type with pytest.raises( - TypeError, - match=('name must be None, an instance of str, or a sequence ' - "of str, but got ")): + TypeError, match=("name must be None, an instance of str, or a sequence of str, but got ") + ): @CATS.register_module(name=7474741) class SiameseCat: @@ -106,19 +103,14 @@ class SiameseCat: # test `force` parameter, which must be a boolean # force is not a boolean - with pytest.raises( - TypeError, - match="force must be a boolean, but got "): + with pytest.raises(TypeError, match="force must be a boolean, but got "): @CATS.register_module(force=1) class BritishShorthair: pass # force=False - with pytest.raises( - KeyError, - match='BritishShorthair is already registered in cat ' - 'at test_registry'): + with pytest.raises(KeyError, match="BritishShorthair is already registered in cat at test_registry"): @CATS.register_module() class BritishShorthair: @@ -134,37 +126,34 @@ class BritishShorthair: # test `module` parameter, which is either None or a class # when the `register_module`` is called as a method rather than a # decorator, which must be a class - with pytest.raises( - TypeError, - match='module must be Callable,' - " but got "): - CATS.register_module(module='string') + with pytest.raises(TypeError, match="module must be Callable, but got "): + CATS.register_module(module="string") class SphynxCat: pass CATS.register_module(module=SphynxCat) - assert CATS.get('SphynxCat') is SphynxCat + assert CATS.get("SphynxCat") is SphynxCat assert len(CATS) == 6 - CATS.register_module(name='Sphynx1', module=SphynxCat) - assert CATS.get('Sphynx1') is SphynxCat + CATS.register_module(name="Sphynx1", module=SphynxCat) + assert CATS.get("Sphynx1") is SphynxCat assert len(CATS) == 7 - CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat) - assert CATS.get('Sphynx2') is SphynxCat - assert CATS.get('Sphynx3') is SphynxCat + CATS.register_module(name=["Sphynx2", "Sphynx3"], module=SphynxCat) + assert CATS.get("Sphynx2") is SphynxCat + assert CATS.get("Sphynx3") is SphynxCat assert len(CATS) == 9 # partial functions can be registered muchkin0 = functools.partial(muchkin, size=0) - CATS.register_module('muchkin0', False, muchkin0) + CATS.register_module("muchkin0", False, muchkin0) # lambda functions can be registered - CATS.register_module(name='unknown cat', module=lambda: 'unknown') + CATS.register_module(name="unknown cat", module=lambda: "unknown") - assert CATS.get('muchkin0') is muchkin0 - assert 'unknown cat' in CATS - assert 'muchkin0' in CATS + assert 
CATS.get("muchkin0") is muchkin0 + assert "unknown cat" in CATS + assert "muchkin0" in CATS assert len(CATS) == 11 def _build_registry(self): @@ -179,19 +168,17 @@ def _build_registry(self): # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # (little_hound) (mid_hound) (little_samoyed) registries = [] - DOGS = Registry('dogs') + DOGS = Registry("dogs") registries.append(DOGS) - HOUNDS = Registry('hounds', parent=DOGS, scope='hound') + HOUNDS = Registry("hounds", parent=DOGS, scope="hound") registries.append(HOUNDS) - LITTLE_HOUNDS = Registry( - 'little hounds', parent=HOUNDS, scope='little_hound') + LITTLE_HOUNDS = Registry("little hounds", parent=HOUNDS, scope="little_hound") registries.append(LITTLE_HOUNDS) - MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound') + MID_HOUNDS = Registry("mid hounds", parent=HOUNDS, scope="mid_hound") registries.append(MID_HOUNDS) - SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed') + SAMOYEDS = Registry("samoyeds", parent=DOGS, scope="samoyed") registries.append(SAMOYEDS) - LITTLE_SAMOYEDS = Registry( - 'little samoyeds', parent=SAMOYEDS, scope='little_samoyed') + LITTLE_SAMOYEDS = Registry("little samoyeds", parent=SAMOYEDS, scope="little_samoyed") registries.append(LITTLE_SAMOYEDS) return registries @@ -236,17 +223,17 @@ def test_get(self): def bark(word, times): return [word] * times - dog_bark = functools.partial(bark, 'woof') - DOGS.register_module('dog_bark', False, dog_bark) + dog_bark = functools.partial(bark, "woof") + DOGS.register_module("dog_bark", False, dog_bark) @DOGS.register_module() class GoldenRetriever: pass assert len(DOGS) == 3 - assert DOGS.get('GoldenRetriever') is GoldenRetriever - assert DOGS.get('bark') is bark - assert DOGS.get('dog_bark') is dog_bark + assert DOGS.get("GoldenRetriever") is GoldenRetriever + assert DOGS.get("bark") is bark + assert DOGS.get("dog_bark") is dog_bark @HOUNDS.register_module() class BloodHound: @@ -254,17 +241,17 @@ class BloodHound: assert len(HOUNDS) == 1 # get key from current registry - assert HOUNDS.get('BloodHound') is BloodHound + assert HOUNDS.get("BloodHound") is BloodHound # get key from its children - assert DOGS.get('hound.BloodHound') is BloodHound + assert DOGS.get("hound.BloodHound") is BloodHound # get key from current registry - assert HOUNDS.get('hound.BloodHound') is BloodHound + assert HOUNDS.get("hound.BloodHound") is BloodHound # If the key is not found in the current registry, then look for its # parent - assert HOUNDS.get('GoldenRetriever') is GoldenRetriever - assert HOUNDS.get('bark') is bark - assert HOUNDS.get('dog_bark') is dog_bark + assert HOUNDS.get("GoldenRetriever") is GoldenRetriever + assert HOUNDS.get("bark") is bark + assert HOUNDS.get("dog_bark") is dog_bark @LITTLE_HOUNDS.register_module() class Dachshund: @@ -272,25 +259,25 @@ class Dachshund: assert len(LITTLE_HOUNDS) == 1 # get key from current registry - assert LITTLE_HOUNDS.get('Dachshund') is Dachshund + assert LITTLE_HOUNDS.get("Dachshund") is Dachshund # get key from its parent - assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound + assert LITTLE_HOUNDS.get("hound.BloodHound") is BloodHound # get key from its children - assert HOUNDS.get('little_hound.Dachshund') is Dachshund + assert HOUNDS.get("little_hound.Dachshund") is Dachshund # get key from its descendants - assert DOGS.get('hound.little_hound.Dachshund') is Dachshund + assert DOGS.get("hound.little_hound.Dachshund") is Dachshund # If the key is not found in the current registry, then look for its # parent - assert 
LITTLE_HOUNDS.get('BloodHound') is BloodHound - assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever + assert LITTLE_HOUNDS.get("BloodHound") is BloodHound + assert LITTLE_HOUNDS.get("GoldenRetriever") is GoldenRetriever @MID_HOUNDS.register_module() class Beagle: pass # get key from its sibling registries - assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle + assert LITTLE_HOUNDS.get("hound.mid_hound.Beagle") is Beagle @SAMOYEDS.register_module() class PedigreeSamoyed: @@ -298,36 +285,35 @@ class PedigreeSamoyed: assert len(SAMOYEDS) == 1 # get key from its uncle - assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed + assert LITTLE_HOUNDS.get("samoyed.PedigreeSamoyed") is PedigreeSamoyed @LITTLE_SAMOYEDS.register_module() class LittlePedigreeSamoyed: pass # get key from its cousin - assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed' - ) is LittlePedigreeSamoyed + assert LITTLE_HOUNDS.get("samoyed.little_samoyed.LittlePedigreeSamoyed") is LittlePedigreeSamoyed # get key from its nephews - assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed' - ) is LittlePedigreeSamoyed + assert HOUNDS.get("samoyed.little_samoyed.LittlePedigreeSamoyed") is LittlePedigreeSamoyed # invalid keys # GoldenRetrieverererer can not be found at LITTLE_HOUNDS modules - assert LITTLE_HOUNDS.get('GoldenRetrieverererer') is None + assert LITTLE_HOUNDS.get("GoldenRetrieverererer") is None # samoyedddd is not a child of DOGS - assert DOGS.get('samoyedddd.PedigreeSamoyed') is None + assert DOGS.get("samoyedddd.PedigreeSamoyed") is None # samoyed is a child of DOGS but LittlePedigreeSamoyed can not be found # at SAMOYEDS modules - assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None - assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None + assert DOGS.get("samoyed.LittlePedigreeSamoyed") is None + assert LITTLE_HOUNDS.get("mid_hound.PedigreeSamoyedddddd") is None # Get mmengine.utils by string - utils = LITTLE_HOUNDS.get('mmengine.utils') + utils = LITTLE_HOUNDS.get("mmengine.utils") import mmengine.utils + assert utils is mmengine.utils - unknown = LITTLE_HOUNDS.get('mmengine.unknown') + unknown = LITTLE_HOUNDS.get("mmengine.unknown") assert unknown is None def test__search_child(self): @@ -343,13 +329,13 @@ def test__search_child(self): registries = self._build_registry() DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] - assert DOGS._search_child('hound') is HOUNDS - assert DOGS._search_child('not a child') is None - assert DOGS._search_child('little_hound') is LITTLE_HOUNDS - assert LITTLE_HOUNDS._search_child('hound') is None - assert LITTLE_HOUNDS._search_child('mid_hound') is None + assert DOGS._search_child("hound") is HOUNDS + assert DOGS._search_child("not a child") is None + assert DOGS._search_child("little_hound") is LITTLE_HOUNDS + assert LITTLE_HOUNDS._search_child("hound") is None + assert LITTLE_HOUNDS._search_child("mid_hound") is None - @pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) + @pytest.mark.parametrize("cfg_type", [dict, ConfigDict, Config]) def test_build(self, cfg_type): # Hierarchical Registry # DOGS @@ -365,64 +351,61 @@ def test_build(self, cfg_type): @DOGS.register_module() def bark(word, times): - return ' '.join([word] * times) + return " ".join([word] * times) - dog_bark = functools.partial(bark, word='woof') - DOGS.register_module('dog_bark', False, dog_bark) + dog_bark = functools.partial(bark, word="woof") + DOGS.register_module("dog_bark", False, dog_bark) - bark_cfg = 
cfg_type(dict(type='bark', word='meow', times=3)) - dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3)) + bark_cfg = cfg_type(dict(type="bark", word="meow", times=3)) + dog_bark_cfg = cfg_type(dict(type="dog_bark", times=3)) @DOGS.register_module() class GoldenRetriever: pass - gr_cfg = cfg_type(dict(type='GoldenRetriever')) + gr_cfg = cfg_type(dict(type="GoldenRetriever")) assert isinstance(DOGS.build(gr_cfg), GoldenRetriever) - assert DOGS.build(bark_cfg) == 'meow meow meow' - assert DOGS.build(dog_bark_cfg) == 'woof woof woof' + assert DOGS.build(bark_cfg) == "meow meow meow" + assert DOGS.build(dog_bark_cfg) == "woof woof woof" @HOUNDS.register_module() class BloodHound: pass - bh_cfg = cfg_type(dict(type='BloodHound')) + bh_cfg = cfg_type(dict(type="BloodHound")) assert isinstance(HOUNDS.build(bh_cfg), BloodHound) assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever) - assert HOUNDS.build(bark_cfg) == 'meow meow meow' - assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof' + assert HOUNDS.build(bark_cfg) == "meow meow meow" + assert HOUNDS.build(dog_bark_cfg) == "woof woof woof" @LITTLE_HOUNDS.register_module() class Dachshund: pass - d_cfg = cfg_type(dict(type='Dachshund')) + d_cfg = cfg_type(dict(type="Dachshund")) assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund) @MID_HOUNDS.register_module() class Beagle: pass - b_cfg = cfg_type(dict(type='Beagle')) + b_cfg = cfg_type(dict(type="Beagle")) assert isinstance(MID_HOUNDS.build(b_cfg), Beagle) # test `default_scope` # switch the current registry to another registry - DefaultScope.get_instance( - f'test-{time.time()}', scope_name='mid_hound') + DefaultScope.get_instance(f"test-{time.time()}", scope_name="mid_hound") dog = LITTLE_HOUNDS.build(b_cfg) assert isinstance(dog, Beagle) # `default_scope` can not be found - DefaultScope.get_instance( - f'test2-{time.time()}', scope_name='scope-not-found') + DefaultScope.get_instance(f"test2-{time.time()}", scope_name="scope-not-found") dog = MID_HOUNDS.build(b_cfg) assert isinstance(dog, Beagle) # test overwrite default scope with `_scope_` @SAMOYEDS.register_module() class MySamoyed: - def __init__(self, friend): self.friend = DOGS.build(friend) @@ -430,43 +413,33 @@ def __init__(self, friend): class YourSamoyed: pass - s_cfg = cfg_type( - dict( - _scope_='samoyed', - type='MySamoyed', - friend=dict(type='hound.BloodHound'))) + s_cfg = cfg_type(dict(_scope_="samoyed", type="MySamoyed", friend=dict(type="hound.BloodHound"))) dog = DOGS.build(s_cfg) assert isinstance(dog, MySamoyed) assert isinstance(dog.friend, BloodHound) - assert DefaultScope.get_current_instance().scope_name != 'samoyed' + assert DefaultScope.get_current_instance().scope_name != "samoyed" - s_cfg = cfg_type( - dict( - _scope_='samoyed', - type='MySamoyed', - friend=dict(type='YourSamoyed'))) + s_cfg = cfg_type(dict(_scope_="samoyed", type="MySamoyed", friend=dict(type="YourSamoyed"))) dog = DOGS.build(s_cfg) assert isinstance(dog, MySamoyed) assert isinstance(dog.friend, YourSamoyed) - assert DefaultScope.get_current_instance().scope_name != 'samoyed' + assert DefaultScope.get_current_instance().scope_name != "samoyed" # build an instance by lambda or partial function. 
lambda_dog = lambda name: name # noqa: E731 - DOGS.register_module(name='lambda_dog', module=lambda_dog) - lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown')) - assert DOGS.build(lambda_cfg) == 'unknown' + DOGS.register_module(name="lambda_dog", module=lambda_dog) + lambda_cfg = cfg_type(dict(type="lambda_dog", name="unknown")) + assert DOGS.build(lambda_cfg) == "unknown" - DOGS.register_module( - name='patial dog', - module=functools.partial(lambda_dog, name='patial')) - unknown_cfg = cfg_type(dict(type='patial dog')) - assert DOGS.build(unknown_cfg) == 'patial' + DOGS.register_module(name="patial dog", module=functools.partial(lambda_dog, name="patial")) + unknown_cfg = cfg_type(dict(type="patial dog")) + assert DOGS.build(unknown_cfg) == "patial" def test_switch_scope_and_registry(self): - DOGS = Registry('dogs') - HOUNDS = Registry('hounds', scope='hound', parent=DOGS) - SAMOYEDS = Registry('samoyeds', scope='samoyed', parent=DOGS) - CHIHUAHUA = Registry('chihuahuas', scope='chihuahua', parent=DOGS) + DOGS = Registry("dogs") + HOUNDS = Registry("hounds", scope="hound", parent=DOGS) + SAMOYEDS = Registry("samoyeds", scope="samoyed", parent=DOGS) + CHIHUAHUA = Registry("chihuahuas", scope="chihuahua", parent=DOGS) # Hierarchical Registry # DOGS @@ -474,39 +447,33 @@ def test_switch_scope_and_registry(self): # | | | # HOUNDS (hound) SAMOYEDS (samoyed) CHIHUAHUA (chihuahua) - DefaultScope.get_instance( - f'scope_{time.time()}', scope_name='chihuahua') - assert DefaultScope.get_current_instance().scope_name == 'chihuahua' + DefaultScope.get_instance(f"scope_{time.time()}", scope_name="chihuahua") + assert DefaultScope.get_current_instance().scope_name == "chihuahua" # Test switch scope and get target registry. - with CHIHUAHUA.switch_scope_and_registry(scope='hound') as \ - registry: - assert DefaultScope.get_current_instance().scope_name == 'hound' + with CHIHUAHUA.switch_scope_and_registry(scope="hound") as registry: + assert DefaultScope.get_current_instance().scope_name == "hound" assert id(registry) == id(HOUNDS) # Test nested-ly switch scope. - with CHIHUAHUA.switch_scope_and_registry(scope='samoyed') as \ - samoyed_registry: - assert DefaultScope.get_current_instance().scope_name == 'samoyed' + with CHIHUAHUA.switch_scope_and_registry(scope="samoyed") as samoyed_registry: + assert DefaultScope.get_current_instance().scope_name == "samoyed" assert id(samoyed_registry) == id(SAMOYEDS) - with CHIHUAHUA.switch_scope_and_registry(scope='hound') as \ - hound_registry: - assert DefaultScope.get_current_instance().scope_name == \ - 'hound' + with CHIHUAHUA.switch_scope_and_registry(scope="hound") as hound_registry: + assert DefaultScope.get_current_instance().scope_name == "hound" assert id(hound_registry) == id(HOUNDS) # Test switch to original scope - assert DefaultScope.get_current_instance().scope_name == 'chihuahua' + assert DefaultScope.get_current_instance().scope_name == "chihuahua" # Test get an unknown registry. 
- with CHIHUAHUA.switch_scope_and_registry(scope='unknown') as \ - registry: + with CHIHUAHUA.switch_scope_and_registry(scope="unknown") as registry: assert id(registry) == id(CHIHUAHUA) - assert DefaultScope.get_current_instance().scope_name == 'unknown' + assert DefaultScope.get_current_instance().scope_name == "unknown" def test_repr(self): - CATS = Registry('cat') + CATS = Registry("cat") @CATS.register_module() class BritishShorthair: @@ -516,40 +483,35 @@ class BritishShorthair: class Munchkin: pass - assert 'Registry of cat' in repr(CATS) - assert 'BritishShorthair' in repr(CATS) - assert 'Munchkin' in repr(CATS) + assert "Registry of cat" in repr(CATS) + assert "BritishShorthair" in repr(CATS) + assert "Munchkin" in repr(CATS) -@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) +@pytest.mark.parametrize("cfg_type", [dict, ConfigDict, Config]) def test_build_from_cfg(cfg_type): - BACKBONES = Registry('backbone') + BACKBONES = Registry("backbone") @BACKBONES.register_module() class ResNet: - def __init__(self, depth, stages=4): self.depth = depth self.stages = stages @BACKBONES.register_module() class ResNeXt: - def __init__(self, depth, stages=4): self.depth = depth self.stages = stages # test `cfg` parameter # `cfg` should be a dict, ConfigDict or Config object - with pytest.raises( - TypeError, - match=('cfg should be a dict, ConfigDict or Config, but got ' - "")): - cfg = 'ResNet' + with pytest.raises(TypeError, match=("cfg should be a dict, ConfigDict or Config, but got ")): + cfg = "ResNet" model = build_from_cfg(cfg, BACKBONES) # `cfg` is a dict, ConfigDict or Config object - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 @@ -561,14 +523,12 @@ def __init__(self, depth, stages=4): model = build_from_cfg(cfg, BACKBONES) # cfg['type'] should be a str or class - with pytest.raises( - TypeError, - match="type must be a str or valid type, but got "): + with pytest.raises(TypeError, match="type must be a str or valid type, but got "): cfg = dict(type=1000) cfg = cfg_type(cfg) model = build_from_cfg(cfg, BACKBONES) - cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3)) + cfg = cfg_type(dict(type="ResNeXt", depth=50, stages=3)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNeXt) assert model.depth == 50 and model.stages == 3 @@ -580,72 +540,64 @@ def __init__(self, depth, stages=4): # `cfg` contains unexpected arguments with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', non_existing_arg=50)) + cfg = cfg_type(dict(type="ResNet", non_existing_arg=50)) model = build_from_cfg(cfg, BACKBONES) # test `default_args` parameter - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 3 # default_args must be a dict or None with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', depth=50)) + cfg = cfg_type(dict(type="ResNet", depth=50)) model = build_from_cfg(cfg, BACKBONES, default_args=1) # cfg or default_args should contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) + model = build_from_cfg(cfg, BACKBONES, 
default_args=cfg_type(dict(stages=4))) # "type" defined using default_args cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) + model = build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(type="ResNet"))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) + model = build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 # test `registry` parameter # incorrect registry type - with pytest.raises( - TypeError, - match=('registry must be a mmengine.Registry object, but got ' - "")): - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, 'BACKBONES') + with pytest.raises(TypeError, match=("registry must be a mmengine.Registry object, but got ")): + cfg = cfg_type(dict(type="ResNet", depth=50)) + model = build_from_cfg(cfg, "BACKBONES") - VISUALIZER = Registry('visualizer') + VISUALIZER = Registry("visualizer") @VISUALIZER.register_module() class Visualizer(ManagerMixin): - def __init__(self, name): super().__init__(name) with pytest.raises(RuntimeError): Visualizer.get_current_instance() - cfg = dict(type='Visualizer', name='visualizer') + cfg = dict(type="Visualizer", name="visualizer") build_from_cfg(cfg, VISUALIZER) Visualizer.get_current_instance() -@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch') +@pytest.mark.skipif(not is_installed("torch"), reason="tests requires torch") def test_build_model_from_cfg(): import torch.nn as nn - BACKBONES = Registry('backbone', build_func=build_model_from_cfg) + BACKBONES = Registry("backbone", build_func=build_model_from_cfg) @BACKBONES.register_module() class ResNet(nn.Module): - def __init__(self, depth, stages=4): super().__init__() self.depth = depth @@ -656,7 +608,6 @@ def forward(self, x): @BACKBONES.register_module() class ResNeXt(nn.Module): - def __init__(self, depth, stages=4): super().__init__() self.depth = depth @@ -665,20 +616,17 @@ def __init__(self, depth, stages=4): def forward(self, x): return x - cfg = dict(type='ResNet', depth=50) + cfg = dict(type="ResNet", depth=50) model = BACKBONES.build(cfg) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 - cfg = dict(type='ResNeXt', depth=50, stages=3) + cfg = dict(type="ResNeXt", depth=50, stages=3) model = BACKBONES.build(cfg) assert isinstance(model, ResNeXt) assert model.depth == 50 and model.stages == 3 - cfg = [ - dict(type='ResNet', depth=50), - dict(type='ResNeXt', depth=50, stages=3) - ] + cfg = [dict(type="ResNet", depth=50), dict(type="ResNeXt", depth=50, stages=3)] model = BACKBONES.build(cfg) assert isinstance(model, nn.Sequential) assert isinstance(model[0], ResNet) @@ -687,12 +635,12 @@ def forward(self, x): assert model[1].depth == 50 and model[1].stages == 3 # test inherit `build_func` from parent - NEW_MODELS = Registry('models', parent=BACKBONES, scope='new') + NEW_MODELS = Registry("models", parent=BACKBONES, scope="new") assert NEW_MODELS.build_func is build_model_from_cfg # test specify `build_func` def pseudo_build(cfg): return cfg - NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build) + NEW_MODELS = Registry("models", parent=BACKBONES, build_func=pseudo_build) assert NEW_MODELS.build_func is pseudo_build diff --git 
a/tests/test_registry/test_registry_utils.py b/tests/test_registry/test_registry_utils.py index ebd53dc03b..cc3d72107b 100644 --- a/tests/test_registry/test_registry_utils.py +++ b/tests/test_registry/test_registry_utils.py @@ -5,14 +5,18 @@ from unittest import TestCase, skipIf from mmengine.logging import MMLogger -from mmengine.registry import (DefaultScope, Registry, - count_registered_modules, init_default_scope, - root, traverse_registry_tree) +from mmengine.registry import ( + DefaultScope, + Registry, + count_registered_modules, + init_default_scope, + root, + traverse_registry_tree, +) from mmengine.utils import is_installed class TestUtils(TestCase): - def test_traverse_registry_tree(self): # Hierarchical Registry # DOGS @@ -23,14 +27,16 @@ def test_traverse_registry_tree(self): # | | | # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # (little_hound) (mid_hound) (little_samoyed) - DOGS = Registry('dogs') - HOUNDS = Registry('dogs', parent=DOGS, scope='hound') + DOGS = Registry("dogs") + HOUNDS = Registry("dogs", parent=DOGS, scope="hound") LITTLE_HOUNDS = Registry( # noqa - 'dogs', parent=HOUNDS, scope='little_hound') - MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound') - SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed') + "dogs", parent=HOUNDS, scope="little_hound" + ) + MID_HOUNDS = Registry("dogs", parent=HOUNDS, scope="mid_hound") + SAMOYEDS = Registry("dogs", parent=DOGS, scope="samoyed") LITTLE_SAMOYEDS = Registry( # noqa - 'dogs', parent=SAMOYEDS, scope='little_samoyed') + "dogs", parent=SAMOYEDS, scope="little_samoyed" + ) @DOGS.register_module() class GoldenRetriever: @@ -38,7 +44,7 @@ class GoldenRetriever: # traversing the tree from the root result = traverse_registry_tree(DOGS) - self.assertEqual(result[0]['num_modules'], 1) + self.assertEqual(result[0]["num_modules"], 1) self.assertEqual(len(result), 6) # traversing the tree from leaf node @@ -46,37 +52,31 @@ class GoldenRetriever: # result from any node should be the same self.assertEqual(result, result_leaf) - @skipIf(not is_installed('torch'), 'tests requires torch') + @skipIf(not is_installed("torch"), "tests requires torch") def test_count_all_registered_modules(self): temp_dir = TemporaryDirectory() results = count_registered_modules(temp_dir.name, verbose=True) - self.assertTrue( - osp.exists( - osp.join(temp_dir.name, 'modules_statistic_results.json'))) - registries_info = results['registries'] + self.assertTrue(osp.exists(osp.join(temp_dir.name, "modules_statistic_results.json"))) + registries_info = results["registries"] for registry in registries_info: self.assertTrue(hasattr(root, registry)) - self.assertEqual(registries_info[registry][0]['num_modules'], - len(getattr(root, registry).module_dict)) + self.assertEqual(registries_info[registry][0]["num_modules"], len(getattr(root, registry).module_dict)) temp_dir.cleanup() # test not saving results count_registered_modules(save_path=None, verbose=False) - self.assertFalse( - osp.exists( - osp.join(temp_dir.name, 'modules_statistic_results.json'))) + self.assertFalse(osp.exists(osp.join(temp_dir.name, "modules_statistic_results.json"))) - @skipIf(not is_installed('torch'), 'tests requires torch') + @skipIf(not is_installed("torch"), "tests requires torch") def test_init_default_scope(self): # init default scope - init_default_scope('mmdet') - self.assertEqual(DefaultScope.get_current_instance().scope_name, - 'mmdet') + init_default_scope("mmdet") + self.assertEqual(DefaultScope.get_current_instance().scope_name, "mmdet") # init default scope 
when another scope is init - name = f'test-{datetime.datetime.now()}' - DefaultScope.get_instance(name, scope_name='test') + name = f"test-{datetime.datetime.now()}" + DefaultScope.get_instance(name, scope_name="test") # Warning should be raised since the current # default scope is not 'mmdet' - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): - init_default_scope('mmdet') + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): + init_default_scope("mmdet") diff --git a/tests/test_runner/test_activation_checkpointing.py b/tests/test_runner/test_activation_checkpointing.py index d48c027cdb..c3b04504f3 100644 --- a/tests/test_runner/test_activation_checkpointing.py +++ b/tests/test_runner/test_activation_checkpointing.py @@ -5,13 +5,11 @@ import torch.nn.functional as F from torch import nn -from mmengine.runner.activation_checkpointing import \ - turn_on_activation_checkpointing +from mmengine.runner.activation_checkpointing import turn_on_activation_checkpointing from mmengine.testing import assert_allclose class Model(nn.Module): - def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) @@ -37,7 +35,6 @@ def forward(self, x): class TestActivationCheckpointing(TestCase): - def test_activation_checkpointing(self): model = Model() input = torch.randn(16, 3, 224, 224) @@ -46,7 +43,7 @@ def test_activation_checkpointing(self): output.sum().backward() grad = input.grad.clone() - turn_on_activation_checkpointing(model, ['conv1', 'conv2', 'conv3']) + turn_on_activation_checkpointing(model, ["conv1", "conv2", "conv3"]) output2 = model(input) output2.sum().backward() grad2 = input.grad.clone() diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py index 7208e25079..75df5b2184 100644 --- a/tests/test_runner/test_amp.py +++ b/tests/test_runner/test_amp.py @@ -5,18 +5,16 @@ import torch.nn as nn import mmengine -from mmengine.device import (get_device, is_mlu_available, is_musa_available, - is_npu_available) +from mmengine.device import get_device, is_mlu_available, is_musa_available, is_npu_available from mmengine.runner import autocast from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION class TestAmp(unittest.TestCase): - def test_autocast(self): if is_npu_available(): - device = 'npu' + device = "npu" with autocast(device_type=device): # torch.autocast support npu mode. layer = nn.Conv2d(1, 1, 1).to(device) @@ -31,7 +29,7 @@ def test_autocast(self): res = layer(torch.randn(1, 1, 1, 1).to(device)) self.assertEqual(res.dtype, torch.float32) elif is_mlu_available(): - device = 'mlu' + device = "mlu" with autocast(device_type=device): # torch.autocast support mlu mode. layer = nn.Conv2d(1, 1, 1).to(device) @@ -46,7 +44,7 @@ def test_autocast(self): res = layer(torch.randn(1, 1, 1, 1).to(device)) self.assertEqual(res.dtype, torch.float32) elif is_musa_available(): - device = 'musa' + device = "musa" with autocast(device_type=device): # torch.autocast support mlu mode. layer = nn.Conv2d(1, 1, 1).to(device) @@ -61,12 +59,11 @@ def test_autocast(self): res = layer(torch.randn(1, 1, 1, 1).to(device)) self.assertEqual(res.dtype, torch.float32) elif not torch.cuda.is_available(): - if digit_version(TORCH_VERSION) < digit_version('1.10.0'): + if digit_version(TORCH_VERSION) < digit_version("1.10.0"): # `torch.cuda.amp.autocast` is only support in gpu mode, if # cuda is not available, it will return an empty context and # should not accept any arguments. 
- with self.assertRaisesRegex(RuntimeError, - 'If pytorch versions is '): + with self.assertRaisesRegex(RuntimeError, "If pytorch versions is "): with autocast(): pass @@ -76,7 +73,7 @@ def test_autocast(self): self.assertEqual(res.dtype, torch.float32) else: - with autocast(device_type='cpu'): + with autocast(device_type="cpu"): # torch.autocast support cpu mode. layer = nn.Conv2d(1, 1, 1) res = layer(torch.randn(1, 1, 1, 1)) @@ -86,10 +83,10 @@ def test_autocast(self): self.assertEqual(res.dtype, torch.float32) else: - if digit_version(TORCH_VERSION) < digit_version('1.10.0'): - devices = ['cuda'] + if digit_version(TORCH_VERSION) < digit_version("1.10.0"): + devices = ["cuda"] else: - devices = ['cpu', 'cuda'] + devices = ["cpu", "cuda"] for device in devices: with autocast(device_type=device): # torch.autocast support cpu and cuda mode. @@ -106,21 +103,20 @@ def test_autocast(self): self.assertEqual(res.dtype, torch.float32) # Test mps - if digit_version(TORCH_VERSION) >= digit_version('1.12.0'): - mmengine.runner.amp.get_device = lambda: 'mps' + if digit_version(TORCH_VERSION) >= digit_version("1.12.0"): + mmengine.runner.amp.get_device = lambda: "mps" with autocast(enabled=False): layer = nn.Conv2d(1, 1, 1) res = layer(torch.randn(1, 1, 1, 1)) self.assertEqual(res.dtype, torch.float32) - with self.assertRaisesRegex(ValueError, - 'User specified autocast device_type'): + with self.assertRaisesRegex(ValueError, "User specified autocast device_type"): with autocast(enabled=True): pass - # Native pytorch does not support mlu, here we simply test autocast - # will call `torch.autocast`, which will be overridden by mlu version - # pytorch - mmengine.runner.amp.get_device = lambda: 'mlu' + # Native pytorch does not support mlu, here we simply test autocast + # will call `torch.autocast`, which will be overridden by mlu version + # pytorch + mmengine.runner.amp.get_device = lambda: "mlu" with self.assertRaises(RuntimeError): with autocast(enabled=False): pass diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index 4655a4c5da..d5eddb3133 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -14,22 +14,25 @@ from mmengine.fileio.file_client import PetrelBackend from mmengine.registry import MODEL_WRAPPERS -from mmengine.runner.checkpoint import (CheckpointLoader, - _load_checkpoint_with_prefix, - get_state_dict, load_checkpoint, - load_from_local, load_from_pavi, - load_state_dict, save_checkpoint) +from mmengine.runner.checkpoint import ( + CheckpointLoader, + _load_checkpoint_with_prefix, + get_state_dict, + load_checkpoint, + load_from_local, + load_from_pavi, + load_state_dict, + save_checkpoint, +) @MODEL_WRAPPERS.register_module() class DDPWrapper: - def __init__(self, module): self.module = module class Block(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 3, 1) @@ -37,7 +40,6 @@ def __init__(self): class Model(nn.Module): - def __init__(self): super().__init__() self.block = Block() @@ -45,8 +47,7 @@ def __init__(self): class Mockpavimodel: - - def __init__(self, name='fakename'): + def __init__(self, name="fakename"): self.name = name def download(self, file): @@ -58,18 +59,28 @@ def assert_tensor_equal(tensor_a, tensor_b): def test_get_state_dict(): - if torch.__version__ == 'parrots': + if torch.__version__ == "parrots": state_dict_keys = { - 'block.conv.weight', 'block.conv.bias', 'block.norm.weight', - 'block.norm.bias', 'block.norm.running_mean', - 
'block.norm.running_var', 'conv.weight', 'conv.bias' + "block.conv.weight", + "block.conv.bias", + "block.norm.weight", + "block.norm.bias", + "block.norm.running_mean", + "block.norm.running_var", + "conv.weight", + "conv.bias", } else: state_dict_keys = { - 'block.conv.weight', 'block.conv.bias', 'block.norm.weight', - 'block.norm.bias', 'block.norm.running_mean', - 'block.norm.running_var', 'block.norm.num_batches_tracked', - 'conv.weight', 'conv.bias' + "block.conv.weight", + "block.conv.bias", + "block.norm.weight", + "block.norm.bias", + "block.norm.running_mean", + "block.norm.running_var", + "block.norm.num_batches_tracked", + "conv.weight", + "conv.bias", } model = Model() @@ -77,46 +88,33 @@ def test_get_state_dict(): assert isinstance(state_dict, OrderedDict) assert set(state_dict.keys()) == state_dict_keys - assert_tensor_equal(state_dict['block.conv.weight'], - model.block.conv.weight) - assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias) - assert_tensor_equal(state_dict['block.norm.weight'], - model.block.norm.weight) - assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias) - assert_tensor_equal(state_dict['block.norm.running_mean'], - model.block.norm.running_mean) - assert_tensor_equal(state_dict['block.norm.running_var'], - model.block.norm.running_var) - if torch.__version__ != 'parrots': - assert_tensor_equal(state_dict['block.norm.num_batches_tracked'], - model.block.norm.num_batches_tracked) - assert_tensor_equal(state_dict['conv.weight'], model.conv.weight) - assert_tensor_equal(state_dict['conv.bias'], model.conv.bias) + assert_tensor_equal(state_dict["block.conv.weight"], model.block.conv.weight) + assert_tensor_equal(state_dict["block.conv.bias"], model.block.conv.bias) + assert_tensor_equal(state_dict["block.norm.weight"], model.block.norm.weight) + assert_tensor_equal(state_dict["block.norm.bias"], model.block.norm.bias) + assert_tensor_equal(state_dict["block.norm.running_mean"], model.block.norm.running_mean) + assert_tensor_equal(state_dict["block.norm.running_var"], model.block.norm.running_var) + if torch.__version__ != "parrots": + assert_tensor_equal(state_dict["block.norm.num_batches_tracked"], model.block.norm.num_batches_tracked) + assert_tensor_equal(state_dict["conv.weight"], model.conv.weight) + assert_tensor_equal(state_dict["conv.bias"], model.conv.bias) wrapped_model = DDPWrapper(model) state_dict = get_state_dict(wrapped_model) assert isinstance(state_dict, OrderedDict) assert set(state_dict.keys()) == state_dict_keys - assert_tensor_equal(state_dict['block.conv.weight'], - wrapped_model.module.block.conv.weight) - assert_tensor_equal(state_dict['block.conv.bias'], - wrapped_model.module.block.conv.bias) - assert_tensor_equal(state_dict['block.norm.weight'], - wrapped_model.module.block.norm.weight) - assert_tensor_equal(state_dict['block.norm.bias'], - wrapped_model.module.block.norm.bias) - assert_tensor_equal(state_dict['block.norm.running_mean'], - wrapped_model.module.block.norm.running_mean) - assert_tensor_equal(state_dict['block.norm.running_var'], - wrapped_model.module.block.norm.running_var) - if torch.__version__ != 'parrots': + assert_tensor_equal(state_dict["block.conv.weight"], wrapped_model.module.block.conv.weight) + assert_tensor_equal(state_dict["block.conv.bias"], wrapped_model.module.block.conv.bias) + assert_tensor_equal(state_dict["block.norm.weight"], wrapped_model.module.block.norm.weight) + assert_tensor_equal(state_dict["block.norm.bias"], wrapped_model.module.block.norm.bias) + 
assert_tensor_equal(state_dict["block.norm.running_mean"], wrapped_model.module.block.norm.running_mean) + assert_tensor_equal(state_dict["block.norm.running_var"], wrapped_model.module.block.norm.running_var) + if torch.__version__ != "parrots": assert_tensor_equal( - state_dict['block.norm.num_batches_tracked'], - wrapped_model.module.block.norm.num_batches_tracked) - assert_tensor_equal(state_dict['conv.weight'], - wrapped_model.module.conv.weight) - assert_tensor_equal(state_dict['conv.bias'], - wrapped_model.module.conv.bias) + state_dict["block.norm.num_batches_tracked"], wrapped_model.module.block.norm.num_batches_tracked + ) + assert_tensor_equal(state_dict["conv.weight"], wrapped_model.module.conv.weight) + assert_tensor_equal(state_dict["conv.bias"], wrapped_model.module.conv.bias) # wrapped inner module for name, module in wrapped_model.module._modules.items(): @@ -125,46 +123,37 @@ def test_get_state_dict(): state_dict = get_state_dict(wrapped_model) assert isinstance(state_dict, OrderedDict) assert set(state_dict.keys()) == state_dict_keys - assert_tensor_equal(state_dict['block.conv.weight'], - wrapped_model.module.block.module.conv.weight) - assert_tensor_equal(state_dict['block.conv.bias'], - wrapped_model.module.block.module.conv.bias) - assert_tensor_equal(state_dict['block.norm.weight'], - wrapped_model.module.block.module.norm.weight) - assert_tensor_equal(state_dict['block.norm.bias'], - wrapped_model.module.block.module.norm.bias) - assert_tensor_equal(state_dict['block.norm.running_mean'], - wrapped_model.module.block.module.norm.running_mean) - assert_tensor_equal(state_dict['block.norm.running_var'], - wrapped_model.module.block.module.norm.running_var) - if torch.__version__ != 'parrots': + assert_tensor_equal(state_dict["block.conv.weight"], wrapped_model.module.block.module.conv.weight) + assert_tensor_equal(state_dict["block.conv.bias"], wrapped_model.module.block.module.conv.bias) + assert_tensor_equal(state_dict["block.norm.weight"], wrapped_model.module.block.module.norm.weight) + assert_tensor_equal(state_dict["block.norm.bias"], wrapped_model.module.block.module.norm.bias) + assert_tensor_equal(state_dict["block.norm.running_mean"], wrapped_model.module.block.module.norm.running_mean) + assert_tensor_equal(state_dict["block.norm.running_var"], wrapped_model.module.block.module.norm.running_var) + if torch.__version__ != "parrots": assert_tensor_equal( - state_dict['block.norm.num_batches_tracked'], - wrapped_model.module.block.module.norm.num_batches_tracked) - assert_tensor_equal(state_dict['conv.weight'], - wrapped_model.module.conv.module.weight) - assert_tensor_equal(state_dict['conv.bias'], - wrapped_model.module.conv.module.bias) + state_dict["block.norm.num_batches_tracked"], wrapped_model.module.block.module.norm.num_batches_tracked + ) + assert_tensor_equal(state_dict["conv.weight"], wrapped_model.module.conv.module.weight) + assert_tensor_equal(state_dict["conv.bias"], wrapped_model.module.conv.module.bias) -@patch.dict(sys.modules, {'pavi': MagicMock()}) +@patch.dict(sys.modules, {"pavi": MagicMock()}) def test_load_pavimodel_dist(): pavimodel = Mockpavimodel() import pavi + pavi.modelcloud.get = MagicMock(return_value=pavimodel) with pytest.raises(AssertionError): # test pavi prefix - _ = load_from_pavi('MyPaviFolder/checkpoint.pth') + _ = load_from_pavi("MyPaviFolder/checkpoint.pth") with pytest.raises(FileNotFoundError): # there is not such checkpoint for us to load - _ = load_from_pavi('pavi://checkpoint.pth') + _ = 
load_from_pavi("pavi://checkpoint.pth") def test_load_checkpoint_with_prefix(): - class FooModule(nn.Module): - def __init__(self): super().__init__() self.linear = nn.Linear(1, 2) @@ -180,18 +169,16 @@ def __init__(self): nn.init.constant_(model.conv2d_2.bias, 6) with TemporaryDirectory(): - torch.save(model.state_dict(), 'model.pth') - prefix = 'conv2d' - state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth') - assert torch.equal(model.conv2d.state_dict()['weight'], - state_dict['weight']) - assert torch.equal(model.conv2d.state_dict()['bias'], - state_dict['bias']) + torch.save(model.state_dict(), "model.pth") + prefix = "conv2d" + state_dict = _load_checkpoint_with_prefix(prefix, "model.pth") + assert torch.equal(model.conv2d.state_dict()["weight"], state_dict["weight"]) + assert torch.equal(model.conv2d.state_dict()["bias"], state_dict["bias"]) # test whether prefix is in pretrained model with pytest.raises(AssertionError): - prefix = 'back' - _load_checkpoint_with_prefix(prefix, 'model.pth') + prefix = "back" + _load_checkpoint_with_prefix(prefix, "model.pth") def test_load_checkpoint(): @@ -200,36 +187,31 @@ def test_load_checkpoint(): import tempfile class PrefixModel(nn.Module): - def __init__(self): super().__init__() self.backbone = Model() pmodel = PrefixModel() model = Model() - checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth') + checkpoint_path = os.path.join(tempfile.gettempdir(), "checkpoint.pth") # add prefix torch.save(model.state_dict(), checkpoint_path) - state_dict = load_checkpoint( - pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')]) + state_dict = load_checkpoint(pmodel, checkpoint_path, revise_keys=[(r"^", "backbone.")]) for key in pmodel.backbone.state_dict().keys(): assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key]) # strip prefix torch.save(pmodel.state_dict(), checkpoint_path) - state_dict = load_checkpoint( - model, checkpoint_path, revise_keys=[(r'^backbone\.', '')]) + state_dict = load_checkpoint(model, checkpoint_path, revise_keys=[(r"^backbone\.", "")]) for key in state_dict.keys(): - key_stripped = re.sub(r'^backbone\.', '', key) + key_stripped = re.sub(r"^backbone\.", "", key) assert torch.equal(model.state_dict()[key_stripped], state_dict[key]) os.remove(checkpoint_path) def test_load_checkpoint_metadata(): - class ModelV1(nn.Module): - def __init__(self): super().__init__() self.block = Block() @@ -249,15 +231,14 @@ def __init__(self): nn.init.normal_(self.conv0.weight) nn.init.normal_(self.conv1.weight) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - *args, **kwargs): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, **kwargs): """Load checkpoints.""" # Names of some parameters in has been changed. 
- version = local_metadata.get('version', None) + version = local_metadata.get("version", None) if version is None or version < 2: state_dict_keys = list(state_dict.keys()) - convert_map = {'conv1': 'conv0', 'conv2': 'conv1'} + convert_map = {"conv1": "conv0", "conv2": "conv1"} for k in state_dict_keys: for ori_str, new_str in convert_map.items(): if k.startswith(prefix + ori_str): @@ -265,8 +246,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, state_dict[new_key] = state_dict[k] del state_dict[k] - super()._load_from_state_dict(state_dict, prefix, local_metadata, - *args, **kwargs) + super()._load_from_state_dict(state_dict, prefix, local_metadata, *args, **kwargs) model_v1 = ModelV1() model_v1_conv0_weight = model_v1.conv1.weight.detach() @@ -274,8 +254,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, model_v2 = ModelV2() model_v2_conv0_weight = model_v2.conv0.weight.detach() model_v2_conv1_weight = model_v2.conv1.weight.detach() - ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth') - ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth') + ckpt_v1_path = os.path.join(tempfile.gettempdir(), "checkpoint_v1.pth") + ckpt_v2_path = os.path.join(tempfile.gettempdir(), "checkpoint_v2.pth") # Save checkpoint save_checkpoint(model_v1.state_dict(), ckpt_v1_path) @@ -292,66 +272,84 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight) -@patch.dict(sys.modules, {'petrel_client': MagicMock()}) +@patch.dict(sys.modules, {"petrel_client": MagicMock()}) def test_checkpoint_loader(): filenames = [ - 'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth', - 'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth', - 'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth', - 'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth', - 'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth', - 'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth', - 'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path' + "http://xx.xx/xx.pth", + "https://xx.xx/xx.pth", + "modelzoo://xx.xx/xx.pth", + "torchvision://xx.xx/xx.pth", + "open-mmlab://xx.xx/xx.pth", + "openmmlab://xx.xx/xx.pth", + "mmcls://xx.xx/xx.pth", + "pavi://xx.xx/xx.pth", + "s3://xx.xx/xx.pth", + "ss3://xx.xx/xx.pth", + " s3://xx.xx/xx.pth", + "open-mmlab:s3://xx.xx/xx.pth", + "openmmlab:s3://xx.xx/xx.pth", + "openmmlabs3://xx.xx/xx.pth", + ":s3://xx.xx/xx.path", ] fn_names = [ - 'load_from_http', 'load_from_http', 'load_from_torchvision', - 'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab', - 'load_from_mmcls', 'load_from_pavi', 'load_from_ceph', - 'load_from_local', 'load_from_local', 'load_from_ceph', - 'load_from_ceph', 'load_from_local', 'load_from_local' + "load_from_http", + "load_from_http", + "load_from_torchvision", + "load_from_torchvision", + "load_from_openmmlab", + "load_from_openmmlab", + "load_from_mmcls", + "load_from_pavi", + "load_from_ceph", + "load_from_local", + "load_from_local", + "load_from_ceph", + "load_from_ceph", + "load_from_local", + "load_from_local", ] - for filename, fn_name in zip(filenames, fn_names): + for filename, fn_name in zip(filenames, fn_names, strict=False): loader = CheckpointLoader._get_checkpoint_loader(filename) assert loader.__name__ == fn_name - @CheckpointLoader.register_scheme(prefixes='ftp://') + @CheckpointLoader.register_scheme(prefixes="ftp://") def load_from_ftp(filename, map_location): return dict(filename=filename) # 
test register_loader - filename = 'ftp://xx.xx/xx.pth' + filename = "ftp://xx.xx/xx.pth" loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_ftp' + assert loader.__name__ == "load_from_ftp" def load_from_ftp1(filename, map_location): return dict(filename=filename) # test duplicate registered error with pytest.raises(KeyError): - CheckpointLoader.register_scheme('ftp://', load_from_ftp1) + CheckpointLoader.register_scheme("ftp://", load_from_ftp1) # test force param - CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True) + CheckpointLoader.register_scheme("ftp://", load_from_ftp1, force=True) checkpoint = CheckpointLoader.load_checkpoint(filename) - assert checkpoint['filename'] == filename + assert checkpoint["filename"] == filename # test print function name loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_ftp1' + assert loader.__name__ == "load_from_ftp1" # test sort - @CheckpointLoader.register_scheme(prefixes='a/b') + @CheckpointLoader.register_scheme(prefixes="a/b") def load_from_ab(filename, map_location): return dict(filename=filename) - @CheckpointLoader.register_scheme(prefixes='a/b/c') + @CheckpointLoader.register_scheme(prefixes="a/b/c") def load_from_abc(filename, map_location): return dict(filename=filename) - filename = 'a/b/c/d' + filename = "a/b/c/d" loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_abc' + assert loader.__name__ == "load_from_abc" def test_save_checkpoint(tmp_path): @@ -359,52 +357,43 @@ def test_save_checkpoint(tmp_path): optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) # meta is not a dict with pytest.raises(TypeError): - save_checkpoint(model, '/path/of/your/filename', meta='invalid type') + save_checkpoint(model, "/path/of/your/filename", meta="invalid type") # 1. save to disk - filename = str(tmp_path / 'checkpoint1.pth') + filename = str(tmp_path / "checkpoint1.pth") save_checkpoint(model.state_dict(), filename) - filename = str(tmp_path / 'checkpoint2.pth') - checkpoint = dict( - model=model.state_dict(), optimizer=optimizer.state_dict()) + filename = str(tmp_path / "checkpoint2.pth") + checkpoint = dict(model=model.state_dict(), optimizer=optimizer.state_dict()) save_checkpoint(checkpoint, filename) - filename = str(tmp_path / 'checkpoint3.pth') - save_checkpoint( - model.state_dict(), filename, backend_args={'backend': 'local'}) + filename = str(tmp_path / "checkpoint3.pth") + save_checkpoint(model.state_dict(), filename, backend_args={"backend": "local"}) - filename = str(tmp_path / 'checkpoint4.pth') - save_checkpoint( - model.state_dict(), filename, file_client_args={'backend': 'disk'}) + filename = str(tmp_path / "checkpoint4.pth") + save_checkpoint(model.state_dict(), filename, file_client_args={"backend": "disk"}) # 2. 
save to petrel oss - with patch.object(PetrelBackend, 'put') as mock_method: - filename = 's3://path/of/your/checkpoint1.pth' + with patch.object(PetrelBackend, "put") as mock_method: + filename = "s3://path/of/your/checkpoint1.pth" save_checkpoint(model.state_dict(), filename) mock_method.assert_called() - with patch.object(PetrelBackend, 'put') as mock_method: - filename = 's3://path//of/your/checkpoint2.pth' - save_checkpoint( - model.state_dict(), - filename, - file_client_args={'backend': 'petrel'}) + with patch.object(PetrelBackend, "put") as mock_method: + filename = "s3://path//of/your/checkpoint2.pth" + save_checkpoint(model.state_dict(), filename, file_client_args={"backend": "petrel"}) mock_method.assert_called() def test_load_from_local(): import os - home_path = os.path.expanduser('~') - checkpoint_path = os.path.join( - home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth') + + home_path = os.path.expanduser("~") + checkpoint_path = os.path.join(home_path, "dummy_checkpoint_used_to_test_load_from_local.pth") model = Model() save_checkpoint(model.state_dict(), checkpoint_path) - checkpoint = load_from_local( - '~/dummy_checkpoint_used_to_test_load_from_local.pth', - map_location=None) - assert_tensor_equal(checkpoint['block.conv.weight'], - model.block.conv.weight) + checkpoint = load_from_local("~/dummy_checkpoint_used_to_test_load_from_local.pth", map_location=None) + assert_tensor_equal(checkpoint["block.conv.weight"], model.block.conv.weight) os.remove(checkpoint_path) @@ -412,24 +401,24 @@ def test_load_state_dict_post_hooks(): module = Block() state_dict = { - 'conv.weight': torch.empty((3, 3, 1, 1), dtype=torch.float32), - 'conv.bias': torch.empty((3, ), dtype=torch.float32), - 'norm.weight': torch.empty([3], dtype=torch.float32), - 'norm.bias': torch.empty([3], dtype=torch.float32), - 'norm.running_mean': torch.empty([3], dtype=torch.float32), - 'norm.running_var': torch.empty([3], dtype=torch.float32), + "conv.weight": torch.empty((3, 3, 1, 1), dtype=torch.float32), + "conv.bias": torch.empty((3,), dtype=torch.float32), + "norm.weight": torch.empty([3], dtype=torch.float32), + "norm.bias": torch.empty([3], dtype=torch.float32), + "norm.running_mean": torch.empty([3], dtype=torch.float32), + "norm.running_var": torch.empty([3], dtype=torch.float32), } - state_dict.pop('norm.running_var') + state_dict.pop("norm.running_var") - with patch('mmengine.runner.checkpoint.print_log') as mock: + with patch("mmengine.runner.checkpoint.print_log") as mock: load_state_dict(module, state_dict, strict=False) mock.assert_called_once() def post_hook(_, incompatible_keys): - incompatible_keys.missing_keys.remove('norm.running_var') + incompatible_keys.missing_keys.remove("norm.running_var") module._load_state_dict_post_hooks = {0: post_hook} - with patch('mmengine.runner.checkpoint.print_log') as mock: + with patch("mmengine.runner.checkpoint.print_log") as mock: load_state_dict(module, state_dict, strict=False) mock.assert_not_called() diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py index d7fae5722a..fb88c0930c 100644 --- a/tests/test_runner/test_log_processor.py +++ b/tests/test_runner/test_log_processor.py @@ -14,10 +14,8 @@ class TestLogProcessor(RunnerTestCase): - def test_init(self): - log_processor = LogProcessor( - window_size=10, by_epoch=True, custom_cfg=None) + log_processor = LogProcessor(window_size=10, by_epoch=True, custom_cfg=None) assert log_processor.by_epoch assert log_processor.window_size == 10 assert 
log_processor.custom_cfg == [] @@ -25,50 +23,44 @@ def test_init(self): def test_check_custom_cfg(self): # ``by_epoch==False`` and `window_size='epoch'` in log config will # raise AssertionError. - custom_cfg = [dict(data_src='loss', window_size='epoch')] + custom_cfg = [dict(data_src="loss", window_size="epoch")] with pytest.raises(AssertionError): LogProcessor(by_epoch=False, custom_cfg=custom_cfg) # Duplicate log_name will raise AssertionError. - custom_cfg = [ - dict(data_src='loss', log_name='loss_1'), - dict(data_src='loss', log_name='loss_1') - ] + custom_cfg = [dict(data_src="loss", log_name="loss_1"), dict(data_src="loss", log_name="loss_1")] with pytest.raises(AssertionError): LogProcessor(custom_cfg=custom_cfg) # Overwrite loss item twice will raise AssertionError. - custom_cfg = [dict(data_src='loss'), dict(data_src='loss')] + custom_cfg = [dict(data_src="loss"), dict(data_src="loss")] with pytest.raises(AssertionError): LogProcessor(custom_cfg=custom_cfg) custom_cfg = [ - dict(data_src='loss_cls', window_size=100, method_name='min'), - dict(data_src='loss', log_name='loss_min', method_name='max'), - dict(data_src='loss', log_name='loss_max', method_name='max') + dict(data_src="loss_cls", window_size=100, method_name="min"), + dict(data_src="loss", log_name="loss_min", method_name="max"), + dict(data_src="loss", log_name="loss_max", method_name="max"), ] LogProcessor(custom_cfg=custom_cfg) def test_parse_windows_size(self): log_processor = LogProcessor() # Test parse 'epoch' window_size. - custom_cfg = [dict(data_src='loss_cls', window_size='epoch')] - custom_cfg = log_processor._parse_windows_size(self.runner, 1, - custom_cfg) - assert custom_cfg[0]['window_size'] == 2 + custom_cfg = [dict(data_src="loss_cls", window_size="epoch")] + custom_cfg = log_processor._parse_windows_size(self.runner, 1, custom_cfg) + assert custom_cfg[0]["window_size"] == 2 # Test parse 'global' window_size. - custom_cfg = [dict(data_src='loss_cls', window_size='global')] - custom_cfg = log_processor._parse_windows_size(self.runner, 1, - custom_cfg) - assert custom_cfg[0]['window_size'] == 11 + custom_cfg = [dict(data_src="loss_cls", window_size="global")] + custom_cfg = log_processor._parse_windows_size(self.runner, 1, custom_cfg) + assert custom_cfg[0]["window_size"] == 11 # Test parse int window_size - custom_cfg = [dict(data_src='loss_cls', window_size=100)] - custom_cfg = log_processor._parse_windows_size(self.runner, 1, - custom_cfg) - assert custom_cfg[0]['window_size'] == 100 + custom_cfg = [dict(data_src="loss_cls", window_size=100)] + custom_cfg = log_processor._parse_windows_size(self.runner, 1, custom_cfg) + assert custom_cfg[0]["window_size"] == 100 # Invalid type window_size will raise TypeError. 
- custom_cfg = [dict(data_src='loss_cls', window_size=[])] + custom_cfg = [dict(data_src="loss_cls", window_size=[])] with pytest.raises(TypeError): log_processor._parse_windows_size(self.runner, 1, custom_cfg) diff --git a/tests/test_runner/test_priority.py b/tests/test_runner/test_priority.py index 2065897887..821edd538d 100644 --- a/tests/test_runner/test_priority.py +++ b/tests/test_runner/test_priority.py @@ -9,9 +9,9 @@ def test_get_priority(): # `priority` is an integer assert get_priority(10) == 10 # `priority` is an integer but it exceeds the valid ranges - with pytest.raises(ValueError, match='priority must be between 0 and 100'): + with pytest.raises(ValueError, match="priority must be between 0 and 100"): get_priority(-1) - with pytest.raises(ValueError, match='priority must be between 0 and 100'): + with pytest.raises(ValueError, match="priority must be between 0 and 100"): get_priority(101) # `priority` is a Priority enum value @@ -19,11 +19,9 @@ def test_get_priority(): assert get_priority(Priority.LOWEST) == 100 # `priority` is a string - assert get_priority('HIGHEST') == 0 - assert get_priority('LOWEST') == 100 + assert get_priority("HIGHEST") == 0 + assert get_priority("LOWEST") == 100 # `priority` is an invalid type - with pytest.raises( - TypeError, - match='priority must be an integer or Priority enum value'): + with pytest.raises(TypeError, match="priority must be an integer or Priority enum value"): get_priority([10]) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e7668054bb..979c85af2f 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -19,20 +19,35 @@ from mmengine.config import Config from mmengine.dataset import DefaultSampler, pseudo_collate from mmengine.evaluator import BaseMetric, Evaluator -from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, - IterTimerHook, LoggerHook, ParamSchedulerHook, - RuntimeInfoHook) +from mmengine.hooks import ( + CheckpointHook, + DistSamplerSeedHook, + Hook, + IterTimerHook, + LoggerHook, + ParamSchedulerHook, + RuntimeInfoHook, +) from mmengine.logging import HistoryBuffer, MessageHub, MMLogger from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor -from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, - OptimWrapper, OptimWrapperDict, StepLR) -from mmengine.registry import (DATASETS, EVALUATOR, FUNCTIONS, HOOKS, - LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, - MODELS, OPTIM_WRAPPER_CONSTRUCTORS, - OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS, - Registry) -from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, - LogProcessor, Runner, TestLoop, ValLoop) +from mmengine.optim import DefaultOptimWrapperConstructor, MultiStepLR, OptimWrapper, OptimWrapperDict, StepLR +from mmengine.registry import ( + DATASETS, + EVALUATOR, + FUNCTIONS, + HOOKS, + LOG_PROCESSORS, + LOOPS, + METRICS, + MODEL_WRAPPERS, + MODELS, + OPTIM_WRAPPER_CONSTRUCTORS, + OPTIM_WRAPPERS, + PARAM_SCHEDULERS, + RUNNERS, + Registry, +) +from mmengine.runner import BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, LogProcessor, Runner, TestLoop, ValLoop from mmengine.runner.loops import _InfiniteDataloaderIterator from mmengine.runner.priority import Priority, get_priority from mmengine.utils import digit_version, is_list_of @@ -41,7 +56,7 @@ def skip_test_comile(): - if digit_version(torch.__version__) < digit_version('2.0.0'): + if digit_version(torch.__version__) < digit_version("2.0.0"): return True # 
The default compiling backend for PyTorch 2.0, inductor, does not support # Nvidia graphics cards older than Volta architecture. @@ -63,13 +78,12 @@ def skip_test_comile(): class ToyModel(BaseModel): - def __init__(self, data_preprocessor=None): super().__init__(data_preprocessor=data_preprocessor) self.linear1 = nn.Linear(2, 2) self.linear2 = nn.Linear(2, 1) - def forward(self, inputs, data_sample, mode='tensor'): + def forward(self, inputs, data_sample, mode="tensor"): if isinstance(inputs, list): inputs = torch.stack(inputs) if isinstance(data_sample, list): @@ -77,103 +91,95 @@ def forward(self, inputs, data_sample, mode='tensor'): outputs = self.linear1(inputs) outputs = self.linear2(outputs) - if mode == 'tensor': + if mode == "tensor": return outputs - elif mode == 'loss': + elif mode == "loss": loss = (data_sample - outputs).sum() outputs = dict(loss=loss) return outputs - elif mode == 'predict': + elif mode == "predict": return outputs class ToyModel1(ToyModel): - def __init__(self): super().__init__() class ToySyncBNModel(BaseModel): - def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 8, 2) self.bn = nn.SyncBatchNorm(8) - def forward(self, inputs, data_sample, mode='tensor'): + def forward(self, inputs, data_sample, mode="tensor"): data_sample = torch.stack(data_sample) inputs = torch.stack(inputs) outputs = self.conv(inputs) outputs = self.bn(outputs) - if mode == 'tensor': + if mode == "tensor": return outputs - elif mode == 'loss': + elif mode == "loss": loss = (data_sample - outputs).sum() outputs = dict(loss=loss) return outputs - elif mode == 'predict': + elif mode == "predict": outputs = dict(log_vars=dict(a=1, b=0.5)) return outputs class ToyGANModel(BaseModel): - def __init__(self): super().__init__() self.linear1 = nn.Linear(2, 1) self.linear2 = nn.Linear(2, 1) - def forward(self, inputs, data_sample, mode='tensor'): + def forward(self, inputs, data_sample, mode="tensor"): data_sample = torch.stack(data_sample) inputs = torch.stack(inputs) output1 = self.linear1(inputs) output2 = self.linear2(inputs) - if mode == 'tensor': + if mode == "tensor": return output1, output2 - elif mode == 'loss': + elif mode == "loss": loss1 = (data_sample - output1).sum() loss2 = (data_sample - output2).sum() outputs = dict(linear1=loss1, linear2=loss2) return outputs - elif mode == 'predict': + elif mode == "predict": return output1, output2 def train_step(self, data, optim_wrapper): data = self.data_preprocessor(data) - loss = self(**data, mode='loss') - optim_wrapper['linear1'].update_params(loss['linear1']) - optim_wrapper['linear2'].update_params(loss['linear2']) + loss = self(**data, mode="loss") + optim_wrapper["linear1"].update_params(loss["linear1"]) + optim_wrapper["linear2"].update_params(loss["linear2"]) return loss class CustomModelWrapper(nn.Module): - def __init__(self, module): super().__init__() self.model = module class ToyMultipleOptimizerConstructor: - def __init__(self, optim_wrapper_cfg, paramwise_cfg=None): if not isinstance(optim_wrapper_cfg, dict): - raise TypeError('optimizer_cfg should be a dict', - f'but got {type(optim_wrapper_cfg)}') - assert paramwise_cfg is None, ( - 'parawise_cfg should be set in each optimizer separately') + raise TypeError("optimizer_cfg should be a dict", f"but got {type(optim_wrapper_cfg)}") + assert paramwise_cfg is None, "parawise_cfg should be set in each optimizer separately" self.optim_wrapper_cfg = optim_wrapper_cfg self.constructors = {} for key, cfg in self.optim_wrapper_cfg.items(): _cfg = cfg.copy() - 
paramwise_cfg_ = _cfg.pop('paramwise_cfg', None) - self.constructors[key] = DefaultOptimWrapperConstructor( - _cfg, paramwise_cfg_) + paramwise_cfg_ = _cfg.pop("paramwise_cfg", None) + self.constructors[key] = DefaultOptimWrapperConstructor(_cfg, paramwise_cfg_) def __call__(self, model: nn.Module) -> OptimWrapperDict: optimizers = {} - while hasattr(model, 'module'): + while hasattr(model, "module"): model = model.module for key, constructor in self.constructors.items(): @@ -210,13 +216,12 @@ def __getitem__(self, index): class ToyMetric1(BaseMetric): - - def __init__(self, collect_device='cpu', dummy_metrics=None): + def __init__(self, collect_device="cpu", dummy_metrics=None): super().__init__(collect_device=collect_device) self.dummy_metrics = dummy_metrics def process(self, data_batch, predictions): - result = {'acc': 1} + result = {"acc": 1} self.results.append(result) def compute_metrics(self, results): @@ -224,39 +229,36 @@ def compute_metrics(self, results): class ToyMetric2(BaseMetric): - - def __init__(self, collect_device='cpu', dummy_metrics=None): + def __init__(self, collect_device="cpu", dummy_metrics=None): super().__init__(collect_device=collect_device) self.dummy_metrics = dummy_metrics def process(self, data_batch, predictions): - result = {'acc': 1} + result = {"acc": 1} self.results.append(result) def compute_metrics(self, results): return dict(acc=1) -class ToyOptimWrapper(OptimWrapper): - ... +class ToyOptimWrapper(OptimWrapper): ... class ToyHook(Hook): - priority = 'Lowest' + priority = "Lowest" def before_train_epoch(self, runner): pass class ToyHook2(Hook): - priority = 'Lowest' + priority = "Lowest" def after_train_epoch(self, runner): pass class CustomTrainLoop(BaseLoop): - def __init__(self, runner, dataloader, max_epochs): super().__init__(runner, dataloader) self._max_epochs = max_epochs @@ -266,7 +268,6 @@ def run(self) -> None: class CustomValLoop(BaseLoop): - def __init__(self, runner, dataloader, evaluator): super().__init__(runner, dataloader) self._runner = runner @@ -281,7 +282,6 @@ def run(self) -> None: class CustomTestLoop(BaseLoop): - def __init__(self, runner, dataloader, evaluator): super().__init__(runner, dataloader) self._runner = runner @@ -296,7 +296,6 @@ def run(self) -> None: class CustomLogProcessor(LogProcessor): - def __init__(self, window_size=10, by_epoch=True, custom_cfg=None): self.window_size = window_size self.by_epoch = by_epoch @@ -305,35 +304,36 @@ def __init__(self, window_size=10, by_epoch=True, custom_cfg=None): class CustomRunner(Runner): - - def __init__(self, - model, - work_dir, - train_dataloader=None, - val_dataloader=None, - test_dataloader=None, - train_cfg=None, - val_cfg=None, - test_cfg=None, - auto_scale_lr=None, - optim_wrapper=None, - param_scheduler=None, - val_evaluator=None, - test_evaluator=None, - default_hooks=None, - custom_hooks=None, - data_preprocessor=None, - load_from=None, - resume=False, - launcher='none', - env_cfg=dict(dist_cfg=dict(backend='nccl')), - log_processor=None, - log_level='INFO', - visualizer=None, - default_scope=None, - randomness=dict(seed=None), - experiment_name=None, - cfg=None): + def __init__( + self, + model, + work_dir, + train_dataloader=None, + val_dataloader=None, + test_dataloader=None, + train_cfg=None, + val_cfg=None, + test_cfg=None, + auto_scale_lr=None, + optim_wrapper=None, + param_scheduler=None, + val_evaluator=None, + test_evaluator=None, + default_hooks=None, + custom_hooks=None, + data_preprocessor=None, + load_from=None, + resume=False, + launcher="none", 
+ env_cfg=None, + log_processor=None, + log_level="INFO", + visualizer=None, + default_scope=None, + randomness=None, + experiment_name=None, + cfg=None, + ): pass def setup_env(self, env_cfg): @@ -341,7 +341,6 @@ def setup_env(self, env_cfg): class ToyEvaluator(Evaluator): - def __init__(self, metrics): super().__init__(metrics) @@ -360,15 +359,13 @@ def custom_worker_init(worker_id): class TestRunner(TestCase): - def setUp(self): MODELS.register_module(module=ToyModel, force=True) MODELS.register_module(module=ToyModel1, force=True) MODELS.register_module(module=ToySyncBNModel, force=True) MODELS.register_module(module=ToyGANModel, force=True) MODEL_WRAPPERS.register_module(module=CustomModelWrapper, force=True) - OPTIM_WRAPPER_CONSTRUCTORS.register_module( - module=ToyMultipleOptimizerConstructor, force=True) + OPTIM_WRAPPER_CONSTRUCTORS.register_module(module=ToyMultipleOptimizerConstructor, force=True) DATASETS.register_module(module=ToyDataset, force=True) DATASETS.register_module(module=ToyDatasetNoMeta, force=True) METRICS.register_module(module=ToyMetric1, force=True) @@ -387,87 +384,89 @@ def setUp(self): self.temp_dir = tempfile.mkdtemp() epoch_based_cfg = dict( - model=dict(type='ToyModel'), + model=dict(type="ToyModel"), work_dir=self.temp_dir, train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), batch_size=3, - num_workers=0), + num_workers=0, + ), val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=False), batch_size=3, - num_workers=0), + num_workers=0, + ), test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=False), batch_size=3, - num_workers=0), + num_workers=0, + ), auto_scale_lr=dict(base_batch_size=16, enable=False), - optim_wrapper=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - val_evaluator=dict(type='ToyMetric1'), - test_evaluator=dict(type='ToyMetric1'), - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), + optim_wrapper=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + param_scheduler=dict(type="MultiStepLR", milestones=[1, 2]), + val_evaluator=dict(type="ToyMetric1"), + test_evaluator=dict(type="ToyMetric1"), + train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), val_cfg=dict(), test_cfg=dict(), custom_hooks=[], default_hooks=dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict( - type='CheckpointHook', interval=1, by_epoch=True), - sampler_seed=dict(type='DistSamplerSeedHook')), + runtime_info=dict(type="RuntimeInfoHook"), + timer=dict(type="IterTimerHook"), + logger=dict(type="LoggerHook"), + param_scheduler=dict(type="ParamSchedulerHook"), + checkpoint=dict(type="CheckpointHook", interval=1, by_epoch=True), + sampler_seed=dict(type="DistSamplerSeedHook"), + ), data_preprocessor=None, - launcher='none', - env_cfg=dict(dist_cfg=dict(backend='nccl')), + launcher="none", + env_cfg=dict(dist_cfg=dict(backend="nccl")), ) self.epoch_based_cfg = Config(epoch_based_cfg) 
self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg) self.iter_based_cfg.train_dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict(type="ToyDataset"), + sampler=dict(type="InfiniteSampler", shuffle=True), batch_size=3, - num_workers=0) + num_workers=0, + ) self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) self.iter_based_cfg.default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False), - sampler_seed=dict(type='DistSamplerSeedHook')) + runtime_info=dict(type="RuntimeInfoHook"), + timer=dict(type="IterTimerHook"), + logger=dict(type="LoggerHook"), + param_scheduler=dict(type="ParamSchedulerHook"), + checkpoint=dict(type="CheckpointHook", interval=1, by_epoch=False), + sampler_seed=dict(type="DistSamplerSeedHook"), + ) def tearDown(self): # `FileHandler` should be closed in Windows, otherwise we cannot # delete the temporary directory - MODELS.module_dict.pop('ToyModel') - MODELS.module_dict.pop('ToyModel1') - MODELS.module_dict.pop('ToySyncBNModel') - MODELS.module_dict.pop('ToyGANModel') - MODEL_WRAPPERS.module_dict.pop('CustomModelWrapper') - OPTIM_WRAPPER_CONSTRUCTORS.module_dict.pop( - 'ToyMultipleOptimizerConstructor') - OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') - DATASETS.module_dict.pop('ToyDataset') - DATASETS.module_dict.pop('ToyDatasetNoMeta') - METRICS.module_dict.pop('ToyMetric1') - METRICS.module_dict.pop('ToyMetric2') - HOOKS.module_dict.pop('ToyHook') - HOOKS.module_dict.pop('ToyHook2') - LOOPS.module_dict.pop('CustomTrainLoop') - LOOPS.module_dict.pop('CustomValLoop') - LOOPS.module_dict.pop('CustomTestLoop') - LOG_PROCESSORS.module_dict.pop('CustomLogProcessor') - RUNNERS.module_dict.pop('CustomRunner') - EVALUATOR.module_dict.pop('ToyEvaluator') - FUNCTIONS.module_dict.pop('custom_collate') - FUNCTIONS.module_dict.pop('custom_worker_init') + MODELS.module_dict.pop("ToyModel") + MODELS.module_dict.pop("ToyModel1") + MODELS.module_dict.pop("ToySyncBNModel") + MODELS.module_dict.pop("ToyGANModel") + MODEL_WRAPPERS.module_dict.pop("CustomModelWrapper") + OPTIM_WRAPPER_CONSTRUCTORS.module_dict.pop("ToyMultipleOptimizerConstructor") + OPTIM_WRAPPERS.module_dict.pop("ToyOptimWrapper") + DATASETS.module_dict.pop("ToyDataset") + DATASETS.module_dict.pop("ToyDatasetNoMeta") + METRICS.module_dict.pop("ToyMetric1") + METRICS.module_dict.pop("ToyMetric2") + HOOKS.module_dict.pop("ToyHook") + HOOKS.module_dict.pop("ToyHook2") + LOOPS.module_dict.pop("CustomTrainLoop") + LOOPS.module_dict.pop("CustomValLoop") + LOOPS.module_dict.pop("CustomTestLoop") + LOG_PROCESSORS.module_dict.pop("CustomLogProcessor") + RUNNERS.module_dict.pop("CustomRunner") + EVALUATOR.module_dict.pop("ToyEvaluator") + FUNCTIONS.module_dict.pop("custom_collate") + FUNCTIONS.module_dict.pop("custom_worker_init") logging.shutdown() MMLogger._instance_dict.clear() @@ -477,78 +476,78 @@ def test_init(self): # 1. 
test arguments # 1.1 train_dataloader, train_cfg, optimizer and param_scheduler cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init1' - cfg.pop('train_cfg') - with self.assertRaisesRegex(ValueError, 'either all None or not None'): + cfg.experiment_name = "test_init1" + cfg.pop("train_cfg") + with self.assertRaisesRegex(ValueError, "either all None or not None"): Runner(**cfg) # all of training related configs are None and param_scheduler should # also be None - cfg.experiment_name = 'test_init2' - cfg.pop('train_dataloader') - cfg.pop('optim_wrapper') - cfg.pop('param_scheduler') + cfg.experiment_name = "test_init2" + cfg.pop("train_dataloader") + cfg.pop("optim_wrapper") + cfg.pop("param_scheduler") runner = Runner(**cfg) self.assertIsInstance(runner, Runner) # all of training related configs are not None cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init3' + cfg.experiment_name = "test_init3" runner = Runner(**cfg) self.assertIsInstance(runner, Runner) # all of training related configs are not None and param_scheduler # can be None cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init4' - cfg.pop('param_scheduler') + cfg.experiment_name = "test_init4" + cfg.pop("param_scheduler") runner = Runner(**cfg) self.assertIsInstance(runner, Runner) self.assertEqual(runner.param_schedulers, None) # param_scheduler should be None when optimizer is None cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init5' - cfg.pop('train_cfg') - cfg.pop('train_dataloader') - cfg.pop('optim_wrapper') - with self.assertRaisesRegex(ValueError, 'should be None'): + cfg.experiment_name = "test_init5" + cfg.pop("train_cfg") + cfg.pop("train_dataloader") + cfg.pop("optim_wrapper") + with self.assertRaisesRegex(ValueError, "should be None"): runner = Runner(**cfg) # 1.2 val_dataloader, val_evaluator, val_cfg cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init6' - cfg.pop('val_cfg') - with self.assertRaisesRegex(ValueError, 'either all None or not None'): + cfg.experiment_name = "test_init6" + cfg.pop("val_cfg") + with self.assertRaisesRegex(ValueError, "either all None or not None"): Runner(**cfg) - cfg.experiment_name = 'test_init7' - cfg.pop('val_dataloader') - cfg.pop('val_evaluator') + cfg.experiment_name = "test_init7" + cfg.pop("val_dataloader") + cfg.pop("val_evaluator") runner = Runner(**cfg) self.assertIsInstance(runner, Runner) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init8' + cfg.experiment_name = "test_init8" runner = Runner(**cfg) self.assertIsInstance(runner, Runner) # 1.3 test_dataloader, test_evaluator and test_cfg cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init9' - cfg.pop('test_cfg') - with self.assertRaisesRegex(ValueError, 'either all None or not None'): + cfg.experiment_name = "test_init9" + cfg.pop("test_cfg") + with self.assertRaisesRegex(ValueError, "either all None or not None"): runner = Runner(**cfg) - cfg.experiment_name = 'test_init10' - cfg.pop('test_dataloader') - cfg.pop('test_evaluator') + cfg.experiment_name = "test_init10" + cfg.pop("test_dataloader") + cfg.pop("test_evaluator") runner = Runner(**cfg) self.assertIsInstance(runner, Runner) # 1.4 test env params cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init11' + cfg.experiment_name = "test_init11" runner = Runner(**cfg) self.assertFalse(runner.distributed) self.assertFalse(runner.deterministic) @@ -556,7 +555,7 @@ def 
test_init(self): # 1.5 message_hub, logger and visualizer # they are all not specified cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init12' + cfg.experiment_name = "test_init12" runner = Runner(**cfg) self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.message_hub, MessageHub) @@ -564,8 +563,8 @@ def test_init(self): # they are all specified cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init13' - cfg.log_level = 'INFO' + cfg.experiment_name = "test_init13" + cfg.log_level = "INFO" cfg.visualizer = None runner = Runner(**cfg) self.assertIsInstance(runner.logger, MMLogger) @@ -577,9 +576,8 @@ def test_init(self): assert runner.work_dir == self.temp_dir # 2 model should be initialized - self.assertIsInstance(runner.model, - (nn.Module, DistributedDataParallel)) - self.assertEqual(runner.model_name, 'ToyModel') + self.assertIsInstance(runner.model, (nn.Module, DistributedDataParallel)) + self.assertEqual(runner.model_name, "ToyModel") # 3. test lazy initialization self.assertIsInstance(runner._train_dataloader, dict) @@ -610,10 +608,12 @@ def test_init(self): # 4. initialize runner with objects rather than config model = ToyModel() - optim_wrapper = OptimWrapper(SGD( - model.parameters(), - lr=0.01, - )) + optim_wrapper = OptimWrapper( + SGD( + model.parameters(), + lr=0.01, + ) + ) toy_hook = ToyHook() toy_hook2 = ToyHook2() @@ -623,8 +623,7 @@ def test_init(self): runner = Runner( model=model, work_dir=self.temp_dir, - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), + train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), @@ -636,7 +635,8 @@ def test_init(self): test_evaluator=[ToyMetric1()], default_hooks=dict(param_scheduler=toy_hook), custom_hooks=[toy_hook2], - experiment_name='test_init14') + experiment_name="test_init14", + ) runner.train() runner.test() @@ -644,117 +644,106 @@ def test_init(self): # available, and this test will be skipped. if torch.cuda.is_available() and torch.distributed.is_nccl_available(): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init15' - cfg.launcher = 'pytorch' - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29600' - os.environ['RANK'] = '0' - os.environ['WORLD_SIZE'] = '1' - os.environ['LOCAL_RANK'] = '0' + cfg.experiment_name = "test_init15" + cfg.launcher = "pytorch" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29600" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_RANK"] = "0" Runner(**cfg) - cfg.experiment_name = 'test_init16' + cfg.experiment_name = "test_init16" Runner(**cfg) # 6.1 Test initializing with empty scheduler. cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_init17' + cfg.experiment_name = "test_init17" cfg.param_scheduler = None runner = Runner(**cfg) self.assertIsNone(runner.param_schedulers) # 6.2 Test initializing single scheduler. - cfg.experiment_name = 'test_init18' - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) + cfg.experiment_name = "test_init18" + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2]) Runner(**cfg) # 6.3 Test initializing list of scheduler. 
- cfg.param_scheduler = [ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='MultiStepLR', milestones=[2, 3]) - ] - cfg.experiment_name = 'test_init19' + cfg.param_scheduler = [dict(type="MultiStepLR", milestones=[1, 2]), dict(type="MultiStepLR", milestones=[2, 3])] + cfg.experiment_name = "test_init19" Runner(**cfg) # 6.4 Test initializing 2 schedulers for 2 optimizers. cfg.param_scheduler = dict( - linear1=dict(type='MultiStepLR', milestones=[1, 2]), - linear2=dict(type='MultiStepLR', milestones=[1, 2]), + linear1=dict(type="MultiStepLR", milestones=[1, 2]), + linear2=dict(type="MultiStepLR", milestones=[1, 2]), ) - cfg.experiment_name = 'test_init20' + cfg.experiment_name = "test_init20" Runner(**cfg) # 6.5 Test initializing 2 schedulers for 2 optimizers. cfg.param_scheduler = dict( - linear1=[dict(type='MultiStepLR', milestones=[1, 2])], - linear2=[dict(type='MultiStepLR', milestones=[1, 2])], + linear1=[dict(type="MultiStepLR", milestones=[1, 2])], + linear2=[dict(type="MultiStepLR", milestones=[1, 2])], ) - cfg.experiment_name = 'test_init21' + cfg.experiment_name = "test_init21" Runner(**cfg) # 6.6 Test initializing with `_ParameterScheduler`. optimizer = SGD(nn.Linear(1, 1).parameters(), lr=0.1) - cfg.param_scheduler = MultiStepLR( - milestones=[1, 2], optimizer=optimizer) - cfg.experiment_name = 'test_init22' + cfg.param_scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer) + cfg.experiment_name = "test_init22" Runner(**cfg) # 6.7 Test initializing with list of `_ParameterScheduler`. - cfg.param_scheduler = [ - MultiStepLR(milestones=[1, 2], optimizer=optimizer) - ] - cfg.experiment_name = 'test_init23' + cfg.param_scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)] + cfg.experiment_name = "test_init23" Runner(**cfg) # 6.8 Test initializing with 2 `_ParameterScheduler` for 2 optimizers. cfg.param_scheduler = dict( linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer), - linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer)) - cfg.experiment_name = 'test_init24' + linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer), + ) + cfg.experiment_name = "test_init24" Runner(**cfg) # 6.9 Test initializing with 2 list of `_ParameterScheduler` for 2 # optimizers. cfg.param_scheduler = dict( linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)], - linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)]) - cfg.experiment_name = 'test_init25' + linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)], + ) + cfg.experiment_name = "test_init25" Runner(**cfg) # 6.10 Test initializing with error type scheduler. - cfg.param_scheduler = dict(linear1='error_type') - cfg.experiment_name = 'test_init26' - with self.assertRaisesRegex(AssertionError, 'Each value of'): + cfg.param_scheduler = dict(linear1="error_type") + cfg.experiment_name = "test_init26" + with self.assertRaisesRegex(AssertionError, "Each value of"): Runner(**cfg) - cfg.param_scheduler = 'error_type' - cfg.experiment_name = 'test_init27' - with self.assertRaisesRegex(TypeError, - '`param_scheduler` should be a'): + cfg.param_scheduler = "error_type" + cfg.experiment_name = "test_init27" + with self.assertRaisesRegex(TypeError, "`param_scheduler` should be a"): Runner(**cfg) def test_dump_config(self): # dump config from dict. 
cfg = copy.deepcopy(self.epoch_based_cfg) for idx, cfg in enumerate((cfg, cfg._cfg_dict)): - cfg.experiment_name = f'test_dump{idx}' + cfg.experiment_name = f"test_dump{idx}" runner = Runner.from_cfg(cfg=cfg) - assert osp.exists( - osp.join(runner.work_dir, f'{runner.timestamp}.py')) + assert osp.exists(osp.join(runner.work_dir, f"{runner.timestamp}.py")) # dump config from file. with tempfile.TemporaryDirectory() as temp_config_dir: # Set `delete=Flase` and close the file to make it # work in Windows. - temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix='.py', delete=False) + temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py", delete=False) temp_config_file.close() - file_cfg = Config( - self.epoch_based_cfg._cfg_dict, - filename=temp_config_file.name) - file_cfg.experiment_name = f'test_dump2{idx}' + file_cfg = Config(self.epoch_based_cfg._cfg_dict, filename=temp_config_file.name) + file_cfg.experiment_name = f"test_dump2{idx}" runner = Runner.from_cfg(cfg=file_cfg) - assert osp.exists( - osp.join(runner.work_dir, - osp.basename(temp_config_file.name))) + assert osp.exists(osp.join(runner.work_dir, osp.basename(temp_config_file.name))) def test_from_cfg(self): runner = Runner.from_cfg(cfg=self.epoch_based_cfg) @@ -765,115 +754,106 @@ def test_setup_env(self): pass def test_build_logger(self): - self.epoch_based_cfg.experiment_name = 'test_build_logger1' + self.epoch_based_cfg.experiment_name = "test_build_logger1" runner = Runner.from_cfg(self.epoch_based_cfg) self.assertIsInstance(runner.logger, MMLogger) self.assertEqual(runner.experiment_name, runner.logger.instance_name) # input is a dict - logger = runner.build_logger(name='test_build_logger2') + logger = runner.build_logger(name="test_build_logger2") self.assertIsInstance(logger, MMLogger) - self.assertEqual(logger.instance_name, 'test_build_logger2') + self.assertEqual(logger.instance_name, "test_build_logger2") # input is a dict but does not contain name key - runner._experiment_name = 'test_build_logger3' + runner._experiment_name = "test_build_logger3" logger = runner.build_logger() self.assertIsInstance(logger, MMLogger) - self.assertEqual(logger.instance_name, 'test_build_logger3') + self.assertEqual(logger.instance_name, "test_build_logger3") def test_build_message_hub(self): - self.epoch_based_cfg.experiment_name = 'test_build_message_hub1' + self.epoch_based_cfg.experiment_name = "test_build_message_hub1" runner = Runner.from_cfg(self.epoch_based_cfg) self.assertIsInstance(runner.message_hub, MessageHub) - self.assertEqual(runner.message_hub.instance_name, - runner.experiment_name) + self.assertEqual(runner.message_hub.instance_name, runner.experiment_name) # input is a dict - message_hub_cfg = dict(name='test_build_message_hub2') + message_hub_cfg = dict(name="test_build_message_hub2") message_hub = runner.build_message_hub(message_hub_cfg) self.assertIsInstance(message_hub, MessageHub) - self.assertEqual(message_hub.instance_name, 'test_build_message_hub2') + self.assertEqual(message_hub.instance_name, "test_build_message_hub2") # input is a dict but does not contain name key - runner._experiment_name = 'test_build_message_hub3' + runner._experiment_name = "test_build_message_hub3" message_hub_cfg = dict() message_hub = runner.build_message_hub(message_hub_cfg) self.assertIsInstance(message_hub, MessageHub) - self.assertEqual(message_hub.instance_name, 'test_build_message_hub3') + self.assertEqual(message_hub.instance_name, "test_build_message_hub3") # input is not 
a valid type - with self.assertRaisesRegex(TypeError, 'message_hub should be'): - runner.build_message_hub('invalid-type') + with self.assertRaisesRegex(TypeError, "message_hub should be"): + runner.build_message_hub("invalid-type") def test_build_visualizer(self): - self.epoch_based_cfg.experiment_name = 'test_build_visualizer1' + self.epoch_based_cfg.experiment_name = "test_build_visualizer1" runner = Runner.from_cfg(self.epoch_based_cfg) self.assertIsInstance(runner.visualizer, Visualizer) - self.assertEqual(runner.experiment_name, - runner.visualizer.instance_name) + self.assertEqual(runner.experiment_name, runner.visualizer.instance_name) # input is a Visualizer object - self.assertEqual( - id(runner.build_visualizer(runner.visualizer)), - id(runner.visualizer)) + self.assertEqual(id(runner.build_visualizer(runner.visualizer)), id(runner.visualizer)) # input is a dict - visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2') + visualizer_cfg = dict(type="Visualizer", name="test_build_visualizer2") visualizer = runner.build_visualizer(visualizer_cfg) self.assertIsInstance(visualizer, Visualizer) - self.assertEqual(visualizer.instance_name, 'test_build_visualizer2') + self.assertEqual(visualizer.instance_name, "test_build_visualizer2") # input is a dict but does not contain name key - runner._experiment_name = 'test_build_visualizer3' + runner._experiment_name = "test_build_visualizer3" visualizer_cfg = None visualizer = runner.build_visualizer(visualizer_cfg) self.assertIsInstance(visualizer, Visualizer) - self.assertEqual(visualizer.instance_name, 'test_build_visualizer3') + self.assertEqual(visualizer.instance_name, "test_build_visualizer3") # input is not a valid type - with self.assertRaisesRegex(TypeError, 'visualizer should be'): - runner.build_visualizer('invalid-type') + with self.assertRaisesRegex(TypeError, "visualizer should be"): + runner.build_visualizer("invalid-type") def test_default_scope(self): - TOY_SCHEDULERS = Registry( - 'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy') + TOY_SCHEDULERS = Registry("parameter scheduler", parent=PARAM_SCHEDULERS, scope="toy") @TOY_SCHEDULERS.register_module(force=True) class ToyScheduler(MultiStepLR): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.epoch_based_cfg.param_scheduler = dict( - type='ToyScheduler', milestones=[1, 2]) - self.epoch_based_cfg.default_scope = 'toy' + self.epoch_based_cfg.param_scheduler = dict(type="ToyScheduler", milestones=[1, 2]) + self.epoch_based_cfg.default_scope = "toy" cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_default_scope' + cfg.experiment_name = "test_default_scope" runner = Runner.from_cfg(cfg) runner.train() self.assertIsInstance(runner.param_schedulers[0], ToyScheduler) def test_build_model(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_model1' + cfg.experiment_name = "test_build_model1" runner = Runner.from_cfg(cfg) self.assertIsInstance(runner.model, ToyModel) - self.assertIsInstance(runner.model.data_preprocessor, - BaseDataPreprocessor) + self.assertIsInstance(runner.model.data_preprocessor, BaseDataPreprocessor) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_data_preprocessor' - cfg.data_preprocessor = dict(type='ImgDataPreprocessor') + cfg.experiment_name = "test_data_preprocessor" + cfg.data_preprocessor = dict(type="ImgDataPreprocessor") runner = Runner.from_cfg(cfg) # data_preprocessor is passed to used if no `data_preprocessor` # in 
model config. - self.assertIsInstance(runner.model.data_preprocessor, - ImgDataPreprocessor) + self.assertIsInstance(runner.model.data_preprocessor, ImgDataPreprocessor) # input should be a nn.Module object or dict - with self.assertRaisesRegex(TypeError, 'model should be'): - runner.build_model('invalid-type') + with self.assertRaisesRegex(TypeError, "model should be"): + runner.build_model("invalid-type") # input is a nn.Module object _model = ToyModel1() @@ -881,47 +861,46 @@ def test_build_model(self): self.assertEqual(id(model), id(_model)) # input is a dict - model = runner.build_model(dict(type='ToyModel1')) + model = runner.build_model(dict(type="ToyModel1")) self.assertIsInstance(model, ToyModel1) def test_wrap_model(self): # revert sync batchnorm cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_revert_syncbn' - cfg.model = dict(type='ToySyncBNModel') + cfg.experiment_name = "test_revert_syncbn" + cfg.model = dict(type="ToySyncBNModel") runner = Runner.from_cfg(cfg) self.assertIsInstance(runner.model, BaseModel) assert not isinstance(runner.model.bn, nn.SyncBatchNorm) # custom model wrapper cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_wrap_model' - cfg.model_wrapper_cfg = dict(type='CustomModelWrapper') + cfg.experiment_name = "test_wrap_model" + cfg.model_wrapper_cfg = dict(type="CustomModelWrapper") runner = Runner.from_cfg(cfg) self.assertIsInstance(runner.model, BaseModel) # Test with ddp wrapper if torch.cuda.is_available() and torch.distributed.is_nccl_available(): - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29515' - os.environ['RANK'] = str(0) - os.environ['WORLD_SIZE'] = str(1) - cfg.launcher = 'pytorch' - cfg.experiment_name = 'test_wrap_model1' + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29515" + os.environ["RANK"] = str(0) + os.environ["WORLD_SIZE"] = str(1) + cfg.launcher = "pytorch" + cfg.experiment_name = "test_wrap_model1" runner = Runner.from_cfg(cfg) self.assertIsInstance(runner.model, CustomModelWrapper) # Test cfg.sync_bn = 'torch', when model does not have BN layer cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.launcher = 'pytorch' - cfg.experiment_name = 'test_wrap_model2' - cfg.sync_bn = 'torch' - cfg.model_wrapper_cfg = dict(type='CustomModelWrapper') + cfg.launcher = "pytorch" + cfg.experiment_name = "test_wrap_model2" + cfg.sync_bn = "torch" + cfg.model_wrapper_cfg = dict(type="CustomModelWrapper") runner.from_cfg(cfg) @MODELS.register_module(force=True) class ToyBN(BaseModel): - def __init__(self): super().__init__() self.bn = nn.BatchNorm2d(2) @@ -929,20 +908,19 @@ def __init__(self): def forward(self, *args, **kwargs): pass - cfg.model = dict(type='ToyBN') - cfg.experiment_name = 'test_data_preprocessor2' + cfg.model = dict(type="ToyBN") + cfg.experiment_name = "test_data_preprocessor2" runner = Runner.from_cfg(cfg) - self.assertIsInstance(runner.model.model.bn, - torch.nn.SyncBatchNorm) + self.assertIsInstance(runner.model.model.bn, torch.nn.SyncBatchNorm) - cfg.sync_bn = 'unknown' - cfg.experiment_name = 'test_data_preprocessor3' + cfg.sync_bn = "unknown" + cfg.experiment_name = "test_data_preprocessor3" with self.assertRaises(ValueError): Runner.from_cfg(cfg) def test_scale_lr(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_scale_lr' + cfg.experiment_name = "test_scale_lr" runner = Runner.from_cfg(cfg) # When no base_batch_size in auto_scale_lr, an @@ -957,34 +935,33 @@ def test_scale_lr(self): auto_scale_lr = 
dict(base_batch_size=16, enable=False) optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01)) runner.scale_lr(optim_wrapper) - self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'], 0.01) + self.assertEqual(optim_wrapper.optimizer.param_groups[0]["lr"], 0.01) runner.scale_lr(optim_wrapper, auto_scale_lr) - self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'], 0.01) + self.assertEqual(optim_wrapper.optimizer.param_groups[0]["lr"], 0.01) # When auto_scale_lr is correct and enable is True, the lr will # be linearly scaled. auto_scale_lr = dict(base_batch_size=16, enable=True) - real_bs = runner.world_size * cfg.train_dataloader['batch_size'] + real_bs = runner.world_size * cfg.train_dataloader["batch_size"] optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01)) runner.scale_lr(optim_wrapper, auto_scale_lr) - self.assertEqual(optim_wrapper.optimizer.param_groups[0]['lr'], - 0.01 * (real_bs / 16)) + self.assertEqual(optim_wrapper.optimizer.param_groups[0]["lr"], 0.01 * (real_bs / 16)) # Test when optim_wrapper is an OptimWrapperDict optim_wrapper = OptimWrapper(SGD(runner.model.parameters(), lr=0.01)) wrapper_dict = OptimWrapperDict(wrapper=optim_wrapper) runner.scale_lr(wrapper_dict, auto_scale_lr) - scaled_lr = wrapper_dict['wrapper'].optimizer.param_groups[0]['lr'] + scaled_lr = wrapper_dict["wrapper"].optimizer.param_groups[0]["lr"] self.assertEqual(scaled_lr, 0.01 * (real_bs / 16)) def test_build_optim_wrapper(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_optim_wrapper' + cfg.experiment_name = "test_build_optim_wrapper" runner = Runner.from_cfg(cfg) # input should be an Optimizer object or dict - with self.assertRaisesRegex(TypeError, 'optimizer wrapper should be'): - runner.build_optim_wrapper('invalid-type') + with self.assertRaisesRegex(TypeError, "optimizer wrapper should be"): + runner.build_optim_wrapper("invalid-type") # 1. test one optimizer # 1.1 input is an Optimizer object @@ -994,13 +971,11 @@ def test_build_optim_wrapper(self): self.assertEqual(id(optimizer), id(optim_wrapper.optimizer)) # 1.2 input is a dict - optim_wrapper = runner.build_optim_wrapper( - dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))) + optim_wrapper = runner.build_optim_wrapper(dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01))) self.assertIsInstance(optim_wrapper, OptimWrapper) # 1.3 use default OptimWrapper type. - optim_wrapper = runner.build_optim_wrapper( - dict(optimizer=dict(type='SGD', lr=0.01))) + optim_wrapper = runner.build_optim_wrapper(dict(optimizer=dict(type="SGD", lr=0.01))) self.assertIsInstance(optim_wrapper, OptimWrapper) # 2. 
test multiple optmizers @@ -1012,31 +987,28 @@ def test_build_optim_wrapper(self): optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapperDict) - self.assertIsInstance(optim_wrapper['key1'].optimizer, SGD) - self.assertIsInstance(optim_wrapper['key2'].optimizer, Adam) + self.assertIsInstance(optim_wrapper["key1"].optimizer, SGD) + self.assertIsInstance(optim_wrapper["key2"].optimizer, Adam) # 2.2 each item mush be an optimizer object when "type" and # "constructor" are not in optimizer optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01) optim_wrapper1 = OptimWrapper(optimizer1) - optim_wrapper2 = dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01)) + optim_wrapper2 = dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.01)) optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2) - with self.assertRaisesRegex(ValueError, - 'each item mush be an optimizer object'): + with self.assertRaisesRegex(ValueError, "each item mush be an optimizer object"): runner.build_optim_wrapper(optim_cfg) # 2.3 input is a dict which contains multiple configs optim_wrapper_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), + constructor="ToyMultipleOptimizerConstructor", + ) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapperDict) - self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD) - self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam) + self.assertIsInstance(optim_wrapper["linear1"].optimizer, SGD) + self.assertIsInstance(optim_wrapper["linear2"].optimizer, Adam) # 2.4 input is a dict which contains optimizer instance. 
model = nn.Linear(1, 1) @@ -1049,8 +1021,7 @@ def test_build_optim_wrapper(self): # Specify the type of optimizer wrapper model = nn.Linear(1, 1) optimizer = SGD(model.parameters(), lr=0.1) - optim_wrapper_cfg = dict( - optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2) + optim_wrapper_cfg = dict(optimizer=optimizer, type="ToyOptimWrapper", accumulative_counts=2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, ToyOptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) @@ -1058,23 +1029,22 @@ def test_build_optim_wrapper(self): def test_build_param_scheduler(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_param_scheduler' + cfg.experiment_name = "test_build_param_scheduler" runner = Runner.from_cfg(cfg) # `build_optim_wrapper` should be called before # `build_param_scheduler` - cfg = dict(type='MultiStepLR', milestones=[1, 2]) + cfg = dict(type="MultiStepLR", milestones=[1, 2]) runner.optim_wrapper = dict( - key1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - key2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), + key1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + key2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), ) - with self.assertRaisesRegex(AssertionError, 'should be called before'): + with self.assertRaisesRegex(AssertionError, "should be called before"): runner.build_param_scheduler(cfg) runner.optim_wrapper = runner.build_optim_wrapper( - dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))) + dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)) + ) param_schedulers = runner.build_param_scheduler(cfg) self.assertIsInstance(param_schedulers, list) self.assertEqual(len(param_schedulers), 1) @@ -1094,17 +1064,14 @@ def test_build_param_scheduler(self): # 2. test one optimizer and list of parameter schedulers # 2.1 input is a list of dict - cfg = [ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='StepLR', step_size=1) - ] + cfg = [dict(type="MultiStepLR", milestones=[1, 2]), dict(type="StepLR", step_size=1)] param_schedulers = runner.build_param_scheduler(cfg) self.assertEqual(len(param_schedulers), 2) self.assertIsInstance(param_schedulers[0], MultiStepLR) self.assertIsInstance(param_schedulers[1], StepLR) # 2.2 input is a list and some items are ParamScheduler objects - cfg = [param_scheduler, dict(type='StepLR', step_size=1)] + cfg = [param_scheduler, dict(type="StepLR", step_size=1)] param_schedulers = runner.build_param_scheduler(cfg) self.assertEqual(len(param_schedulers), 2) self.assertIsInstance(param_schedulers[0], MultiStepLR) @@ -1117,45 +1084,35 @@ def test_build_param_scheduler(self): optim_wrapper2 = OptimWrapper(optimizer2) optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2) runner.optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - cfg = [ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='StepLR', step_size=1) - ] + cfg = [dict(type="MultiStepLR", milestones=[1, 2]), dict(type="StepLR", step_size=1)] param_schedulers = runner.build_param_scheduler(cfg) print(param_schedulers) self.assertIsInstance(param_schedulers, dict) self.assertEqual(len(param_schedulers), 2) - self.assertEqual(len(param_schedulers['key1']), 2) - self.assertEqual(len(param_schedulers['key2']), 2) + self.assertEqual(len(param_schedulers["key1"]), 2) + self.assertEqual(len(param_schedulers["key2"]), 2) # 4. 
test multiple optimizers and multiple parameter shceduers cfg = dict( - key1=dict(type='MultiStepLR', milestones=[1, 2]), - key2=[ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='StepLR', step_size=1) - ]) + key1=dict(type="MultiStepLR", milestones=[1, 2]), + key2=[dict(type="MultiStepLR", milestones=[1, 2]), dict(type="StepLR", step_size=1)], + ) param_schedulers = runner.build_param_scheduler(cfg) self.assertIsInstance(param_schedulers, dict) self.assertEqual(len(param_schedulers), 2) - self.assertEqual(len(param_schedulers['key1']), 1) - self.assertEqual(len(param_schedulers['key2']), 2) + self.assertEqual(len(param_schedulers["key1"]), 1) + self.assertEqual(len(param_schedulers["key2"]), 2) # 5. test converting epoch-based scheduler to iter-based runner.optim_wrapper = runner.build_optim_wrapper( - dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))) + dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)) + ) # 5.1 train loop should be built before converting scheduler - cfg = dict( - type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True) + cfg = dict(type="MultiStepLR", milestones=[1, 2], convert_to_iter_based=True) # 5.2 convert epoch-based to iter-based scheduler - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - end=7, - convert_to_iter_based=True) + cfg = dict(type="MultiStepLR", milestones=[1, 2], begin=1, end=7, convert_to_iter_based=True) runner._train_loop = runner.build_train_loop(runner.train_loop) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) @@ -1163,18 +1120,14 @@ def test_build_param_scheduler(self): self.assertEqual(param_schedulers[0].end, 28) # 6. test set default end of schedulers - cfg = dict(type='MultiStepLR', milestones=[1, 2], begin=1) + cfg = dict(type="MultiStepLR", milestones=[1, 2], begin=1) param_schedulers = runner.build_param_scheduler(cfg) self.assertTrue(param_schedulers[0].by_epoch) self.assertEqual(param_schedulers[0].begin, 1) # runner.max_epochs = 3 self.assertEqual(param_schedulers[0].end, 3) - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - convert_to_iter_based=True) + cfg = dict(type="MultiStepLR", milestones=[1, 2], begin=1, convert_to_iter_based=True) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) self.assertEqual(param_schedulers[0].begin, 4) @@ -1183,7 +1136,7 @@ def test_build_param_scheduler(self): def test_build_evaluator(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_evaluator' + cfg.experiment_name = "test_build_evaluator" runner = Runner.from_cfg(cfg) # input is a BaseEvaluator or ComposedEvaluator object @@ -1194,11 +1147,11 @@ def test_build_evaluator(self): self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator)) # input is a dict - evaluator = dict(type='ToyMetric1') + evaluator = dict(type="ToyMetric1") self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator) # input is a list of dict - evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')] + evaluator = [dict(type="ToyMetric1"), dict(type="ToyMetric2")] self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator) # input is a list of built metric. 
@@ -1208,40 +1161,36 @@ def test_build_evaluator(self): self.assertIs(_evaluator.metrics[1], metric[1]) # test collect device - evaluator = [ - dict(type='ToyMetric1', collect_device='cpu'), - dict(type='ToyMetric2', collect_device='gpu') - ] + evaluator = [dict(type="ToyMetric1", collect_device="cpu"), dict(type="ToyMetric2", collect_device="gpu")] _evaluator = runner.build_evaluator(evaluator) - self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') - self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') + self.assertEqual(_evaluator.metrics[0].collect_device, "cpu") + self.assertEqual(_evaluator.metrics[1].collect_device, "gpu") # test build a customize evaluator evaluator = dict( - type='ToyEvaluator', - metrics=[ - dict(type='ToyMetric1', collect_device='cpu'), - dict(type='ToyMetric2', collect_device='gpu') - ]) + type="ToyEvaluator", + metrics=[dict(type="ToyMetric1", collect_device="cpu"), dict(type="ToyMetric2", collect_device="gpu")], + ) _evaluator = runner.build_evaluator(evaluator) self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator) - self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') - self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') + self.assertEqual(_evaluator.metrics[0].collect_device, "cpu") + self.assertEqual(_evaluator.metrics[1].collect_device, "gpu") # test evaluator must be a Evaluator instance - with self.assertRaisesRegex(TypeError, 'evaluator should be'): + with self.assertRaisesRegex(TypeError, "evaluator should be"): _evaluator = runner.build_evaluator(ToyMetric1()) def test_build_dataloader(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_dataloader' + cfg.experiment_name = "test_build_dataloader" runner = Runner.from_cfg(cfg) cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), batch_size=1, - num_workers=0) + num_workers=0, + ) seed = np.random.randint(2**31) dataloader = runner.build_dataloader(cfg, seed=seed) self.assertIsInstance(dataloader, DataLoader) @@ -1250,28 +1199,29 @@ def test_build_dataloader(self): self.assertEqual(dataloader.sampler.seed, seed) # diff_rank_seed is True - dataloader = runner.build_dataloader( - cfg, seed=seed, diff_rank_seed=True) + dataloader = runner.build_dataloader(cfg, seed=seed, diff_rank_seed=True) self.assertNotEqual(dataloader.sampler.seed, seed) # custom worker_init_fn cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), + worker_init_fn=dict(type="custom_worker_init"), batch_size=1, - num_workers=2) + num_workers=2, + ) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.worker_init_fn.func, custom_worker_init) # collate_fn is a dict cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), + worker_init_fn=dict(type="custom_worker_init"), batch_size=1, num_workers=2, - collate_fn=dict(type='pseudo_collate')) + collate_fn=dict(type="pseudo_collate"), + ) dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) @@ -1280,58 +1230,60 @@ def custom_collate(data_batch): return 
data_batch cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), + worker_init_fn=dict(type="custom_worker_init"), batch_size=1, num_workers=2, - collate_fn=custom_collate) + collate_fn=custom_collate, + ) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.collate_fn, custom_collate) # collate_fn is a invalid value - with self.assertRaisesRegex( - TypeError, 'collate_fn should be a dict or callable object'): + with self.assertRaisesRegex(TypeError, "collate_fn should be a dict or callable object"): cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), + worker_init_fn=dict(type="custom_worker_init"), batch_size=1, num_workers=2, - collate_fn='collate_fn') + collate_fn="collate_fn", + ) dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) # num_batch_per_epoch is not None cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate'), + dataset=dict(type="ToyDataset"), + sampler=dict(type="DefaultSampler", shuffle=True), + collate_fn=dict(type="default_collate"), batch_size=3, num_workers=2, - num_batch_per_epoch=2) + num_batch_per_epoch=2, + ) dataloader = runner.build_dataloader(cfg) self.assertEqual(len(dataloader.dataset), 6) def test_build_train_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_train_loop' + cfg.experiment_name = "test_build_train_loop" runner = Runner.from_cfg(cfg) # input should be a Loop object or dict - with self.assertRaisesRegex(TypeError, 'should be'): - runner.build_train_loop('invalid-type') + with self.assertRaisesRegex(TypeError, "should be"): + runner.build_train_loop("invalid-type") # Only one of type or by_epoch can exist in cfg - cfg = dict(type='EpochBasedTrainLoop', by_epoch=True, max_epochs=3) - with self.assertRaisesRegex(RuntimeError, 'Only one'): + cfg = dict(type="EpochBasedTrainLoop", by_epoch=True, max_epochs=3) + with self.assertRaisesRegex(RuntimeError, "Only one"): runner.build_train_loop(cfg) # input is a dict and contains type key - cfg = dict(type='EpochBasedTrainLoop', max_epochs=3) + cfg = dict(type="EpochBasedTrainLoop", max_epochs=3) loop = runner.build_train_loop(cfg) self.assertIsInstance(loop, EpochBasedTrainLoop) - cfg = dict(type='IterBasedTrainLoop', max_iters=3) + cfg = dict(type="IterBasedTrainLoop", max_iters=3) loop = runner.build_train_loop(cfg) self.assertIsInstance(loop, IterBasedTrainLoop) @@ -1348,27 +1300,27 @@ def test_build_train_loop(self): self.assertEqual(id(runner.build_train_loop(loop)), id(loop)) # param_schedulers can be None - cfg = dict(type='EpochBasedTrainLoop', max_epochs=3) + cfg = dict(type="EpochBasedTrainLoop", max_epochs=3) runner.param_schedulers = None loop = runner.build_train_loop(cfg) self.assertIsInstance(loop, EpochBasedTrainLoop) # test custom training loop - cfg = dict(type='CustomTrainLoop', max_epochs=3) + cfg = dict(type="CustomTrainLoop", max_epochs=3) loop = runner.build_train_loop(cfg) self.assertIsInstance(loop, CustomTrainLoop) def test_build_val_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_val_loop' + 
cfg.experiment_name = "test_build_val_loop" runner = Runner.from_cfg(cfg) # input should be a Loop object or dict - with self.assertRaisesRegex(TypeError, 'should be'): - runner.build_test_loop('invalid-type') + with self.assertRaisesRegex(TypeError, "should be"): + runner.build_test_loop("invalid-type") # input is a dict and contains type key - cfg = dict(type='ValLoop') + cfg = dict(type="ValLoop") loop = runner.build_test_loop(cfg) self.assertIsInstance(loop, ValLoop) @@ -1381,21 +1333,21 @@ def test_build_val_loop(self): self.assertEqual(id(runner.build_val_loop(loop)), id(loop)) # test custom validation loop - cfg = dict(type='CustomValLoop') + cfg = dict(type="CustomValLoop") loop = runner.build_val_loop(cfg) self.assertIsInstance(loop, CustomValLoop) def test_build_test_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_test_loop' + cfg.experiment_name = "test_build_test_loop" runner = Runner.from_cfg(cfg) # input should be a Loop object or dict - with self.assertRaisesRegex(TypeError, 'should be'): - runner.build_test_loop('invalid-type') + with self.assertRaisesRegex(TypeError, "should be"): + runner.build_test_loop("invalid-type") # input is a dict and contains type key - cfg = dict(type='TestLoop') + cfg = dict(type="TestLoop") loop = runner.build_test_loop(cfg) self.assertIsInstance(loop, TestLoop) @@ -1408,21 +1360,21 @@ def test_build_test_loop(self): self.assertEqual(id(runner.build_test_loop(loop)), id(loop)) # test custom validation loop - cfg = dict(type='CustomTestLoop') + cfg = dict(type="CustomTestLoop") loop = runner.build_val_loop(cfg) self.assertIsInstance(loop, CustomTestLoop) def test_build_log_processor(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_log_processor' + cfg.experiment_name = "test_build_log_processor" runner = Runner.from_cfg(cfg) # input should be a LogProcessor object or dict - with self.assertRaisesRegex(TypeError, 'should be'): - runner.build_log_processor('invalid-type') + with self.assertRaisesRegex(TypeError, "should be"): + runner.build_log_processor("invalid-type") # input is a dict and contains type key - cfg = dict(type='LogProcessor') + cfg = dict(type="LogProcessor") log_processor = runner.build_log_processor(cfg) self.assertIsInstance(log_processor, LogProcessor) @@ -1432,40 +1384,38 @@ def test_build_log_processor(self): self.assertIsInstance(log_processor, LogProcessor) # input is a LogProcessor object - self.assertEqual( - id(runner.build_log_processor(log_processor)), id(log_processor)) + self.assertEqual(id(runner.build_log_processor(log_processor)), id(log_processor)) # test custom validation log_processor - cfg = dict(type='CustomLogProcessor') + cfg = dict(type="CustomLogProcessor") log_processor = runner.build_log_processor(cfg) self.assertIsInstance(log_processor, CustomLogProcessor) def test_train(self): # 1. test `self.train_loop` is None cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train1' - cfg.pop('train_dataloader') - cfg.pop('train_cfg') - cfg.pop('optim_wrapper') - cfg.pop('param_scheduler') + cfg.experiment_name = "test_train1" + cfg.pop("train_dataloader") + cfg.pop("train_cfg") + cfg.pop("optim_wrapper") + cfg.pop("param_scheduler") runner = Runner.from_cfg(cfg) - with self.assertRaisesRegex(RuntimeError, 'should not be None'): + with self.assertRaisesRegex(RuntimeError, "should not be None"): runner.train() # 2. 
test iter and epoch counter of EpochBasedTrainLoop and timing of # running ValLoop epoch_results = [] - epoch_targets = [i for i in range(3)] + epoch_targets = list(range(3)) iter_results = [] - iter_targets = [i for i in range(4 * 3)] + iter_targets = list(range(4 * 3)) batch_idx_results = [] - batch_idx_targets = [i for i in range(4)] * 3 # train and val + batch_idx_targets = list(range(4)) * 3 # train and val val_epoch_results = [] - val_epoch_targets = [i for i in range(2, 4)] + val_epoch_targets = list(range(2, 4)) @HOOKS.register_module(force=True) class TestEpochHook(Hook): - def before_train_epoch(self, runner): epoch_results.append(runner.epoch) @@ -1477,8 +1427,8 @@ def before_val_epoch(self, runner): val_epoch_results.append(runner.epoch) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train2' - cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)] + cfg.experiment_name = "test_train2" + cfg.custom_hooks = [dict(type="TestEpochHook", priority=50)] cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2) runner = Runner.from_cfg(cfg) runner.train() @@ -1487,13 +1437,25 @@ def before_val_epoch(self, runner): assert isinstance(runner.train_loop, EpochBasedTrainLoop) - for result, target, in zip(epoch_results, epoch_targets): + for ( + result, + target, + ) in zip(epoch_results, epoch_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(iter_results, iter_targets): + for ( + result, + target, + ) in zip(iter_results, iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(batch_idx_results, batch_idx_targets): + for ( + result, + target, + ) in zip(batch_idx_results, batch_idx_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_epoch_results, val_epoch_targets): + for ( + result, + target, + ) in zip(val_epoch_results, val_epoch_targets, strict=False): self.assertEqual(result, target) # 3. 
test iter and epoch counter of IterBasedTrainLoop and timing of @@ -1503,14 +1465,13 @@ def before_val_epoch(self, runner): batch_idx_results = [] val_iter_results = [] val_batch_idx_results = [] - iter_targets = [i for i in range(12)] - batch_idx_targets = [i for i in range(12)] - val_iter_targets = [i for i in range(4, 12)] - val_batch_idx_targets = [i for i in range(4)] * 2 + iter_targets = list(range(12)) + batch_idx_targets = list(range(12)) + val_iter_targets = list(range(4, 12)) + val_batch_idx_targets = list(range(4)) * 2 @HOOKS.register_module(force=True) class TestIterHook(Hook): - def before_train_epoch(self, runner): epoch_results.append(runner.epoch) @@ -1523,10 +1484,9 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): val_batch_idx_results.append(batch_idx) cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train3' - cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.experiment_name = "test_train3" + cfg.custom_hooks = [dict(type="TestIterHook", priority=50)] + cfg.train_cfg = dict(by_epoch=False, max_iters=12, val_interval=4, val_begin=4) runner = Runner.from_cfg(cfg) runner.train() @@ -1538,14 +1498,25 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): self.assertEqual(epoch_results[0], 0) self.assertEqual(runner.val_interval, 4) self.assertEqual(runner.val_begin, 4) - for result, target, in zip(iter_results, iter_targets): + for ( + result, + target, + ) in zip(iter_results, iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(batch_idx_results, batch_idx_targets): + for ( + result, + target, + ) in zip(batch_idx_results, batch_idx_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_iter_results, val_iter_targets): + for ( + result, + target, + ) in zip(val_iter_results, val_iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_batch_idx_results, - val_batch_idx_targets): + for ( + result, + target, + ) in zip(val_batch_idx_results, val_batch_idx_targets, strict=False): self.assertEqual(result, target) # 4. test iter and epoch counter of IterBasedTrainLoop and timing of @@ -1555,39 +1526,47 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): batch_idx_results = [] val_iter_results = [] val_batch_idx_results = [] - iter_targets = [i for i in range(12)] - batch_idx_targets = [i for i in range(12)] - val_iter_targets = [i for i in range(4, 12)] - val_batch_idx_targets = [i for i in range(4)] * 2 + iter_targets = list(range(12)) + batch_idx_targets = list(range(12)) + val_iter_targets = list(range(4, 12)) + val_batch_idx_targets = list(range(4)) * 2 cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train4' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) - cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.experiment_name = "test_train4" + cfg.train_dataloader.sampler = dict(type="DefaultSampler", shuffle=True) + cfg.custom_hooks = [dict(type="TestIterHook", priority=50)] + cfg.train_cfg = dict(by_epoch=False, max_iters=12, val_interval=4, val_begin=4) runner = Runner.from_cfg(cfg) # Warning should be raised since the sampler is not InfiniteSampler. 
- with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): runner.train() assert isinstance(runner.train_loop, IterBasedTrainLoop) - assert isinstance(runner.train_loop.dataloader_iterator, - _InfiniteDataloaderIterator) + assert isinstance(runner.train_loop.dataloader_iterator, _InfiniteDataloaderIterator) self.assertEqual(len(epoch_results), 1) self.assertEqual(epoch_results[0], 0) self.assertEqual(runner.val_interval, 4) self.assertEqual(runner.val_begin, 4) - for result, target, in zip(iter_results, iter_targets): + for ( + result, + target, + ) in zip(iter_results, iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(batch_idx_results, batch_idx_targets): + for ( + result, + target, + ) in zip(batch_idx_results, batch_idx_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_iter_results, val_iter_targets): + for ( + result, + target, + ) in zip(val_iter_results, val_iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_batch_idx_results, - val_batch_idx_targets): + for ( + result, + target, + ) in zip(val_batch_idx_results, val_batch_idx_targets, strict=False): self.assertEqual(result, target) # 5.1 test dynamic interval in IterBasedTrainLoop @@ -1601,7 +1580,6 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): @HOOKS.register_module(force=True) class TestIterDynamicIntervalHook(Hook): - def before_val(self, runner): iter_results.append(runner.iter) @@ -1609,22 +1587,23 @@ def before_train_iter(self, runner, batch_idx, data_batch=None): val_interval_results.append(runner.train_loop.val_interval) cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train5' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) - cfg.custom_hooks = [ - dict(type='TestIterDynamicIntervalHook', priority=50) - ] + cfg.experiment_name = "test_train5" + cfg.train_dataloader.sampler = dict(type="DefaultSampler", shuffle=True) + cfg.custom_hooks = [dict(type="TestIterDynamicIntervalHook", priority=50)] cfg.train_cfg = dict( - by_epoch=False, - max_iters=max_iters, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + by_epoch=False, max_iters=max_iters, val_interval=interval, dynamic_intervals=dynamic_intervals + ) runner = Runner.from_cfg(cfg) runner.train() - for result, target, in zip(iter_results, iter_targets): + for ( + result, + target, + ) in zip(iter_results, iter_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_interval_results, val_interval_targets): + for ( + result, + target, + ) in zip(val_interval_results, val_interval_targets, strict=False): self.assertEqual(result, target) # 5.2 test dynamic interval in EpochBasedTrainLoop @@ -1638,7 +1617,6 @@ def before_train_iter(self, runner, batch_idx, data_batch=None): @HOOKS.register_module(force=True) class TestEpochDynamicIntervalHook(Hook): - def before_val_epoch(self, runner): epoch_results.append(runner.epoch) @@ -1646,28 +1624,28 @@ def before_train_epoch(self, runner): val_interval_results.append(runner.train_loop.val_interval) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train6' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) - cfg.custom_hooks = [ - dict(type='TestEpochDynamicIntervalHook', priority=50) - ] + cfg.experiment_name = "test_train6" + cfg.train_dataloader.sampler = 
dict(type="DefaultSampler", shuffle=True) + cfg.custom_hooks = [dict(type="TestEpochDynamicIntervalHook", priority=50)] cfg.train_cfg = dict( - by_epoch=True, - max_epochs=max_epochs, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + by_epoch=True, max_epochs=max_epochs, val_interval=interval, dynamic_intervals=dynamic_intervals + ) runner = Runner.from_cfg(cfg) runner.train() - for result, target, in zip(epoch_results, epoch_targets): + for ( + result, + target, + ) in zip(epoch_results, epoch_targets, strict=False): self.assertEqual(result, target) - for result, target, in zip(val_interval_results, val_interval_targets): + for ( + result, + target, + ) in zip(val_interval_results, val_interval_targets, strict=False): self.assertEqual(result, target) # 7. test init weights @MODELS.register_module(force=True) class ToyModel2(ToyModel): - def __init__(self): super().__init__() self.initiailzed = False @@ -1676,7 +1654,7 @@ def init_weights(self): self.initiailzed = True cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train7' + cfg.experiment_name = "test_train7" runner = Runner.from_cfg(cfg) model = ToyModel2() runner.model = model @@ -1685,82 +1663,78 @@ def init_weights(self): # 8.1 test train with multiple optimizer and single list of schedulers. cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train8' - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) + cfg.experiment_name = "test_train8" + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2]) cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') - cfg.model = dict(type='ToyGANModel') + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), + constructor="ToyMultipleOptimizerConstructor", + ) + cfg.model = dict(type="ToyGANModel") runner = runner.from_cfg(cfg) runner.train() # 8.1 Test train with multiple optimizer and single schedulers. cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train8.1.1' - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) + cfg.experiment_name = "test_train8.1.1" + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2]) cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') - cfg.model = dict(type='ToyGANModel') + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), + constructor="ToyMultipleOptimizerConstructor", + ) + cfg.model = dict(type="ToyGANModel") runner = runner.from_cfg(cfg) runner.train() # Test list like single scheduler. - cfg.experiment_name = 'test_train8.1.2' - cfg.param_scheduler = [dict(type='MultiStepLR', milestones=[1, 2])] + cfg.experiment_name = "test_train8.1.2" + cfg.param_scheduler = [dict(type="MultiStepLR", milestones=[1, 2])] runner = runner.from_cfg(cfg) runner.train() # 8.2 Test train with multiple optimizer and multiple schedulers. 
- cfg.experiment_name = 'test_train8.2.1' + cfg.experiment_name = "test_train8.2.1" cfg.param_scheduler = dict( - linear1=dict(type='MultiStepLR', milestones=[1, 2]), - linear2=dict(type='MultiStepLR', milestones=[1, 2]), + linear1=dict(type="MultiStepLR", milestones=[1, 2]), + linear2=dict(type="MultiStepLR", milestones=[1, 2]), ) runner = runner.from_cfg(cfg) runner.train() - cfg.experiment_name = 'test_train8.2.2' + cfg.experiment_name = "test_train8.2.2" cfg.param_scheduler = dict( - linear1=[dict(type='MultiStepLR', milestones=[1, 2])], - linear2=[dict(type='MultiStepLR', milestones=[1, 2])], + linear1=[dict(type="MultiStepLR", milestones=[1, 2])], + linear2=[dict(type="MultiStepLR", milestones=[1, 2])], ) runner = runner.from_cfg(cfg) runner.train() # 9 Test training with a dataset without metainfo - cfg.experiment_name = 'test_train9' + cfg.experiment_name = "test_train9" cfg = copy.deepcopy(cfg) - cfg.train_dataloader.dataset = dict(type='ToyDatasetNoMeta') + cfg.train_dataloader.dataset = dict(type="ToyDatasetNoMeta") runner = runner.from_cfg(cfg) runner.train() # 10.1 Test build dataloader with default collate function cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train10.1' - cfg.train_dataloader.update(collate_fn=dict(type='default_collate')) + cfg.experiment_name = "test_train10.1" + cfg.train_dataloader.update(collate_fn=dict(type="default_collate")) runner = Runner.from_cfg(cfg) runner.train() # 10.2 Test build dataloader with custom collate function cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train10.2' - cfg.train_dataloader.update( - collate_fn=dict(type='custom_collate', pad_value=100)) + cfg.experiment_name = "test_train10.2" + cfg.train_dataloader.update(collate_fn=dict(type="custom_collate", pad_value=100)) runner = Runner.from_cfg(cfg) runner.train() # 10.3 Test build dataloader with custom worker_init function cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train10.3' - cfg.train_dataloader.update( - worker_init_fn=dict(type='custom_worker_init')) + cfg.experiment_name = "test_train10.3" + cfg.train_dataloader.update(worker_init_fn=dict(type="custom_worker_init")) runner = Runner.from_cfg(cfg) runner.train() @@ -1768,15 +1742,14 @@ def init_weights(self): # function. 
with self.assertRaises(TypeError): cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_train11' - cfg.train_dataloader.update(collate_fn=dict(type='custom_collate')) + cfg.experiment_name = "test_train11" + cfg.train_dataloader.update(collate_fn=dict(type="custom_collate")) runner = Runner.from_cfg(cfg) runner.train() # 12.1 Test train with model, which does not inherit from BaseModel @MODELS.register_module(force=True) class ToyModel3(nn.Module): - def __init__(self): super().__init__() self.linear = nn.Linear(1, 1) @@ -1785,38 +1758,38 @@ def train_step(self, *args, **kwargs): return dict(loss=torch.tensor(1)) cfg = copy.deepcopy(self.iter_based_cfg) - cfg.pop('val_cfg') - cfg.pop('val_dataloader') - cfg.pop('val_evaluator') - cfg.model = dict(type='ToyModel3') - cfg.experiment_name = 'test_train12.1' + cfg.pop("val_cfg") + cfg.pop("val_dataloader") + cfg.pop("val_evaluator") + cfg.model = dict(type="ToyModel3") + cfg.experiment_name = "test_train12.1" runner = Runner.from_cfg(cfg) runner.train() # 12.2 Test val_step should be implemented if val_cfg is not None cfg = copy.deepcopy(self.iter_based_cfg) - cfg.model = dict(type='ToyModel3') - cfg.experiment_name = 'test_train12.2' + cfg.model = dict(type="ToyModel3") + cfg.experiment_name = "test_train12.2" runner = Runner.from_cfg(cfg) - with self.assertRaisesRegex(AssertionError, 'If you want to validate'): + with self.assertRaisesRegex(AssertionError, "If you want to validate"): runner.train() # 13 Test the logs will be printed when the length of # train_dataloader is smaller than the interval set in LoggerHook cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train13' - cfg.default_hooks = dict(logger=dict(type='LoggerHook', interval=5)) + cfg.experiment_name = "test_train13" + cfg.default_hooks = dict(logger=dict(type="LoggerHook", interval=5)) runner = Runner.from_cfg(cfg) runner.train() with open(runner.logger._log_file) as f: log = f.read() - self.assertIn('Epoch(train) [1][4/4]', log) + self.assertIn("Epoch(train) [1][4/4]", log) # 14. test_loop will not be built for cfg in (self.epoch_based_cfg, self.iter_based_cfg): cfg = copy.deepcopy(cfg) - cfg.experiment_name = 'test_train14' + cfg.experiment_name = "test_train14" runner = Runner.from_cfg(cfg) runner.train() self.assertIsInstance(runner._train_loop, BaseLoop) @@ -1825,8 +1798,8 @@ def train_step(self, *args, **kwargs): # 15. test num_batch_per_epoch cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train15' - cfg.train_dataloader['num_batch_per_epoch'] = 2 + cfg.experiment_name = "test_train15" + cfg.train_dataloader["num_batch_per_epoch"] = 2 cfg.train_cfg = dict( by_epoch=True, max_epochs=3, @@ -1835,57 +1808,54 @@ def train_step(self, *args, **kwargs): runner.train() self.assertEqual(runner.iter, 3 * 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, reason="torch.compile is not valid, please install PyTorch>=2.0.0") def test_train_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train_compile_simple' + cfg.experiment_name = "test_train_compile_simple" cfg.compile = True runner = Runner.from_cfg(cfg) runner.train() # 2. 
test with advanced configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_train_compile_advanced' - cfg.compile = dict(backend='inductor', mode='default') + cfg.experiment_name = "test_train_compile_advanced" + cfg.compile = dict(backend="inductor", mode="default") runner = Runner.from_cfg(cfg) runner.train() - runner._maybe_compile('train_step') + runner._maybe_compile("train_step") # PyTorch 2.0.0 could close the FileHandler after calling of # ``torch.compile``. So we need to test our file handler still works. - with open(osp.join(f'{runner.log_dir}', - f'{runner.timestamp}.log')) as f: + with open(osp.join(f"{runner.log_dir}", f"{runner.timestamp}.log")) as f: last_line = f.readlines()[-1] - self.assertTrue(last_line.endswith('please be patient.\n')) + self.assertTrue(last_line.endswith("please be patient.\n")) def test_val(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_val1' - cfg.pop('val_dataloader') - cfg.pop('val_cfg') - cfg.pop('val_evaluator') + cfg.experiment_name = "test_val1" + cfg.pop("val_dataloader") + cfg.pop("val_cfg") + cfg.pop("val_evaluator") runner = Runner.from_cfg(cfg) - with self.assertRaisesRegex(RuntimeError, 'should not be None'): + with self.assertRaisesRegex(RuntimeError, "should not be None"): runner.val() cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_val2' + cfg.experiment_name = "test_val2" runner = Runner.from_cfg(cfg) runner.val() # test run val without train and test components cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_individually_val' - cfg.pop('train_dataloader') - cfg.pop('train_cfg') - cfg.pop('optim_wrapper') - cfg.pop('param_scheduler') - cfg.pop('test_dataloader') - cfg.pop('test_cfg') - cfg.pop('test_evaluator') + cfg.experiment_name = "test_individually_val" + cfg.pop("train_dataloader") + cfg.pop("train_cfg") + cfg.pop("optim_wrapper") + cfg.pop("param_scheduler") + cfg.pop("test_dataloader") + cfg.pop("test_cfg") + cfg.pop("test_evaluator") runner = Runner.from_cfg(cfg) # Test default fp32 `autocast` context. @@ -1900,23 +1870,21 @@ def get_outputs_callback(module, inputs, outputs): predictions.clear() # Test fp16 `autocast` context. 
- cfg.experiment_name = 'test_val3' + cfg.experiment_name = "test_val3" cfg.val_cfg = dict(fp16=True) runner = Runner.from_cfg(cfg) runner.model.register_forward_hook(get_outputs_callback) - if (digit_version(TORCH_VERSION) < digit_version('1.10.0') - and not torch.cuda.is_available()): - with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'): + if digit_version(TORCH_VERSION) < digit_version("1.10.0") and not torch.cuda.is_available(): + with self.assertRaisesRegex(RuntimeError, "If pytorch versions"): runner.val() else: runner.val() - self.assertIn(predictions[0].dtype, - (torch.float16, torch.bfloat16)) + self.assertIn(predictions[0].dtype, (torch.float16, torch.bfloat16)) # train_loop and test_loop will not be built for cfg in (self.epoch_based_cfg, self.iter_based_cfg): cfg = copy.deepcopy(cfg) - cfg.experiment_name = 'test_val4' + cfg.experiment_name = "test_val4" runner = Runner.from_cfg(cfg) runner.val() self.assertIsInstance(runner._train_loop, dict) @@ -1927,56 +1895,49 @@ def get_outputs_callback(module, inputs, outputs): @HOOKS.register_module(force=True) class TestIterHook(Hook): - def __init__(self): self.val_iter = 0 - def after_val_iter(self, - runner, - batch_idx, - data_batch=None, - outputs=None): + def after_val_iter(self, runner, batch_idx, data_batch=None, outputs=None): self.val_iter += 1 nonlocal val_result val_result = self.val_iter cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.val_dataloader['num_batch_per_epoch'] = 2 + cfg.custom_hooks = [dict(type="TestIterHook", priority=50)] + cfg.val_dataloader["num_batch_per_epoch"] = 2 runner = Runner.from_cfg(cfg) runner.val() self.assertEqual(val_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, reason="torch.compile is not valid, please install PyTorch>=2.0.0") def test_val_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_val_compile_simple' + cfg.experiment_name = "test_val_compile_simple" cfg.compile = True runner = Runner.from_cfg(cfg) runner.val() # 2. test with advanced configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_val_compile_advanced' - cfg.compile = dict(backend='inductor', mode='default') + cfg.experiment_name = "test_val_compile_advanced" + cfg.compile = dict(backend="inductor", mode="default") runner = Runner.from_cfg(cfg) runner.val() def test_test(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_test1' - cfg.pop('test_dataloader') - cfg.pop('test_cfg') - cfg.pop('test_evaluator') + cfg.experiment_name = "test_test1" + cfg.pop("test_dataloader") + cfg.pop("test_cfg") + cfg.pop("test_evaluator") runner = Runner.from_cfg(cfg) - with self.assertRaisesRegex(RuntimeError, 'should not be None'): + with self.assertRaisesRegex(RuntimeError, "should not be None"): runner.test() cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_test2' + cfg.experiment_name = "test_test2" runner = Runner.from_cfg(cfg) runner.test() # Test run test without building train loop. 
@@ -1984,14 +1945,14 @@ def test_test(self): # test run test without train and test components cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_individually_test' - cfg.pop('train_dataloader') - cfg.pop('train_cfg') - cfg.pop('optim_wrapper') - cfg.pop('param_scheduler') - cfg.pop('val_dataloader') - cfg.pop('val_cfg') - cfg.pop('val_evaluator') + cfg.experiment_name = "test_individually_test" + cfg.pop("train_dataloader") + cfg.pop("train_cfg") + cfg.pop("optim_wrapper") + cfg.pop("param_scheduler") + cfg.pop("val_dataloader") + cfg.pop("val_cfg") + cfg.pop("val_evaluator") runner = Runner.from_cfg(cfg) # Test default fp32 `autocast` context. @@ -2006,22 +1967,20 @@ def get_outputs_callback(module, inputs, outputs): predictions.clear() # Test fp16 `autocast` context. - cfg.experiment_name = 'test_test3' + cfg.experiment_name = "test_test3" cfg.test_cfg = dict(fp16=True) runner = Runner.from_cfg(cfg) runner.model.register_forward_hook(get_outputs_callback) - if (digit_version(TORCH_VERSION) < digit_version('1.10.0') - and not torch.cuda.is_available()): - with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'): + if digit_version(TORCH_VERSION) < digit_version("1.10.0") and not torch.cuda.is_available(): + with self.assertRaisesRegex(RuntimeError, "If pytorch versions"): runner.test() else: runner.test() - self.assertIn(predictions[0].dtype, - (torch.float16, torch.bfloat16)) + self.assertIn(predictions[0].dtype, (torch.float16, torch.bfloat16)) # train_loop and val_loop will not be built for cfg in (self.epoch_based_cfg, self.iter_based_cfg): cfg = copy.deepcopy(cfg) - cfg.experiment_name = 'test_test4' + cfg.experiment_name = "test_test4" runner = Runner.from_cfg(cfg) runner.test() self.assertIsInstance(runner._train_loop, dict) @@ -2032,75 +1991,64 @@ def get_outputs_callback(module, inputs, outputs): @HOOKS.register_module(force=True) class TestIterHook(Hook): - def __init__(self): self.test_iter = 0 - def after_test_iter(self, - runner, - batch_idx, - data_batch=None, - outputs=None): + def after_test_iter(self, runner, batch_idx, data_batch=None, outputs=None): self.test_iter += 1 nonlocal test_result test_result = self.test_iter cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.test_dataloader['num_batch_per_epoch'] = 2 + cfg.custom_hooks = [dict(type="TestIterHook", priority=50)] + cfg.test_dataloader["num_batch_per_epoch"] = 2 runner = Runner.from_cfg(cfg) runner.test() self.assertEqual(test_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, reason="torch.compile is not valid, please install PyTorch>=2.0.0") def test_test_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_test_compile_simple' + cfg.experiment_name = "test_test_compile_simple" cfg.compile = True runner = Runner.from_cfg(cfg) runner.test() # 2. 
test with advanced configuration cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_test_compile_advanced' - cfg.compile = dict(backend='inductor', mode='default') + cfg.experiment_name = "test_test_compile_advanced" + cfg.compile = dict(backend="inductor", mode="default") runner = Runner.from_cfg(cfg) runner.test() def test_register_hook(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_register_hook' + cfg.experiment_name = "test_register_hook" runner = Runner.from_cfg(cfg) runner._hooks = [] # 1. test `hook` parameter # 1.1 `hook` should be either a Hook object or dict - with self.assertRaisesRegex( - TypeError, 'hook should be an instance of Hook or dict'): - runner.register_hook(['string']) + with self.assertRaisesRegex(TypeError, "hook should be an instance of Hook or dict"): + runner.register_hook(["string"]) # 1.2 `hook` is a dict - timer_cfg = dict(type='IterTimerHook') + timer_cfg = dict(type="IterTimerHook") runner.register_hook(timer_cfg) self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) # default priority of `IterTimerHook` is 'NORMAL' - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), get_priority("NORMAL")) runner._hooks = [] # 1.2.1 `hook` is a dict and contains `priority` field # set the priority of `IterTimerHook` as 'BELOW_NORMAL' - timer_cfg = dict(type='IterTimerHook', priority='BELOW_NORMAL') + timer_cfg = dict(type="IterTimerHook", priority="BELOW_NORMAL") runner.register_hook(timer_cfg) self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), - get_priority('BELOW_NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), get_priority("BELOW_NORMAL")) # 1.3 `hook` is a hook object runtime_info_hook = RuntimeInfoHook() @@ -2110,30 +2058,27 @@ def test_register_hook(self): # `IterTimerHook`, so the first item of `_hooks` should be # `runtime_info_hook` self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH')) + self.assertEqual(get_priority(runner._hooks[0].priority), get_priority("VERY_HIGH")) # 2. 
test `priority` parameter # `priority` argument is not None and it will be set as priority of # hook - param_scheduler_cfg = dict(type='ParamSchedulerHook', priority='LOW') - runner.register_hook(param_scheduler_cfg, priority='VERY_LOW') + param_scheduler_cfg = dict(type="ParamSchedulerHook", priority="LOW") + runner.register_hook(param_scheduler_cfg, priority="VERY_LOW") self.assertEqual(len(runner._hooks), 3) self.assertTrue(isinstance(runner._hooks[2], ParamSchedulerHook)) - self.assertEqual( - get_priority(runner._hooks[2].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[2].priority), get_priority("VERY_LOW")) # `priority` is Priority - logger_cfg = dict(type='LoggerHook', priority='BELOW_NORMAL') + logger_cfg = dict(type="LoggerHook", priority="BELOW_NORMAL") runner.register_hook(logger_cfg, priority=Priority.VERY_LOW) self.assertEqual(len(runner._hooks), 4) self.assertTrue(isinstance(runner._hooks[3], LoggerHook)) - self.assertEqual( - get_priority(runner._hooks[3].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[3].priority), get_priority("VERY_LOW")) def test_default_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_default_hooks' + cfg.experiment_name = "test_default_hooks" runner = Runner.from_cfg(cfg) runner._hooks = [] @@ -2154,28 +2099,28 @@ def test_default_hooks(self): # add a new default hook runner._hooks = [] - runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook'))) + runner.register_default_hooks(hooks=dict(ToyHook=dict(type="ToyHook"))) self.assertEqual(len(runner._hooks), 7) self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_custom_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_custom_hooks' + cfg.experiment_name = "test_custom_hooks" runner = Runner.from_cfg(cfg) self.assertEqual(len(runner._hooks), 6) - custom_hooks = [dict(type='ToyHook')] + custom_hooks = [dict(type="ToyHook")] runner.register_custom_hooks(custom_hooks) self.assertEqual(len(runner._hooks), 7) self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_register_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_register_hooks' + cfg.experiment_name = "test_register_hooks" runner = Runner.from_cfg(cfg) runner._hooks = [] - custom_hooks = [dict(type='ToyHook')] + custom_hooks = [dict(type="ToyHook")] runner.register_hooks(custom_hooks=custom_hooks) # six default hooks + custom hook (ToyHook) self.assertEqual(len(runner._hooks), 7) @@ -2187,16 +2132,13 @@ def test_custom_loop(self): class CustomTrainLoop2(IterBasedTrainLoop): """Custom train loop with additional warmup stage.""" - def __init__(self, runner, dataloader, max_iters, warmup_loader, - max_warmup_iters): - super().__init__( - runner=runner, dataloader=dataloader, max_iters=max_iters) - self.warmup_loader = self.runner.build_dataloader( - warmup_loader) + def __init__(self, runner, dataloader, max_iters, warmup_loader, max_warmup_iters): + super().__init__(runner=runner, dataloader=dataloader, max_iters=max_iters) + self.warmup_loader = self.runner.build_dataloader(warmup_loader) self.max_warmup_iters = max_warmup_iters def run(self): - self.runner.call_hook('before_train') + self.runner.call_hook("before_train") self.runner.cur_dataloader = self.warmup_loader for idx, data_batch in enumerate(self.warmup_loader, 1): self.warmup_iter(data_batch) @@ -2204,22 +2146,19 @@ def run(self): break self.runner.cur_dataloader = self.warmup_loader - 
self.runner.call_hook('before_train_epoch') + self.runner.call_hook("before_train_epoch") while self.runner.iter < self._max_iters: data_batch = next(self.dataloader_iterator) self.run_iter(data_batch) - self.runner.call_hook('after_train_epoch') + self.runner.call_hook("after_train_epoch") - self.runner.call_hook('after_train') + self.runner.call_hook("after_train") def warmup_iter(self, data_batch): - self.runner.call_hook( - 'before_warmup_iter', data_batch=data_batch) - train_logs = self.runner.model.train_step( - data_batch, self.runner.optim_wrapper) - self.runner.message_hub.update_info('train_logs', train_logs) - self.runner.call_hook( - 'after_warmup_iter', data_batch=data_batch) + self.runner.call_hook("before_warmup_iter", data_batch=data_batch) + train_logs = self.runner.model.train_step(data_batch, self.runner.optim_wrapper) + self.runner.message_hub.update_info("train_logs", train_logs) + self.runner.call_hook("after_warmup_iter", data_batch=data_batch) before_warmup_iter_results = [] after_warmup_iter_results = [] @@ -2229,24 +2168,24 @@ class TestWarmupHook(Hook): """Test custom train loop.""" def before_warmup_iter(self, runner, data_batch=None): - before_warmup_iter_results.append('before') + before_warmup_iter_results.append("before") def after_warmup_iter(self, runner, data_batch=None, outputs=None): - after_warmup_iter_results.append('after') + after_warmup_iter_results.append("after") self.iter_based_cfg.train_cfg = dict( - type='CustomTrainLoop2', + type="CustomTrainLoop2", max_iters=10, warmup_loader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict(type="ToyDataset"), + sampler=dict(type="InfiniteSampler", shuffle=True), batch_size=1, - num_workers=0), - max_warmup_iters=5) - self.iter_based_cfg.custom_hooks = [ - dict(type='TestWarmupHook', priority=50) - ] - self.iter_based_cfg.experiment_name = 'test_custom_loop' + num_workers=0, + ), + max_warmup_iters=5, + ) + self.iter_based_cfg.custom_hooks = [dict(type="TestWarmupHook", priority=50)] + self.iter_based_cfg.experiment_name = "test_custom_loop" runner = Runner.from_cfg(self.iter_based_cfg) runner.train() @@ -2255,42 +2194,40 @@ def after_warmup_iter(self, runner, data_batch=None, outputs=None): # test custom hook triggered as expected self.assertEqual(len(before_warmup_iter_results), 5) self.assertEqual(len(after_warmup_iter_results), 5) - for before, after in zip(before_warmup_iter_results, - after_warmup_iter_results): - self.assertEqual(before, 'before') - self.assertEqual(after, 'after') + for before, after in zip(before_warmup_iter_results, after_warmup_iter_results, strict=False): + self.assertEqual(before, "before") + self.assertEqual(after, "after") def test_checkpoint(self): # 1. 
test epoch based cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint1' + cfg.experiment_name = "test_checkpoint1" runner = Runner.from_cfg(cfg) runner.train() # 1.1 test `save_checkpoint` which is called by `CheckpointHook` - path = osp.join(self.temp_dir, 'epoch_3.pth') + path = osp.join(self.temp_dir, "epoch_3.pth") self.assertTrue(osp.exists(path)) - self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) + self.assertFalse(osp.exists(osp.join(self.temp_dir, "epoch_4.pth"))) ckpt = torch.load(path) - self.assertEqual(ckpt['meta']['epoch'], 3) - self.assertEqual(ckpt['meta']['iter'], 12) - self.assertEqual(ckpt['meta']['experiment_name'], - runner.experiment_name) - self.assertEqual(ckpt['meta']['seed'], runner.seed) - assert isinstance(ckpt['optimizer'], dict) - assert isinstance(ckpt['param_schedulers'], list) - self.assertIsInstance(ckpt['message_hub'], dict) - message_hub = MessageHub.get_instance('test_ckpt') - message_hub.load_state_dict(ckpt['message_hub']) - self.assertEqual(message_hub.get_info('epoch'), 2) - self.assertEqual(message_hub.get_info('iter'), 11) + self.assertEqual(ckpt["meta"]["epoch"], 3) + self.assertEqual(ckpt["meta"]["iter"], 12) + self.assertEqual(ckpt["meta"]["experiment_name"], runner.experiment_name) + self.assertEqual(ckpt["meta"]["seed"], runner.seed) + assert isinstance(ckpt["optimizer"], dict) + assert isinstance(ckpt["param_schedulers"], list) + self.assertIsInstance(ckpt["message_hub"], dict) + message_hub = MessageHub.get_instance("test_ckpt") + message_hub.load_state_dict(ckpt["message_hub"]) + self.assertEqual(message_hub.get_info("epoch"), 2) + self.assertEqual(message_hub.get_info("iter"), 11) # 1.2 test `load_checkpoint` cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint2' - cfg.optim_wrapper = dict(type='SGD', lr=0.2) - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) + cfg.experiment_name = "test_checkpoint2" + cfg.optim_wrapper = dict(type="SGD", lr=0.2) + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) runner.load_checkpoint(path) self.assertEqual(runner.epoch, 0) @@ -2303,10 +2240,9 @@ def test_checkpoint(self): # 1.3.1 test `resume` cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint3' - cfg.optim_wrapper = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)) - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) + cfg.experiment_name = "test_checkpoint3" + cfg.optim_wrapper = dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.2)) + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) runner.resume(path) self.assertEqual(runner.epoch, 3) @@ -2314,47 +2250,45 @@ def test_checkpoint(self): self.assertTrue(runner._has_loaded) self.assertIsInstance(runner.optim_wrapper.optimizer, SGD) self.assertIsInstance(runner.optim_wrapper.optimizer, SGD) - self.assertEqual(runner.optim_wrapper.param_groups[0]['lr'], 0.0001) + self.assertEqual(runner.optim_wrapper.param_groups[0]["lr"], 0.0001) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1}) self.assertIsInstance(runner.message_hub, MessageHub) - self.assertEqual(runner.message_hub.get_info('epoch'), 2) - self.assertEqual(runner.message_hub.get_info('iter'), 11) - self.assertEqual(MessageHub.get_current_instance().get_info('epoch'), - 2) - 
self.assertEqual(MessageHub.get_current_instance().get_info('iter'), - 11) + self.assertEqual(runner.message_hub.get_info("epoch"), 2) + self.assertEqual(runner.message_hub.get_info("iter"), 11) + self.assertEqual(MessageHub.get_current_instance().get_info("epoch"), 2) + self.assertEqual(MessageHub.get_current_instance().get_info("iter"), 11) # 1.3.2 test resume with unmatched dataset_meta ckpt_modified = copy.deepcopy(ckpt) - ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']} + ckpt_modified["meta"]["dataset_meta"] = {"CLASSES": ["cat", "dog"]} # ckpt_modified['meta']['seed'] = 123 - path_modified = osp.join(self.temp_dir, 'modified.pth') + path_modified = osp.join(self.temp_dir, "modified.pth") torch.save(ckpt_modified, path_modified) # Warning should be raised since dataset_meta is not matched - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): runner.resume(path_modified) # 1.3.3 test resume with unmatched seed ckpt_modified = copy.deepcopy(ckpt) - ckpt_modified['meta']['seed'] = 123 - path_modified = osp.join(self.temp_dir, 'modified.pth') + ckpt_modified["meta"]["seed"] = 123 + path_modified = osp.join(self.temp_dir, "modified.pth") torch.save(ckpt_modified, path_modified) # Warning should be raised since seed is not matched - with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): + with self.assertLogs(MMLogger.get_current_instance(), level="WARNING"): runner.resume(path_modified) # 1.3.3 test resume with no seed and dataset meta ckpt_modified = copy.deepcopy(ckpt) - ckpt_modified['meta'].pop('seed') - ckpt_modified['meta'].pop('dataset_meta') - path_modified = osp.join(self.temp_dir, 'modified.pth') + ckpt_modified["meta"].pop("seed") + ckpt_modified["meta"].pop("dataset_meta") + path_modified = osp.join(self.temp_dir, "modified.pth") torch.save(ckpt_modified, path_modified) runner.resume(path_modified) # 1.4 test auto resume cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint4' + cfg.experiment_name = "test_checkpoint4" cfg.resume = True runner = Runner.from_cfg(cfg) runner.load_or_resume() @@ -2366,9 +2300,9 @@ def test_checkpoint(self): # 1.5 test resume from a specified checkpoint cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint5' + cfg.experiment_name = "test_checkpoint5" cfg.resume = True - cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth') + cfg.load_from = osp.join(self.temp_dir, "epoch_1.pth") runner = Runner.from_cfg(cfg) runner.load_or_resume() self.assertEqual(runner.epoch, 1) @@ -2379,88 +2313,74 @@ def test_checkpoint(self): # 1.6 multiple optimizers cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint6' + cfg.experiment_name = "test_checkpoint6" cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') - cfg.model = dict(type='ToyGANModel') + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), + constructor="ToyMultipleOptimizerConstructor", + ) + cfg.model = dict(type="ToyGANModel") # disable OptimizerHook because it only works with one optimizer runner = Runner.from_cfg(cfg) runner.train() - path = osp.join(self.temp_dir, 'epoch_3.pth') + path = osp.join(self.temp_dir, 
"epoch_3.pth") self.assertTrue(osp.exists(path)) - self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'], - 0.0001) - self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam) - self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'], - 0.0002) + self.assertEqual(runner.optim_wrapper["linear1"].param_groups[0]["lr"], 0.0001) + self.assertIsInstance(runner.optim_wrapper["linear2"].optimizer, Adam) + self.assertEqual(runner.optim_wrapper["linear2"].param_groups[0]["lr"], 0.0002) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint7' + cfg.experiment_name = "test_checkpoint7" cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)), - constructor='ToyMultipleOptimizerConstructor') - cfg.model = dict(type='ToyGANModel') - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.2)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.03)), + constructor="ToyMultipleOptimizerConstructor", + ) + cfg.model = dict(type="ToyGANModel") + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) runner.resume(path) self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict) - self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD) - self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'], - 0.0001) - self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam) - self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'], - 0.0002) + self.assertIsInstance(runner.optim_wrapper["linear1"].optimizer, SGD) + self.assertEqual(runner.optim_wrapper["linear1"].param_groups[0]["lr"], 0.0001) + self.assertIsInstance(runner.optim_wrapper["linear2"].optimizer, Adam) + self.assertEqual(runner.optim_wrapper["linear2"].param_groups[0]["lr"], 0.0002) self.assertIsInstance(runner.param_schedulers, dict) - self.assertEqual(len(runner.param_schedulers['linear1']), 1) - self.assertIsInstance(runner.param_schedulers['linear1'][0], - MultiStepLR) - self.assertEqual(runner.param_schedulers['linear1'][0].milestones, { - 1: 1, - 2: 1 - }) - self.assertEqual(len(runner.param_schedulers['linear2']), 1) - self.assertIsInstance(runner.param_schedulers['linear2'][0], - MultiStepLR) - self.assertEqual(runner.param_schedulers['linear2'][0].milestones, { - 1: 1, - 2: 1 - }) + self.assertEqual(len(runner.param_schedulers["linear1"]), 1) + self.assertIsInstance(runner.param_schedulers["linear1"][0], MultiStepLR) + self.assertEqual(runner.param_schedulers["linear1"][0].milestones, {1: 1, 2: 1}) + self.assertEqual(len(runner.param_schedulers["linear2"]), 1) + self.assertIsInstance(runner.param_schedulers["linear2"][0], MultiStepLR) + self.assertEqual(runner.param_schedulers["linear2"][0].milestones, {1: 1, 2: 1}) # 2. 
test iter based cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint8' + cfg.experiment_name = "test_checkpoint8" runner = Runner.from_cfg(cfg) runner.train() # 2.1.1 test `save_checkpoint` which is called by `CheckpointHook` - path = osp.join(self.temp_dir, 'iter_12.pth') + path = osp.join(self.temp_dir, "iter_12.pth") self.assertTrue(osp.exists(path)) - self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) + self.assertFalse(osp.exists(osp.join(self.temp_dir, "epoch_13.pth"))) ckpt = torch.load(path) - self.assertEqual(ckpt['meta']['epoch'], 0) - self.assertEqual(ckpt['meta']['iter'], 12) - assert isinstance(ckpt['optimizer'], dict) - assert isinstance(ckpt['param_schedulers'], list) - self.assertIsInstance(ckpt['message_hub'], dict) - message_hub.load_state_dict(ckpt['message_hub']) - self.assertEqual(message_hub.get_info('epoch'), 0) - self.assertEqual(message_hub.get_info('iter'), 11) + self.assertEqual(ckpt["meta"]["epoch"], 0) + self.assertEqual(ckpt["meta"]["iter"], 12) + assert isinstance(ckpt["optimizer"], dict) + assert isinstance(ckpt["param_schedulers"], list) + self.assertIsInstance(ckpt["message_hub"], dict) + message_hub.load_state_dict(ckpt["message_hub"]) + self.assertEqual(message_hub.get_info("epoch"), 0) + self.assertEqual(message_hub.get_info("iter"), 11) # 2.1.2 check class attribute _statistic_methods can be saved HistoryBuffer._statistics_methods.clear() ckpt = torch.load(path) - self.assertIn('min', HistoryBuffer._statistics_methods) + self.assertIn("min", HistoryBuffer._statistics_methods) # 2.2 test `load_checkpoint` cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint9' + cfg.experiment_name = "test_checkpoint9" runner = Runner.from_cfg(cfg) runner.load_checkpoint(path) self.assertEqual(runner.epoch, 0) @@ -2469,7 +2389,7 @@ def test_checkpoint(self): # 2.3 test `resume` cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint10' + cfg.experiment_name = "test_checkpoint10" runner = Runner.from_cfg(cfg) runner.resume(path) self.assertEqual(runner.epoch, 0) @@ -2477,12 +2397,12 @@ def test_checkpoint(self): self.assertTrue(runner._has_loaded) self.assertIsInstance(runner.optim_wrapper.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) - self.assertEqual(runner.message_hub.get_info('epoch'), 0) - self.assertEqual(runner.message_hub.get_info('iter'), 11) + self.assertEqual(runner.message_hub.get_info("epoch"), 0) + self.assertEqual(runner.message_hub.get_info("iter"), 11) # 2.4 test auto resume cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint11' + cfg.experiment_name = "test_checkpoint11" cfg.resume = True runner = Runner.from_cfg(cfg) runner.load_or_resume() @@ -2494,9 +2414,9 @@ def test_checkpoint(self): # 2.5 test resume from a specified checkpoint cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint12' + cfg.experiment_name = "test_checkpoint12" cfg.resume = True - cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth') + cfg.load_from = osp.join(self.temp_dir, "iter_3.pth") runner = Runner.from_cfg(cfg) runner.load_or_resume() self.assertEqual(runner.epoch, 0) @@ -2507,71 +2427,66 @@ def test_checkpoint(self): # 2.6 test resumed message_hub has the history value. 
cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint13' + cfg.experiment_name = "test_checkpoint13" cfg.resume = True - cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth') + cfg.load_from = osp.join(self.temp_dir, "iter_3.pth") runner = Runner.from_cfg(cfg) runner.load_or_resume() - assert len(runner.message_hub.log_scalars['train/lr'].data[1]) == 3 - assert len(MessageHub.get_current_instance().log_scalars['train/lr']. - data[1]) == 3 + assert len(runner.message_hub.log_scalars["train/lr"].data[1]) == 3 + assert len(MessageHub.get_current_instance().log_scalars["train/lr"].data[1]) == 3 # 2.7.1 test `resume` 2 optimizers and 1 scheduler list. - path = osp.join(self.temp_dir, 'epoch_3.pth') + path = osp.join(self.temp_dir, "epoch_3.pth") optim_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + linear1=dict(type="OptimWrapper", optimizer=dict(type="SGD", lr=0.01)), + linear2=dict(type="OptimWrapper", optimizer=dict(type="Adam", lr=0.02)), + constructor="ToyMultipleOptimizerConstructor", + ) cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint14' + cfg.experiment_name = "test_checkpoint14" cfg.optim_wrapper = optim_cfg - cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) - cfg.model = dict(type='ToyGANModel') + cfg.param_scheduler = dict(type="MultiStepLR", milestones=[1, 2, 3]) + cfg.model = dict(type="ToyGANModel") resumed_cfg = copy.deepcopy(cfg) runner = Runner.from_cfg(cfg) runner.train() - resumed_cfg.experiment_name = 'test_checkpoint15' + resumed_cfg.experiment_name = "test_checkpoint15" runner = Runner.from_cfg(resumed_cfg) runner.resume(path) - self.assertEqual(len(runner.param_schedulers['linear1']), 1) - self.assertEqual(len(runner.param_schedulers['linear2']), 1) - self.assertIsInstance(runner.param_schedulers['linear1'][0], - MultiStepLR) - self.assertIsInstance(runner.param_schedulers['linear2'][0], - MultiStepLR) + self.assertEqual(len(runner.param_schedulers["linear1"]), 1) + self.assertEqual(len(runner.param_schedulers["linear2"]), 1) + self.assertIsInstance(runner.param_schedulers["linear1"][0], MultiStepLR) + self.assertIsInstance(runner.param_schedulers["linear2"][0], MultiStepLR) # 2.7.2 test `resume` 2 optimizers and 2 scheduler list. 
cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint16' + cfg.experiment_name = "test_checkpoint16" cfg.optim_wrapper = optim_cfg cfg.param_scheduler = dict( - linear1=dict(type='MultiStepLR', milestones=[1, 2, 3]), - linear2=dict(type='StepLR', gamma=0.1, step_size=3)) - cfg.model = dict(type='ToyGANModel') + linear1=dict(type="MultiStepLR", milestones=[1, 2, 3]), linear2=dict(type="StepLR", gamma=0.1, step_size=3) + ) + cfg.model = dict(type="ToyGANModel") resumed_cfg = copy.deepcopy(cfg) runner = Runner.from_cfg(cfg) runner.train() - resumed_cfg.experiment_name = 'test_checkpoint17' + resumed_cfg.experiment_name = "test_checkpoint17" runner = Runner.from_cfg(resumed_cfg) runner.resume(path) - self.assertEqual(len(runner.param_schedulers['linear1']), 1) - self.assertEqual(len(runner.param_schedulers['linear2']), 1) - self.assertIsInstance(runner.param_schedulers['linear1'][0], - MultiStepLR) - self.assertIsInstance(runner.param_schedulers['linear2'][0], StepLR) + self.assertEqual(len(runner.param_schedulers["linear1"]), 1) + self.assertEqual(len(runner.param_schedulers["linear2"]), 1) + self.assertIsInstance(runner.param_schedulers["linear1"][0], MultiStepLR) + self.assertIsInstance(runner.param_schedulers["linear2"][0], StepLR) # 2.7.3 test `resume` 2 optimizers and 0 scheduler list. cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_checkpoint18' + cfg.experiment_name = "test_checkpoint18" cfg.optim_wrapper = optim_cfg - cfg.model = dict(type='ToyGANModel') + cfg.model = dict(type="ToyGANModel") cfg.param_scheduler = None resumed_cfg = copy.deepcopy(cfg) runner = Runner.from_cfg(cfg) runner.train() - resumed_cfg.experiment_name = 'test_checkpoint19' + resumed_cfg.experiment_name = "test_checkpoint19" runner = Runner.from_cfg(resumed_cfg) runner.resume(path) self.assertIsNone(runner.param_schedulers) @@ -2581,27 +2496,28 @@ def test_build_runner(self): # `test_build_from_cfg` # test custom runner cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_runner1' - cfg.runner_type = 'CustomRunner' + cfg.experiment_name = "test_build_runner1" + cfg.runner_type = "CustomRunner" assert isinstance(RUNNERS.build(cfg), CustomRunner) # test default runner cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_build_runner2' + cfg.experiment_name = "test_build_runner2" assert isinstance(RUNNERS.build(cfg), Runner) def test_get_hooks_info(self): # test get_hooks_info() function cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.experiment_name = 'test_get_hooks_info_from_test_runner_py' - cfg.runner_type = 'Runner' + cfg.experiment_name = "test_get_hooks_info_from_test_runner_py" + cfg.runner_type = "Runner" runner = RUNNERS.build(cfg) self.assertIsInstance(runner, Runner) - target_str = ('after_train_iter:\n' - '(VERY_HIGH ) RuntimeInfoHook \n' - '(NORMAL ) IterTimerHook \n' - '(BELOW_NORMAL) LoggerHook \n' - '(LOW ) ParamSchedulerHook \n' - '(VERY_LOW ) CheckpointHook \n') - self.assertIn(target_str, runner.get_hooks_info(), - 'target string is not in logged hooks information.') + target_str = ( + "after_train_iter:\n" + "(VERY_HIGH ) RuntimeInfoHook \n" + "(NORMAL ) IterTimerHook \n" + "(BELOW_NORMAL) LoggerHook \n" + "(LOW ) ParamSchedulerHook \n" + "(VERY_LOW ) CheckpointHook \n" + ) + self.assertIn(target_str, runner.get_hooks_info(), "target string is not in logged hooks information.") diff --git a/tests/test_strategies/test_fsdp.py b/tests/test_strategies/test_fsdp.py index 64b900d2f8..e0021403e3 100644 
--- a/tests/test_strategies/test_fsdp.py +++ b/tests/test_strategies/test_fsdp.py @@ -7,12 +7,15 @@ import torch import torch.nn as nn + try: - from torch.distributed.fsdp import (FullStateDictConfig, - FullyShardedDataParallel, - LocalStateDictConfig, StateDictType) - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullOptimStateDictConfig, LocalOptimStateDictConfig) + from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel, + LocalStateDictConfig, + StateDictType, + ) + from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, LocalOptimStateDictConfig from mmengine._strategy import FSDPStrategy except: # noqa: E722 @@ -20,8 +23,7 @@ from torch.multiprocessing.spawn import start_processes from torch.optim import SGD -from mmengine.dist import (all_gather_object, broadcast_object_list, - is_main_process) +from mmengine.dist import all_gather_object, broadcast_object_list, is_main_process from mmengine.optim import LinearLR, OptimWrapper from mmengine.testing.runner_test_case import ToyModel from mmengine.utils import digit_version @@ -38,11 +40,10 @@ def linear_wrap_policy( @skipIf( - digit_version(torch.__version__) < digit_version('2.0.0') - or not torch.cuda.is_available(), - 'Only test FSDP with CUDA and PyTorch >= 2.0.0') + digit_version(torch.__version__) < digit_version("2.0.0") or not torch.cuda.is_available(), + "Only test FSDP with CUDA and PyTorch >= 2.0.0", +) class TestStrategy(TestCase): - def setUp(self): self.world_size = 2 self.temp_dir = TemporaryDirectory() @@ -53,15 +54,13 @@ def tearDown(self) -> None: def test_init(self): strategy = FSDPStrategy() self.assertFalse(strategy.skip_init_weights) - strategy = FSDPStrategy(state_dict_cfg='local') + strategy = FSDPStrategy(state_dict_cfg="local") self._assert_local(strategy) - strategy = FSDPStrategy(state_dict_cfg='full') + strategy = FSDPStrategy(state_dict_cfg="full") self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.LOCAL_STATE_DICT)) + strategy = FSDPStrategy(state_dict_cfg=dict(state_dict_type=StateDictType.LOCAL_STATE_DICT)) self._assert_local(strategy) strategy = FSDPStrategy( @@ -69,15 +68,17 @@ def test_init(self): state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig(), optim_state_dict_config=FullOptimStateDictConfig(), - )) + ) + ) self._assert_full(strategy) strategy = FSDPStrategy( state_dict_cfg=dict( - state_dict_type='FULL_STATE_DICT', - state_dict_config=dict(type='FullStateDictConfig'), - optim_state_dict_config=dict(type='FullOptimStateDictConfig'), - )) + state_dict_type="FULL_STATE_DICT", + state_dict_config=dict(type="FullStateDictConfig"), + optim_state_dict_config=dict(type="FullOptimStateDictConfig"), + ) + ) self._assert_full(strategy) strategy = FSDPStrategy( @@ -85,11 +86,12 @@ def test_init(self): state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=dict(type=FullStateDictConfig), optim_state_dict_config=dict(type=FullOptimStateDictConfig), - )) + ) + ) self._assert_full(strategy) with self.assertRaises(ValueError): - strategy = FSDPStrategy(state_dict_cfg='error-str') + strategy = FSDPStrategy(state_dict_cfg="error-str") # state_dict_cfg should be a str or a dict with self.assertRaises(TypeError): @@ -101,9 +103,9 @@ def test_init(self): state_dict_cfg=dict( state_dict_type=[], state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + 
optim_state_dict_config=dict(type=FullOptimStateDictConfig), + ) + ) # state_dict_config should be a dict or a subclass of StateDictConfig with self.assertRaises(TypeError): @@ -111,9 +113,9 @@ def test_init(self): state_dict_cfg=dict( state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=[], - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + ) + ) # optim_state_dict_config should be a dict or a subclass of # OptimStateDictConfig @@ -123,37 +125,38 @@ def test_init(self): state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=dict(type=FullStateDictConfig), optim_state_dict_config=[], - )) + ) + ) def run_strategy(self): # Strategy can run with the built model, optimizer and schedulers. - for skip_init_weights, state_dict_cfg in [(True, 'local'), - (False, 'full')]: + for skip_init_weights, state_dict_cfg in [(True, "local"), (False, "full")]: strategy = FSDPStrategy( skip_init_weights=skip_init_weights, state_dict_cfg=state_dict_cfg, - model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + model_wrapper=dict(auto_wrap_policy=linear_wrap_policy), + ) model = ToyModel() optim = OptimWrapper(SGD(model.parameters(), lr=0.1, momentum=0.9)) lr_scheduler = LinearLR(optimizer=optim) model, optim, lr_scheduler = strategy.prepare( - model=model, optim_wrapper=optim, param_scheduler=lr_scheduler) + model=model, optim_wrapper=optim, param_scheduler=lr_scheduler + ) self.assertIsInstance(model, FullyShardedDataParallel) self.assertIsInstance(model.linear1, FullyShardedDataParallel) self.assertIsInstance(model.linear2, FullyShardedDataParallel) data = torch.ones(2, 2).cuda() data_samples = torch.zeros(2, 2).cuda() - loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss = model(data, data_samples=data_samples, mode="loss")["loss"] loss.backward() optim.step() [scheduler.step() for scheduler in lr_scheduler] - ckpt_path = osp.join(self.temp_dir.name, - f'checkpoint_{state_dict_cfg}.pth') + ckpt_path = osp.join(self.temp_dir.name, f"checkpoint_{state_dict_cfg}.pth") strategy.save_checkpoint(ckpt_path) - if state_dict_cfg == 'full': + if state_dict_cfg == "full": if not is_main_process(): self.assertFalse(osp.exists(ckpt_path)) ckpt_path = [ckpt_path] @@ -161,29 +164,27 @@ def run_strategy(self): ckpt_path = ckpt_path[0] strategy.load_checkpoint(ckpt_path) - loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss = model(data, data_samples=data_samples, mode="loss")["loss"] loss.backward() optim.step() [scheduler.step() for scheduler in lr_scheduler] # optimizer with multiple param_groups can be reconstructed. 
model = ToyModel() - strategy = FSDPStrategy( - model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + strategy = FSDPStrategy(model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) param_groups = [] for param in model.parameters(): param_groups.append(dict(params=[param], lr=0.1)) optim = SGD(param_groups, lr=0.1, momentum=0.9) lr_scheduler = LinearLR(optimizer=optim) - model, optim, lr_scheduler = strategy.prepare( - model=model, optim_wrapper=optim, param_scheduler=lr_scheduler) + model, optim, lr_scheduler = strategy.prepare(model=model, optim_wrapper=optim, param_scheduler=lr_scheduler) data = torch.ones(2, 2).cuda() data_samples = torch.zeros(2, 2).cuda() - loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss = model(data, data_samples=data_samples, mode="loss")["loss"] loss.backward() optim.step() [scheduler.step() for scheduler in lr_scheduler] - optim_state = optim.state_dict()['state'] + optim_state = optim.state_dict()["state"] optim_state = all_gather_object(optim_state) @classmethod @@ -193,21 +194,18 @@ def _worker(cls, rank, func): self.setUp() self.rank = rank - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(self.world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(12123) - torch.cuda.set_device(f'cuda:{rank}') + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(12123) + torch.cuda.set_device(f"cuda:{rank}") getattr(self, func)() self.tearDown() def test_run_strategy(self): - start_processes( - TestStrategy._worker, - args=('run_strategy', ), - nprocs=self.world_size) + start_processes(TestStrategy._worker, args=("run_strategy",), nprocs=self.world_size) def test_build_model(self): ... 
@@ -217,15 +215,11 @@ def test_build_model(self): # state_dict = dict() def _assert_local(self, strategy): - self.assertEqual(strategy.state_dict_type, - StateDictType.LOCAL_STATE_DICT) + self.assertEqual(strategy.state_dict_type, StateDictType.LOCAL_STATE_DICT) self.assertIsInstance(strategy.state_dict_config, LocalStateDictConfig) - self.assertIsInstance(strategy.optim_state_dict_config, - LocalOptimStateDictConfig) + self.assertIsInstance(strategy.optim_state_dict_config, LocalOptimStateDictConfig) def _assert_full(self, strategy): - self.assertEqual(strategy.state_dict_type, - StateDictType.FULL_STATE_DICT) + self.assertEqual(strategy.state_dict_type, StateDictType.FULL_STATE_DICT) self.assertIsInstance(strategy.state_dict_config, FullStateDictConfig) - self.assertIsInstance(strategy.optim_state_dict_config, - FullOptimStateDictConfig) + self.assertIsInstance(strategy.optim_state_dict_config, FullOptimStateDictConfig) diff --git a/tests/test_structures/test_data_element.py b/tests/test_structures/test_data_element.py index 1cb7cd1745..352a6b69b9 100644 --- a/tests/test_structures/test_data_element.py +++ b/tests/test_structures/test_data_element.py @@ -10,14 +10,13 @@ class DetDataSample(BaseDataElement): - @property def proposals(self): return self._proposals @proposals.setter def proposals(self, value): - self.set_field(value=value, name='_proposals', dtype=BaseDataElement) + self.set_field(value=value, name="_proposals", dtype=BaseDataElement) @proposals.deleter def proposals(self): @@ -29,8 +28,7 @@ def gt_instances(self): @gt_instances.setter def gt_instances(self, value): - self.set_field( - value=value, name='_gt_instances', dtype=BaseDataElement) + self.set_field(value=value, name="_gt_instances", dtype=BaseDataElement) @gt_instances.deleter def gt_instances(self): @@ -42,8 +40,7 @@ def pred_instances(self): @pred_instances.setter def pred_instances(self, value): - self.set_field( - value=value, name='_pred_instances', dtype=BaseDataElement) + self.set_field(value=value, name="_pred_instances", dtype=BaseDataElement) @pred_instances.deleter def pred_instances(self): @@ -51,24 +48,18 @@ def pred_instances(self): class TestBaseDataElement(TestCase): - def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) - gt_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) - pred_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + metainfo = dict(img_id=random.randint(0, 100), img_shape=(random.randint(400, 600), random.randint(400, 600))) + gt_instances = BaseDataElement(bboxes=torch.rand((5, 4)), labels=torch.rand((5,))) + pred_instances = BaseDataElement(bboxes=torch.rand((5, 4)), scores=torch.rand((5,))) data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data def is_equal(self, x, y): assert type(x) is type(y) - if isinstance( - x, (int, float, str, list, tuple, dict, set, BaseDataElement)): + if isinstance(x, int | float | str | list | tuple | dict | set | BaseDataElement): return x == y - elif isinstance(x, (torch.Tensor, np.ndarray)): + elif isinstance(x, torch.Tensor | np.ndarray): return (x == y).all() def check_key_value(self, instances, metainfo=None, data=None): @@ -100,7 +91,7 @@ def check_data_device(self, instances, device): def check_data_dtype(self, instances, dtype): for v in instances.values(): - if isinstance(v, (torch.Tensor, np.ndarray)): + if isinstance(v, torch.Tensor | np.ndarray): 
assert isinstance(v, dtype) if isinstance(v, BaseDataElement): self.check_data_dtype(v, dtype) @@ -121,7 +112,7 @@ def test_init(self): assert instances.get(k, None) is None for k in data: assert k not in instances - assert instances.get(k, 'abc') == 'abc' + assert instances.get(k, "abc") == "abc" # initialization with kwargs metainfo, data = self.setup_data() @@ -147,8 +138,7 @@ def test_new(self): # element and will have new address _, data = self.setup_data() new_instances.set_data(data) - assert not self.is_equal(new_instances.gt_instances, - instances.gt_instances) + assert not self.is_equal(new_instances.gt_instances, instances.gt_instances) self.check_key_value(new_instances, metainfo, data) # test new() with arguments @@ -158,7 +148,7 @@ def test_new(self): assert id(new_instances.gt_instances) != id(instances.gt_instances) _, new_data = self.setup_data() new_instances.set_data(new_data) - assert id(new_instances.gt_instances) != id(data['gt_instances']) + assert id(new_instances.gt_instances) != id(data["gt_instances"]) self.check_key_value(new_instances, metainfo, new_data) metainfo, data = self.setup_data() @@ -196,8 +186,8 @@ def test_set_data(self): metainfo, data = self.setup_data() instances = BaseDataElement() - instances.gt_instances = data['gt_instances'] - instances.pred_instances = data['pred_instances'] + instances.gt_instances = data["gt_instances"] + instances.pred_instances = data["pred_instances"] self.check_key_value(instances, data=data) metainfo, data = self.setup_data() @@ -206,14 +196,14 @@ def test_set_data(self): self.check_key_value(instances, data=data) # a.xx only set data rather than metainfo - instances.img_shape = metainfo['img_shape'] - instances.img_id = metainfo['img_id'] + instances.img_shape = metainfo["img_shape"] + instances.img_id = metainfo["img_id"] self.check_key_value(instances, data=metainfo) metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) with self.assertRaises(AttributeError): - instances.img_shape = metainfo['img_shape'] + instances.img_shape = metainfo["img_shape"] # test set '_metainfo_fields' or '_data_fields' with self.assertRaises(AttributeError): @@ -232,12 +222,10 @@ def test_set_data(self): def test_update(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) - proposals = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + proposals = BaseDataElement(bboxes=torch.rand((5, 4)), scores=torch.rand((5,))) new_instances = BaseDataElement(proposals=proposals) instances.update(new_instances) - self.check_key_value(instances, metainfo, - data.update(dict(proposals=proposals))) + self.check_key_value(instances, metainfo, data.update(dict(proposals=proposals))) def test_delete_modify(self): random.seed(10) @@ -247,45 +235,42 @@ def test_delete_modify(self): new_metainfo, new_data = self.setup_data() # avoid generating same metainfo, data while True: - if new_metainfo['img_id'] == metainfo['img_id'] or new_metainfo[ - 'img_shape'] == metainfo['img_shape']: + if new_metainfo["img_id"] == metainfo["img_id"] or new_metainfo["img_shape"] == metainfo["img_shape"]: new_metainfo, new_data = self.setup_data() else: break - instances.gt_instances = new_data['gt_instances'] - instances.pred_instances = new_data['pred_instances'] + instances.gt_instances = new_data["gt_instances"] + instances.pred_instances = new_data["pred_instances"] # a.xx only set data rather than metainfo instances.set_metainfo(new_metainfo) self.check_key_value(instances, 
new_metainfo, new_data) - assert not self.is_equal(instances.gt_instances, data['gt_instances']) - assert not self.is_equal(instances.pred_instances, - data['pred_instances']) - assert not self.is_equal(instances.img_id, metainfo['img_id']) - assert not self.is_equal(instances.img_shape, metainfo['img_shape']) + assert not self.is_equal(instances.gt_instances, data["gt_instances"]) + assert not self.is_equal(instances.pred_instances, data["pred_instances"]) + assert not self.is_equal(instances.img_id, metainfo["img_id"]) + assert not self.is_equal(instances.img_shape, metainfo["img_shape"]) del instances.gt_instances del instances.img_id - assert not self.is_equal( - instances.pop('pred_instances', None), data['pred_instances']) + assert not self.is_equal(instances.pop("pred_instances", None), data["pred_instances"]) with self.assertRaises(AttributeError): del instances.pred_instances - assert 'gt_instances' not in instances - assert 'pred_instances' not in instances - assert 'img_id' not in instances - assert instances.pop('gt_instances', None) is None + assert "gt_instances" not in instances + assert "pred_instances" not in instances + assert "img_id" not in instances + assert instances.pop("gt_instances", None) is None # test pop not exist key without default with self.assertRaises(KeyError): - instances.pop('gt_instances') - assert instances.pop('pred_instances', 'abcdef') == 'abcdef' + instances.pop("gt_instances") + assert instances.pop("pred_instances", "abcdef") == "abcdef" - assert instances.pop('img_id', None) is None + assert instances.pop("img_id", None) is None # test pop not exist key without default with self.assertRaises(KeyError): - instances.pop('img_id') - assert instances.pop('img_shape') == new_metainfo['img_shape'] + instances.pop("img_id") + assert instances.pop("img_shape") == new_metainfo["img_shape"] # test del '_metainfo_fields' or '_data_fields' with self.assertRaises(AttributeError): @@ -293,32 +278,31 @@ def test_delete_modify(self): with self.assertRaises(AttributeError): del instances._data_fields - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is required!") def test_cuda(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) cuda_instances = instances.cuda() - self.check_data_device(cuda_instances, 'cuda:0') + self.check_data_device(cuda_instances, "cuda:0") # here we further test to convert from cuda to cpu cpu_instances = cuda_instances.cpu() - self.check_data_device(cpu_instances, 'cpu') + self.check_data_device(cpu_instances, "cpu") del cuda_instances - cuda_instances = instances.to('cuda:0') - self.check_data_device(cuda_instances, 'cuda:0') + cuda_instances = instances.to("cuda:0") + self.check_data_device(cuda_instances, "cuda:0") def test_cpu(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) - self.check_data_device(instances, 'cpu') + self.check_data_device(instances, "cpu") cpu_instances = instances.cpu() # assert cpu_instances.device == 'cpu' - assert cpu_instances.gt_instances.bboxes.device == torch.device('cpu') - assert cpu_instances.gt_instances.labels.device == torch.device('cpu') + assert cpu_instances.gt_instances.bboxes.device == torch.device("cpu") + assert cpu_instances.gt_instances.labels.device == torch.device("cpu") def test_numpy_tensor(self): metainfo, data = self.setup_data() @@ -338,23 +322,23 @@ def test_detach(self): def test_repr(self): 
metainfo = dict(img_shape=(800, 1196, 3)) - gt_instances = BaseDataElement( - metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + gt_instances = BaseDataElement(metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances) address = hex(id(sample)) address_gt_instances = hex(id(sample.gt_instances)) assert repr(sample) == ( - '\n' - f') at {address}>') + "\n" + f") at {address}>" + ) def test_set_fields(self): metainfo, data = self.setup_data() @@ -371,20 +355,19 @@ def test_set_fields(self): instances.set_field(name=key, value=value, dtype=torch.Tensor) def test_inheritance(self): - det_sample = DetDataSample() # test set proposals = BaseDataElement(bboxes=torch.rand((5, 4))) det_sample.proposals = proposals - assert 'proposals' in det_sample + assert "proposals" in det_sample # test get assert det_sample.proposals == proposals # test delete del det_sample.proposals - assert 'proposals' not in det_sample + assert "proposals" not in det_sample # test the data whether meet the requirements with self.assertRaises(AssertionError): @@ -396,8 +379,7 @@ def test_values(self): instances = BaseDataElement(metainfo=metainfo, **data) assert len(instances.metainfo_values()) == len(metainfo.values()) # test_all_values - assert len(instances.all_values()) == len(metainfo.values()) + len( - data.values()) + assert len(instances.all_values()) == len(metainfo.values()) + len(data.values()) # test_values assert len(instances.values()) == len(data.values()) @@ -409,8 +391,7 @@ def test_keys(self): assert len(instances.metainfo_keys()) == len(metainfo.keys()) # test_all_keys - assert len( - instances.all_keys()) == len(data.keys()) + len(metainfo.keys()) + assert len(instances.all_keys()) == len(data.keys()) + len(metainfo.keys()) # test_keys assert len(instances.keys()) == len(data.keys()) @@ -418,17 +399,15 @@ def test_keys(self): det_sample = DetDataSample() proposals = BaseDataElement(bboxes=torch.rand((5, 4))) det_sample.proposals = proposals - assert '_proposals' not in det_sample.keys() + assert "_proposals" not in det_sample.keys() def test_items(self): # test_metainfo_items metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) - assert len(dict(instances.metainfo_items())) == len( - dict(metainfo.items())) + assert len(dict(instances.metainfo_items())) == len(dict(metainfo.items())) # test_all_items - assert len(dict(instances.all_items())) == len(dict( - metainfo.items())) + len(dict(data.items())) + assert len(dict(instances.all_items())) == len(dict(metainfo.items())) + len(dict(data.items())) # test_items assert len(dict(instances.items())) == len(dict(data.items())) @@ -443,15 +422,15 @@ def test_to_dict(self): assert k in dict_instances assert isinstance(dict_instances, dict) # sub data element should also be converted to dict - assert isinstance(dict_instances['gt_instances'], dict) - assert isinstance(dict_instances['pred_instances'], dict) + assert isinstance(dict_instances["gt_instances"], dict) + assert isinstance(dict_instances["pred_instances"], dict) det_sample = DetDataSample() proposals = BaseDataElement(bboxes=torch.rand((5, 4))) det_sample.proposals = proposals dict_sample = det_sample.to_dict() - assert '_proposals' not in dict_sample - assert 'proposals' in dict_sample + assert "_proposals" not in dict_sample + assert "proposals" in dict_sample def test_metainfo(self): # test metainfo property diff --git a/tests/test_structures/test_instance_data.py 
b/tests/test_structures/test_instance_data.py index fe4a1b2603..9412c60b4d 100644 --- a/tests/test_structures/test_instance_data.py +++ b/tests/test_structures/test_instance_data.py @@ -11,7 +11,6 @@ class TmpObject: - def __init__(self, tmp) -> None: assert isinstance(tmp, list) if len(tmp) > 0: @@ -25,7 +24,7 @@ def __len__(self): def __getitem__(self, item): if isinstance(item, int): if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f'Index {item} out of range!') + raise IndexError(f"Index {item} out of range!") else: # keep the dimension item = slice(item, None, len(self)) @@ -46,7 +45,6 @@ def __repr__(self): class TmpObjectWithoutCat: - def __init__(self, tmp) -> None: assert isinstance(tmp, list) if len(tmp) > 0: @@ -60,7 +58,7 @@ def __len__(self): def __getitem__(self, item): if isinstance(item, int): if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f'Index {item} out of range!') + raise IndexError(f"Index {item} out of range!") else: # keep the dimension item = slice(item, None, len(self)) @@ -71,17 +69,14 @@ def __repr__(self): class TestInstanceData(TestCase): - def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), img_shape=(random.randint(400, 600), random.randint(400, 600))) instances_infos = [1] * 5 bboxes = torch.rand((5, 4)) labels = np.random.rand(5) kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]] ids = (1, 2, 3, 4, 5) - name_ids = '12345' + name_ids = "12345" polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist()) instance_data = InstanceData( metainfo=metainfo, @@ -91,7 +86,8 @@ def setup_data(self): kps=kps, ids=ids, name_ids=name_ids, - instances_infos=instances_infos) + instances_infos=instances_infos, + ) return instance_data def test_set_data(self): @@ -108,7 +104,7 @@ def test_set_data(self): instance_data.keypoints = torch.rand((17, 2)) instance_data.keypoints = torch.rand((5, 2)) - assert 'keypoints' in instance_data + assert "keypoints" in instance_data def test_getitem(self): instance_data = InstanceData() @@ -138,7 +134,7 @@ def test_getitem(self): instance_data[item.bool()] # test LongTensor - long_tensor = torch.randint(5, (2, )) + long_tensor = torch.randint(5, (2,)) long_index_instance_data = instance_data[long_tensor] assert len(long_index_instance_data) == len(long_tensor) @@ -170,27 +166,27 @@ def test_getitem(self): assert len(bool_numpy_instance_data) == bool_numpy.sum() # without cat - instance_data.polygons = TmpObjectWithoutCat( - np.arange(25).reshape((5, -1)).tolist()) + instance_data.polygons = TmpObjectWithoutCat(np.arange(25).reshape((5, -1)).tolist()) bool_numpy = np.random.rand(5) > 0.5 with pytest.raises( - ValueError, - match=('The type of `polygons` is ' - f'`{type(instance_data.polygons)}`, ' - 'which has no attribute of `cat`, so it does not ' - f'support slice with `bool`')): + ValueError, + match=( + "The type of `polygons` is " + f"`{type(instance_data.polygons)}`, " + "which has no attribute of `cat`, so it does not " + f"support slice with `bool`" + ), + ): bool_numpy_instance_data = instance_data[bool_numpy] def test_cat(self): instance_data_1 = self.setup_data() instance_data_2 = self.setup_data() - cat_instance_data = InstanceData.cat( - [instance_data_1, instance_data_2]) + cat_instance_data = InstanceData.cat([instance_data_1, instance_data_2]) assert len(cat_instance_data) == 10 # All inputs must be InstanceData - instance_data_2 = 
BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) + instance_data_2 = BaseDataElement(bboxes=torch.rand((5, 4)), labels=torch.rand((5,))) with self.assertRaises(AssertionError): InstanceData.cat([instance_data_1, instance_data_2]) @@ -199,22 +195,18 @@ def test_cat(self): InstanceData.cat([]) instance_data_2 = instance_data_1.clone() instance_data_2 = instance_data_2[torch.zeros(5) > 0.5] - cat_instance_data = InstanceData.cat( - [instance_data_1, instance_data_2]) + cat_instance_data = InstanceData.cat([instance_data_1, instance_data_2]) cat_instance_data = InstanceData.cat([instance_data_1]) assert len(cat_instance_data) == 5 # test custom data cat - instance_data_1.polygons = TmpObjectWithoutCat( - np.arange(25).reshape((5, -1)).tolist()) + instance_data_1.polygons = TmpObjectWithoutCat(np.arange(25).reshape((5, -1)).tolist()) instance_data_2 = instance_data_1.clone() with pytest.raises( - ValueError, - match=('The type of `polygons` is ' - f'`{type(instance_data_1.polygons)}` ' - 'which has no attribute of `cat`')): - cat_instance_data = InstanceData.cat( - [instance_data_1, instance_data_2]) + ValueError, + match=(f"The type of `polygons` is `{type(instance_data_1.polygons)}` which has no attribute of `cat`"), + ): + cat_instance_data = InstanceData.cat([instance_data_1, instance_data_2]) def test_len(self): instance_data = self.setup_data() diff --git a/tests/test_structures/test_label_data.py b/tests/test_structures/test_label_data.py index 8c73bca767..e0a7ea2ead 100644 --- a/tests/test_structures/test_label_data.py +++ b/tests/test_structures/test_label_data.py @@ -8,36 +8,29 @@ class TestLabelData(TestCase): - def test_label_to_onehot(self): item = torch.tensor([1], dtype=torch.int64) num_classes = 10 onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes) - assert tuple(onehot.shape) == (num_classes, ) + assert tuple(onehot.shape) == (num_classes,) assert onehot.device == item.device # item is not onehot with self.assertRaises(AssertionError): - LabelData.label_to_onehot(label='item', num_classes=num_classes) + LabelData.label_to_onehot(label="item", num_classes=num_classes) # item'max bigger than num_classes with self.assertRaises(AssertionError): - LabelData.label_to_onehot( - torch.tensor([11], dtype=torch.int64), num_classes) - onehot = LabelData.label_to_onehot( - label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) - assert (onehot == torch.zeros((num_classes, ), - dtype=torch.int64)).all() + LabelData.label_to_onehot(torch.tensor([11], dtype=torch.int64), num_classes) + onehot = LabelData.label_to_onehot(label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) + assert (onehot == torch.zeros((num_classes,), dtype=torch.int64)).all() def test_onehot_to_label(self): # item is not onehot - with self.assertRaisesRegex( - ValueError, - 'input is not one-hot and can not convert to label'): - LabelData.onehot_to_label( - onehot=torch.tensor([2], dtype=torch.int64)) + with self.assertRaisesRegex(ValueError, "input is not one-hot and can not convert to label"): + LabelData.onehot_to_label(onehot=torch.tensor([2], dtype=torch.int64)) with self.assertRaises(AssertionError): - LabelData.onehot_to_label(onehot='item') + LabelData.onehot_to_label(onehot="item") item = torch.arange(0, 9) onehot = LabelData.label_to_onehot(item, num_classes=10) @@ -50,8 +43,7 @@ def test_onehot_to_label(self): assert label == item assert label.device == item.device - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is 
required!') + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is required!") def test_cuda(self): item = torch.arange(0, 9).cuda() onehot = LabelData.label_to_onehot(item, num_classes=10) diff --git a/tests/test_structures/test_pixel_data.py b/tests/test_structures/test_pixel_data.py index 1ca80373af..27e183a886 100644 --- a/tests/test_structures/test_pixel_data.py +++ b/tests/test_structures/test_pixel_data.py @@ -10,11 +10,8 @@ class TestPixelData(TestCase): - def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), img_shape=(random.randint(400, 600), random.randint(400, 600))) image = np.random.randint(0, 255, (4, 20, 40)) featmap = torch.randint(0, 255, (10, 20, 40)) pixel_data = PixelData(metainfo=metainfo, image=image, featmap=featmap) @@ -31,7 +28,7 @@ def test_set_data(self): # value only supports (torch.Tensor, np.ndarray) with self.assertRaises(AssertionError): - pixel_data.v = 'value' + pixel_data.v = "value" # The width and height must be the same with self.assertRaises(AssertionError): @@ -42,7 +39,7 @@ def test_set_data(self): pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40)) pixel_data.map2 = torch.randint(0, 255, (3, 20, 40)) - assert 'map2' in pixel_data + assert "map2" in pixel_data pixel_data.map3 = torch.randint(0, 255, (20, 40)) assert tuple(pixel_data.map3.shape) == (1, 20, 40) @@ -60,20 +57,15 @@ def test_getitem(self): # must be tuple item = torch.Tensor([1, 2, 3, 4]) - with pytest.raises( - TypeError, - match=f'Unsupported type {type(item)} for slicing PixelData'): + with pytest.raises(TypeError, match=f"Unsupported type {type(item)} for slicing PixelData"): pixel_data[item] item = 1 - with pytest.raises( - TypeError, - match=f'Unsupported type {type(item)} for slicing PixelData'): + with pytest.raises(TypeError, match=f"Unsupported type {type(item)} for slicing PixelData"): pixel_data[item] item = (5.5, 5) with pytest.raises( - TypeError, - match=('The type of element in input must be int or slice, ' - f'but got {type(item[0])}')): + TypeError, match=(f"The type of element in input must be int or slice, but got {type(item[0])}") + ): pixel_data[item] def test_shape(self): diff --git a/tests/test_testing/test_compare.py b/tests/test_testing/test_compare.py index cd4e79bc57..e1f9833ac4 100644 --- a/tests/test_testing/test_compare.py +++ b/tests/test_testing/test_compare.py @@ -4,6 +4,7 @@ import mmengine.testing as testing + try: import torch except ImportError: @@ -13,96 +14,64 @@ def test_assert_dict_contains_subset(): - dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)} + dict_obj = {"a": "test1", "b": 2, "c": (4, 6)} # case 1 - expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)} + expected_subset = {"a": "test1", "b": 2, "c": (4, 6)} assert testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 2 - expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)} + expected_subset = {"a": "test1", "b": 2, "c": (6, 4)} assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 3 - expected_subset = {'a': 'test1', 'b': 2, 'c': None} + expected_subset = {"a": "test1", "b": 2, "c": None} assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 4 - expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)} + expected_subset = {"a": "test1", "b": 2, "d": (4, 6)} assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 5 - dict_obj = { - 'a': 'test1', - 
'b': 2, - 'c': (4, 6), - 'd': np.array([[5, 3, 5], [1, 2, 3]]) - } - expected_subset = { - 'a': 'test1', - 'b': 2, - 'c': (4, 6), - 'd': np.array([[5, 3, 5], [6, 2, 3]]) - } + dict_obj = {"a": "test1", "b": 2, "c": (4, 6), "d": np.array([[5, 3, 5], [1, 2, 3]])} + expected_subset = {"a": "test1", "b": 2, "c": (4, 6), "d": np.array([[5, 3, 5], [6, 2, 3]])} assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 6 - dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} - expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} + dict_obj = {"a": "test1", "b": 2, "c": (4, 6), "d": np.array([[1]])} + expected_subset = {"a": "test1", "b": 2, "c": (4, 6), "d": np.array([[1]])} assert testing.assert_dict_contains_subset(dict_obj, expected_subset) if torch is not None: - dict_obj = { - 'a': 'test1', - 'b': 2, - 'c': (4, 6), - 'd': torch.tensor([5, 3, 5]) - } + dict_obj = {"a": "test1", "b": 2, "c": (4, 6), "d": torch.tensor([5, 3, 5])} # case 7 - expected_subset = {'d': torch.tensor([5, 5, 5])} - assert not testing.assert_dict_contains_subset(dict_obj, - expected_subset) + expected_subset = {"d": torch.tensor([5, 5, 5])} + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) # case 8 - expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])} - assert not testing.assert_dict_contains_subset(dict_obj, - expected_subset) + expected_subset = {"d": torch.tensor([[5, 3, 5], [4, 1, 2]])} + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) def test_assert_attrs_equal(): - class TestExample: - a, b, c = 1, ('wvi', 3), [4.5, 3.14] + a, b, c = 1, ("wvi", 3), [4.5, 3.14] def test_func(self): return self.b # case 1 - assert testing.assert_attrs_equal(TestExample, { - 'a': 1, - 'b': ('wvi', 3), - 'c': [4.5, 3.14] - }) + assert testing.assert_attrs_equal(TestExample, {"a": 1, "b": ("wvi", 3), "c": [4.5, 3.14]}) # case 2 - assert not testing.assert_attrs_equal(TestExample, { - 'a': 1, - 'b': ('wvi', 3), - 'c': [4.5, 3.14, 2] - }) + assert not testing.assert_attrs_equal(TestExample, {"a": 1, "b": ("wvi", 3), "c": [4.5, 3.14, 2]}) # case 3 - assert not testing.assert_attrs_equal(TestExample, { - 'bc': 54, - 'c': [4.5, 3.14] - }) + assert not testing.assert_attrs_equal(TestExample, {"bc": 54, "c": [4.5, 3.14]}) # case 4 - assert testing.assert_attrs_equal(TestExample, { - 'b': ('wvi', 3), - 'test_func': TestExample.test_func - }) + assert testing.assert_attrs_equal(TestExample, {"b": ("wvi", 3), "test_func": TestExample.test_func}) if torch is not None: @@ -110,48 +79,38 @@ class TestExample: a, b = torch.tensor([1]), torch.tensor([4, 5]) # case 5 - assert testing.assert_attrs_equal(TestExample, { - 'a': torch.tensor([1]), - 'b': torch.tensor([4, 5]) - }) + assert testing.assert_attrs_equal(TestExample, {"a": torch.tensor([1]), "b": torch.tensor([4, 5])}) # case 6 - assert not testing.assert_attrs_equal(TestExample, { - 'a': torch.tensor([1]), - 'b': torch.tensor([4, 6]) - }) + assert not testing.assert_attrs_equal(TestExample, {"a": torch.tensor([1]), "b": torch.tensor([4, 6])}) -assert_dict_has_keys_data_1 = [({ - 'res_layer': 1, - 'norm_layer': 2, - 'dense_layer': 3 -})] -assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True), - (['res_layer', 'conv_layer'], False)] +assert_dict_has_keys_data_1 = [({"res_layer": 1, "norm_layer": 2, "dense_layer": 3})] +assert_dict_has_keys_data_2 = [(["res_layer", "dense_layer"], True), (["res_layer", "conv_layer"], False)] -@pytest.mark.parametrize('obj', 
assert_dict_has_keys_data_1) -@pytest.mark.parametrize('expected_keys, ret_value', - assert_dict_has_keys_data_2) +@pytest.mark.parametrize("obj", assert_dict_has_keys_data_1) +@pytest.mark.parametrize("expected_keys, ret_value", assert_dict_has_keys_data_2) def test_assert_dict_has_keys(obj, expected_keys, ret_value): assert testing.assert_dict_has_keys(obj, expected_keys) == ret_value -assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])] -assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True), - (['res_layer', 'dense_layer', 'norm_layer'], True), - (['res_layer', 'norm_layer'], False), - (['res_layer', 'conv_layer', 'norm_layer'], False)] +assert_keys_equal_data_1 = [(["res_layer", "norm_layer", "dense_layer"])] +assert_keys_equal_data_2 = [ + (["res_layer", "norm_layer", "dense_layer"], True), + (["res_layer", "dense_layer", "norm_layer"], True), + (["res_layer", "norm_layer"], False), + (["res_layer", "conv_layer", "norm_layer"], False), +] -@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1) -@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2) +@pytest.mark.parametrize("result_keys", assert_keys_equal_data_1) +@pytest.mark.parametrize("target_keys, ret_value", assert_keys_equal_data_2) def test_assert_keys_equal(result_keys, target_keys, ret_value): assert testing.assert_keys_equal(result_keys, target_keys) == ret_value -@pytest.mark.skipif(torch is None, reason='requires torch library') +@pytest.mark.skipif(torch is None, reason="requires torch library") def test_assert_is_norm_layer(): # case 1 assert not testing.assert_is_norm_layer(nn.Conv3d(3, 64, 3)) @@ -166,7 +125,7 @@ def test_assert_is_norm_layer(): assert not testing.assert_is_norm_layer(nn.Sigmoid()) -@pytest.mark.skipif(torch is None, reason='requires torch library') +@pytest.mark.skipif(torch is None, reason="requires torch library") def test_assert_params_all_zeros(): demo_module = nn.Conv2d(3, 64, 3) nn.init.constant_(demo_module.weight, 0) @@ -186,12 +145,12 @@ def test_assert_params_all_zeros(): def test_check_python_script(capsys): - testing.check_python_script('./tests/data/scripts/hello.py zz') + testing.check_python_script("./tests/data/scripts/hello.py zz") captured = capsys.readouterr().out - assert captured == 'hello zz!\n' - testing.check_python_script('./tests/data/scripts/hello.py agent') + assert captured == "hello zz!\n" + testing.check_python_script("./tests/data/scripts/hello.py agent") captured = capsys.readouterr().out - assert captured == 'hello agent!\n' + assert captured == "hello agent!\n" # Make sure that wrong cmd raises an error with pytest.raises(SystemExit): - testing.check_python_script('./tests/data/scripts/hello.py li zz') + testing.check_python_script("./tests/data/scripts/hello.py li zz") diff --git a/tests/test_testing/test_runner_test_case.py b/tests/test_testing/test_runner_test_case.py index 5d41c03531..2cb948fe24 100644 --- a/tests/test_testing/test_runner_test_case.py +++ b/tests/test_testing/test_runner_test_case.py @@ -9,15 +9,14 @@ class TestRunnerTestCase(RunnerTestCase): - def test_setup(self): self.assertIsInstance(self.epoch_based_cfg, Config) self.assertIsInstance(self.iter_based_cfg, Config) - self.assertIn('MASTER_ADDR', self.dist_cfg) - self.assertIn('MASTER_PORT', self.dist_cfg) - self.assertIn('RANK', self.dist_cfg) - self.assertIn('WORLD_SIZE', self.dist_cfg) - self.assertIn('LOCAL_RANK', self.dist_cfg) + self.assertIn("MASTER_ADDR", self.dist_cfg) + self.assertIn("MASTER_PORT", 
self.dist_cfg) + self.assertIn("RANK", self.dist_cfg) + self.assertIn("WORLD_SIZE", self.dist_cfg) + self.assertIn("LOCAL_RANK", self.dist_cfg) def test_tearDown(self): self.tearDown() @@ -46,13 +45,11 @@ def test_experiment_name(self): def test_init_dist(self): self.setup_dist_env() - self.assertEqual( - str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT']) - self.assertEqual(self.dist_cfg['MASTER_ADDR'], - os.environ['MASTER_ADDR']) - self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK']) - self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK']) - self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE']) - fisrt_port = os.environ['MASTER_ADDR'] + self.assertEqual(str(self.dist_cfg["MASTER_PORT"]), os.environ["MASTER_PORT"]) + self.assertEqual(self.dist_cfg["MASTER_ADDR"], os.environ["MASTER_ADDR"]) + self.assertEqual(self.dist_cfg["RANK"], os.environ["RANK"]) + self.assertEqual(self.dist_cfg["LOCAL_RANK"], os.environ["LOCAL_RANK"]) + self.assertEqual(self.dist_cfg["WORLD_SIZE"], os.environ["WORLD_SIZE"]) + fisrt_port = os.environ["MASTER_ADDR"] self.setup_dist_env() - self.assertNotEqual(fisrt_port, os.environ['MASTER_PORT']) + self.assertNotEqual(fisrt_port, os.environ["MASTER_PORT"]) diff --git a/tests/test_utils/test_dl_utils/test_get_env.py b/tests/test_utils/test_dl_utils/test_get_env.py index 38d258acbe..0891bf1972 100644 --- a/tests/test_utils/test_dl_utils/test_get_env.py +++ b/tests/test_utils/test_dl_utils/test_get_env.py @@ -7,20 +7,25 @@ class TestCollectEnv(TestCase): - def test_collect_env(self): env_info = collect_env() expected_keys = [ - 'sys.platform', 'Python', 'CUDA available', 'PyTorch', - 'PyTorch compiling details', 'OpenCV', 'MMEngine', 'GCC' + "sys.platform", + "Python", + "CUDA available", + "PyTorch", + "PyTorch compiling details", + "OpenCV", + "MMEngine", + "GCC", ] for key in expected_keys: assert key in env_info - if env_info['CUDA available']: - for key in ['CUDA_HOME', 'NVCC']: + if env_info["CUDA available"]: + for key in ["CUDA_HOME", "NVCC"]: assert key in env_info - assert env_info['sys.platform'] == sys.platform - assert env_info['Python'] == sys.version.replace('\n', '') - assert env_info['MMEngine'] == mmengine.__version__ + assert env_info["sys.platform"] == sys.platform + assert env_info["Python"] == sys.version.replace("\n", "") + assert env_info["MMEngine"] == mmengine.__version__ diff --git a/tests/test_utils/test_dl_utils/test_setup_env.py b/tests/test_utils/test_dl_utils/test_setup_env.py index 9ca98b4311..40b80b6a53 100644 --- a/tests/test_utils/test_dl_utils/test_setup_env.py +++ b/tests/test_utils/test_dl_utils/test_setup_env.py @@ -13,46 +13,45 @@ def test_setup_multi_processes(): sys_start_mehod = mp.get_start_method(allow_none=True) sys_cv_threads = cv2.getNumThreads() # pop and temp save system env vars - sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) - sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) + sys_omp_threads = os.environ.pop("OMP_NUM_THREADS", default=None) + sys_mkl_threads = os.environ.pop("MKL_NUM_THREADS", default=None) # test distributed set_multi_processing(distributed=True) - assert os.getenv('OMP_NUM_THREADS') == '1' - assert os.getenv('MKL_NUM_THREADS') == '1' + assert os.getenv("OMP_NUM_THREADS") == "1" + assert os.getenv("MKL_NUM_THREADS") == "1" # when set to 0, the num threads will be 1 assert cv2.getNumThreads() == 1 - if platform.system() != 'Windows': - assert mp.get_start_method() == 'fork' + if platform.system() != "Windows": + 
assert mp.get_start_method() == "fork" # test num workers <= 1 - os.environ.pop('OMP_NUM_THREADS') - os.environ.pop('MKL_NUM_THREADS') + os.environ.pop("OMP_NUM_THREADS") + os.environ.pop("MKL_NUM_THREADS") set_multi_processing(distributed=False) - assert 'OMP_NUM_THREADS' not in os.environ - assert 'MKL_NUM_THREADS' not in os.environ + assert "OMP_NUM_THREADS" not in os.environ + assert "MKL_NUM_THREADS" not in os.environ # test manually set env var - os.environ['OMP_NUM_THREADS'] = '4' + os.environ["OMP_NUM_THREADS"] = "4" set_multi_processing(distributed=False) - assert os.getenv('OMP_NUM_THREADS') == '4' + assert os.getenv("OMP_NUM_THREADS") == "4" # test manually set opencv threads and mp start method - config = dict( - mp_start_method='spawn', opencv_num_threads=4, distributed=True) + config = dict(mp_start_method="spawn", opencv_num_threads=4, distributed=True) set_multi_processing(**config) assert cv2.getNumThreads() == 4 - assert mp.get_start_method() == 'spawn' + assert mp.get_start_method() == "spawn" # revert setting to avoid affecting other programs if sys_start_mehod: mp.set_start_method(sys_start_mehod, force=True) cv2.setNumThreads(sys_cv_threads) if sys_omp_threads: - os.environ['OMP_NUM_THREADS'] = sys_omp_threads + os.environ["OMP_NUM_THREADS"] = sys_omp_threads else: - os.environ.pop('OMP_NUM_THREADS') + os.environ.pop("OMP_NUM_THREADS") if sys_mkl_threads: - os.environ['MKL_NUM_THREADS'] = sys_mkl_threads + os.environ["MKL_NUM_THREADS"] = sys_mkl_threads else: - os.environ.pop('MKL_NUM_THREADS') + os.environ.pop("MKL_NUM_THREADS") diff --git a/tests/test_utils/test_dl_utils/test_time_counter.py b/tests/test_utils/test_dl_utils/test_time_counter.py index 9c7d884ab9..a2de660366 100644 --- a/tests/test_utils/test_dl_utils/test_time_counter.py +++ b/tests/test_utils/test_dl_utils/test_time_counter.py @@ -6,9 +6,7 @@ class TestTimeCounter(unittest.TestCase): - def test_decorate_timer(self): - @TimeCounter() def demo_fun(): time.sleep(0.1) @@ -22,7 +20,7 @@ def demo_fun(): for _ in range(10): demo_fun() - @TimeCounter(log_interval=2, with_sync=False, tag='demo_fun1') + @TimeCounter(log_interval=2, with_sync=False, tag="demo_fun1") def demo_fun(): time.sleep(0.1) @@ -36,7 +34,6 @@ def demo_fun(): time.sleep(0.1) def test_context_timer(self): - # tag must be specified in context mode with self.assertRaises(AssertionError): with TimeCounter(): @@ -44,12 +41,12 @@ def test_context_timer(self): # warmup_interval must be greater than 0 with self.assertRaises(AssertionError): - with TimeCounter(warmup_interval=0, tag='func_1'): + with TimeCounter(warmup_interval=0, tag="func_1"): time.sleep(0.1) - with TimeCounter(tag='func_1'): + with TimeCounter(tag="func_1"): time.sleep(0.1) for _ in range(10): - with TimeCounter(log_interval=2, with_sync=False, tag='func_2'): + with TimeCounter(log_interval=2, with_sync=False, tag="func_2"): time.sleep(0.1) diff --git a/tests/test_utils/test_dl_utils/test_torch_ops.py b/tests/test_utils/test_dl_utils/test_torch_ops.py index 03e5decf0d..2b2cc5428e 100644 --- a/tests/test_utils/test_dl_utils/test_torch_ops.py +++ b/tests/test_utils/test_dl_utils/test_torch_ops.py @@ -9,7 +9,7 @@ def test_torch_meshgrid(): # torch_meshgrid should not throw warning with warnings.catch_warnings(): - warnings.simplefilter('error') + warnings.simplefilter("error") x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5, 6]) grid_x, grid_y = torch_meshgrid(x, y) diff --git a/tests/test_utils/test_dl_utils/test_trace.py b/tests/test_utils/test_dl_utils/test_trace.py index 
a4ff6fea87..4baa3b4f68 100644 --- a/tests/test_utils/test_dl_utils/test_trace.py +++ b/tests/test_utils/test_dl_utils/test_trace.py @@ -7,10 +7,10 @@ @pytest.mark.skipif( - digit_version(torch.__version__) < digit_version('1.6.0'), - reason='torch.jit.is_tracing is not available before 1.6.0') + digit_version(torch.__version__) < digit_version("1.6.0"), + reason="torch.jit.is_tracing is not available before 1.6.0", +) def test_is_jit_tracing(): - def foo(x): if is_jit_tracing(): return x @@ -22,5 +22,5 @@ def foo(x): assert isinstance(foo(x), list) # test with trace - traced_foo = torch.jit.trace(foo, (torch.rand(1), )) + traced_foo = torch.jit.trace(foo, (torch.rand(1),)) assert isinstance(traced_foo(x), torch.Tensor) diff --git a/tests/test_utils/test_manager.py b/tests/test_utils/test_manager.py index 913affb649..df7afe636d 100644 --- a/tests/test_utils/test_manager.py +++ b/tests/test_utils/test_manager.py @@ -5,67 +5,61 @@ class SubClassA(ManagerMixin): - - def __init__(self, name='', *args, **kwargs): + def __init__(self, name="", *args, **kwargs): super().__init__(name, *args, **kwargs) class SubClassB(ManagerMixin): - - def __init__(self, name='', *args, **kwargs): + def __init__(self, name="", *args, **kwargs): super().__init__(name, *args, **kwargs) class TestGlobalMeta: - def test_init(self): # Subclass's constructor does not contain name arguments will raise an # error. with pytest.raises(AssertionError): class SubClassNoName1(metaclass=ManagerMeta): - def __init__(self, a, *args, **kwargs): pass # Valid subclass. class GlobalAccessible1(metaclass=ManagerMeta): - def __init__(self, name): self.name = name class TestManagerMixin: - def test_init(self): # test create instance by name. - base_cls = ManagerMixin('name') - assert base_cls.instance_name == 'name' + base_cls = ManagerMixin("name") + assert base_cls.instance_name == "name" def test_get_instance(self): # SubClass should manage their own `_instance_dict`. with pytest.raises(RuntimeError): SubClassA.get_current_instance() - SubClassA.get_instance('instance_a') - SubClassB.get_instance('instance_b') + SubClassA.get_instance("instance_a") + SubClassB.get_instance("instance_b") assert SubClassB._instance_dict != SubClassA._instance_dict # Test `message_hub` can create by name. - message_hub = SubClassA.get_instance('name1') - assert message_hub.instance_name == 'name1' + message_hub = SubClassA.get_instance("name1") + assert message_hub.instance_name == "name1" # No arguments will raise an assertion error. - SubClassA.get_instance('name2') + SubClassA.get_instance("name2") message_hub = SubClassA.get_current_instance() message_hub.mark = -1 - assert message_hub.instance_name == 'name2' + assert message_hub.instance_name == "name2" # Test get latest `message_hub` repeatedly. - message_hub = SubClassA.get_instance('name3') - assert message_hub.instance_name == 'name3' + message_hub = SubClassA.get_instance("name3") + assert message_hub.instance_name == "name3" message_hub = SubClassA.get_current_instance() - assert message_hub.instance_name == 'name3' + assert message_hub.instance_name == "name3" # Test get name2 repeatedly. - message_hub = SubClassA.get_instance('name2') + message_hub = SubClassA.get_instance("name2") assert message_hub.mark == -1 # Non-string instance name will raise `AssertionError`. with pytest.raises(AssertionError): @@ -73,4 +67,4 @@ def test_get_instance(self): # `get_instance` should not accept other arguments if corresponding # instance has been created. 
with pytest.warns(UserWarning): - SubClassA.get_instance('name2', a=1, b=2) + SubClassA.get_instance("name2", a=1, b=2) diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 7c43d04853..c5918c56bc 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -7,50 +7,69 @@ from mmengine import MMLogger from mmengine.utils import is_installed + # yapf: disable -from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning, - deprecated_function, get_object_from_string, - has_method, import_modules_from_strings, - is_list_of, is_method_overridden, is_seq_of, - is_tuple_of, iter_cast, list_cast, - requires_executable, requires_package, - slice_list, to_1tuple, to_2tuple, to_3tuple, - to_4tuple, to_ntuple, tuple_cast) +from mmengine.utils.misc import ( + apply_to, + concat_list, + deprecated_api_warning, + deprecated_function, + get_object_from_string, + has_method, + import_modules_from_strings, + is_list_of, + is_method_overridden, + is_seq_of, + is_tuple_of, + iter_cast, + list_cast, + requires_executable, + requires_package, + slice_list, + to_1tuple, + to_2tuple, + to_3tuple, + to_4tuple, + to_ntuple, + tuple_cast, +) + # yapf: enable def test_to_ntuple(): single_number = 2 - assert to_1tuple(single_number) == (single_number, ) + assert to_1tuple(single_number) == (single_number,) assert to_2tuple(single_number) == (single_number, single_number) - assert to_3tuple(single_number) == (single_number, single_number, - single_number) - assert to_4tuple(single_number) == (single_number, single_number, - single_number, single_number) - assert to_ntuple(5)(single_number) == (single_number, single_number, - single_number, single_number, - single_number) - assert to_ntuple(6)(single_number) == (single_number, single_number, - single_number, single_number, - single_number, single_number) + assert to_3tuple(single_number) == (single_number, single_number, single_number) + assert to_4tuple(single_number) == (single_number, single_number, single_number, single_number) + assert to_ntuple(5)(single_number) == (single_number, single_number, single_number, single_number, single_number) + assert to_ntuple(6)(single_number) == ( + single_number, + single_number, + single_number, + single_number, + single_number, + single_number, + ) def test_iter_cast(): assert list_cast([1, 2, 3], int) == [1, 2, 3] - assert list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0] - assert list_cast([1, 2, 3], str) == ['1', '2', '3'] - assert tuple_cast((1, 2, 3), str) == ('1', '2', '3') - assert next(iter_cast([1, 2, 3], str)) == '1' + assert list_cast(["1.1", 2, "3"], float) == [1.1, 2.0, 3.0] + assert list_cast([1, 2, 3], str) == ["1", "2", "3"] + assert tuple_cast((1, 2, 3), str) == ("1", "2", "3") + assert next(iter_cast([1, 2, 3], str)) == "1" with pytest.raises(TypeError): - iter_cast([1, 2, 3], '') + iter_cast([1, 2, 3], "") with pytest.raises(TypeError): iter_cast(1, str) def test_is_seq_of(): assert is_seq_of([1.0, 2.0, 3.0], float) - assert is_seq_of([(1, ), (2, ), (3, )], tuple) + assert is_seq_of([(1,), (2,), (3,)], tuple) assert is_seq_of((1.0, 2.0, 3.0), float) assert is_list_of([1.0, 2.0, 3.0], float) assert not is_seq_of((1.0, 2.0, 3.0), float, seq_type=list) @@ -75,61 +94,53 @@ def test_concat_list(): def test_requires_package(capsys): - - @requires_package('nnn') + @requires_package("nnn") def func_a(): pass - @requires_package(['numpy', 'n1', 'n2']) + @requires_package(["numpy", "n1", "n2"]) def func_b(): pass - @requires_package('numpy') + 
@requires_package("numpy") def func_c(): return 1 with pytest.raises(RuntimeError): func_a() out, _ = capsys.readouterr() - assert out == ('Prerequisites "nnn" are required in method "func_a" but ' - 'not found, please install them first.\n') + assert out == ('Prerequisites "nnn" are required in method "func_a" but not found, please install them first.\n') with pytest.raises(RuntimeError): func_b() out, _ = capsys.readouterr() - assert out == ( - 'Prerequisites "n1, n2" are required in method "func_b" but not found,' - ' please install them first.\n') + assert out == ('Prerequisites "n1, n2" are required in method "func_b" but not found, please install them first.\n') assert func_c() == 1 def test_requires_executable(capsys): - - @requires_executable('nnn') + @requires_executable("nnn") def func_a(): pass - @requires_executable(['ls', 'n1', 'n2']) + @requires_executable(["ls", "n1", "n2"]) def func_b(): pass - @requires_executable('mv') + @requires_executable("mv") def func_c(): return 1 with pytest.raises(RuntimeError): func_a() out, _ = capsys.readouterr() - assert out == ('Prerequisites "nnn" are required in method "func_a" but ' - 'not found, please install them first.\n') + assert out == ('Prerequisites "nnn" are required in method "func_a" but not found, please install them first.\n') with pytest.raises(RuntimeError): func_b() out, _ = capsys.readouterr() - assert out == ( - 'Prerequisites "n1, n2" are required in method "func_b" but not found,' - ' please install them first.\n') + assert out == ('Prerequisites "n1, n2" are required in method "func_b" but not found, please install them first.\n') assert func_c() == 1 @@ -138,17 +149,18 @@ def test_import_modules_from_strings(): # multiple imports import os.path as osp_ import sys as sys_ - osp, sys = import_modules_from_strings(['os.path', 'sys']) + + osp, sys = import_modules_from_strings(["os.path", "sys"]) assert osp == osp_ assert sys == sys_ # single imports - osp = import_modules_from_strings('os.path') + osp = import_modules_from_strings("os.path") assert osp == osp_ # No imports assert import_modules_from_strings(None) is None assert import_modules_from_strings([]) is None - assert import_modules_from_strings('') is None + assert import_modules_from_strings("") is None # Unsupported types with pytest.raises(TypeError): import_modules_from_strings(1) @@ -156,22 +168,18 @@ def test_import_modules_from_strings(): import_modules_from_strings([1]) # Failed imports with pytest.raises(ImportError): - import_modules_from_strings('_not_implemented_module') + import_modules_from_strings("_not_implemented_module") with pytest.warns(UserWarning): - imported = import_modules_from_strings( - '_not_implemented_module', allow_failed_imports=True) + imported = import_modules_from_strings("_not_implemented_module", allow_failed_imports=True) assert imported is None with pytest.warns(UserWarning): - imported = import_modules_from_strings(['os.path', '_not_implemented'], - allow_failed_imports=True) + imported = import_modules_from_strings(["os.path", "_not_implemented"], allow_failed_imports=True) assert imported[0] == osp assert imported[1] is None def test_is_method_overridden(): - class Base: - def foo1(): pass @@ -179,43 +187,39 @@ def foo2(): pass class Sub(Base): - def foo1(): pass # test passing sub class directly - assert is_method_overridden('foo1', Base, Sub) - assert not is_method_overridden('foo2', Base, Sub) + assert is_method_overridden("foo1", Base, Sub) + assert not is_method_overridden("foo2", Base, Sub) # test passing instance 
     sub_instance = Sub()
-    assert is_method_overridden('foo1', Base, sub_instance)
-    assert not is_method_overridden('foo2', Base, sub_instance)
+    assert is_method_overridden("foo1", Base, sub_instance)
+    assert not is_method_overridden("foo2", Base, sub_instance)
 
     # base_class should be a class, not instance
     base_instance = Base()
     with pytest.raises(AssertionError):
-        is_method_overridden('foo1', base_instance, sub_instance)
+        is_method_overridden("foo1", base_instance, sub_instance)
 
 
 def test_has_method():
-
     class Foo:
-
         def __init__(self, name):
             self.name = name
 
         def print_name(self):
             print(self.name)
 
-    foo = Foo('foo')
-    assert not has_method(foo, 'name')
-    assert has_method(foo, 'print_name')
+    foo = Foo("foo")
+    assert not has_method(foo, "name")
+    assert has_method(foo, "print_name")
 
 
 def test_deprecated_api_warning():
-
-    @deprecated_api_warning(name_dict=dict(old_key='new_key'))
+    @deprecated_api_warning(name_dict=dict(old_key="new_key"))
     def dummy_func(new_key=1):
         return new_key
@@ -230,8 +234,7 @@ def dummy_func(new_key=1):
 
 
 def test_deprecated_function():
-
-    @deprecated_function('0.2.0', '0.3.0', 'toy instruction')
+    @deprecated_function("0.2.0", "0.3.0", "toy instruction")
     def deprecated_demo(arg1: int, arg2: int) -> tuple:
         """This is a long summary. This is a long summary. This is a long
         summary. This is a long summary.
@@ -248,14 +251,13 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple:
 
         return arg1, arg2
 
-    MMLogger.get_instance('test_deprecated_function')
+    MMLogger.get_instance("test_deprecated_function")
     deprecated_demo(1, 2)
     # out, _ = capsys.readouterr()
     # assert "'test_misc.deprecated_demo' is deprecated" in out
     assert (1, 2) == deprecated_demo(1, 2)
 
-    expected_docstring = \
-        """.. deprecated:: 0.2.0
+    expected_docstring = """.. deprecated:: 0.2.0
 
     Deprecated and will be removed in version 0.3.0.
     Please toy instruction.
@@ -272,25 +274,24 @@ def deprecated_demo(arg1: int, arg2: int) -> tuple:
 
     Long description without a line break. Long description without a line
     break.
     """  # noqa: E122
-    assert expected_docstring.strip(' ') == deprecated_demo.__doc__
+    assert expected_docstring.strip(" ") == deprecated_demo.__doc__
     MMLogger._instance_dict.clear()
 
     # Test with short summary without args.
-    @deprecated_function('0.2.0', '0.3.0', 'toy instruction')
+    @deprecated_function("0.2.0", "0.3.0", "toy instruction")
    def deprecated_demo1():
        """Short summary."""
 
-    expected_docstring = \
-        """.. deprecated:: 0.2.0
+    expected_docstring = """.. deprecated:: 0.2.0
 
     Deprecated and will be removed in version 0.3.0.
     Please toy instruction.
Short summary.""" # noqa: E122 - assert expected_docstring.strip(' ') == deprecated_demo1.__doc__ + assert expected_docstring.strip(" ") == deprecated_demo1.__doc__ -@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch') +@pytest.mark.skipif(not is_installed("torch"), reason="tests requires torch") def test_apply_to(): import torch @@ -306,38 +307,35 @@ def test_apply_to(): # Tensor to numpy data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2)) - result = apply_to(data, lambda x: isinstance(x, torch.Tensor), - lambda x: x.numpy()) - assert isinstance(result['b'], np.ndarray) - assert isinstance(result['a'][0]['c'], np.ndarray) + result = apply_to(data, lambda x: isinstance(x, torch.Tensor), lambda x: x.numpy()) + assert isinstance(result["b"], np.ndarray) + assert isinstance(result["a"][0]["c"], np.ndarray) # Tuple and convert string - data = (1, dict(a=[dict(b=2.0)]), 'test') + data = (1, dict(a=[dict(b=2.0)]), "test") result = apply_to( - data, lambda x: isinstance(x, int) or x == 'test', - lambda x: torch.Tensor(x) if isinstance(x, int) else 'train') + data, lambda x: isinstance(x, int) or x == "test", lambda x: torch.Tensor(x) if isinstance(x, int) else "train" + ) assert isinstance(result, tuple) assert isinstance(result[0], torch.Tensor) - assert isinstance(result[1]['a'][0]['b'], float) - assert result[2] == 'train' + assert isinstance(result[1]["a"][0]["b"], float) + assert result[2] == "train" # Named Tuple - dataclass = namedtuple('Data', ['a', 'b']) - data = dataclass('test', dict(a=[dict(c=1)], b=2.0)) + dataclass = namedtuple("Data", ["a", "b"]) + data = dataclass("test", dict(a=[dict(c=1)], b=2.0)) result = apply_to( - data, lambda x: isinstance(x, int) or x == 'test', - lambda x: torch.Tensor(x) if isinstance(x, int) else 'train') + data, lambda x: isinstance(x, int) or x == "test", lambda x: torch.Tensor(x) if isinstance(x, int) else "train" + ) assert isinstance(result, dataclass) - assert result[0] == 'train' - assert isinstance(result.b['a'][0]['c'], torch.Tensor) - assert isinstance(result.b['b'], float) + assert result[0] == "train" + assert isinstance(result.b["a"][0]["c"], torch.Tensor) + assert isinstance(result.b["b"], float) def test_locate(): - assert get_object_from_string('a.b.c') is None - config_module = import_module('mmengine.config') - assert get_object_from_string('mmengine.config') is config_module - assert get_object_from_string( - 'mmengine.config.Config') is config_module.Config - assert get_object_from_string('mmengine.config.Config.fromfile') is \ - config_module.Config.fromfile + assert get_object_from_string("a.b.c") is None + config_module = import_module("mmengine.config") + assert get_object_from_string("mmengine.config") is config_module + assert get_object_from_string("mmengine.config.Config") is config_module.Config + assert get_object_from_string("mmengine.config.Config.fromfile") is config_module.Config.fromfile diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index bed91b6c18..8d1772eab2 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -10,28 +10,26 @@ def test_is_installed(): # TODO: Windows CI may failed in unknown reason. Skip check the value - is_installed('mmengine') + is_installed("mmengine") # If there is `__init__.py` in the directory which is added into # `sys.path`, the directory will be recognized as a package. 
- PYTHONPATH = osp.abspath( - osp.join(osp.dirname(__file__), '..', '..', 'mmengine')) + PYTHONPATH = osp.abspath(osp.join(osp.dirname(__file__), "..", "..", "mmengine")) sys.path.append(PYTHONPATH) - assert is_installed('optim') + assert is_installed("optim") sys.path.pop() def test_get_install_path(): # TODO: Windows CI may failed in unknown reason. Skip check the value - get_installed_path('mmengine') + get_installed_path("mmengine") # get path for package "installed" by setting PYTHONPATH - PYTHONPATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) - PYTHONPATH = osp.abspath( - osp.join(osp.dirname(__file__), '..', '..', 'mmengine')) + PYTHONPATH = osp.abspath(osp.join(osp.dirname(__file__), "..")) + PYTHONPATH = osp.abspath(osp.join(osp.dirname(__file__), "..", "..", "mmengine")) sys.path.append(PYTHONPATH) - assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim') + assert get_installed_path("optim") == osp.join(PYTHONPATH, "optim") sys.path.pop() with pytest.raises(pkg_resources.DistributionNotFound): - get_installed_path('unknown') + get_installed_path("unknown") diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py index 0636e25e1d..be432abc07 100644 --- a/tests/test_utils/test_progressbar.py +++ b/tests/test_utils/test_progressbar.py @@ -15,35 +15,30 @@ def reset_string_io(io): class TestProgressBar: - def test_start(self): out = StringIO() bar_width = 20 # without total task num prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out) - assert out.getvalue() == 'completed: 0, elapsed: 0s' + assert out.getvalue() == "completed: 0, elapsed: 0s" reset_string_io(out) - prog_bar = mmengine.ProgressBar( - bar_width=bar_width, start=False, file=out) - assert out.getvalue() == '' + prog_bar = mmengine.ProgressBar(bar_width=bar_width, start=False, file=out) + assert out.getvalue() == "" reset_string_io(out) prog_bar.start() - assert out.getvalue() == 'completed: 0, elapsed: 0s' + assert out.getvalue() == "completed: 0, elapsed: 0s" # with total task num reset_string_io(out) prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) - assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' + assert out.getvalue() == f"[{' ' * bar_width}] 0/10, elapsed: 0s, ETA:" reset_string_io(out) - prog_bar = mmengine.ProgressBar( - 10, bar_width=bar_width, start=False, file=out) - assert out.getvalue() == '' + prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, start=False, file=out) + assert out.getvalue() == "" reset_string_io(out) prog_bar.start() - assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' + assert out.getvalue() == f"[{' ' * bar_width}] 0/10, elapsed: 0s, ETA:" - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_update` in Linux') + @skipIf(platform.system() != "Linux", reason="Only test `TestProgressBar.test_update` in Linux") def test_update(self): out = StringIO() bar_width = 20 @@ -52,21 +47,18 @@ def test_update(self): time.sleep(1) reset_string_io(out) prog_bar.update() - assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s' + assert out.getvalue() == "completed: 1, elapsed: 1s, 1.0 tasks/s" reset_string_io(out) # with total task num prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() - assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ - 'task/s, elapsed: 1s, ETA: 9s' + assert out.getvalue() == f"\r[{'>' * 2 + ' ' * 18}] 1/10, 1.0 task/s, elapsed: 
1s, ETA: 9s" - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_adaptive_length` in Linux') + @skipIf(platform.system() != "Linux", reason="Only test `TestProgressBar.test_adaptive_length` in Linux") def test_adaptive_length(self): - with patch.dict('os.environ', {'COLUMNS': '80'}): + with patch.dict("os.environ", {"COLUMNS": "80"}): out = StringIO() bar_width = 20 prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) @@ -75,12 +67,12 @@ def test_adaptive_length(self): prog_bar.update() assert len(out.getvalue()) == 66 - os.environ['COLUMNS'] = '30' + os.environ["COLUMNS"] = "30" reset_string_io(out) prog_bar.update() assert len(out.getvalue()) == 48 - os.environ['COLUMNS'] = '60' + os.environ["COLUMNS"] = "60" reset_string_io(out) prog_bar.update() assert len(out.getvalue()) == 60 @@ -99,22 +91,21 @@ def test_track_progress(): # tasks is a list out = StringIO() ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) - if platform == 'Linux': + if platform == "Linux": assert out.getvalue() == ( - '[ ] 0/3, elapsed: 0s, ETA:' - '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' - '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' - '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') + "[ ] 0/3, elapsed: 0s, ETA:" + "\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s" + "\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s" + "\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n" + ) assert ret == [1, 2, 3] # tasks is an iterable object - ret = mmengine.track_progress( - return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) assert ret == [1, 2, 3] # tasks is a range object - ret = mmengine.track_progress( - return_itself, range(1, 4), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, range(1, 4), bar_width=3, file=out) assert ret == [1, 2, 3] @@ -128,8 +119,7 @@ def test_track_iter_progress(): ret = [] count = [] - for i, num in enumerate( - mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)): + for i, num in enumerate(mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)): ret.append(num) count.append(i) assert ret == [1, 2, 3] @@ -143,19 +133,13 @@ def test_track_iter_progress(): def test_track_parallel_progress(): # tasks is a list out = StringIO() - ret = mmengine.track_parallel_progress( - return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out) assert ret == [1, 2, 3, 4] # tasks is an iterable object - ret = mmengine.track_parallel_progress( - return_itself, ((i for i in [1, 2, 3, 4]), 4), - 2, - bar_width=4, - file=out) + ret = mmengine.track_parallel_progress(return_itself, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out) assert ret == [1, 2, 3, 4] # tasks is a range object - ret = mmengine.track_parallel_progress( - return_itself, range(1, 5), 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, range(1, 5), 2, bar_width=4, file=out) assert ret == [1, 2, 3, 4] diff --git a/tests/test_utils/test_progressbar_rich.py b/tests/test_utils/test_progressbar_rich.py index 9c507bf629..02513220dd 100644 --- a/tests/test_utils/test_progressbar_rich.py +++ b/tests/test_utils/test_progressbar_rich.py @@ -35,14 +35,14 @@ def test_progressbar_rich_exception(): track_progress_rich(foo1, nproc=0) -@pytest.mark.parametrize('nproc', [1, 2]) +@pytest.mark.parametrize("nproc", [1, 2]) 
def test_progressbar_rich(nproc): # empty tasks results = track_progress_rich(foo, nproc=nproc, task_num=10) assert results == [1] * 10 # Ordered results # foo1 - tasks_ = [i for i in range(10)] + tasks_ = list(range(10)) for tasks in (tasks_, iter(tasks_)): results = track_progress_rich(foo1, tasks, nproc=nproc) assert results == tasks_ diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py index 570f7ea380..80fee4157d 100644 --- a/tests/test_utils/test_timer.py +++ b/tests/test_utils/test_timer.py @@ -7,8 +7,7 @@ import mmengine -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != "Linux", reason="Only test `Timer` in linux!") def test_timer_init(): timer = mmengine.Timer(start=False) assert not timer.is_running @@ -18,8 +17,7 @@ def test_timer_init(): assert timer.is_running -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != "Linux", reason="Only test `Timer` in linux!") def test_timer_run(): timer = mmengine.Timer() time.sleep(1) @@ -36,8 +34,7 @@ def test_timer_run(): timer.since_last_check() -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != "Linux", reason="Only test `Timer` in linux!") def test_timer_context(capsys): with mmengine.Timer(): time.sleep(1) @@ -45,7 +42,7 @@ def test_timer_context(capsys): # In Windows, the error could be larger than 20ms. More details in # https://stackoverflow.com/questions/11657734/sleep-for-exact-time-in-python. # noqa: E501 assert abs(float(out) - 1) < 3e-2 - with mmengine.Timer(print_tmpl='time: {:.1f}s'): + with mmengine.Timer(print_tmpl="time: {:.1f}s"): time.sleep(1) out, _ = capsys.readouterr() - assert out == 'time: 1.0s\n' + assert out == "time: 1.0s\n" diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index c991462ef9..3e004c2a13 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -14,365 +14,346 @@ from mmengine.fileio import load from mmengine.registry import VISBACKENDS from mmengine.utils import digit_version, is_installed -from mmengine.visualization import (AimVisBackend, ClearMLVisBackend, - DVCLiveVisBackend, LocalVisBackend, - MLflowVisBackend, NeptuneVisBackend, - TensorboardVisBackend, WandbVisBackend) +from mmengine.visualization import ( + AimVisBackend, + ClearMLVisBackend, + DVCLiveVisBackend, + LocalVisBackend, + MLflowVisBackend, + NeptuneVisBackend, + TensorboardVisBackend, + WandbVisBackend, +) class TestLocalVisBackend: - def test_init(self): # 'config_save_file' format must be py with pytest.raises(AssertionError): - LocalVisBackend('temp_dir', config_save_file='a.txt') + LocalVisBackend("temp_dir", config_save_file="a.txt") # 'scalar_save_file' format must be json with pytest.raises(AssertionError): - LocalVisBackend('temp_dir', scalar_save_file='a.yaml') + LocalVisBackend("temp_dir", scalar_save_file="a.yaml") - local_vis_backend = VISBACKENDS.build( - dict(type='LocalVisBackend', save_dir='temp_dir')) + local_vis_backend = VISBACKENDS.build(dict(type="LocalVisBackend", save_dir="temp_dir")) assert isinstance(local_vis_backend, LocalVisBackend) def test_experiment(self): - local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend = LocalVisBackend("temp_dir") assert local_vis_backend.experiment == local_vis_backend def test_add_config(self): 
cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend = LocalVisBackend("temp_dir") local_vis_backend.add_config(cfg) assert os.path.exists(local_vis_backend._config_save_file) - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)) - local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend = LocalVisBackend("temp_dir") # image must be in np.uint8 format with pytest.raises(AssertionError): - local_vis_backend.add_image('img', image) + local_vis_backend.add_image("img", image) - local_vis_backend.add_image('img', image.astype(np.uint8)) - assert os.path.exists( - os.path.join(local_vis_backend._img_save_dir, 'img_0.png')) - local_vis_backend.add_image('img', image.astype(np.uint8), step=2) - assert os.path.exists( - os.path.join(local_vis_backend._img_save_dir, 'img_2.png')) - shutil.rmtree('temp_dir') + local_vis_backend.add_image("img", image.astype(np.uint8)) + assert os.path.exists(os.path.join(local_vis_backend._img_save_dir, "img_0.png")) + local_vis_backend.add_image("img", image.astype(np.uint8), step=2) + assert os.path.exists(os.path.join(local_vis_backend._img_save_dir, "img_2.png")) + shutil.rmtree("temp_dir") def test_add_scalar(self): - local_vis_backend = LocalVisBackend('temp_dir') - local_vis_backend.add_scalar('map', 0.9) - out_dict = load(local_vis_backend._scalar_save_file, 'json') - assert out_dict == {'map': 0.9, 'step': 0} - shutil.rmtree('temp_dir') + local_vis_backend = LocalVisBackend("temp_dir") + local_vis_backend.add_scalar("map", 0.9) + out_dict = load(local_vis_backend._scalar_save_file, "json") + assert out_dict == {"map": 0.9, "step": 0} + shutil.rmtree("temp_dir") # test append mode - local_vis_backend = LocalVisBackend('temp_dir') - local_vis_backend.add_scalar('map', 1, step=0) - local_vis_backend.add_scalar('map', 0.95, step=1) + local_vis_backend = LocalVisBackend("temp_dir") + local_vis_backend.add_scalar("map", 1, step=0) + local_vis_backend.add_scalar("map", 0.95, step=1) # local_vis_backend.add_scalar('map', torch.IntTensor(1), step=2) - local_vis_backend.add_scalar('map', np.array(0.9), step=2) + local_vis_backend.add_scalar("map", np.array(0.9), step=2) with open(local_vis_backend._scalar_save_file) as f: out_dict = f.read() - assert out_dict == '{"map": 1, "step": 0}\n' \ - '{"map": 0.95, "step": 1}\n' \ - '{"map": 0.9, "step": 2}\n' - shutil.rmtree('temp_dir') + assert out_dict == '{"map": 1, "step": 0}\n{"map": 0.95, "step": 1}\n{"map": 0.9, "step": 2}\n' + shutil.rmtree("temp_dir") - local_vis_backend = LocalVisBackend('temp_dir') - local_vis_backend.add_scalar('map', torch.tensor(1.)) + local_vis_backend = LocalVisBackend("temp_dir") + local_vis_backend.add_scalar("map", torch.tensor(1.0)) assert os.path.exists(local_vis_backend._scalar_save_file) - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_scalars(self): - input_dict = {'map': 0.7, 'acc': 0.9} - local_vis_backend = LocalVisBackend('temp_dir') + input_dict = {"map": 0.7, "acc": 0.9} + local_vis_backend = LocalVisBackend("temp_dir") local_vis_backend.add_scalars(input_dict) - assert input_dict == {'map': 0.7, 'acc': 0.9} - out_dict = load(local_vis_backend._scalar_save_file, 'json') - assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} + assert input_dict == {"map": 0.7, "acc": 0.9} + out_dict = load(local_vis_backend._scalar_save_file, "json") + assert out_dict == {"map": 0.7, "acc": 0.9, "step": 0} # test append mode - 
local_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + local_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}, step=1) with open(local_vis_backend._scalar_save_file) as f: out_dict = f.read() - assert out_dict == '{"map": 0.7, "acc": 0.9, ' \ - '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' + assert out_dict == '{"map": 0.7, "acc": 0.9, "step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' # test file_path - local_vis_backend.add_scalars(input_dict, file_path='temp.json') + local_vis_backend.add_scalars(input_dict, file_path="temp.json") assert os.path.exists(local_vis_backend._scalar_save_file) - assert os.path.exists( - os.path.join(local_vis_backend._save_dir, 'temp.json')) + assert os.path.exists(os.path.join(local_vis_backend._save_dir, "temp.json")) # file_path and scalar_save_file cannot be the same with pytest.raises(AssertionError): - local_vis_backend.add_scalars(input_dict, file_path='scalars.json') + local_vis_backend.add_scalars(input_dict, file_path="scalars.json") - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") class TestTensorboardVisBackend: - sys.modules['torch.utils.tensorboard'] = MagicMock() - sys.modules['tensorboardX'] = MagicMock() + sys.modules["torch.utils.tensorboard"] = MagicMock() + sys.modules["tensorboardX"] = MagicMock() def test_init(self): - TensorboardVisBackend('temp_dir') - VISBACKENDS.build( - dict(type='TensorboardVisBackend', save_dir='temp_dir')) + TensorboardVisBackend("temp_dir") + VISBACKENDS.build(dict(type="TensorboardVisBackend", save_dir="temp_dir")) def test_experiment(self): - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') - assert (tensorboard_vis_backend.experiment == - tensorboard_vis_backend._tensorboard) - shutil.rmtree('temp_dir') + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") + assert tensorboard_vis_backend.experiment == tensorboard_vis_backend._tensorboard + shutil.rmtree("temp_dir") def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") tensorboard_vis_backend.add_config(cfg) - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') - tensorboard_vis_backend.add_image('img', image) - tensorboard_vis_backend.add_image('img', image, step=2) - shutil.rmtree('temp_dir') + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") + tensorboard_vis_backend.add_image("img", image) + tensorboard_vis_backend.add_image("img", image, step=2) + shutil.rmtree("temp_dir") def test_add_scalar(self): - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') - tensorboard_vis_backend.add_scalar('map', 0.9) + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") + tensorboard_vis_backend.add_scalar("map", 0.9) # test append mode - tensorboard_vis_backend.add_scalar('map', 0.9, step=0) - tensorboard_vis_backend.add_scalar('map', 0.95, step=1) + tensorboard_vis_backend.add_scalar("map", 0.9, step=0) + tensorboard_vis_backend.add_scalar("map", 0.95, step=1) # test with numpy with warnings.catch_warnings(record=True) as record: - tensorboard_vis_backend.add_scalar('map', np.array(0.9), step=0) - tensorboard_vis_backend.add_scalar('map', np.array(0.95), step=1) - tensorboard_vis_backend.add_scalar('map', np.array(9), step=0) - tensorboard_vis_backend.add_scalar('map', np.array(95), step=1) - 
tensorboard_vis_backend.add_scalar('map', np.array([9])[0], step=0) - tensorboard_vis_backend.add_scalar( - 'map', np.array([95])[0], step=1) + tensorboard_vis_backend.add_scalar("map", np.array(0.9), step=0) + tensorboard_vis_backend.add_scalar("map", np.array(0.95), step=1) + tensorboard_vis_backend.add_scalar("map", np.array(9), step=0) + tensorboard_vis_backend.add_scalar("map", np.array(95), step=1) + tensorboard_vis_backend.add_scalar("map", np.array([9])[0], step=0) + tensorboard_vis_backend.add_scalar("map", np.array([95])[0], step=1) assert len(record) == 0 # test with tensor - tensorboard_vis_backend.add_scalar('map', torch.tensor(0.9), step=0) - tensorboard_vis_backend.add_scalar('map', torch.tensor(0.95), step=1) + tensorboard_vis_backend.add_scalar("map", torch.tensor(0.9), step=0) + tensorboard_vis_backend.add_scalar("map", torch.tensor(0.95), step=1) # Unprocessable data will output a warning message with pytest.warns(Warning): - tensorboard_vis_backend.add_scalar('map', [0.95]) - shutil.rmtree('temp_dir') + tensorboard_vis_backend.add_scalar("map", [0.95]) + shutil.rmtree("temp_dir") def test_add_scalars(self): - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") # The step value must be passed through the parameter with pytest.raises(AssertionError): - tensorboard_vis_backend.add_scalars({ - 'map': 0.7, - 'acc': 0.9, - 'step': 1 - }) + tensorboard_vis_backend.add_scalars({"map": 0.7, "acc": 0.9, "step": 1}) # Unprocessable data will output a warning message with pytest.warns(Warning): - tensorboard_vis_backend.add_scalars({ - 'map': [1, 2], - }) + tensorboard_vis_backend.add_scalars( + { + "map": [1, 2], + } + ) - input_dict = {'map': 0.7, 'acc': 0.9} + input_dict = {"map": 0.7, "acc": 0.9} tensorboard_vis_backend.add_scalars(input_dict) # test append mode - tensorboard_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - shutil.rmtree('temp_dir') + tensorboard_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}, step=1) + shutil.rmtree("temp_dir") def test_close(self): - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend = TensorboardVisBackend("temp_dir") tensorboard_vis_backend._init_env() tensorboard_vis_backend.close() - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") class TestWandbVisBackend: - sys.modules['wandb'] = MagicMock() - sys.modules['wandb.run'] = MagicMock() + sys.modules["wandb"] = MagicMock() + sys.modules["wandb.run"] = MagicMock() def test_init(self): - WandbVisBackend('temp_dir') - VISBACKENDS.build(dict(type='WandbVisBackend', save_dir='temp_dir')) + WandbVisBackend("temp_dir") + VISBACKENDS.build(dict(type="WandbVisBackend", save_dir="temp_dir")) def test_experiment(self): - wandb_vis_backend = WandbVisBackend('temp_dir') + wandb_vis_backend = WandbVisBackend("temp_dir") assert wandb_vis_backend.experiment == wandb_vis_backend._wandb - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - wandb_vis_backend = WandbVisBackend('temp_dir', log_code_name='code') + wandb_vis_backend = WandbVisBackend("temp_dir", log_code_name="code") _wandb = wandb_vis_backend.experiment - _wandb.run.dir = 'temp_dir' + _wandb.run.dir = "temp_dir" wandb_vis_backend.add_config(cfg) - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - wandb_vis_backend = WandbVisBackend('temp_dir') - 
wandb_vis_backend.add_image('img', image) - wandb_vis_backend.add_image('img', image) - shutil.rmtree('temp_dir') + wandb_vis_backend = WandbVisBackend("temp_dir") + wandb_vis_backend.add_image("img", image) + wandb_vis_backend.add_image("img", image) + shutil.rmtree("temp_dir") def test_add_scalar(self): - wandb_vis_backend = WandbVisBackend('temp_dir') - wandb_vis_backend.add_scalar('map', 0.9) + wandb_vis_backend = WandbVisBackend("temp_dir") + wandb_vis_backend.add_scalar("map", 0.9) # test append mode - wandb_vis_backend.add_scalar('map', 0.9) - wandb_vis_backend.add_scalar('map', 0.95) - shutil.rmtree('temp_dir') + wandb_vis_backend.add_scalar("map", 0.9) + wandb_vis_backend.add_scalar("map", 0.95) + shutil.rmtree("temp_dir") def test_add_scalars(self): - wandb_vis_backend = WandbVisBackend('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} + wandb_vis_backend = WandbVisBackend("temp_dir") + input_dict = {"map": 0.7, "acc": 0.9} wandb_vis_backend.add_scalars(input_dict) # test append mode - wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}) - wandb_vis_backend.add_scalars({'map': [0.8], 'acc': 0.8}) - shutil.rmtree('temp_dir') + wandb_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}) + wandb_vis_backend.add_scalars({"map": [0.8], "acc": 0.8}) + shutil.rmtree("temp_dir") def test_close(self): - wandb_vis_backend = WandbVisBackend('temp_dir') + wandb_vis_backend = WandbVisBackend("temp_dir") wandb_vis_backend._init_env() wandb_vis_backend.close() - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_define_metric_cfg(self): # list of dict define_metric_cfg = [ - dict(name='test1', step_metric='iter'), - dict(name='test2', step_metric='epoch'), + dict(name="test1", step_metric="iter"), + dict(name="test2", step_metric="epoch"), ] - wandb_vis_backend = WandbVisBackend( - 'temp_dir', define_metric_cfg=define_metric_cfg) + wandb_vis_backend = WandbVisBackend("temp_dir", define_metric_cfg=define_metric_cfg) wandb_vis_backend._init_env() - wandb_vis_backend._wandb.define_metric.assert_any_call( - name='test1', step_metric='iter') - wandb_vis_backend._wandb.define_metric.assert_any_call( - name='test2', step_metric='epoch') + wandb_vis_backend._wandb.define_metric.assert_any_call(name="test1", step_metric="iter") + wandb_vis_backend._wandb.define_metric.assert_any_call(name="test2", step_metric="epoch") # dict - define_metric_cfg = dict(test3='max') - wandb_vis_backend = WandbVisBackend( - 'temp_dir', define_metric_cfg=define_metric_cfg) + define_metric_cfg = dict(test3="max") + wandb_vis_backend = WandbVisBackend("temp_dir", define_metric_cfg=define_metric_cfg) wandb_vis_backend._init_env() - wandb_vis_backend._wandb.define_metric.assert_any_call( - 'test3', summary='max') + wandb_vis_backend._wandb.define_metric.assert_any_call("test3", summary="max") - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") class TestMLflowVisBackend: - def test_init(self): - MLflowVisBackend('temp_dir') - VISBACKENDS.build(dict(type='MLflowVisBackend', save_dir='temp_dir')) + MLflowVisBackend("temp_dir") + VISBACKENDS.build(dict(type="MLflowVisBackend", save_dir="temp_dir")) def test_experiment(self): - mlflow_vis_backend = MLflowVisBackend('temp_dir') + mlflow_vis_backend = MLflowVisBackend("temp_dir") assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow def test_create_experiment(self): - with patch('mlflow.create_experiment') as mock_create_experiment: - MLflowVisBackend( - 'temp_dir', exp_name='test', - artifact_location='foo')._init_env() - 
mock_create_experiment.assert_any_call( - 'test', artifact_location='foo') + with patch("mlflow.create_experiment") as mock_create_experiment: + MLflowVisBackend("temp_dir", exp_name="test", artifact_location="foo")._init_env() + mock_create_experiment.assert_any_call("test", artifact_location="foo") def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - mlflow_vis_backend = MLflowVisBackend('temp_dir') + mlflow_vis_backend = MLflowVisBackend("temp_dir") mlflow_vis_backend.add_config(cfg) def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - mlflow_vis_backend = MLflowVisBackend('temp_dir') - mlflow_vis_backend.add_image('img.png', image) + mlflow_vis_backend = MLflowVisBackend("temp_dir") + mlflow_vis_backend.add_image("img.png", image) def test_add_scalar(self): - mlflow_vis_backend = MLflowVisBackend('temp_dir') - mlflow_vis_backend.add_scalar('map', 0.9) + mlflow_vis_backend = MLflowVisBackend("temp_dir") + mlflow_vis_backend.add_scalar("map", 0.9) # test append mode - mlflow_vis_backend.add_scalar('map', 0.9) - mlflow_vis_backend.add_scalar('map', 0.95) + mlflow_vis_backend.add_scalar("map", 0.9) + mlflow_vis_backend.add_scalar("map", 0.95) def test_add_scalars(self): - mlflow_vis_backend = MLflowVisBackend('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} + mlflow_vis_backend = MLflowVisBackend("temp_dir") + input_dict = {"map": 0.7, "acc": 0.9} mlflow_vis_backend.add_scalars(input_dict) # test append mode - mlflow_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}) + mlflow_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}) def test_close(self): - cfg = Config(dict(work_dir='temp_dir')) - mlflow_vis_backend = MLflowVisBackend('temp_dir') + cfg = Config(dict(work_dir="temp_dir")) + mlflow_vis_backend = MLflowVisBackend("temp_dir") mlflow_vis_backend._init_env() mlflow_vis_backend.add_config(cfg) mlflow_vis_backend.close() - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") -@patch.dict(sys.modules, {'clearml': MagicMock()}) +@patch.dict(sys.modules, {"clearml": MagicMock()}) class TestClearMLVisBackend: - def test_init(self): - ClearMLVisBackend('temp_dir') - VISBACKENDS.build(dict(type='ClearMLVisBackend', save_dir='temp_dir')) + ClearMLVisBackend("temp_dir") + VISBACKENDS.build(dict(type="ClearMLVisBackend", save_dir="temp_dir")) def test_experiment(self): - clearml_vis_backend = ClearMLVisBackend('temp_dir') + clearml_vis_backend = ClearMLVisBackend("temp_dir") assert clearml_vis_backend.experiment == clearml_vis_backend._clearml def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - clearml_vis_backend = ClearMLVisBackend('temp_dir') + clearml_vis_backend = ClearMLVisBackend("temp_dir") clearml_vis_backend.add_config(cfg) def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - clearml_vis_backend = ClearMLVisBackend('temp_dir') - clearml_vis_backend.add_image('img.png', image) + clearml_vis_backend = ClearMLVisBackend("temp_dir") + clearml_vis_backend.add_image("img.png", image) def test_add_scalar(self): - clearml_vis_backend = ClearMLVisBackend('temp_dir') - clearml_vis_backend.add_scalar('map', 0.9) + clearml_vis_backend = ClearMLVisBackend("temp_dir") + clearml_vis_backend.add_scalar("map", 0.9) # test append mode - clearml_vis_backend.add_scalar('map', 0.9) - clearml_vis_backend.add_scalar('map', 0.95) + clearml_vis_backend.add_scalar("map", 0.9) + clearml_vis_backend.add_scalar("map", 0.95) def test_add_scalars(self): - clearml_vis_backend = 
ClearMLVisBackend('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} + clearml_vis_backend = ClearMLVisBackend("temp_dir") + input_dict = {"map": 0.7, "acc": 0.9} clearml_vis_backend.add_scalars(input_dict) # test append mode - clearml_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}) + clearml_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}) def test_close(self): - cfg = Config(dict(work_dir='temp_dir')) - clearml_vis_backend = ClearMLVisBackend('temp_dir') + cfg = Config(dict(work_dir="temp_dir")) + clearml_vis_backend = ClearMLVisBackend("temp_dir") clearml_vis_backend._init_env() clearml_vis_backend.add_config(cfg) clearml_vis_backend.close() -@pytest.mark.skipif( - not is_installed('neptune'), reason='Neptune is not installed.') +@pytest.mark.skipif(not is_installed("neptune"), reason="Neptune is not installed.") class TestNeptuneVisBackend: - def test_init(self): NeptuneVisBackend() - VISBACKENDS.build(dict(type='NeptuneVisBackend')) + VISBACKENDS.build(dict(type="NeptuneVisBackend")) def test_experiment(self): neptune_vis_backend = NeptuneVisBackend() @@ -386,18 +367,18 @@ def test_add_config(self): def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) neptune_vis_backend = NeptuneVisBackend() - neptune_vis_backend.add_image('img', image) - neptune_vis_backend.add_image('img', image, step=1) + neptune_vis_backend.add_image("img", image) + neptune_vis_backend.add_image("img", image, step=1) def test_add_scalar(self): neptune_vis_backend = NeptuneVisBackend() - neptune_vis_backend.add_scalar('map', 0.9) - neptune_vis_backend.add_scalar('map', 0.9, step=1) - neptune_vis_backend.add_scalar('map', 0.95, step=2) + neptune_vis_backend.add_scalar("map", 0.9) + neptune_vis_backend.add_scalar("map", 0.9, step=1) + neptune_vis_backend.add_scalar("map", 0.95, step=2) def test_add_scalars(self): neptune_vis_backend = NeptuneVisBackend() - input_dict = {'map': 0.7, 'acc': 0.9} + input_dict = {"map": 0.7, "acc": 0.9} neptune_vis_backend.add_scalars(input_dict) def test_close(self): @@ -407,64 +388,63 @@ def test_close(self): @pytest.mark.skipif( - digit_version(platform.python_version()) < digit_version('3.8'), - reason='DVCLiveVisBackend does not support python version < 3.8') + digit_version(platform.python_version()) < digit_version("3.8"), + reason="DVCLiveVisBackend does not support python version < 3.8", +) +@pytest.mark.skipif(not is_installed("dvclive"), reason="DVCLive is not installed.") class TestDVCLiveVisBackend: - def test_init(self): - DVCLiveVisBackend('temp_dir') - VISBACKENDS.build(dict(type='DVCLiveVisBackend', save_dir='temp_dir')) + DVCLiveVisBackend("temp_dir") + VISBACKENDS.build(dict(type="DVCLiveVisBackend", save_dir="temp_dir")) def test_experiment(self): - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") assert dvclive_vis_backend.experiment == dvclive_vis_backend._dvclive - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") dvclive_vis_backend.add_config(cfg) - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") def test_add_image(self): img = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') - dvclive_vis_backend.add_image('img', img) - shutil.rmtree('temp_dir') + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") + 
dvclive_vis_backend.add_image("img", img) + shutil.rmtree("temp_dir") def test_add_scalar(self): - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') - dvclive_vis_backend.add_scalar('mAP', 0.9) + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") + dvclive_vis_backend.add_scalar("mAP", 0.9) # test append mode - dvclive_vis_backend.add_scalar('mAP', 0.9) - dvclive_vis_backend.add_scalar('mAP', 0.95) - shutil.rmtree('temp_dir') + dvclive_vis_backend.add_scalar("mAP", 0.9) + dvclive_vis_backend.add_scalar("mAP", 0.95) + shutil.rmtree("temp_dir") def test_add_scalars(self): - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") + input_dict = {"map": 0.7, "acc": 0.9} dvclive_vis_backend.add_scalars(input_dict) # test append mode - dvclive_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}) - shutil.rmtree('temp_dir') + dvclive_vis_backend.add_scalars({"map": 0.8, "acc": 0.8}) + shutil.rmtree("temp_dir") def test_close(self): - cfg = Config(dict(work_dir='temp_dir')) - dvclive_vis_backend = DVCLiveVisBackend('temp_dir') + cfg = Config(dict(work_dir="temp_dir")) + dvclive_vis_backend = DVCLiveVisBackend("temp_dir") dvclive_vis_backend._init_env() dvclive_vis_backend.add_config(cfg) dvclive_vis_backend.close() - shutil.rmtree('temp_dir') + shutil.rmtree("temp_dir") -@pytest.mark.skipif( - platform.system() == 'Windows', - reason='Aim does not support Windows for now.') +@pytest.mark.skipif(platform.system() == "Windows", reason="Aim does not support Windows for now.") +@pytest.mark.skipif(not is_installed("aim"), reason="Aim is not installed.") class TestAimVisBackend: - def test_init(self): AimVisBackend() - VISBACKENDS.build(dict(type='AimVisBackend')) + VISBACKENDS.build(dict(type="AimVisBackend")) def test_experiment(self): aim_vis_backend = AimVisBackend() @@ -478,18 +458,18 @@ def test_add_config(self): def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) aim_vis_backend = AimVisBackend() - aim_vis_backend.add_image('img', image) - aim_vis_backend.add_image('img', image, step=1) + aim_vis_backend.add_image("img", image) + aim_vis_backend.add_image("img", image, step=1) def test_add_scalar(self): aim_vis_backend = AimVisBackend() - aim_vis_backend.add_scalar('map', 0.9) - aim_vis_backend.add_scalar('map', 0.9, step=1) - aim_vis_backend.add_scalar('map', 0.95, step=2) + aim_vis_backend.add_scalar("map", 0.9) + aim_vis_backend.add_scalar("map", 0.9, step=1) + aim_vis_backend.add_scalar("map", 0.95, step=2) def test_add_scalars(self): aim_vis_backend = AimVisBackend() - input_dict = {'map': 0.7, 'acc': 0.9} + input_dict = {"map": 0.7, "acc": 0.9} aim_vis_backend.add_scalars(input_dict) def test_close(self): diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index e4ababc637..701bff561b 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -16,8 +16,7 @@ @VISBACKENDS.register_module() class MockVisBackend: - - def __init__(self, save_dir: str = 'none'): + def __init__(self, save_dir: str = "none"): self._save_dir = save_dir self._close = False @@ -37,11 +36,7 @@ def add_image(self, name, image, step=0, **kwargs) -> None: def add_scalar(self, name, value, step=0, **kwargs) -> None: self._add_scalar = True - def add_scalars(self, - scalar_dict, - step=0, - file_path=None, - **kwargs) -> None: + def add_scalars(self, scalar_dict, step=0, file_path=None, **kwargs) -> 
None: self._add_scalars = True def close(self) -> None: @@ -50,57 +45,41 @@ def close(self) -> None: class TestVisualizer(TestCase): - def setUp(self): """Setup the demo image in every test method. TestCase calls functions in this order: setUp() -> testMethod() -> tearDown() -> cleanUp() """ - self.image = np.random.randint( - 0, 256, size=(10, 10, 3)).astype('uint8') - self.vis_backend_cfg = [ - dict(type='MockVisBackend', name='mock1'), - dict(type='MockVisBackend', name='mock2') - ] + self.image = np.random.randint(0, 256, size=(10, 10, 3)).astype("uint8") + self.vis_backend_cfg = [dict(type="MockVisBackend", name="mock1"), dict(type="MockVisBackend", name="mock2")] def test_init(self): visualizer = Visualizer(image=self.image) visualizer.get_image() # build visualizer without `save_dir` - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg)) - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') - assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) + visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir") + assert isinstance(visualizer.get_backend("mock1"), MockVisBackend) assert len(visualizer._vis_backends) == 2 # The name fields cannot be the same with pytest.raises(RuntimeError): - Visualizer( - vis_backends=[ - dict(type='MockVisBackend'), - dict(type='MockVisBackend') - ], - save_dir='temp_dir') + Visualizer(vis_backends=[dict(type="MockVisBackend"), dict(type="MockVisBackend")], save_dir="temp_dir") with pytest.raises(RuntimeError): Visualizer( - vis_backends=[ - dict(type='MockVisBackend', name='mock1'), - dict(type='MockVisBackend', name='mock1') - ], - save_dir='temp_dir') + vis_backends=[dict(type="MockVisBackend", name="mock1"), dict(type="MockVisBackend", name="mock1")], + save_dir="temp_dir", + ) # test global init - instance_name = 'visualizer' + str(time.time()) + instance_name = "visualizer" + str(time.time()) visualizer = Visualizer.get_instance( - instance_name, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + instance_name, vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir" + ) assert len(visualizer._vis_backends) == 2 visualizer_any = Visualizer.get_instance(instance_name) assert visualizer_any == visualizer @@ -108,27 +87,22 @@ def test_init(self): # local backend will not be built without `save_dir` argument @VISBACKENDS.register_module() class CustomLocalVisBackend: - def __init__(self, save_dir: str) -> None: self._save_dir = save_dir with pytest.warns(UserWarning): - visualizer = Visualizer.get_instance( - 'test_save_dir', - vis_backends=[dict(type='CustomLocalVisBackend')]) + visualizer = Visualizer.get_instance("test_save_dir", vis_backends=[dict(type="CustomLocalVisBackend")]) assert not visualizer._vis_backends - VISBACKENDS.module_dict.pop('CustomLocalVisBackend') + VISBACKENDS.module_dict.pop("CustomLocalVisBackend") visualizer = Visualizer.get_instance( - 'test_save_dir', - vis_backends=dict(type='CustomLocalVisBackend', save_dir='tmp')) + "test_save_dir", vis_backends=dict(type="CustomLocalVisBackend", save_dir="tmp") + ) - visualizer = Visualizer.get_instance( - 'test_save_dir', vis_backends=[CustomLocalVisBackend('tmp')]) + visualizer = Visualizer.get_instance("test_save_dir", vis_backends=[CustomLocalVisBackend("tmp")]) - visualizer = Visualizer.get_instance( - 'test_save_dir', 
vis_backends=CustomLocalVisBackend('tmp')) + visualizer = Visualizer.get_instance("test_save_dir", vis_backends=CustomLocalVisBackend("tmp")) def test_set_image(self): visualizer = Visualizer() @@ -148,8 +122,7 @@ def test_draw_bboxes(self): # valid bbox visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2])) bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]]) - visualizer.draw_bboxes( - bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-') + visualizer.draw_bboxes(bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles="-") bboxes = bboxes.numpy() visualizer.draw_bboxes(bboxes) @@ -160,9 +133,8 @@ def test_draw_bboxes(self): # test out of bounds with pytest.warns( - UserWarning, - match='Warning: The bbox is out of bounds,' - ' the drawn bbox may not be in the image'): + UserWarning, match="Warning: The bbox is out of bounds, the drawn bbox may not be in the image" + ): visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2])) # test incorrect bbox format @@ -170,15 +142,12 @@ def test_draw_bboxes(self): visualizer.draw_bboxes([1, 1, 2, 2]) def test_close(self): - visualizer = Visualizer( - image=self.image, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir") - for name in ['mock1', 'mock2']: + for name in ["mock1", "mock2"]: assert visualizer.get_backend(name)._close is False visualizer.close() - for name in ['mock1', 'mock2']: + for name in ["mock1", "mock2"]: assert visualizer.get_backend(name)._close is True def test_draw_points(self): @@ -189,94 +158,72 @@ def test_draw_points(self): with pytest.raises(AssertionError): visualizer.draw_points(positions=np.array([1, 2, 3], dtype=object)) # test color + visualizer.draw_points(positions=torch.tensor([[1, 1], [3, 3]]), colors=["g", (255, 255, 0)]) visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)]) - visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)], - marker='.', - sizes=[1, 5]) + positions=torch.tensor([[1, 1], [3, 3]]), colors=["g", (255, 255, 0)], marker=".", sizes=[1, 5] + ) def test_draw_texts(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy + visualizer.draw_texts("text1", positions=torch.tensor([5, 5]), colors=(0, 255, 0)) visualizer.draw_texts( - 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0)) - visualizer.draw_texts(['text1', 'text2'], - positions=torch.tensor([[5, 5], [3, 3]]), - colors=[(255, 0, 0), (255, 0, 0)]) - visualizer.draw_texts('text1', positions=np.array([5, 5])) - visualizer.draw_texts(['text1', 'text2'], - positions=np.array([[5, 5], [3, 3]])) - visualizer.draw_texts( - 'text1', - positions=torch.tensor([5, 5]), - bboxes=dict(facecolor='r', alpha=0.6)) + ["text1", "text2"], positions=torch.tensor([[5, 5], [3, 3]]), colors=[(255, 0, 0), (255, 0, 0)] + ) + visualizer.draw_texts("text1", positions=np.array([5, 5])) + visualizer.draw_texts(["text1", "text2"], positions=np.array([[5, 5], [3, 3]])) + visualizer.draw_texts("text1", positions=torch.tensor([5, 5]), bboxes=dict(facecolor="r", alpha=0.6)) # test out of bounds with pytest.warns( - UserWarning, - match='Warning: The text is out of bounds,' - ' the drawn text may not be in the image'): - visualizer.draw_texts('text1', positions=torch.tensor([15, 5])) + UserWarning, match="Warning: The text is out of bounds, the drawn text may not be in the image" + ): + visualizer.draw_texts("text1", 
positions=torch.tensor([15, 5])) # test incorrect format with pytest.raises(TypeError): - visualizer.draw_texts('text', positions=[5, 5]) + visualizer.draw_texts("text", positions=[5, 5]) # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_texts(['text1', 'text2'], - positions=torch.tensor([5, 5])) + visualizer.draw_texts(["text1", "text2"], positions=torch.tensor([5, 5])) with pytest.raises(AssertionError): - visualizer.draw_texts( - 'text1', positions=torch.tensor([[5, 5], [3, 3]])) + visualizer.draw_texts("text1", positions=torch.tensor([[5, 5], [3, 3]])) with pytest.raises(AssertionError): - visualizer.draw_texts(['text1', 'test2'], - positions=torch.tensor([[5, 5], [3, 3]]), - colors=['r']) + visualizer.draw_texts(["text1", "test2"], positions=torch.tensor([[5, 5], [3, 3]]), colors=["r"]) with pytest.raises(AssertionError): - visualizer.draw_texts(['text1', 'test2'], - positions=torch.tensor([[5, 5], [3, 3]]), - vertical_alignments=['top']) + visualizer.draw_texts( + ["text1", "test2"], positions=torch.tensor([[5, 5], [3, 3]]), vertical_alignments=["top"] + ) with pytest.raises(AssertionError): - visualizer.draw_texts(['text1', 'test2'], - positions=torch.tensor([[5, 5], [3, 3]]), - horizontal_alignments=['left']) + visualizer.draw_texts( + ["text1", "test2"], positions=torch.tensor([[5, 5], [3, 3]]), horizontal_alignments=["left"] + ) with pytest.raises(AssertionError): - visualizer.draw_texts(['text1', 'test2'], - positions=torch.tensor([[5, 5], [3, 3]]), - font_sizes=[1]) + visualizer.draw_texts(["text1", "test2"], positions=torch.tensor([[5, 5], [3, 3]]), font_sizes=[1]) # test type valid with pytest.raises(TypeError): - visualizer.draw_texts(['text1', 'test2'], - positions=torch.tensor([[5, 5], [3, 3]]), - font_sizes='b') + visualizer.draw_texts(["text1", "test2"], positions=torch.tensor([[5, 5], [3, 3]]), font_sizes="b") def test_draw_lines(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6])) - visualizer.draw_lines( - x_datas=np.array([[1, 5], [2, 4]]), - y_datas=np.array([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6])) + visualizer.draw_lines(x_datas=np.array([[1, 5], [2, 4]]), y_datas=np.array([[2, 6], [4, 7]])) visualizer.draw_lines( x_datas=np.array([[1, 5], [2, 4]]), y_datas=np.array([[2, 6], [4, 7]]), - colors='r', - line_styles=['-', '-.'], - line_widths=[1, 2]) + colors="r", + line_styles=["-", "-."], + line_widths=[1, 2], + ) # test out of bounds with pytest.warns( - UserWarning, - match='Warning: The line is out of bounds,' - ' the drawn line may not be in the image'): - visualizer.draw_lines( - x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6])) + UserWarning, match="Warning: The line is out of bounds, the drawn line may not be in the image" + ): + visualizer.draw_lines(x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6])) # test incorrect format with pytest.raises(TypeError): @@ -286,9 +233,7 @@ def test_draw_lines(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), - y_datas=torch.tensor([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([[2, 6], [4, 7]])) def test_draw_circles(self): visualizer = Visualizer(image=self.image) @@ -296,33 +241,31 @@ def test_draw_circles(self): # only support tensor and numpy visualizer.draw_circles(torch.tensor([1, 
5]), torch.tensor([1])) visualizer.draw_circles(np.array([1, 5]), np.array([1])) - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) # test face_colors visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) + edge_colors=(255, 0, 0), + ) # test config visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), - edge_colors=['g', 'r'], - line_styles=['-', '-.'], - line_widths=[1, 2]) + edge_colors=["g", "r"], + line_styles=["-", "-."], + line_widths=[1, 2], + ) # test out of bounds with pytest.warns( - UserWarning, - match='Warning: The circle is out of bounds,' - ' the drawn circle may not be in the image'): - visualizer.draw_circles( - torch.tensor([12, 5]), radius=torch.tensor([1])) - visualizer.draw_circles( - torch.tensor([1, 5]), radius=torch.tensor([10])) + UserWarning, match="Warning: The circle is out of bounds, the drawn circle may not be in the image" + ): + visualizer.draw_circles(torch.tensor([12, 5]), radius=torch.tensor([1])) + visualizer.draw_circles(torch.tensor([1, 5]), radius=torch.tensor([10])) # test incorrect format with pytest.raises(TypeError): @@ -332,39 +275,30 @@ def test_draw_circles(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_circles( - torch.tensor([[1, 5]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5]]), radius=torch.tensor([1, 2])) def test_draw_polygons(self): visualizer = Visualizer(image=self.image) # shape Nx2 or list[Nx2] visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])) visualizer.draw_polygons(np.array([[1, 1], [2, 2], [3, 4]])) - visualizer.draw_polygons([ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ]) + visualizer.draw_polygons([np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]])]) visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], + polygons=[np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]])], face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) + edge_colors=(255, 0, 0), + ) visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], - edge_colors=['r', 'g'], - line_styles='-', - line_widths=[2, 1]) + polygons=[np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]])], + edge_colors=["r", "g"], + line_styles="-", + line_widths=[2, 1], + ) # test out of bounds with pytest.warns( - UserWarning, - match='Warning: The polygon is out of bounds,' - ' the drawn polygon may not be in the image'): + UserWarning, match="Warning: The polygon is out of bounds, the drawn polygon may not be in the image" + ): visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]])) def test_draw_binary_masks(self): @@ -375,7 +309,7 @@ def test_draw_binary_masks(self): # multi binary binary_mask = np.random.randint(0, 2, size=(2, 10, 10)).astype(bool) visualizer = Visualizer(image=self.image) - visualizer.draw_binary_masks(binary_mask, colors=['r', (0, 255, 0)]) + visualizer.draw_binary_masks(binary_mask, colors=["r", (0, 255, 0)]) # test the error that the size of mask and image are different. 
         with pytest.raises(AssertionError):
             binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(bool)
@@ -388,26 +322,21 @@ def test_draw_binary_masks(self):
         # test color dim
         with pytest.raises(AssertionError):
-            visualizer.draw_binary_masks(
-                binary_mask, colors=np.array([1, 22, 4, 45]))
+            visualizer.draw_binary_masks(binary_mask, colors=np.array([1, 22, 4, 45]))
         binary_mask = np.random.randint(0, 2, size=(10, 10))
         with pytest.raises(AssertionError):
             visualizer.draw_binary_masks(binary_mask)

     def test_draw_featmap(self):
         visualizer = Visualizer()
-        image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')
+        image = np.random.randint(0, 256, size=(3, 3, 3), dtype="uint8")
         # must be Tensor
-        with pytest.raises(
-                AssertionError,
-                match='`featmap` should be torch.Tensor, but got '
-                "<class 'numpy.ndarray'>"):
+        with pytest.raises(AssertionError, match="`featmap` should be torch.Tensor, but got <class 'numpy.ndarray'>"):
             visualizer.draw_featmap(np.ones((3, 3, 3)))
         # test tensor format
-        with pytest.raises(
-                AssertionError, match='Input dimension must be 3, but got 4'):
+        with pytest.raises(AssertionError, match="Input dimension must be 3, but got 4"):
             visualizer.draw_featmap(torch.randn(1, 1, 3, 3))
         # test overlaid_image shape
@@ -415,124 +344,89 @@ def test_draw_featmap(self):
             visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image)
         # test resize_shape
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), resize_shape=(6, 7))
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), resize_shape=(6, 7))
         assert featmap.shape[:2] == (6, 7)
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
         assert featmap.shape[:2] == (6, 7)
         # test channel_reduction parameter
         # mode only supports 'squeeze_mean' and 'select_max'
         with pytest.raises(AssertionError):
-            visualizer.draw_featmap(
-                torch.randn(2, 3, 3), channel_reduction='xx')
+            visualizer.draw_featmap(torch.randn(2, 3, 3), channel_reduction="xx")

-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 3, 3), channel_reduction='squeeze_mean')
+        featmap = visualizer.draw_featmap(torch.randn(2, 3, 3), channel_reduction="squeeze_mean")
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 3, 3), channel_reduction='select_max')
+        featmap = visualizer.draw_featmap(torch.randn(2, 3, 3), channel_reduction="select_max")
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 4, 3),
-            overlaid_image=image,
-            channel_reduction='select_max')
+        featmap = visualizer.draw_featmap(torch.randn(2, 4, 3), overlaid_image=image, channel_reduction="select_max")
         assert featmap.shape[:2] == (3, 3)
         # test topk parameter
         with pytest.raises(
-                AssertionError,
-                match='The input tensor channel dimension must be 1 or 3 '
-                'when topk is less than 1, but the channel '
-                'dimension you input is 6, you can use the '
-                'channel_reduction parameter or set topk '
-                'greater than 0 to solve the error'):
-            visualizer.draw_featmap(
-                torch.randn(6, 3, 3), channel_reduction=None, topk=0)
-
-        featmap = visualizer.draw_featmap(
-            torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
+            AssertionError,
+            match="The input tensor channel dimension must be 1 or 3 "
+            "when topk is less than 1, but the channel "
+            "dimension you input is 6, you can use the "
+            "channel_reduction parameter or set topk "
+            "greater than 0 to solve the error",
+        ):
+            visualizer.draw_featmap(torch.randn(6, 3, 3), channel_reduction=None, topk=0)
+
+        featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), channel_reduction="select_max", topk=10)
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
         assert featmap.shape[:2] == (4, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(3, 4, 3),
-            overlaid_image=image,
-            channel_reduction=None,
-            topk=-1)
+        featmap = visualizer.draw_featmap(torch.randn(3, 4, 3), overlaid_image=image, channel_reduction=None, topk=-1)
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(6, 3, 3),
-            channel_reduction=None,
-            topk=4,
-            arrangement=(2, 2))
+        featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), channel_reduction=None, topk=4, arrangement=(2, 2))
         assert featmap.shape[:2] == (6, 6)
-        featmap = visualizer.draw_featmap(
-            torch.randn(6, 3, 3),
-            channel_reduction=None,
-            topk=4,
-            arrangement=(1, 4))
+        featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), channel_reduction=None, topk=4, arrangement=(1, 4))
         assert featmap.shape[:2] == (3, 12)
         with pytest.raises(
-                AssertionError,
-                match='The product of row and col in the `arrangement` '
-                'is less than topk, please set '
-                'the `arrangement` correctly'):
-            visualizer.draw_featmap(
-                torch.randn(6, 3, 3),
-                channel_reduction=None,
-                topk=4,
-                arrangement=(1, 2))
+            AssertionError,
+            match="The product of row and col in the `arrangement` "
+            "is less than topk, please set "
+            "the `arrangement` correctly",
+        ):
+            visualizer.draw_featmap(torch.randn(6, 3, 3), channel_reduction=None, topk=4, arrangement=(1, 2))
         # test gray
         featmap = visualizer.draw_featmap(
             torch.randn(6, 3, 3),
-            overlaid_image=np.random.randint(
-                0, 256, size=(3, 3), dtype='uint8'),
+            overlaid_image=np.random.randint(0, 256, size=(3, 3), dtype="uint8"),
             channel_reduction=None,
             topk=4,
-            arrangement=(2, 2))
+            arrangement=(2, 2),
+        )
         assert featmap.shape[:2] == (6, 6)

     def test_chain_call(self):
         visualizer = Visualizer(image=self.image)
         binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(bool)
-        visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2])). \
-            draw_texts('test', torch.tensor([5, 5])). \
-            draw_lines(x_datas=torch.tensor([1, 5]),
-                       y_datas=torch.tensor([2, 6])). \
-            draw_circles(torch.tensor([1, 5]), radius=torch.tensor([2])). \
-            draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \
-            draw_binary_masks(binary_mask)
+        visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2])).draw_texts("test", torch.tensor([5, 5])).draw_lines(
+            x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6])
+        ).draw_circles(torch.tensor([1, 5]), radius=torch.tensor([2])).draw_polygons(
+            torch.tensor([[1, 1], [2, 2], [3, 4]])
+        ).draw_binary_masks(binary_mask)

     def test_get_backend(self):
-        visualizer = Visualizer(
-            image=self.image,
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
-        for name in ['mock1', 'mock2']:
+        visualizer = Visualizer(image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")
+        for name in ["mock1", "mock2"]:
             assert isinstance(visualizer.get_backend(name), MockVisBackend)

     def test_add_config(self):
-        visualizer = Visualizer(
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
+        visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")
         cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
         visualizer.add_config(cfg)
-        for name in ['mock1', 'mock2']:
+        for name in ["mock1", "mock2"]:
             assert visualizer.get_backend(name)._add_config is True

     def test_add_graph(self):
-        visualizer = Visualizer(
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
+        visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")
         class Model(nn.Module):
-
             def __init__(self):
                 super().__init__()
                 self.conv = nn.Conv2d(1, 2, 1)
@@ -541,107 +435,73 @@ def forward(self, x, y=None):
                 return self.conv(x)
         visualizer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
-        for name in ['mock1', 'mock2']:
+        for name in ["mock1", "mock2"]:
             assert visualizer.get_backend(name)._add_graph is True

     def test_add_image(self):
         image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
-        visualizer = Visualizer(
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
+        visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")

-        visualizer.add_image('img', image)
-        for name in ['mock1', 'mock2']:
+        visualizer.add_image("img", image)
+        for name in ["mock1", "mock2"]:
             assert visualizer.get_backend(name)._add_image is True

     def test_add_scalar(self):
-        visualizer = Visualizer(
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
-        visualizer.add_scalar('map', 0.9, step=0)
-        for name in ['mock1', 'mock2']:
+        visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")
+        visualizer.add_scalar("map", 0.9, step=0)
+        for name in ["mock1", "mock2"]:
             assert visualizer.get_backend(name)._add_scalar is True

     def test_add_scalars(self):
-        visualizer = Visualizer(
-            vis_backends=copy.deepcopy(self.vis_backend_cfg),
-            save_dir='temp_dir')
-        input_dict = {'map': 0.7, 'acc': 0.9}
+        visualizer = Visualizer(vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir="temp_dir")
+        input_dict = {"map": 0.7, "acc": 0.9}
         visualizer.add_scalars(input_dict)
-        for name in ['mock1', 'mock2']:
+        for name in ["mock1", "mock2"]:
             assert visualizer.get_backend(name)._add_scalars is True

     def test_get_instance(self):
-
         class DetLocalVisualizer(Visualizer):
-
             def __init__(self, name):
                 super().__init__(name)
-        visualizer1 = DetLocalVisualizer.get_instance('name1')
+        visualizer1 = DetLocalVisualizer.get_instance("name1")
         visualizer2 = Visualizer.get_current_instance()
         visualizer3 = DetLocalVisualizer.get_current_instance()
         assert id(visualizer1) == id(visualizer2) == id(visualizer3)

     def test_data_info(self):
         visualizer = Visualizer()
-        visualizer.dataset_meta = {'class': 'cat'}
-        assert visualizer.dataset_meta['class'] == 'cat'
+        visualizer.dataset_meta = {"class": "cat"}
+        assert visualizer.dataset_meta["class"] == "cat"

     def test_show(self):
         cv2 = MagicMock()
         wait_continue = MagicMock()
-        visualizer = Visualizer('test_show')
+        visualizer = Visualizer("test_show")
         img = np.ones([1, 1, 1])
-        with patch('mmengine.visualization.visualizer.cv2', cv2), \
-                patch('mmengine.visualization.visualizer.wait_continue',
-                      wait_continue):
+        with (
+            patch("mmengine.visualization.visualizer.cv2", cv2),
+            patch("mmengine.visualization.visualizer.wait_continue", wait_continue),
+        ):
             # test default backend
-            visualizer.show(
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='matplotlib')
-            assert hasattr(visualizer, 'manager')
-            calls = [
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' ')
-            ]
+            visualizer.show(drawn_img=img, win_name="test_show", wait_time=0, backend="matplotlib")
+            assert hasattr(visualizer, "manager")
+            calls = [call(visualizer.manager.canvas.figure, timeout=0, continue_key=" ")]
             wait_continue.assert_has_calls(calls)
             # matplotlib backend
-            visualizer.show(
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='matplotlib')
-            assert hasattr(visualizer, 'manager')
+            visualizer.show(drawn_img=img, win_name="test_show", wait_time=0, backend="matplotlib")
+            assert hasattr(visualizer, "manager")
             calls = [
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' '),
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' ')
+                call(visualizer.manager.canvas.figure, timeout=0, continue_key=" "),
+                call(visualizer.manager.canvas.figure, timeout=0, continue_key=" "),
             ]
             wait_continue.assert_has_calls(calls)
             # cv2 backend
-            visualizer.show(
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='cv2')
+            visualizer.show(drawn_img=img, win_name="test_show", wait_time=0, backend="cv2")
             cv2.imshow.assert_called_once_with(str(id(visualizer)), img)
             # unknown backend
             with pytest.raises(ValueError):
-                visualizer.show(
-                    drawn_img=img,
-                    win_name='test_show',
-                    wait_time=0,
-                    backend='unknown')
+                visualizer.show(drawn_img=img, win_name="test_show", wait_time=0, backend="unknown")