From 4f62c98d93d1b0899218f200d74246c2b98945ad Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Sun, 21 Jul 2024 17:44:47 +0800 Subject: [PATCH 01/18] Fix torch FutureWarning FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead. --- mmengine/optim/optimizer/amp_optimizer_wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 4f3323f2cc..dcdb552943 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager +from functools import partial from typing import Union import torch @@ -17,7 +18,9 @@ elif is_mlu_available(): from torch.mlu.amp import GradScaler else: - from torch.cuda.amp import GradScaler + # from torch.cuda.amp import GradScaler + from torch.amp import GradScaler as amp_GradScaler + GradScaler = partial(amp_GradScaler, device='cuda') @OPTIM_WRAPPERS.register_module() From b6b42241e77c19745fc0bb0d20afceff9debcc54 Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Sun, 21 Jul 2024 17:45:28 +0800 Subject: [PATCH 02/18] Fix torch FutureWarning FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead. --- mmengine/optim/optimizer/amp_optimizer_wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index dcdb552943..60200924b5 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -18,7 +18,6 @@ elif is_mlu_available(): from torch.mlu.amp import GradScaler else: - # from torch.cuda.amp import GradScaler from torch.amp import GradScaler as amp_GradScaler GradScaler = partial(amp_GradScaler, device='cuda') From 4c7a5d499ff232eaf97f2d4f15fc37088c407bfa Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Fri, 26 Jul 2024 10:08:17 +0800 Subject: [PATCH 03/18] Optimize the prompt for compile --- mmengine/_strategy/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index 5df3a79c92..05d070f6f2 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -322,7 +322,8 @@ def compile_model( Returns: nn.Module: Compiled model. 
""" - if isinstance(compile, bool) and not compile: + if isinstance(compile, bool) and not compile or \ + isinstance(compile, dict) and not compile.get('disable', False): return model assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( From 28d47f849d3ee2fa6041500baf058df1325e047e Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Wed, 21 Aug 2024 14:15:53 +0800 Subject: [PATCH 04/18] Fix Incorrect Optim Param Resume Method FSDP.optim_state_dict_to_load requires the following parameters: model: Module, optim: Optimizer, optim_state_dict: Dict[str, Any] --- mmengine/_strategy/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 0788fafdab..124dfd7c57 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) + self.model, self.optim_wrapper.optimizer, state_dict) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: From 91d945f29205c16c3dd8e9655f80daa73bd6ab63 Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Wed, 28 Aug 2024 23:59:38 +0800 Subject: [PATCH 05/18] Update runner.py to support pure-python style model wrapper configurations The current runner implementation has not yet supported for pure-python style configurations on model wrapper class. I follow the mainstream implementation to support this feature. --- mmengine/runner/runner.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 68716ab253..7160ac84d7 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect import logging import os import os.path as osp @@ -902,8 +903,18 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_cfg.get('type')) # type: ignore + model_wrapper_type = model_wrapper_cfg.get('type') + if isinstance(model_wrapper_type, str): + model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + elif inspect.isclass(model_wrapper_type): + pass + else: + raise KeyError( + f'{model_wrapper_type} is not in the ' + 'registry. Please check whether the value of ' + f'`{model_wrapper_type}` is correct or it was registered ' + 'as expected. 
More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 + ) default_args: dict = dict() if issubclass( model_wrapper_type, # type: ignore From 7103c3e629a189336cac3308add2b080319025d6 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Mon, 23 Sep 2024 03:00:04 +0000 Subject: [PATCH 06/18] reconstruct --- mmengine/_strategy/fsdp.py | 2 +- mmengine/model/wrappers/distributed.py | 3 +++ mmengine/optim/optimizer/builder.py | 1 + mmengine/runner/loops.py | 4 +++- mmengine/runner/runner.py | 4 +++- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 124dfd7c57..0788fafdab 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - self.model, self.optim_wrapper.optimizer, state_dict) + state_dict, self.model, self.optim_wrapper.optimizer) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 4113aebf9e..b88bc7c2b0 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -95,6 +95,7 @@ def __init__(self, def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + return self.module.train_step(data, optim_wrapper) """Interface for model forward, backward and parameters updating during training process. @@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list], return log_vars def val_step(self, data: Union[dict, tuple, list]) -> list: + return self.module.val_step(data) """Gets the prediction of module during validation process. Args: @@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list: return self.module.val_step(data) def test_step(self, data: Union[dict, tuple, list]) -> list: + return self.module.test_step(data) """Gets the predictions of module during testing process. Args: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 8557f4d34c..b57ebc315a 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module, type=constructor_type, optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) + optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..25ff690f0b 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -12,6 +12,7 @@ from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of +from mmengine.dataset.sampler import InfiniteSampler from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -274,13 +275,14 @@ def run(self) -> None: # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. 
self.runner.call_hook('before_train_epoch') - if self._iter > 0: + if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler): print_log( f'Advance dataloader {self._iter} steps to skip data ' 'that has already been trained', logger='current', level=logging.WARNING) for _ in range(self._iter): + break next(self.dataloader_iterator) while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 7160ac84d7..435bd55ac0 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -903,9 +903,11 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') + model_wrapper_type = model_wrapper_cfg.get('type') if isinstance(model_wrapper_type, str): - model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + model_wrapper_type = MODEL_WRAPPERS.get( + model_wrapper_type) # type: ignore elif inspect.isclass(model_wrapper_type): pass else: From eecaa92179bb275f650931f7597b0ead0420f6b5 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 3 Nov 2024 05:18:48 +0000 Subject: [PATCH 07/18] PyTorch Profiler within IterBasedTrainLoop --- mmengine/runner/loops.py | 19 +++++++++++++++++-- mmengine/runner/runner.py | 2 +- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 25ff690f0b..7be8995781 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -282,8 +282,21 @@ def run(self) -> None: logger='current', level=logging.WARNING) for _ in range(self._iter): - break + break # NOTE MGAM: override all preprocessing steps during resume. next(self.dataloader_iterator) + + # with torch.profiler.profile( + # activities=[torch.profiler.ProfilerActivity.CPU, + # torch.profiler.ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3), + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'), + # record_shapes=True, + # profile_memory=True, + # with_stack=False, + # with_flops=True, + # with_modules=True, + # ) as p: + while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() @@ -294,8 +307,10 @@ def run(self) -> None: if (self.runner.val_loop is not None and self._iter >= self.val_begin and (self._iter % self.val_interval == 0 - or self._iter == self._max_iters)): + or self._iter == self._max_iters)): self.runner.val_loop.run() + + # p.step() self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 435bd55ac0..6b8dd60e2b 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1851,7 +1851,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from None + raise TypeError(f'{e} in {hook}') from e def register_hook( self, From 698ad5ebaed47965fb3c999da2ee82228a4b0600 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 3 Nov 2024 13:53:29 +0800 Subject: [PATCH 08/18] enable hook error exception traceback --- mmengine/runner/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 435bd55ac0..6b8dd60e2b 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1851,7 +1851,7 @@ def call_hook(self, 
fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from None + raise TypeError(f'{e} in {hook}') from e def register_hook( self, From 1e4c2ed17e6bb01af74ccd45923e844a2764bc32 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 15 Nov 2024 01:18:32 +0000 Subject: [PATCH 09/18] improve codes --- mmengine/runner/checkpoint.py | 2 +- mmengine/runner/loops.py | 4 ++-- mmengine/visualization/vis_backend.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..fa0a1eb520 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -344,7 +344,7 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, map_location=map_location, weights_only=False) return checkpoint diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 7be8995781..f511c14e68 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -288,11 +288,11 @@ def run(self) -> None: # with torch.profiler.profile( # activities=[torch.profiler.ProfilerActivity.CPU, # torch.profiler.ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3), + # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3), # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'), # record_shapes=True, # profile_memory=True, - # with_stack=False, + # with_stack=True, # with_flops=True, # with_modules=True, # ) as p: diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index b752ec85a7..a5bf7d88e7 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -604,7 +604,8 @@ def add_scalar(self, (int, float, torch.Tensor, np.ndarray, np.number)): self._tensorboard.add_scalar(name, value, step) else: - warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, ' + warnings.warn(f'Got type {type(value)} with name {name}, ' + 'but numpy array, torch tensor, ' f'int or float are expected. skip it!') @force_init_env From 29e3a0882cace17b6d5391224a8c4abcf35419c7 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 3 Jan 2025 08:56:10 +0000 Subject: [PATCH 10/18] KeyError: 'Adafactor is already registered in optimizer at torch.optim'. This may be due to the version confliction. Newer PyTorch may have introduced this optimizer. 
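One way to keep the transformers implementation usable on older PyTorch builds, rather than commenting the registration out, is to guard the name before registering it. A minimal sketch of that idea (it assumes OPTIMIZERS.get() returns None for an unregistered name and reuses the OPTIMIZERS registry already imported at the top of builder.py; the final patch in this series takes the equivalent try/except route instead):

    def register_transformers_optimizers():
        transformer_optimizers = []
        try:
            from transformers import Adafactor
        except ImportError:
            pass
        else:
            # Newer PyTorch ships torch.optim.Adafactor and registers it
            # first, so only claim the name if it is still free.
            if OPTIMIZERS.get('Adafactor') is None:
                OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
            transformer_optimizers.append('Adafactor')
        return transformer_optimizers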
--- mmengine/optim/optimizer/builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index b57ebc315a..e778a3d5bc 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -166,7 +166,8 @@ def register_transformers_optimizers(): except ImportError: pass else: - OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) + # KeyError: 'Adafactor is already registered in optimizer at torch.optim' + # OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) transformer_optimizers.append('Adafactor') return transformer_optimizers From be86710ecb96d5682b7d7e7f1e9e72bccc1bd6a2 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sat, 11 Jan 2025 17:57:03 +0800 Subject: [PATCH 11/18] Update support for deep speed and multiple improvements. --- mmengine/_strategy/deepspeed.py | 10 +- mmengine/config/config.py | 218 ++++++++++++++++--------------- mmengine/logging/message_hub.py | 6 +- mmengine/model/averaged_model.py | 1 + 4 files changed, 123 insertions(+), 112 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..1fff461bf3 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -63,9 +63,11 @@ def register_deepspeed_optimizers() -> List[str]: @OPTIM_WRAPPERS.register_module() class DeepSpeedOptimWrapper(BaseOptimWrapper): - def __init__(self, optimizer): + def __init__(self, optimizer, accumulative_counts): super().__init__(optimizer) self._model = None + self._inner_count = 0 + self._accumulative_counts = accumulative_counts @property def model(self): @@ -80,11 +82,13 @@ def model(self, value): def update_params(self, loss) -> None: # type: ignore """Update parameters in :attr:`optimizer`.""" self.backward(loss) - self.step() + if self.should_update(): + self.step() def backward(self, loss: torch.Tensor, **kwargs) -> None: """"Perform gradient back propagation.""" self.model.backward(loss) + self._inner_count += 1 def zero_grad(self, **kwargs) -> None: raise NotImplementedError( @@ -107,6 +111,8 @@ def load_state_dict(self, state_dict: dict) -> None: if base_param_settings is not None: self.base_param_settings = base_param_settings + def should_update(self) -> bool: + return (self._inner_count % self._accumulative_counts == 0) @MODEL_WRAPPERS.register_module() class MMDeepSpeedEngineWrapper: diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 36f92f0b3a..5ca06954ed 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -1375,120 +1375,122 @@ def env_variables(self) -> dict: @property def pretty_text(self) -> str: """Get formatted python config text.""" + try: + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = repr(v) + else: + v_str = str(v) - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split('\n') - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = repr(v) - else: - v_str = str(v) - - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: 
{v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) - return attr_str + return attr_str - def _format_list_tuple(k, v, use_mapping=False): - if isinstance(v, list): - left = '[' - right = ']' - else: - left = '(' - right = ')' - - v_str = f'{left}\n' - # check if all items in the list are dict - for item in v: - if isinstance(item, dict): - v_str += f'dict({_indent(_format_dict(item), indent)}),\n' - elif isinstance(item, tuple): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, list): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, str): - v_str += f'{_indent(repr(item), indent)},\n' + def _format_list_tuple(k, v, use_mapping=False): + if isinstance(v, list): + left = '[' + right = ']' else: - v_str += str(item) + ',\n' - if k is None: - return _indent(v_str, indent) + right - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + right - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= \ - (not str(key_name).isidentifier()) - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = '' - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - if use_mapping: - r += '{' - for idx, (k, v) in enumerate( - sorted(input_dict.items(), key=lambda x: str(x[0]))): - is_last = idx >= len(input_dict) - 1 - end = '' if outest_level or is_last else ',' - if isinstance(v, dict): - v_str = '\n' + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: dict({v_str}' + left = '(' + right = ')' + + v_str = f'{left}\n' + # check if all items in the list are dict + for item in v: + if isinstance(item, dict): + v_str += f'dict({_indent(_format_dict(item), indent)}),\n' + elif isinstance(item, tuple): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, list): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, str): + v_str += f'{_indent(repr(item), indent)},\n' else: - attr_str = f'{str(k)}=dict({v_str}' - attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, (list, tuple)): - attr_str = _format_list_tuple(k, v, use_mapping) + end + v_str += str(item) + ',\n' + if k is None: + return _indent(v_str, indent) + right + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' else: - attr_str = _format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += '\n'.join(s) - if use_mapping: - r += '}' - return r - - cfg_dict = self.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - if self._format_python_code: - # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) - try: - if digit_version(yapf.__version__) >= digit_version('0.40.2'): - text, _ = FormatCode(text, style_config=yapf_style) - else: - text, _ = FormatCode( - text, style_config=yapf_style, 
verify=True) - except: # noqa: E722 - raise SyntaxError('Failed to format the config file, please ' - f'check the syntax of: \n{text}') - return text + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + right + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate( + sorted(input_dict.items(), key=lambda x: str(x[0]))): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, (list, tuple)): + attr_str = _format_list_tuple(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + if self._format_python_code: + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + try: + if digit_version(yapf.__version__) >= digit_version('0.40.2'): + text, _ = FormatCode(text, style_config=yapf_style) + else: + text, _ = FormatCode( + text, style_config=yapf_style, verify=True) + except: # noqa: E722 + raise SyntaxError('Failed to format the config file, please ' + f'check the syntax of: \n{text}') + return text + except Exception as e: + return f'Error occurs when formatting config: {e}' def __repr__(self): return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 82565d8832..e4edc3466e 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -342,8 +342,10 @@ def _get_valid_value( else: # check whether value is torch.Tensor but don't want # to import torch in this file - assert hasattr(value, 'numel') and value.numel() == 1 - value = value.item() + if hasattr(value, 'numel') and value.numel() == 1: + value = value.item() + else: + print_log(f"MessageHub got unexpceted log: {value}", level=logging.WARN) return value # type: ignore def state_dict(self) -> dict: diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 58457c2a6e..cc83a5976d 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,6 +103,7 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: + print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device From 861fc1b91b111ceecf9c05895de4156f3725a6f3 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 12 Jan 2025 14:40:11 +0800 Subject: [PATCH 12/18] improve multiple mmengine undeveloped issues. 
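The most substantive change below is in the optimizer builder: the 'constructor' entry of the optim_wrapper config is now consumed as a config (so pure-Python style configs can pass the constructor class directly), and the builder falls back to DefaultOptimWrapperConstructor when it is omitted. A hedged usage sketch of what this enables; the model and hyper-parameters are illustrative only:

    import torch.nn as nn
    from mmengine.optim import build_optim_wrapper

    model = nn.Linear(4, 2)
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='AdamW', lr=1e-4),
        # With this change, 'constructor' is popped as a config and built
        # through OPTIM_WRAPPER_CONSTRUCTORS instead of being looked up as
        # a plain type-name string.
        constructor=dict(type='DefaultOptimWrapperConstructor'),
    )
    optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)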
--- mmengine/model/base_module.py | 1 - mmengine/optim/optimizer/builder.py | 18 ++++++++++-------- mmengine/runner/loops.py | 1 - 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 3cfe0b14a8..276e6fe218 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -65,7 +65,6 @@ def is_init(self, value): def init_weights(self): """Initialize the weights.""" - is_top_level_module = False # check if it is top-level module if not hasattr(self, '_params_init_info'): diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 09467a192f..ebba603dbf 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -9,6 +9,7 @@ from mmengine.config import Config, ConfigDict from mmengine.device import is_npu_available, is_npu_support_full_precision from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS +from .default_constructor import DefaultOptimWrapperConstructor from .optimizer_wrapper import OptimWrapper @@ -197,8 +198,9 @@ def build_optim_wrapper(model: nn.Module, OptimWrapper: The built optimizer wrapper. """ optim_wrapper_cfg = copy.deepcopy(cfg) - constructor_type = optim_wrapper_cfg.pop('constructor', - 'DefaultOptimWrapperConstructor') + constructor_cfg = optim_wrapper_cfg.pop('constructor', None) + if constructor_cfg is None: + constructor_cfg = dict(type=DefaultOptimWrapperConstructor) paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) # Since the current generation of NPU(Ascend 910) only supports @@ -206,12 +208,12 @@ def build_optim_wrapper(model: nn.Module, # to make the training normal if is_npu_available() and not is_npu_support_full_precision(): optim_wrapper_cfg['type'] = 'AmpOptimWrapper' - - optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - dict( - type=constructor_type, - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg)) + constructor_cfg.update(dict( + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg + )) + + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg) optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index f511c14e68..ba9ec9d9dd 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -394,7 +394,6 @@ def run(self) -> dict: self.val_loss.clear() for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) From 8f37dd2d16f8ee4ea46cb8a9603f0f42eb96280d Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 17 Jan 2025 09:56:42 +0000 Subject: [PATCH 13/18] Multiple improvements --- mmengine/model/averaged_model.py | 1 - mmengine/runner/loops.py | 4 ++-- mmengine/visualization/vis_backend.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index cc83a5976d..58457c2a6e 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,7 +103,6 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: - print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device diff --git 
a/mmengine/runner/loops.py b/mmengine/runner/loops.py index ba9ec9d9dd..4411edb412 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -418,9 +418,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - + outputs, self.val_loss = _update_losses(outputs, self.val_loss) - + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index a5bf7d88e7..fcb7d23b05 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -578,7 +578,7 @@ def add_image(self, step: int = 0, **kwargs) -> None: """Record the image to tensorboard. - + Args: name (str): The image identifier. image (np.ndarray): The image to be saved. The format From 46cfdbb96caa88090cadbd6c0732b0d11710cf2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <15337184+MGAM@user.noreply.gitee.com> Date: Fri, 17 Jan 2025 23:11:04 +0800 Subject: [PATCH 14/18] Revert "Multiple improvements" This reverts commit 8f37dd2d16f8ee4ea46cb8a9603f0f42eb96280d. --- mmengine/model/averaged_model.py | 1 + mmengine/runner/loops.py | 4 ++-- mmengine/visualization/vis_backend.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 58457c2a6e..cc83a5976d 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,6 +103,7 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: + print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 4411edb412..ba9ec9d9dd 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -418,9 +418,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - + outputs, self.val_loss = _update_losses(outputs, self.val_loss) - + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index fcb7d23b05..a5bf7d88e7 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -578,7 +578,7 @@ def add_image(self, step: int = 0, **kwargs) -> None: """Record the image to tensorboard. - + Args: name (str): The image identifier. image (np.ndarray): The image to be saved. The format From 5376661f5715786ca1fae138a349e97bb311e1fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <15337184+MGAM@user.noreply.gitee.com> Date: Fri, 17 Jan 2025 23:13:47 +0800 Subject: [PATCH 15/18] Revert "Update support for deep speed and multiple improvements." This reverts commit be86710ecb96d5682b7d7e7f1e9e72bccc1bd6a2. 
--- mmengine/_strategy/deepspeed.py | 10 +- mmengine/config/config.py | 218 +++++++++++++++---------------- mmengine/logging/message_hub.py | 6 +- mmengine/model/averaged_model.py | 1 - 4 files changed, 112 insertions(+), 123 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 1fff461bf3..3f89ff760d 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -63,11 +63,9 @@ def register_deepspeed_optimizers() -> List[str]: @OPTIM_WRAPPERS.register_module() class DeepSpeedOptimWrapper(BaseOptimWrapper): - def __init__(self, optimizer, accumulative_counts): + def __init__(self, optimizer): super().__init__(optimizer) self._model = None - self._inner_count = 0 - self._accumulative_counts = accumulative_counts @property def model(self): @@ -82,13 +80,11 @@ def model(self, value): def update_params(self, loss) -> None: # type: ignore """Update parameters in :attr:`optimizer`.""" self.backward(loss) - if self.should_update(): - self.step() + self.step() def backward(self, loss: torch.Tensor, **kwargs) -> None: """"Perform gradient back propagation.""" self.model.backward(loss) - self._inner_count += 1 def zero_grad(self, **kwargs) -> None: raise NotImplementedError( @@ -111,8 +107,6 @@ def load_state_dict(self, state_dict: dict) -> None: if base_param_settings is not None: self.base_param_settings = base_param_settings - def should_update(self) -> bool: - return (self._inner_count % self._accumulative_counts == 0) @MODEL_WRAPPERS.register_module() class MMDeepSpeedEngineWrapper: diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 5ca06954ed..36f92f0b3a 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -1375,122 +1375,120 @@ def env_variables(self) -> dict: @property def pretty_text(self) -> str: """Get formatted python config text.""" - try: - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split('\n') - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = repr(v) - else: - v_str = str(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = repr(v) + else: + v_str = str(v) - return attr_str + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) - def _format_list_tuple(k, v, use_mapping=False): - if isinstance(v, list): - left = '[' - right = ']' + return attr_str + + def _format_list_tuple(k, v, use_mapping=False): + if isinstance(v, list): + left = '[' + right = ']' + else: + left = '(' + right = ')' + + v_str = f'{left}\n' + # check if all items in the list are dict + for item in v: + if isinstance(item, dict): + v_str += f'dict({_indent(_format_dict(item), indent)}),\n' + elif isinstance(item, tuple): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, list): + v_str += 
f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, str): + v_str += f'{_indent(repr(item), indent)},\n' else: - left = '(' - right = ')' - - v_str = f'{left}\n' - # check if all items in the list are dict - for item in v: - if isinstance(item, dict): - v_str += f'dict({_indent(_format_dict(item), indent)}),\n' - elif isinstance(item, tuple): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, list): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, str): - v_str += f'{_indent(repr(item), indent)},\n' + v_str += str(item) + ',\n' + if k is None: + return _indent(v_str, indent) + right + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + right + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate( + sorted(input_dict.items(), key=lambda x: str(x[0]))): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' else: - v_str += str(item) + ',\n' - if k is None: - return _indent(v_str, indent) + right - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, (list, tuple)): + attr_str = _format_list_tuple(k, v, use_mapping) + end else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + right - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= \ - (not str(key_name).isidentifier()) - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = '' - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - if use_mapping: - r += '{' - for idx, (k, v) in enumerate( - sorted(input_dict.items(), key=lambda x: str(x[0]))): - is_last = idx >= len(input_dict) - 1 - end = '' if outest_level or is_last else ',' - if isinstance(v, dict): - v_str = '\n' + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: dict({v_str}' - else: - attr_str = f'{str(k)}=dict({v_str}' - attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, (list, tuple)): - attr_str = _format_list_tuple(k, v, use_mapping) + end - else: - attr_str = _format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += '\n'.join(s) - if use_mapping: - r += '}' - return r - - cfg_dict = self.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - if self._format_python_code: - # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) - try: - if digit_version(yapf.__version__) >= 
digit_version('0.40.2'): - text, _ = FormatCode(text, style_config=yapf_style) - else: - text, _ = FormatCode( - text, style_config=yapf_style, verify=True) - except: # noqa: E722 - raise SyntaxError('Failed to format the config file, please ' - f'check the syntax of: \n{text}') - return text - except Exception as e: - return f'Error occurs when formatting config: {e}' + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + if self._format_python_code: + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + try: + if digit_version(yapf.__version__) >= digit_version('0.40.2'): + text, _ = FormatCode(text, style_config=yapf_style) + else: + text, _ = FormatCode( + text, style_config=yapf_style, verify=True) + except: # noqa: E722 + raise SyntaxError('Failed to format the config file, please ' + f'check the syntax of: \n{text}') + return text def __repr__(self): return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index e4edc3466e..82565d8832 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -342,10 +342,8 @@ def _get_valid_value( else: # check whether value is torch.Tensor but don't want # to import torch in this file - if hasattr(value, 'numel') and value.numel() == 1: - value = value.item() - else: - print_log(f"MessageHub got unexpceted log: {value}", level=logging.WARN) + assert hasattr(value, 'numel') and value.numel() == 1 + value = value.item() return value # type: ignore def state_dict(self) -> dict: diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index cc83a5976d..58457c2a6e 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,7 +103,6 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: - print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device From f038e5ef888d6b03405622700322836195698149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <15337184+MGAM@user.noreply.gitee.com> Date: Fri, 17 Jan 2025 23:18:43 +0800 Subject: [PATCH 16/18] Revert "reconstruct" This reverts commit 7103c3e629a189336cac3308add2b080319025d6. 
--- mmengine/_strategy/fsdp.py | 2 +- mmengine/model/wrappers/distributed.py | 3 --- mmengine/runner/loops.py | 4 +--- mmengine/runner/runner.py | 4 +--- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 0788fafdab..124dfd7c57 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) + self.model, self.optim_wrapper.optimizer, state_dict) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index b88bc7c2b0..4113aebf9e 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -95,7 +95,6 @@ def __init__(self, def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: - return self.module.train_step(data, optim_wrapper) """Interface for model forward, backward and parameters updating during training process. @@ -127,7 +126,6 @@ def train_step(self, data: Union[dict, tuple, list], return log_vars def val_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.val_step(data) """Gets the prediction of module during validation process. Args: @@ -139,7 +137,6 @@ def val_step(self, data: Union[dict, tuple, list]) -> list: return self.module.val_step(data) def test_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.test_step(data) """Gets the predictions of module during testing process. Args: diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index ba9ec9d9dd..4503f85b5c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -12,7 +12,6 @@ from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of -from mmengine.dataset.sampler import InfiniteSampler from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -275,14 +274,13 @@ def run(self) -> None: # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. self.runner.call_hook('before_train_epoch') - if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler): + if self._iter > 0: print_log( f'Advance dataloader {self._iter} steps to skip data ' 'that has already been trained', logger='current', level=logging.WARNING) for _ in range(self._iter): - break # NOTE MGAM: override all preprocessing steps during resume. 
next(self.dataloader_iterator) # with torch.profiler.profile( diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index f89fb260a1..9821a85077 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -903,11 +903,9 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = model_wrapper_cfg.get('type') if isinstance(model_wrapper_type, str): - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_type) # type: ignore + model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore elif inspect.isclass(model_wrapper_type): pass else: From e36e2f1c58a11e3bb534fef16870c665e43913dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <15337184+MGAM@user.noreply.gitee.com> Date: Fri, 17 Jan 2025 23:19:58 +0800 Subject: [PATCH 17/18] Revert "PyTorch Profiler within IterBasedTrainLoop" This reverts commit eecaa92179bb275f650931f7597b0ead0420f6b5. --- mmengine/runner/loops.py | 17 +---------------- mmengine/runner/runner.py | 2 +- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 4503f85b5c..065bf9243c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -282,19 +282,6 @@ def run(self) -> None: level=logging.WARNING) for _ in range(self._iter): next(self.dataloader_iterator) - - # with torch.profiler.profile( - # activities=[torch.profiler.ProfilerActivity.CPU, - # torch.profiler.ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3), - # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'), - # record_shapes=True, - # profile_memory=True, - # with_stack=True, - # with_flops=True, - # with_modules=True, - # ) as p: - while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() @@ -305,10 +292,8 @@ def run(self) -> None: if (self.runner.val_loop is not None and self._iter >= self.val_begin and (self._iter % self.val_interval == 0 - or self._iter == self._max_iters)): + or self._iter == self._max_iters)): self.runner.val_loop.run() - - # p.step() self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 9821a85077..9bbbcaedce 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1849,7 +1849,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from e + raise TypeError(f'{e} in {hook}') from None def register_hook( self, From 834cf9a77839f75a210723ee2189f39bc71790ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <15337184+MGAM@user.noreply.gitee.com> Date: Fri, 17 Jan 2025 23:26:12 +0800 Subject: [PATCH 18/18] fix --- mmengine/optim/optimizer/builder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index ebba603dbf..af98043b7f 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -8,6 +8,7 @@ from mmengine.config import Config, ConfigDict from mmengine.device import is_npu_available, is_npu_support_full_precision +from mmengine.logging.logger import print_log from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS from .default_constructor import 
DefaultOptimWrapperConstructor from .optimizer_wrapper import OptimWrapper @@ -171,8 +172,10 @@ def register_transformers_optimizers(): except ImportError: pass else: - # KeyError: 'Adafactor is already registered in optimizer at torch.optim' - # OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) + try: + OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) + except KeyError as e: + pass transformer_optimizers.append('Adafactor') return transformer_optimizers
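With the registration guarded as above, downstream configs can keep requesting the optimizer by name regardless of which implementation ends up owning it. A minimal, hedged config sketch (no keyword arguments are shown because torch.optim.Adafactor and the transformers Adafactor accept different ones):

    # Which Adafactor class gets built depends on which implementation was
    # registered under the name 'Adafactor' at import time.
    optim_wrapper = dict(
        type='OptimWrapper',
        optimizer=dict(type='Adafactor'),
    )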