Skip to content

feat(scm_provider): add branch-based version retrieval #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions commitizen/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Settings(TypedDict, total=False):
always_signoff: bool
template: str | None
extras: dict[str, Any]
branch_prerelease_map: dict[str, str]


name: str = "cz_conventional_commits"
Expand Down Expand Up @@ -101,6 +102,7 @@ class Settings(TypedDict, total=False):
"always_signoff": False,
"template": None, # default provided by plugin
"extras": {},
"branch_prerelease_map": {}
}

MAJOR = "MAJOR"
Expand Down
7 changes: 7 additions & 0 deletions commitizen/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ def smart_open(*args, **kargs):
return open(*args, newline=get_eol_style().get_eol_for_open(), **kargs)


def get_current_branch() -> str:
c = cmd.run("git rev-parse --abbrev-ref HEAD")
if c.return_code != 0:
raise GitCommandError(c.err)
return c.out.strip("\n")


def _get_log_as_str_list(start: str | None, end: str, args: str) -> list[str]:
"""Get string representation of each log entry"""
delimiter = "----------commit-delimiter----------"
Expand Down
20 changes: 19 additions & 1 deletion commitizen/providers/scm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Callable

from commitizen.git import get_tags
from commitizen.git import get_tags, get_current_branch
from commitizen.providers.base_provider import VersionProvider
from commitizen.version_schemes import (
InvalidVersion,
Expand Down Expand Up @@ -79,6 +79,24 @@ def get_version(self) -> str:
)
if not matches:
return "0.0.0"

branch_prerelease_map = self.config.settings.get("branch_prerelease_map")
current_branch = get_current_branch()

if branch_prerelease_map and current_branch in branch_prerelease_map:
release_type = branch_prerelease_map[current_branch]

if release_type:
prerelease_matches = [
v for v in matches if v.is_prerelease and release_type in v.prerelease
]
if prerelease_matches:
return str(prerelease_matches[-1])
Copy link

@chadrik chadrik Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to only keep tags of the same prerelease type? isn't it possible that I could be on staging/rc branch and the latest version is a beta version because no rc versions have been created from that beta version?

would it work to say that we want to filter prereleases that are "higher" than the prerelease that corresponds to our branch?

  • alpha would filter beta and rc.
  • beta would filter rc.
  • rc would not filter any prereleases

or to put this in the reverse:

  • rc would keep rc, beta, and alpha versions
  • beta would keep beta, and alpha versions
  • alpha would keep only alpha

else:
stable_matches = [v for v in matches if not v.is_prerelease]
if stable_matches:
return str(stable_matches[-1])

return str(matches[-1])

def set_version(self, version: str):
Expand Down
63 changes: 63 additions & 0 deletions tests/providers/test_scm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from commitizen.config.base_config import BaseConfig
from commitizen.git import get_tags
from commitizen.providers import get_provider
from commitizen.providers.scm_provider import ScmProvider
from tests.utils import (
Expand Down Expand Up @@ -113,3 +114,65 @@ def test_scm_provider_default_with_commits_and_tags(config: BaseConfig):
merge_branch("master")

assert provider.get_version() == "1.1.0rc0"


@pytest.mark.usefixtures("tmp_git_project")
def test_scm_provider_highest_tag_across_branches(config: BaseConfig):
from collections import Counter
config.settings["version_provider"] = "scm"
config.settings["tag_format"] = "$version"

# Providing branch_prerelease_map
config.settings["branch_prerelease_map"] = {
"develop": "b",
"staging": "rc",
"master": "",
}
provider = ScmProvider(config)

assert isinstance(provider, ScmProvider)

create_file_and_commit("Initial state")

# Add feature to develop
create_branch("develop")
switch_branch("develop")
create_file_and_commit("Initial state")
create_tag("0.1.0b0")

# Create staging branch and promote develop to staging
create_branch("staging")
switch_branch("staging")
merge_branch("develop")
create_tag("0.1.0rc0")

# Add another feature to develop
switch_branch("develop")
create_file_and_commit("develop: feature 2")
create_tag("0.2.0b0")

# Promote staging to master
switch_branch("master")
merge_branch("staging")
create_tag("0.1.0")

# Promote develop to staging
switch_branch("staging")
merge_branch("develop")
create_tag("0.2.0rc0")

# Check the version and tags on each branch
switch_branch("master")
master_tags = [x.name for x in get_tags(reachable_only=True)]
assert Counter(master_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0'])
assert provider.get_version() == "0.1.0"

switch_branch("staging")
staging_tags = [x.name for x in get_tags(reachable_only=True)]
assert Counter(staging_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0', '0.2.0b0', '0.2.0rc0'])
assert provider.get_version() == "0.2.0rc0"

switch_branch("develop")
develop_tags = [x.name for x in get_tags(reachable_only=True)]
assert Counter(develop_tags) == Counter(['0.1.0', '0.1.0b0', '0.1.0rc0', '0.2.0b0', '0.2.0rc0'])
assert provider.get_version() == "0.2.0b0"
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_current_branch() -> str:
c = cmd.run("git rev-parse --abbrev-ref HEAD")
if c.return_code != 0:
raise exceptions.GitCommandError(c.err)
return c.out
return c.out.strip("\n")


def create_tag(tag: str, message: str | None = None) -> None:
Expand Down