Skip to content

Commit 166f768

Browse files
committed
refactor(bump_rule): use enum on increment
1 parent 42d965e commit 166f768

8 files changed

+583
-351
lines changed

commitizen/bump_rule.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,51 @@
22

33
import re
44
from collections.abc import Iterable
5+
from enum import Enum, auto
56
from functools import cached_property
6-
from typing import Callable, Protocol
7+
from typing import Any, Callable, Protocol
78

89
from commitizen.exceptions import NoPatternMapError
9-
from commitizen.version_schemes import Increment
1010

11-
_VERSION_ORDERING = dict(zip((None, "PATCH", "MINOR", "MAJOR"), range(4)))
11+
12+
class SemVerIncrement(Enum):
13+
MAJOR = auto()
14+
MINOR = auto()
15+
PATCH = auto()
16+
17+
def __str__(self) -> str:
18+
return self.name
19+
20+
@classmethod
21+
def safe_cast(cls, value: str | None) -> SemVerIncrement | None:
22+
if value is None:
23+
return None
24+
try:
25+
return cls[value]
26+
except ValueError:
27+
return None
28+
29+
@classmethod
30+
def safe_cast_dict(cls, d: dict[str, Any]) -> dict[str, SemVerIncrement]:
31+
return {
32+
k: v
33+
for k, v in ((k, SemVerIncrement.safe_cast(v)) for k, v in d.items())
34+
if v is not None
35+
}
36+
37+
38+
_VERSION_ORDERING = dict(
39+
zip(
40+
(None, SemVerIncrement.PATCH, SemVerIncrement.MINOR, SemVerIncrement.MAJOR),
41+
range(4),
42+
)
43+
)
1244

1345

1446
def find_increment_by_callable(
15-
commit_messages: Iterable[str], get_increment: Callable[[str], Increment | None]
16-
) -> Increment | None:
47+
commit_messages: Iterable[str],
48+
get_increment: Callable[[str], SemVerIncrement | None],
49+
) -> SemVerIncrement | None:
1750
"""Find the highest version increment from a list of messages.
1851
1952
This function processes a list of messages and determines the highest version
@@ -23,7 +56,7 @@ def find_increment_by_callable(
2356
Args:
2457
commit_messages: A list of messages to analyze.
2558
get_increment: A callable that takes a commit message string and returns an
26-
Increment value (MAJOR, MINOR, PATCH) or None if no increment is needed.
59+
SemVerIncrement value (MAJOR, MINOR, PATCH) or None if no increment is needed.
2760
2861
Returns:
2962
The highest version increment needed (MAJOR, MINOR, PATCH) or None if no
@@ -40,14 +73,16 @@ def find_increment_by_callable(
4073
return _find_highest_increment(increments)
4174

4275

43-
def _find_highest_increment(increments: Iterable[Increment | None]) -> Increment | None:
76+
def _find_highest_increment(
77+
increments: Iterable[SemVerIncrement | None],
78+
) -> SemVerIncrement | None:
4479
return max(increments, key=lambda x: _VERSION_ORDERING[x], default=None)
4580

4681

4782
class BumpRule(Protocol):
4883
def get_increment(
4984
self, commit_message: str, major_version_zero: bool
50-
) -> Increment | None:
85+
) -> SemVerIncrement | None:
5186
"""Determine the version increment based on a commit message.
5287
5388
This method analyzes a commit message to determine what kind of version increment
@@ -60,7 +95,7 @@ def get_increment(
6095
instead of MAJOR. This is useful for projects in 0.x.x versions.
6196
6297
Returns:
63-
Increment | None: The type of version increment needed:
98+
SemVerIncrement | None: The type of version increment needed:
6499
- "MAJOR": For breaking changes when major_version_zero is False
65100
- "MINOR": For breaking changes when major_version_zero is True, or for new features
66101
- "PATCH": For bug fixes, performance improvements, or refactors
@@ -76,19 +111,21 @@ class ConventionalCommitBumpRule(BumpRule):
76111

77112
def get_increment(
78113
self, commit_message: str, major_version_zero: bool
79-
) -> Increment | None:
114+
) -> SemVerIncrement | None:
80115
if not (m := self._head_pattern.match(commit_message)):
81116
return None
82117

83118
change_type = m.group("change_type")
84119
if m.group("bang") or self._RE_BREAKING_CHANGE.match(change_type):
85-
return "MINOR" if major_version_zero else "MAJOR"
120+
return (
121+
SemVerIncrement.MINOR if major_version_zero else SemVerIncrement.MAJOR
122+
)
86123

87124
if change_type == "feat":
88-
return "MINOR"
125+
return SemVerIncrement.MINOR
89126

90127
if change_type in self._PATCH_CHANGE_TYPES:
91-
return "PATCH"
128+
return SemVerIncrement.PATCH
92129

93130
return None
94131

@@ -118,8 +155,8 @@ class OldSchoolBumpRule(BumpRule):
118155
def __init__(
119156
self,
120157
bump_pattern: str,
121-
bump_map: dict[str, Increment],
122-
bump_map_major_version_zero: dict[str, Increment],
158+
bump_map: dict[str, SemVerIncrement],
159+
bump_map_major_version_zero: dict[str, SemVerIncrement],
123160
):
124161
if not bump_map or not bump_pattern or not bump_map_major_version_zero:
125162
raise NoPatternMapError(
@@ -132,7 +169,7 @@ def __init__(
132169

133170
def get_increment(
134171
self, commit_message: str, major_version_zero: bool
135-
) -> Increment | None:
172+
) -> SemVerIncrement | None:
136173
if not (m := self.bump_pattern.search(commit_message)):
137174
return None
138175

commitizen/commands/bump.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import questionary
88

99
from commitizen import bump, factory, git, hooks, out
10-
from commitizen.bump_rule import OldSchoolBumpRule, find_increment_by_callable
10+
from commitizen.bump_rule import (
11+
OldSchoolBumpRule,
12+
SemVerIncrement,
13+
find_increment_by_callable,
14+
)
1115
from commitizen.changelog_formats import get_changelog_format
1216
from commitizen.commands.changelog import Changelog
1317
from commitizen.config import BaseConfig
@@ -29,7 +33,6 @@
2933
from commitizen.providers import get_provider
3034
from commitizen.tags import TagRules
3135
from commitizen.version_schemes import (
32-
Increment,
3336
InvalidVersion,
3437
Prerelease,
3538
get_version_scheme,
@@ -120,22 +123,23 @@ def is_initial_tag(
120123
is_initial = questionary.confirm("Is this the first tag created?").ask()
121124
return is_initial
122125

123-
def find_increment(self, commits: list[git.GitCommit]) -> Increment | None:
126+
def find_increment(self, commits: list[git.GitCommit]) -> SemVerIncrement | None:
124127
# Update the bump map to ensure major version doesn't increment.
125128
is_major_version_zero: bool = self.bump_settings["major_version_zero"]
126129

127130
# Fallback to old school bump rule if no bump rule is provided
128131
rule = self.cz.bump_rule or OldSchoolBumpRule(
129132
*self._get_validated_cz_bump(),
130133
)
134+
131135
return find_increment_by_callable(
132136
(commit.message for commit in commits),
133137
lambda x: rule.get_increment(x, is_major_version_zero),
134138
)
135139

136140
def _get_validated_cz_bump(
137141
self,
138-
) -> tuple[str, dict[str, Increment], dict[str, Increment]]:
142+
) -> tuple[str, dict[str, SemVerIncrement], dict[str, SemVerIncrement]]:
139143
"""For fixing the type errors"""
140144
bump_pattern = self.cz.bump_pattern
141145
bump_map = self.cz.bump_map
@@ -145,9 +149,10 @@ def _get_validated_cz_bump(
145149
f"'{self.config.settings['name']}' rule does not support bump"
146150
)
147151

148-
return cast(
149-
tuple[str, dict[str, Increment], dict[str, Increment]],
150-
(bump_pattern, bump_map, bump_map_major_version_zero),
152+
return (
153+
bump_pattern,
154+
SemVerIncrement.safe_cast_dict(bump_map),
155+
SemVerIncrement.safe_cast_dict(bump_map_major_version_zero),
151156
)
152157

153158
def __call__(self) -> None: # noqa: C901
@@ -166,7 +171,9 @@ def __call__(self) -> None: # noqa: C901
166171

167172
dry_run: bool = self.arguments["dry_run"]
168173
is_yes: bool = self.arguments["yes"]
169-
increment: Increment | None = self.arguments["increment"]
174+
increment: SemVerIncrement | None = SemVerIncrement.safe_cast(
175+
self.arguments["increment"]
176+
)
170177
prerelease: Prerelease | None = self.arguments["prerelease"]
171178
devrelease: int | None = self.arguments["devrelease"]
172179
is_files_only: bool | None = self.arguments["files_only"]
@@ -283,7 +290,7 @@ def __call__(self) -> None: # noqa: C901
283290

284291
# we create an empty PATCH increment for empty tag
285292
if increment is None and allow_no_commit:
286-
increment = "PATCH"
293+
increment = SemVerIncrement.PATCH
287294

288295
new_version = current_version.bump(
289296
increment,

commitizen/defaults.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from collections.abc import Iterable, MutableMapping, Sequence
66
from typing import Any, TypedDict
77

8+
from commitizen.bump_rule import SemVerIncrement
9+
810
# Type
911
Questions = Iterable[MutableMapping[str, Any]]
1012

@@ -108,31 +110,27 @@ class Settings(TypedDict, total=False):
108110
"extras": {},
109111
}
110112

111-
MAJOR = "MAJOR"
112-
MINOR = "MINOR"
113-
PATCH = "PATCH"
114-
115113
CHANGELOG_FORMAT = "markdown"
116114

117115
BUMP_PATTERN = r"^((BREAKING[\-\ ]CHANGE|\w+)(\(.+\))?!?):"
118-
BUMP_MAP = OrderedDict(
116+
BUMP_MAP = dict(
119117
(
120-
(r"^.+!$", MAJOR),
121-
(r"^BREAKING[\-\ ]CHANGE", MAJOR),
122-
(r"^feat", MINOR),
123-
(r"^fix", PATCH),
124-
(r"^refactor", PATCH),
125-
(r"^perf", PATCH),
118+
(r"^.+!$", str(SemVerIncrement.MAJOR)),
119+
(r"^BREAKING[\-\ ]CHANGE", str(SemVerIncrement.MAJOR)),
120+
(r"^feat", str(SemVerIncrement.MINOR)),
121+
(r"^fix", str(SemVerIncrement.PATCH)),
122+
(r"^refactor", str(SemVerIncrement.PATCH)),
123+
(r"^perf", str(SemVerIncrement.PATCH)),
126124
)
127125
)
128-
BUMP_MAP_MAJOR_VERSION_ZERO = OrderedDict(
126+
BUMP_MAP_MAJOR_VERSION_ZERO = dict(
129127
(
130-
(r"^.+!$", MINOR),
131-
(r"^BREAKING[\-\ ]CHANGE", MINOR),
132-
(r"^feat", MINOR),
133-
(r"^fix", PATCH),
134-
(r"^refactor", PATCH),
135-
(r"^perf", PATCH),
128+
(r"^.+!$", str(SemVerIncrement.MINOR)),
129+
(r"^BREAKING[\-\ ]CHANGE", str(SemVerIncrement.MINOR)),
130+
(r"^feat", str(SemVerIncrement.MINOR)),
131+
(r"^fix", str(SemVerIncrement.PATCH)),
132+
(r"^refactor", str(SemVerIncrement.PATCH)),
133+
(r"^perf", str(SemVerIncrement.PATCH)),
136134
)
137135
)
138136
change_type_order = ["BREAKING CHANGE", "Feat", "Fix", "Refactor", "Perf"]

commitizen/version_schemes.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
runtime_checkable,
1515
)
1616

17+
from commitizen.bump_rule import SemVerIncrement
18+
1719
if sys.version_info >= (3, 10):
1820
from importlib import metadata
1921
else:
@@ -22,7 +24,7 @@
2224
from packaging.version import InvalidVersion # noqa: F401: expose the common exception
2325
from packaging.version import Version as _BaseVersion
2426

25-
from commitizen.defaults import MAJOR, MINOR, PATCH, Settings
27+
from commitizen.defaults import Settings
2628
from commitizen.exceptions import VersionSchemeUnknown
2729

2830
if TYPE_CHECKING:
@@ -39,7 +41,6 @@
3941
from typing import Self
4042

4143

42-
Increment: TypeAlias = Literal["MAJOR", "MINOR", "PATCH"]
4344
Prerelease: TypeAlias = Literal["alpha", "beta", "rc"]
4445
DEFAULT_VERSION_PARSER = r"v?(?P<version>([0-9]+)\.([0-9]+)(?:\.([0-9]+))?(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+[0-9A-Za-z.]+)?(\w+)?)"
4546

@@ -126,7 +127,7 @@ def __ne__(self, other: object) -> bool:
126127

127128
def bump(
128129
self,
129-
increment: Increment | None,
130+
increment: SemVerIncrement | None,
130131
prerelease: Prerelease | None = None,
131132
prerelease_offset: int = 0,
132133
devrelease: int | None = None,
@@ -223,26 +224,30 @@ def generate_build_metadata(self, build_metadata: str | None) -> str:
223224

224225
return f"+{build_metadata}"
225226

226-
def increment_base(self, increment: Increment | None = None) -> str:
227+
def increment_base(self, increment: SemVerIncrement | None = None) -> str:
227228
prev_release = list(self.release)
228-
increments = [MAJOR, MINOR, PATCH]
229+
increments = [
230+
SemVerIncrement.MAJOR,
231+
SemVerIncrement.MINOR,
232+
SemVerIncrement.PATCH,
233+
]
229234
base = dict(zip_longest(increments, prev_release, fillvalue=0))
230235

231-
if increment == MAJOR:
232-
base[MAJOR] += 1
233-
base[MINOR] = 0
234-
base[PATCH] = 0
235-
elif increment == MINOR:
236-
base[MINOR] += 1
237-
base[PATCH] = 0
238-
elif increment == PATCH:
239-
base[PATCH] += 1
236+
if increment == SemVerIncrement.MAJOR:
237+
base[SemVerIncrement.MAJOR] += 1
238+
base[SemVerIncrement.MINOR] = 0
239+
base[SemVerIncrement.PATCH] = 0
240+
elif increment == SemVerIncrement.MINOR:
241+
base[SemVerIncrement.MINOR] += 1
242+
base[SemVerIncrement.PATCH] = 0
243+
elif increment == SemVerIncrement.PATCH:
244+
base[SemVerIncrement.PATCH] += 1
240245

241-
return f"{base[MAJOR]}.{base[MINOR]}.{base[PATCH]}"
246+
return f"{base[SemVerIncrement.MAJOR]}.{base[SemVerIncrement.MINOR]}.{base[SemVerIncrement.PATCH]}"
242247

243248
def bump(
244249
self,
245-
increment: Increment | None,
250+
increment: SemVerIncrement | None,
246251
prerelease: Prerelease | None = None,
247252
prerelease_offset: int = 0,
248253
devrelease: int | None = None,
@@ -272,12 +277,12 @@ def bump(
272277
base = self.increment_base(increment)
273278
else:
274279
base = f"{self.major}.{self.minor}.{self.micro}"
275-
if increment == PATCH:
280+
if increment == SemVerIncrement.PATCH:
276281
pass
277-
elif increment == MINOR:
282+
elif increment == SemVerIncrement.MINOR:
278283
if self.micro != 0:
279284
base = self.increment_base(increment)
280-
elif increment == MAJOR:
285+
elif increment == SemVerIncrement.MAJOR:
281286
if self.minor != 0 or self.micro != 0:
282287
base = self.increment_base(increment)
283288
dev_version = self.generate_devrelease(devrelease)

0 commit comments

Comments
 (0)