diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index a713da9a70..b555df9e94 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
         Returns:
             nn.Module: Compiled model.
         """
-        if isinstance(compile, bool) and not compile:
+        if (isinstance(compile, bool) and not compile) or \
+                (isinstance(compile, dict) and compile.get('disable', False)):
             return model
 
         assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py
index 0788fafdab..124dfd7c57 100644
--- a/mmengine/_strategy/fsdp.py
+++ b/mmengine/_strategy/fsdp.py
@@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
                 ``optimizer.state_dict()``
         """
         optim_state_dict = FSDP.optim_state_dict_to_load(
-            state_dict, self.model, self.optim_wrapper.optimizer)
+            self.model, self.optim_wrapper.optimizer, state_dict)
         self.optim_wrapper.load_state_dict(optim_state_dict)
 
     def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
index 3cfe0b14a8..276e6fe218 100644
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -65,7 +65,7 @@ def is_init(self, value):
 
     def init_weights(self):
         """Initialize the weights."""
-        is_top_level_module = False
         # check if it is top-level module
-        if not hasattr(self, '_params_init_info'):
+        is_top_level_module = not hasattr(self, '_params_init_info')
+        if is_top_level_module:
             # The `_params_init_info` is used to record the initialization
diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 4f3323f2cc..60200924b5 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
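Note on the `compile_model` change: `compile` can now be either a bool or a dict of `torch.compile` options, where `disable=True` skips compilation (the guard above returns early in both the `False` and the `disable` case). A minimal standalone sketch of that behavior, assuming the remaining dict keys are forwarded to `torch.compile`; the helper name `maybe_compile` is illustrative, not mmengine API:

    import torch
    import torch.nn as nn

    def maybe_compile(model: nn.Module, compile=False) -> nn.Module:
        # bool False or dict(disable=True) -> return the model untouched
        if (isinstance(compile, bool) and not compile) or \
                (isinstance(compile, dict) and compile.get('disable', False)):
            return model
        # remaining dict entries are forwarded to torch.compile
        kwargs = dict(compile) if isinstance(compile, dict) else {}
        kwargs.pop('disable', None)
        return torch.compile(model, **kwargs)

    model = nn.Linear(4, 4)
    assert maybe_compile(model, False) is model
    assert maybe_compile(model, dict(disable=True)) is model
    compiled = maybe_compile(model, dict(mode='reduce-overhead'))  # torch>=2.0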
""" optim_wrapper_cfg = copy.deepcopy(cfg) - constructor_type = optim_wrapper_cfg.pop('constructor', - 'DefaultOptimWrapperConstructor') + constructor_cfg = optim_wrapper_cfg.pop('constructor', None) + if constructor_cfg is None: + constructor_cfg = dict(type=DefaultOptimWrapperConstructor) paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) # Since the current generation of NPU(Ascend 910) only supports @@ -205,11 +211,12 @@ def build_optim_wrapper(model: nn.Module, # to make the training normal if is_npu_available() and not is_npu_support_full_precision(): optim_wrapper_cfg['type'] = 'AmpOptimWrapper' + + constructor_cfg.update(dict( + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg + )) - optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - dict( - type=constructor_type, - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg)) + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg) optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..fa0a1eb520 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -344,7 +344,7 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, map_location=map_location, weights_only=False) return checkpoint diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..065bf9243c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -377,7 +377,6 @@ def run(self) -> dict: self.val_loss.clear() for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 7d1f655aad..9bbbcaedce 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect import logging import os import os.path as osp @@ -902,8 +903,18 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_cfg.get('type')) # type: ignore + model_wrapper_type = model_wrapper_cfg.get('type') + if isinstance(model_wrapper_type, str): + model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + elif inspect.isclass(model_wrapper_type): + pass + else: + raise KeyError( + f'{model_wrapper_type} is not in the ' + 'registry. Please check whether the value of ' + f'`{model_wrapper_type}` is correct or it was registered ' + 'as expected. 
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 7d1f655aad..9bbbcaedce 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import inspect
 import logging
 import os
 import os.path as osp
@@ -902,8 +903,16 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-            model_wrapper_type = MODEL_WRAPPERS.get(
-                model_wrapper_cfg.get('type'))  # type: ignore
+            model_wrapper_type = model_wrapper_cfg.get('type')
+            if isinstance(model_wrapper_type, str):
+                model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type)
+            if not inspect.isclass(model_wrapper_type):
+                raise KeyError(
+                    f'{model_wrapper_cfg.get("type")} is not in the '
+                    'MODEL_WRAPPERS registry. Please check whether the '
+                    'value of `type` is correct or it was registered as '
+                    'expected. More details can be found at '
+                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module')  # noqa: E501
             default_args: dict = dict()
             if issubclass(
                 model_wrapper_type,  # type: ignore
diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py
index b752ec85a7..a5bf7d88e7 100644
--- a/mmengine/visualization/vis_backend.py
+++ b/mmengine/visualization/vis_backend.py
@@ -604,7 +604,8 @@ def add_scalar(self,
                       (int, float, torch.Tensor, np.ndarray, np.number)):
             self._tensorboard.add_scalar(name, value, step)
         else:
-            warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, '
+            warnings.warn(f'Got type {type(value)} with name {name}, '
+                          'but numpy array, torch tensor, '
                           f'int or float are expected. skip it!')
 
     @force_init_env
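Since `wrap_model` now resolves `model_wrapper_cfg['type']` as either a registered wrapper name or a wrapper class, here is a standalone sketch of the lookup logic (it mirrors the patch above; `resolve_wrapper_type` is illustrative, not Runner API):

    import inspect

    from mmengine.model import MMDistributedDataParallel
    from mmengine.registry import MODEL_WRAPPERS

    def resolve_wrapper_type(wrapper_type):
        # A registered name resolves to its class; a class passes through.
        if isinstance(wrapper_type, str):
            wrapper_type = MODEL_WRAPPERS.get(wrapper_type)
        if not inspect.isclass(wrapper_type):
            raise KeyError(f'{wrapper_type} is neither a registered model '
                           'wrapper name nor a wrapper class')
        return wrapper_type

    assert resolve_wrapper_type('MMDistributedDataParallel') \
        is resolve_wrapper_type(MMDistributedDataParallel)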