diff --git a/commitizen/defaults.py b/commitizen/defaults.py index e4363f4ab..5c70d3119 100644 --- a/commitizen/defaults.py +++ b/commitizen/defaults.py @@ -55,6 +55,7 @@ class Settings(TypedDict, total=False): always_signoff: bool template: str | None extras: dict[str, Any] + branch_prerelease_map: dict[str, str] name: str = "cz_conventional_commits" @@ -101,6 +102,7 @@ class Settings(TypedDict, total=False): "always_signoff": False, "template": None, # default provided by plugin "extras": {}, + "branch_prerelease_map": {} } MAJOR = "MAJOR" diff --git a/commitizen/git.py b/commitizen/git.py index 1f758889e..dcfaf345f 100644 --- a/commitizen/git.py +++ b/commitizen/git.py @@ -276,6 +276,13 @@ def smart_open(*args, **kargs): return open(*args, newline=get_eol_style().get_eol_for_open(), **kargs) +def get_current_branch() -> str: + c = cmd.run("git rev-parse --abbrev-ref HEAD") + if c.return_code != 0: + raise GitCommandError(c.err) + return c.out.strip("\n") + + def _get_log_as_str_list(start: str | None, end: str, args: str) -> list[str]: """Get string representation of each log entry""" delimiter = "----------commit-delimiter----------" diff --git a/commitizen/providers/scm_provider.py b/commitizen/providers/scm_provider.py index 00df3e415..2e2f10c22 100644 --- a/commitizen/providers/scm_provider.py +++ b/commitizen/providers/scm_provider.py @@ -3,7 +3,7 @@ import re from typing import Callable -from commitizen.git import get_tags +from commitizen.git import get_tags, get_current_branch from commitizen.providers.base_provider import VersionProvider from commitizen.version_schemes import ( InvalidVersion, @@ -79,6 +79,24 @@ def get_version(self) -> str: ) if not matches: return "0.0.0" + + branch_prerelease_map = self.config.settings.get("branch_prerelease_map") + current_branch = get_current_branch() + + if branch_prerelease_map and current_branch in branch_prerelease_map: + release_type = branch_prerelease_map[current_branch] + + if release_type: + prerelease_matches = [ + v for v in matches if v.is_prerelease and release_type in v.prerelease + ] + if prerelease_matches: + return str(prerelease_matches[-1]) + else: + stable_matches = [v for v in matches if not v.is_prerelease] + if stable_matches: + return str(stable_matches[-1]) + return str(matches[-1]) def set_version(self, version: str): diff --git a/tests/providers/test_scm_provider.py b/tests/providers/test_scm_provider.py index 01e7ab994..8a2b73527 100644 --- a/tests/providers/test_scm_provider.py +++ b/tests/providers/test_scm_provider.py @@ -3,6 +3,7 @@ import pytest from commitizen.config.base_config import BaseConfig +from commitizen.git import get_tags from commitizen.providers import get_provider from commitizen.providers.scm_provider import ScmProvider from tests.utils import ( @@ -113,3 +114,65 @@ def test_scm_provider_default_with_commits_and_tags(config: BaseConfig): merge_branch("master") assert provider.get_version() == "1.1.0rc0" + + +@pytest.mark.usefixtures("tmp_git_project") +def test_scm_provider_highest_tag_across_branches(config: BaseConfig): + from collections import Counter + config.settings["version_provider"] = "scm" + config.settings["tag_format"] = "$version" + + # Providing branch_prerelease_map + config.settings["branch_prerelease_map"] = { + "develop": "b", + "staging": "rc", + "master": "", + } + provider = ScmProvider(config) + + assert isinstance(provider, ScmProvider) + + create_file_and_commit("Initial state") + + # Add feature to develop + create_branch("develop") + switch_branch("develop") + create_file_and_commit("Initial state") + create_tag("0.1.0b0") + + # Create staging branch and promote develop to staging + create_branch("staging") + switch_branch("staging") + merge_branch("develop") + create_tag("0.1.0rc0") + + # Add another feature to develop + switch_branch("develop") + create_file_and_commit("develop: feature 2") + create_tag("0.2.0b0") + + # Promote staging to master + switch_branch("master") + merge_branch("staging") + create_tag("0.1.0") + + # Promote develop to staging + switch_branch("staging") + merge_branch("develop") + create_tag("0.2.0rc0") + + # Check the version and tags on each branch + switch_branch("master") + master_tags = [x.name for x in get_tags(reachable_only=True)] + assert Counter(master_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0']) + assert provider.get_version() == "0.1.0" + + switch_branch("staging") + staging_tags = [x.name for x in get_tags(reachable_only=True)] + assert Counter(staging_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0', '0.2.0b0', '0.2.0rc0']) + assert provider.get_version() == "0.2.0rc0" + + switch_branch("develop") + develop_tags = [x.name for x in get_tags(reachable_only=True)] + assert Counter(develop_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0', '0.2.0b0', '0.2.0rc0']) + assert provider.get_version() == "0.2.0b0" diff --git a/tests/utils.py b/tests/utils.py index 971ff9182..917049410 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,7 +60,7 @@ def get_current_branch() -> str: c = cmd.run("git rev-parse --abbrev-ref HEAD") if c.return_code != 0: raise exceptions.GitCommandError(c.err) - return c.out + return c.out.strip("\n") def create_tag(tag: str, message: str | None = None) -> None: