Skip to content

Commit cd7cf6a

Browse files
committed
refactor(SemVerIncrement): use IntEnum
1 parent d820bf7 commit cd7cf6a

File tree

2 files changed

+19
-131
lines changed

2 files changed

+19
-131
lines changed

commitizen/bump_rule.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,41 @@
11
from __future__ import annotations
22

33
import re
4-
from collections.abc import Iterable
5-
from enum import Enum, auto
4+
from collections.abc import Iterable, Mapping
5+
from enum import IntEnum, auto
66
from functools import cached_property
7-
from typing import Any, Callable, Protocol
7+
from typing import Callable, Protocol
88

99
from commitizen.exceptions import NoPatternMapError
1010

1111

12-
class SemVerIncrement(Enum):
12+
class SemVerIncrement(IntEnum):
1313
"""An enumeration representing semantic versioning increments.
1414
1515
This class defines the three types of version increments according to semantic versioning:
16-
- MAJOR: For incompatible API changes
17-
- MINOR: For backwards-compatible functionality additions
1816
- PATCH: For backwards-compatible bug fixes
17+
- MINOR: For backwards-compatible functionality additions
18+
- MAJOR: For incompatible API changes
1919
"""
2020

21-
MAJOR = auto()
22-
MINOR = auto()
2321
PATCH = auto()
22+
MINOR = auto()
23+
MAJOR = auto()
2424

2525
def __str__(self) -> str:
2626
return self.name
2727

2828
@classmethod
29-
def safe_cast(cls, value: Any) -> SemVerIncrement | None:
30-
if value is None:
29+
def safe_cast(cls, value: object) -> SemVerIncrement | None:
30+
if not isinstance(value, str):
3131
return None
3232
try:
3333
return cls[value]
3434
except KeyError:
3535
return None
3636

3737
@classmethod
38-
def safe_cast_dict(cls, d: dict[str, Any]) -> dict[str, SemVerIncrement]:
38+
def safe_cast_dict(cls, d: Mapping[str, object]) -> dict[str, SemVerIncrement]:
3939
return {
4040
k: v
4141
for k, v in ((k, SemVerIncrement.safe_cast(v)) for k, v in d.items())
@@ -68,25 +68,17 @@ def get_highest_by_messages(
6868
>>> SemVerIncrement.get_highest_by_messages(commit_messages, lambda x: rule.get_increment(x, False))
6969
'MINOR'
7070
"""
71-
return _find_highest_increment(
71+
return SemVerIncrement.get_highest(
7272
get_increment(line)
7373
for message in commit_messages
7474
for line in message.split("\n")
7575
)
7676

77-
78-
_VERSION_ORDERING = dict(
79-
zip(
80-
(None, SemVerIncrement.PATCH, SemVerIncrement.MINOR, SemVerIncrement.MAJOR),
81-
range(4),
82-
)
83-
)
84-
85-
86-
def _find_highest_increment(
87-
increments: Iterable[SemVerIncrement | None],
88-
) -> SemVerIncrement | None:
89-
return max(increments, key=lambda x: _VERSION_ORDERING[x], default=None)
77+
@staticmethod
78+
def get_highest(
79+
increments: Iterable[SemVerIncrement | None],
80+
) -> SemVerIncrement | None:
81+
return max(filter(None, increments), default=None)
9082

9183

9284
class BumpRule(Protocol):
@@ -193,8 +185,8 @@ def get_increment(
193185
)
194186

195187
try:
196-
if ret := _find_highest_increment(
197-
(increment for name, increment in bump_map.items() if m.group(name))
188+
if ret := SemVerIncrement.get_highest(
189+
(increment for name, increment in bump_map.items() if m.group(name)),
198190
):
199191
return ret
200192
except IndexError:

tests/test_bump_rule.py

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
ConventionalCommitBumpRule,
55
CustomBumpRule,
66
SemVerIncrement,
7-
_find_highest_increment,
87
)
98
from commitizen.defaults import (
109
BUMP_MAP,
@@ -587,106 +586,3 @@ def test_with_find_increment_by_callable(self, custom_bump_rule):
587586
)
588587
== SemVerIncrement.MAJOR
589588
)
590-
591-
592-
def test_find_highest_increment():
593-
"""Test the _find_highest_increment function."""
594-
# Test with single increment
595-
assert _find_highest_increment([SemVerIncrement.MAJOR]) == SemVerIncrement.MAJOR
596-
assert _find_highest_increment([SemVerIncrement.MINOR]) == SemVerIncrement.MINOR
597-
assert _find_highest_increment([SemVerIncrement.PATCH]) == SemVerIncrement.PATCH
598-
599-
# Test with multiple increments
600-
assert (
601-
_find_highest_increment(
602-
[SemVerIncrement.PATCH, SemVerIncrement.MINOR, SemVerIncrement.MAJOR]
603-
)
604-
== SemVerIncrement.MAJOR
605-
)
606-
assert (
607-
_find_highest_increment([SemVerIncrement.PATCH, SemVerIncrement.MINOR])
608-
== SemVerIncrement.MINOR
609-
)
610-
assert (
611-
_find_highest_increment([SemVerIncrement.PATCH, SemVerIncrement.PATCH])
612-
== SemVerIncrement.PATCH
613-
)
614-
615-
# Test with None values
616-
assert (
617-
_find_highest_increment([None, SemVerIncrement.PATCH]) == SemVerIncrement.PATCH
618-
)
619-
assert _find_highest_increment([None, None]) is None
620-
assert _find_highest_increment([]) is None
621-
622-
# Test with mixed values
623-
assert (
624-
_find_highest_increment(
625-
[None, SemVerIncrement.PATCH, SemVerIncrement.MINOR, SemVerIncrement.MAJOR]
626-
)
627-
== SemVerIncrement.MAJOR
628-
)
629-
assert (
630-
_find_highest_increment([None, SemVerIncrement.PATCH, SemVerIncrement.MINOR])
631-
== SemVerIncrement.MINOR
632-
)
633-
assert (
634-
_find_highest_increment([None, SemVerIncrement.PATCH]) == SemVerIncrement.PATCH
635-
)
636-
637-
# Test with empty iterator
638-
assert _find_highest_increment(iter([])) is None
639-
640-
# Test with generator expression
641-
assert (
642-
_find_highest_increment(
643-
x
644-
for x in [
645-
SemVerIncrement.PATCH,
646-
SemVerIncrement.MINOR,
647-
SemVerIncrement.MAJOR,
648-
]
649-
)
650-
== SemVerIncrement.MAJOR
651-
)
652-
assert (
653-
_find_highest_increment(
654-
x for x in [None, SemVerIncrement.PATCH, SemVerIncrement.MINOR]
655-
)
656-
== SemVerIncrement.MINOR
657-
)
658-
659-
660-
class TestSemVerIncrementSafeCast:
661-
def test_safe_cast_valid_values(self):
662-
"""Test safe_cast with valid enum values."""
663-
assert SemVerIncrement.safe_cast("MAJOR") == SemVerIncrement.MAJOR
664-
assert SemVerIncrement.safe_cast("MINOR") == SemVerIncrement.MINOR
665-
assert SemVerIncrement.safe_cast("PATCH") == SemVerIncrement.PATCH
666-
667-
def test_safe_cast_invalid_values(self):
668-
"""Test safe_cast with invalid values."""
669-
assert SemVerIncrement.safe_cast("INVALID") is None
670-
assert SemVerIncrement.safe_cast("") is None
671-
assert SemVerIncrement.safe_cast(123) is None
672-
assert SemVerIncrement.safe_cast(None) is None
673-
674-
def test_safe_cast_dict(self):
675-
"""Test safe_cast_dict method."""
676-
test_dict = {
677-
"MAJOR": "MAJOR",
678-
"MINOR": "MINOR",
679-
"PATCH": "PATCH",
680-
"INVALID": "INVALID",
681-
"empty": "",
682-
"number": 123,
683-
"none": None,
684-
}
685-
686-
expected_dict = {
687-
"MAJOR": SemVerIncrement.MAJOR,
688-
"MINOR": SemVerIncrement.MINOR,
689-
"PATCH": SemVerIncrement.PATCH,
690-
}
691-
692-
assert SemVerIncrement.safe_cast_dict(test_dict) == expected_dict

0 commit comments

Comments
 (0)