diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 452bbaddaa..606d3686c3 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess +from importlib.metadata import PackageNotFoundError, distribution +from typing import Any def is_installed(package: str) -> bool: @@ -9,28 +11,19 @@ def is_installed(package: str) -> bool: Args: package (str): Name of package to be checked. """ - # When executing `import mmengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. import importlib.util - import pkg_resources # type: ignore - from pkg_resources import get_distribution + # First check if it's an importable module + spec = importlib.util.find_spec(package) + if spec is not None and spec.origin is not None: + return True - # refresh the pkg_resources - # more datails at https://github.com/pypa/setuptools/issues/373 - importlib.reload(pkg_resources) + # If not found as module, check if it's a distribution package try: - get_distribution(package) + distribution(package) return True - except pkg_resources.DistributionNotFound: - spec = importlib.util.find_spec(package) - if spec is None: - return False - elif spec.origin is not None: - return True - else: - return False + except PackageNotFoundError: + return False def get_installed_path(package: str) -> str: @@ -45,17 +38,21 @@ def get_installed_path(package: str) -> str: """ import importlib.util - from pkg_resources import DistributionNotFound, get_distribution - # if the package name is not the same as module name, module name should be # inferred. For example, mmcv-full is the package name, but mmcv is module # name. If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name + # Try to get location from distribution package metadata + location = None try: - pkg = get_distribution(package) - except DistributionNotFound as e: - # if the package is not installed, package path set in PYTHONPATH - # can be detected by `find_spec` + dist = distribution(package) + locate_result: Any = dist.locate_file('') + location = str(locate_result.parent) + except PackageNotFoundError: + pass + + # If distribution package not found, try to find via importlib + if location is None: spec = importlib.util.find_spec(package) if spec is not None: if spec.origin is not None: @@ -67,28 +64,33 @@ def get_installed_path(package: str) -> str: f'{package} is a namespace package, which is invalid ' 'for `get_install_path`') else: - raise e + raise PackageNotFoundError(f'Package {package} is not installed') - possible_path = osp.join(pkg.location, package) # type: ignore + # Check if package directory exists in the location + possible_path = osp.join(location, package) if osp.exists(possible_path): return possible_path else: - return osp.join(pkg.location, package2module(package)) # type: ignore + return osp.join(location, package2module(package)) -def package2module(package: str): +def package2module(package: str) -> str: """Infer module name from package. Args: package (str): Package to infer module name. """ - from pkg_resources import get_distribution - pkg = get_distribution(package) - if pkg.has_metadata('top_level.txt'): - module_name = pkg.get_metadata('top_level.txt').split('\n')[0] - return module_name - else: - raise ValueError(f'can not infer the module name of {package}') + dist = distribution(package) + + # In importlib.metadata, + # top-level modules are in dist.read_text('top_level.txt') + top_level_text = dist.read_text('top_level.txt') + if top_level_text is not None: + lines = top_level_text.strip().split('\n') + if lines: + module_name = lines[0].strip() + return module_name + raise ValueError(f'can not infer the module name of {package}') def call_command(cmd: list) -> None: diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py index a8803dde24..141ec0304e 100644 --- a/tests/data/config/lazy_module_config/test_ast_transform.py +++ b/tests/data/config/lazy_module_config/test_ast_transform.py @@ -12,4 +12,5 @@ from ._base_.default_runtime import default_scope as scope from ._base_.scheduler import val_cfg from rich.progress import Progress + start = Progress.start diff --git a/tests/data/config/lazy_module_config/test_mix_builtin.py b/tests/data/config/lazy_module_config/test_mix_builtin.py index e36da58a3b..1698f1ea4b 100644 --- a/tests/data/config/lazy_module_config/test_mix_builtin.py +++ b/tests/data/config/lazy_module_config/test_mix_builtin.py @@ -13,4 +13,3 @@ chained = list(chain([1, 2], [3, 4])) existed = ex(__file__) cfgname = partial(basename, __file__)() - diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py index a9d2a3f64a..99755b4525 100644 --- a/tests/data/config/lazy_module_config/toy_model.py +++ b/tests/data/config/lazy_module_config/toy_model.py @@ -13,7 +13,6 @@ param_scheduler.milestones = [2, 4] - train_dataloader = dict( dataset=dict(type=ToyDataset), sampler=dict(type=DefaultSampler, shuffle=True), diff --git a/tests/data/config/py_config/test_custom_class.py b/tests/data/config/py_config/test_custom_class.py index ad706b087e..ae6af19e25 100644 --- a/tests/data/config/py_config/test_custom_class.py +++ b/tests/data/config/py_config/test_custom_class.py @@ -2,4 +2,5 @@ class A: ... + item_a = dict(a=A) diff --git a/tests/data/config/py_config/test_dump_pickle_support.py b/tests/data/config/py_config/test_dump_pickle_support.py index 6050ce10b1..2f8dafa3aa 100644 --- a/tests/data/config/py_config/test_dump_pickle_support.py +++ b/tests/data/config/py_config/test_dump_pickle_support.py @@ -24,5 +24,5 @@ def func(): dict_item5 = {'x/x': {'a.0': 233}} dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]} # Test windows path and escape. -str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in +str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in str_item_8 = func() diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py index 5ae261350a..2dded0da76 100644 --- a/tests/data/config/py_config/test_get_external_cfg3.py +++ b/tests/data/config/py_config/test_get_external_cfg3.py @@ -3,16 +3,11 @@ 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', 'mmdet::_base_/datasets/coco_detection.py', 'mmdet::_base_/schedules/schedule_1x.py', - 'mmdet::_base_/default_runtime.py', - './test_get_external_cfg_base.py' + 'mmdet::_base_/default_runtime.py', './test_get_external_cfg_base.py' ] custom_hooks = [dict(type='mmdet.DetVisualizationHook')] model = dict( roi_head=dict( - bbox_head=dict( - loss_cls=dict(_delete_=True, type='test.ToyLoss') - ) - ) -) + bbox_head=dict(loss_cls=dict(_delete_=True, type='test.ToyLoss')))) diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index 11ce294c29..276d514dfc 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import sys +from importlib.metadata import PackageNotFoundError -import pkg_resources # type: ignore import pytest from mmengine.utils import get_installed_path, is_installed @@ -20,6 +20,12 @@ def test_is_installed(): assert is_installed('optim') sys.path.pop() + assert is_installed('nonexistentpackage12345') is False + assert is_installed('os') is True # 'os' is a module name + assert is_installed('setuptools') is True + # Should work on both distribution and module name + assert is_installed('pillow') is True and is_installed('PIL') is True + def test_get_install_path(): # TODO: Windows CI may failed in unknown reason. Skip check the value @@ -33,5 +39,5 @@ def test_get_install_path(): assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim') sys.path.pop() - with pytest.raises(pkg_resources.DistributionNotFound): + with pytest.raises(PackageNotFoundError): get_installed_path('unknown')