Skip to content

Commit bf2ee82

Browse files
committed
Improve get_version method
1 parent 553d552 commit bf2ee82

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

src/scyjava/_versions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def get_version(java_class_or_python_package) -> str:
1515
"""
1616
Return the version of a Java class or Python package.
1717
18-
For Python package, uses importlib.metadata.version if available
19-
(Python 3.8+), with pkg_resources.get_distribution as a fallback.
18+
For Python packages, invokes importlib.metadata.version on the given
19+
object's base __module__ or __package__ (before the first dot symbol).
2020
2121
For Java classes, requires org.scijava:scijava-common on the classpath.
2222
@@ -32,8 +32,16 @@ def get_version(java_class_or_python_package) -> str:
3232
VersionUtils = jimport("org.scijava.util.VersionUtils")
3333
return str(VersionUtils.getVersion(java_class_or_python_package))
3434

35-
# Assume we were given a Python package name.
36-
return version(java_class_or_python_package)
35+
# Assume we were given a Python package name or module.
36+
package_name = None
37+
if hasattr(java_class_or_python_package, "__module__"):
38+
package_name = java_class_or_python_package.__module__
39+
elif hasattr(java_class_or_python_package, "__package__"):
40+
package_name = java_class_or_python_package.__package__
41+
else:
42+
package_name = str(java_class_or_python_package)
43+
44+
return version(package_name.split(".")[0])
3745

3846

3947
def is_version_at_least(actual_version: str, minimum_version: str) -> bool:

tests/test_versions.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tests for functions in _versions submodule.
33
"""
44

5+
from importlib.metadata import version
56
from pathlib import Path
67

78
import toml
@@ -18,8 +19,18 @@ def _expected_version():
1819

1920

2021
def test_version():
21-
# First, ensure that the version is correct
22-
assert _expected_version() == scyjava.__version__
22+
sjver = _expected_version()
2323

24-
# Then, ensure that we get the correct version via get_version
25-
assert _expected_version() == scyjava.get_version("scyjava")
24+
# First, ensure that the version is correct.
25+
assert sjver == scyjava.__version__
26+
27+
# Then, ensure that we get the correct version via get_version.
28+
assert sjver == scyjava.get_version("scyjava")
29+
assert sjver == scyjava.get_version(scyjava)
30+
assert sjver == scyjava.get_version("scyjava.config")
31+
assert sjver == scyjava.get_version(scyjava.config)
32+
assert sjver == scyjava.get_version(scyjava.config.mode)
33+
assert sjver == scyjava.get_version(scyjava.config.Mode)
34+
35+
# And that we get the correct version of other things, too.
36+
assert version("toml") == scyjava.get_version(toml)

0 commit comments

Comments
 (0)