diff --git a/scripts/create_release_branch.py b/scripts/create_release_branch.py new file mode 100644 index 00000000..0d3b6af9 --- /dev/null +++ b/scripts/create_release_branch.py @@ -0,0 +1,640 @@ +#!/usr/bin/env python3 +# Copyright Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +""" +ROCm TheRock – Release Branch Automation Tool +-------------------------------------------- + +Creates release branches for TheRock and every tracked ROCm submodule at a +caller-provided commit. The script maintains a cached clone (configurable via +`--cache-dir`, default: `/tmp/rock-branching-cache`), fetches the latest refs, +and hard-resets to the specified commit before creating release branches. + +Authentication is done via SSH. Remotes are configured via +`git remote set-url` (with `add` fallback). + +High-level workflow: +1. Reuse (or populate) the cached TheRock clone; reclone only when the cache + is missing or corrupt (`--force-clone` deletes and reclones when the cache + directory exists but is not a valid git repo). Otherwise fetch/prune to + pick up new commits. +2. Hard-reset to the requested commit and populate submodules via + `fetch_sources.py` when available (fallback to `git submodule update`). +3. Build an execution plan from `.gitmodules` + `git submodule status`, + capturing repo URL, commit SHA, and working tree path for each component + plus TheRock itself. Repos listed in `--exclude-list` and repos outside + the ROCm GitHub org are filtered out. +4. For each component: + a. Set up the SSH `rocm-github` remote. + b. Check if the release branch already exists on the remote; if so, skip + the repo entirely (recorded as skipped, not a failure). + c. Create (or reset) the branch at the recorded commit. + d. Push to `rocm-github` (skipped in dry-run mode). +5. Log a summary of successful and failed repos. + +Dry-run mode (the default) logs every action without touching remotes; +`--no-dry-run` enables actual pushes. + +Usage: + python create_release_branch.py \\ + --branch-name \\ + --commitid \\ + [--no-dry-run] + +Options: + --branch-name Name of the release branch to create (required) + --commitid Commit SHA of TheRock to branch from (required) + --dry-run/--no-dry-run + Log actions without pushing to remotes (default: enabled) + --exclude-list Submodule repo names to skip (space-separated) + --force-clone Delete and reclone if cache dir is not a valid git repo + --cache-dir Directory to cache the TheRock clone + (default: /tmp/rock-branching-cache) +""" +import argparse +import logging +import re +import shlex +import shutil +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from pprint import pformat + + +@dataclass +class RepoInfo: + """Information about a repository to branch.""" + + url: str + commit: str + path: Path + + +class RockBranchingAutomation: + """Automates creation of release branches for TheRock and its ROCm submodules.""" + + def __init__(self, cli_args: argparse.Namespace) -> None: + self.release_branch: str = cli_args.branch_name + self.dry_run: bool = cli_args.dry_run + self.commitid: str = cli_args.commitid + + if not re.fullmatch(r"[0-9a-f]{40}", self.commitid): + raise SystemExit( + f"ERROR: --commitid must be a full 40-character lowercase hex " + f"SHA-1 hash, got: {self.commitid!r}" + ) + + self.exclude_list: set[str] = set(cli_args.exclude_list or []) + self.force_clone: bool = cli_args.force_clone + self.cache_dir: Path | None = ( + Path(cli_args.cache_dir) if cli_args.cache_dir else None + ) + self.rock_url: str = "https://github.com/ROCm/TheRock.git" + self.cache_root: Path | None = None + + self._logger = logging.getLogger("rock_branching") + + self.log("Authentication Mode: SSH") + self.log(f"Dry run mode = {self.dry_run}") + if self.exclude_list: + self.log(f"Exclude list: {self.exclude_list}") + + def log(self, msg: str) -> None: + """Log an info-level message.""" + self._logger.info(msg) + + def run_command( + self, + args: list[str | Path], + cwd: Path, + *, + input_data: bytes | None = None, + stream: bool = False, + timeout: int | None = None, + ) -> None: + """Execute a subprocess command, raising CalledProcessError on failure. + + Args: + args: Command and arguments to execute. + cwd: Working directory for the command. + input_data: Optional bytes piped to stdin. + stream: If True, print stdout/stderr line-by-line as it arrives + (useful for long-running operations like clone/fetch). + If False, buffer output and log after completion. + timeout: Maximum seconds to wait before raising TimeoutExpired. + """ + cmd = args if isinstance(args, list) else [args] + self.log(f"++ Exec [{cwd}]$ {shlex.join(map(str, cmd))}") + sys.stdout.flush() + + if stream: + process = subprocess.Popen( + cmd, + cwd=str(cwd), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + for line in process.stdout: + self.log(line.rstrip()) + + try: + ret = process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + raise + if ret != 0: + raise subprocess.CalledProcessError(ret, cmd) + + return + + try: + result = subprocess.run( + cmd, + cwd=str(cwd), + shell=False, + input=input_data, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + stdin=None if input_data else subprocess.DEVNULL, + text=False, + timeout=timeout, + ) + + if result.stdout: + self.log( + result.stdout + if isinstance(result.stdout, str) + else result.stdout.decode(errors="ignore") + ) + if result.stderr: + self.log( + result.stderr + if isinstance(result.stderr, str) + else result.stderr.decode(errors="ignore") + ) + + except subprocess.CalledProcessError as exc: + self.log( + (exc.stdout or b"").decode(errors="ignore") + if isinstance(exc.stdout, bytes) + else (exc.stdout or "") + ) + self.log( + (exc.stderr or b"").decode(errors="ignore") + if isinstance(exc.stderr, bytes) + else (exc.stderr or "") + ) + raise + + def run_command_output( + self, args: list[str | Path], cwd: Path, timeout: int | None = None + ) -> str: + """Run a command and return its stripped stdout as a string. + + Raises CalledProcessError on non-zero exit. + Raises subprocess.TimeoutExpired when *timeout* seconds elapse. + """ + cmd = args if isinstance(args, list) else [args] + self.log(f"++ Exec [{cwd}]$ {shlex.join(map(str, cmd))}") + + result = subprocess.run( + cmd, + cwd=str(cwd), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + stdin=subprocess.DEVNULL, + timeout=timeout, + ) + return result.stdout.strip() + + def _setup_remote(self, url: str, repo_dir: Path) -> None: + """Add or update the rocm-github remote for a repo.""" + remote_url = self.convert_to_ssh(url) + try: + self.run_command( + ["git", "remote", "set-url", "rocm-github", remote_url], + cwd=repo_dir, + ) + except subprocess.CalledProcessError: + self.run_command( + ["git", "remote", "add", "rocm-github", remote_url], + cwd=repo_dir, + ) + + def _remote_branch_exists(self, repo_dir: Path) -> bool: + """Return True if the release branch already exists on rocm-github. + + Raises CalledProcessError if the remote check itself fails. + Raises subprocess.TimeoutExpired if the network call hangs (60 s). + """ + output = self.run_command_output( + ["git", "ls-remote", "--heads", "rocm-github", self.release_branch], + cwd=repo_dir, + timeout=60, + ) + return bool(output) + + def _create_branch(self, commit: str, repo_dir: Path) -> None: + """Create (or reset) the release branch at the given commit.""" + self.run_command( + ["git", "checkout", "-B", self.release_branch, commit], + cwd=repo_dir, + ) + + def _push_branch(self, repo_name: str, repo_dir: Path) -> None: + """Push the release branch to rocm-github, respecting dry-run mode.""" + if self.dry_run: + self.log( + f"[DRY RUN] Skipping push of {self.release_branch} " + f"for {repo_name}" + ) + else: + self.run_command( + ["git", "push", "rocm-github", self.release_branch], + cwd=repo_dir, + timeout=120, + ) + + def execute_plan(self, plan: dict[str, RepoInfo]) -> None: + """Execute the branching plan for every repo in *plan*. + + For each repo: + 1. Set up the ``rocm-github`` remote with the SSH URL. + 2. Guard against a pre-existing remote branch (recorded as skipped, not failed, if found). + 3. Create (or reset) the release branch at the recorded commit SHA. + 4. Push to ``rocm-github`` (skipped in dry-run mode). + """ + successful_repos: dict[str, RepoInfo] = {} + skipped_repos: dict[str, str] = {} + failed_repos: dict[str, str] = {} + + for repo_name, info in plan.items(): + self.log(f"Processing {repo_name} at {info.path}") + + if not info.path.exists(): + failed_repos[repo_name] = ( + f"Repo path does not exist: {info.path}" + ) + continue + + try: + self._setup_remote(info.url, info.path) + except subprocess.CalledProcessError as exc: + failed_repos[repo_name] = f"Remote setup failed: {exc}" + continue + + try: + branch_exists = self._remote_branch_exists(info.path) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: + failed_repos[repo_name] = ( + f"Remote branch check failed: {exc}" + ) + continue + + if branch_exists: + msg = ( + f"Remote branch {self.release_branch} already exists " + "on rocm-github" + ) + self.log(msg) + skipped_repos[repo_name] = msg + continue + + try: + self._create_branch(info.commit, info.path) + except subprocess.CalledProcessError as exc: + failed_repos[repo_name] = ( + f"Branch creation failed at {info.commit}: {exc}" + ) + continue + + try: + self._push_branch(repo_name, info.path) + successful_repos[repo_name] = info + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: + failed_repos[repo_name] = f"Branch push failed: {exc}" + + self.log( + f"Summary: {len(successful_repos)} succeeded, " + f"{len(skipped_repos)} skipped, " + f"{len(failed_repos)} failed out of {len(plan)} repos" + ) + if successful_repos: + self.log(f"Successful repos: {pformat(successful_repos)}") + if skipped_repos: + self.log(f"Skipped repos (branch already exists): {pformat(skipped_repos)}") + if failed_repos: + self.log(f"Failed repos: {pformat(failed_repos)}") + + def convert_to_ssh(self, url: str) -> str: + """Convert https://github.com/X/Y.git to git@github.com:X/Y.git.""" + if url.startswith("https://github.com/"): + path = url.replace("https://github.com/", "") + return f"git@github.com:{path}" + return url + + def get_submodule_url_map(self, repo_dir: Path) -> dict[str, str]: + """Return mapping of submodule working-tree paths to remote URLs.""" + gitmodules_path = repo_dir / ".gitmodules" + if not gitmodules_path.exists(): + return {} + + try: + path_entries = self.run_command_output( + [ + "git", + "config", + "--file", + str(gitmodules_path), + "--get-regexp", + r"submodule\..*\.path", + ], + cwd=repo_dir, + ) + except subprocess.CalledProcessError: + return {} + + url_map: dict[str, str] = {} + for line in path_entries.splitlines(): + # Each line looks like: + # "submodule.external/hipcc.path external/hipcc" + parts = line.strip().split(None, 1) + if len(parts) != 2: + continue + key, path_value = parts + section = key.rsplit(".", 1)[0] + try: + url = self.run_command_output( + [ + "git", + "config", + "--file", + str(gitmodules_path), + "--get", + f"{section}.url", + ], + cwd=repo_dir, + ) + except subprocess.CalledProcessError: + self.log(f"No URL entry for {section}; skipping") + continue + url_map[path_value.strip()] = url + + return url_map + + def build_plan(self) -> dict[str, RepoInfo]: + """Build the branching execution plan. + + 1. Clone (or reuse cached clone of) TheRock. + 2. Check out and hard-reset to ``self.commitid``. + 3. Populate submodules via ``fetch_sources.py`` (or ``git submodule update``). + 4. Read ``git submodule status`` and ``.gitmodules`` to collect each + submodule's commit SHA, remote URL, and local path. + 5. Return a dict keyed by repo name, including TheRock itself. + """ + cache_root = ( + self.cache_dir + or Path(tempfile.gettempdir()) / "rock-branching-cache" + ) + cache_root.mkdir(parents=True, exist_ok=True) + clone_dir = cache_root / "TheRock" + self.cache_root = cache_root + + needs_clone = not clone_dir.exists() + if not needs_clone and not (clone_dir / ".git").exists(): + if not self.force_clone: + raise RuntimeError( + f"Cache directory {clone_dir} exists but is not a git " + "repo. Use --force-clone to delete it and reclone." + ) + self.log( + f"Cache directory {clone_dir} is not a git repo; " + "removing before reclone (--force-clone)" + ) + shutil.rmtree(clone_dir) + needs_clone = True + + if needs_clone: + self.log( + f"Cloning TheRock repo from {self.rock_url} into {clone_dir}" + ) + self.run_command( + ["git", "clone", str(self.rock_url), str(clone_dir)], + cwd=cache_root, + stream=True, + timeout=600, + ) + else: + self.log(f"Reusing existing TheRock repo at {clone_dir}") + + try: + remote_url = self.run_command_output( + ["git", "remote", "get-url", "origin"], + cwd=clone_dir, + ) + if "TheRock" not in remote_url: + raise RuntimeError( + f"Existing repo at {clone_dir} does not look like " + f"TheRock (origin={remote_url})" + ) + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"Failed to inspect existing repo at {clone_dir}: {exc}" + ) from exc + + self.log("Fetching latest changes for existing TheRock clone...") + self.run_command( + [ + "git", + "fetch", + "origin", + "--prune", + "--recurse-submodules=on-demand", + ], + cwd=clone_dir, + stream=True, + timeout=600, + ) + + fetch_script = clone_dir / "build_tools" / "fetch_sources.py" + rock_commit = self.commitid + + self.log(f"Checking out TheRock at commit {rock_commit}") + self.run_command(["git", "checkout", rock_commit], cwd=clone_dir) + self.run_command( + ["git", "reset", "--hard", rock_commit], cwd=clone_dir + ) + + if fetch_script.exists(): + self.log( + "Updating submodules via fetch_sources.py " + "(jobs=10, no patches)..." + ) + self.run_command( + [ + "python3", + str(fetch_script), + "--jobs", + "10", + "--no-apply-patches", + ], + cwd=clone_dir, + stream=True, + ) + else: + self.log( + "fetch_sources.py not found; " + "falling back to git submodule update" + ) + self.run_command( + ["git", "submodule", "update", "--init", "--recursive"], + cwd=clone_dir, + stream=True, + ) + + self.log("Reading submodule status...") + try: + status_output = self.run_command_output( + ["git", "submodule", "status"], + cwd=clone_dir, + ) + lines = status_output.split("\n") if status_output else [] + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"Failed to read submodule status: {exc}" + ) from exc + + url_map = self.get_submodule_url_map(clone_dir) + + plan: dict[str, RepoInfo] = {} + + # Each line from `git submodule status` looks like: + # " ()" or "- " (not initialized) + for line in lines: + if not line: + continue + + parts = line.split() + if len(parts) < 2: + continue + sha = parts[0].lstrip("-+") + path = parts[1] + + repo_name = Path(path).name + repo_url = url_map.get(path) + + if not repo_url: + self.log( + f"No URL found for submodule {path} in .gitmodules" + ) + continue + + if repo_name in self.exclude_list: + self.log(f"Skipping {repo_name} (in exclude list)") + continue + + url_lower = repo_url.lower() + if ( + "github.com/rocm/" not in url_lower + and "github.com:rocm/" not in url_lower + ): + self.log( + f"Skipping {repo_name} " + f"(not a ROCm org repo: {repo_url})" + ) + continue + + plan[repo_name] = RepoInfo( + url=repo_url, + commit=sha, + path=clone_dir / path, + ) + + plan["TheRock"] = RepoInfo( + url=self.rock_url, + commit=rock_commit, + path=clone_dir, + ) + + return plan + + def run(self) -> None: + """Build the execution plan and execute it.""" + plan = self.build_plan() + self.log(f"Execution plan:\n{pformat(plan)}") + self.execute_plan(plan) + + +def main(argv: list[str]) -> int: + """Parse arguments and run the branching automation.""" + parser = argparse.ArgumentParser( + description="Rock Branching Automation Tool", + ) + parser.add_argument( + "-B", + "--branch-name", + required=True, + help="Name of the release branch to create", + ) + parser.add_argument( + "-C", + "--commitid", + required=True, + help="Commit ID of TheRock to branch from", + ) + parser.add_argument( + "--dry-run", + action=argparse.BooleanOptionalAction, + default=True, + help="Log actions without pushing to remotes (default: enabled)", + ) + parser.add_argument( + "--exclude-list", + nargs="*", + default=[], + help="List of submodule repo names to exclude from branching", + ) + parser.add_argument( + "--force-clone", + action="store_true", + default=False, + help="Delete and reclone if cache directory exists but is not a " + "valid git repo", + ) + parser.add_argument( + "--cache-dir", + default=None, + help="Directory to cache the TheRock clone " + "(default: /tmp/rock-branching-cache)", + ) + args = parser.parse_args(argv) + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" + ) + + try: + RockBranchingAutomation(args).run() + return 0 + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: + logging.error("Command failed: %s", exc) + return 1 + except RuntimeError as exc: + logging.error("%s", exc) + return 1 + except Exception as exc: + logging.error("Unexpected error: %s", exc) + return 1 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/scripts/tests/test_create_release_branch.py b/scripts/tests/test_create_release_branch.py new file mode 100644 index 00000000..186e3db8 --- /dev/null +++ b/scripts/tests/test_create_release_branch.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +# Copyright Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +""" +Tests for create_release_branch.py. + +Covers: +- convert_to_ssh URL conversion +- ROCm org filtering logic +- get_submodule_url_map parsing +- execute_plan behaviour with mocked subprocess calls +""" +import subprocess +import textwrap +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, call, patch + +import pytest + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) +from scripts.create_release_branch import RepoInfo, RockBranchingAutomation + + +_FAKE_COMMIT = "a" * 40 + + +def make_automation(**kwargs) -> RockBranchingAutomation: + defaults = dict( + branch_name="release/6.4", + commitid=_FAKE_COMMIT, + dry_run=True, + exclude_list=[], + force_clone=False, + cache_dir=None, + ) + defaults.update(kwargs) + return RockBranchingAutomation(SimpleNamespace(**defaults)) + + +# --------------------------------------------------------------------------- +# convert_to_ssh +# --------------------------------------------------------------------------- + +class TestConvertToSsh: + def test_https_converted(self): + auto = make_automation() + assert auto.convert_to_ssh("https://github.com/ROCm/hip.git") == \ + "git@github.com:ROCm/hip.git" + + def test_https_without_dot_git(self): + auto = make_automation() + assert auto.convert_to_ssh("https://github.com/ROCm/clr") == \ + "git@github.com:ROCm/clr" + + def test_ssh_url_passthrough(self): + auto = make_automation() + url = "git@github.com:ROCm/hip.git" + assert auto.convert_to_ssh(url) == url + + def test_non_github_url_passthrough(self): + auto = make_automation() + url = "https://gitlab.com/someorg/repo.git" + assert auto.convert_to_ssh(url) == url + + +# --------------------------------------------------------------------------- +# ROCm org filter logic (mirrors the logic in build_plan) +# --------------------------------------------------------------------------- + +class TestRocmOrgFilter: + @pytest.mark.parametrize("url,is_rocm", [ + ("https://github.com/ROCm/hip.git", True), + ("https://github.com/rocm/hip.git", True), # case-insensitive + ("git@github.com:ROCm/clr.git", True), + ("git@github.com:rocm/clr.git", True), + ("https://github.com/llvm/llvm-project.git", False), + ("https://github.com/other/repo.git", False), + ("https://gitlab.com/ROCm/hip.git", False), # wrong host + ]) + def test_rocm_org_detection(self, url, is_rocm): + url_lower = url.lower() + result = ( + "github.com/rocm/" in url_lower + or "github.com:rocm/" in url_lower + ) + assert result == is_rocm + + +# --------------------------------------------------------------------------- +# get_submodule_url_map +# --------------------------------------------------------------------------- + +class TestGetSubmoduleUrlMap: + def test_no_gitmodules_returns_empty(self, tmp_path): + auto = make_automation() + assert auto.get_submodule_url_map(tmp_path) == {} + + def test_parses_paths_and_urls(self, tmp_path): + gitmodules = tmp_path / ".gitmodules" + gitmodules.write_text(textwrap.dedent("""\ + [submodule "external/hip"] + path = external/hip + url = https://github.com/ROCm/hip.git + [submodule "external/clr"] + path = external/clr + url = https://github.com/ROCm/clr.git + """)) + + auto = make_automation() + url_map = auto.get_submodule_url_map(tmp_path) + + assert url_map["external/hip"] == "https://github.com/ROCm/hip.git" + assert url_map["external/clr"] == "https://github.com/ROCm/clr.git" + + def test_missing_url_entry_skipped(self, tmp_path): + # A path entry with no corresponding URL entry should be silently skipped. + gitmodules = tmp_path / ".gitmodules" + gitmodules.write_text(textwrap.dedent("""\ + [submodule "external/hip"] + path = external/hip + """)) + + auto = make_automation() + url_map = auto.get_submodule_url_map(tmp_path) + assert "external/hip" not in url_map + + +# --------------------------------------------------------------------------- +# execute_plan — integration tests with mocked subprocess +# --------------------------------------------------------------------------- + +def _make_plan(tmp_path: Path) -> dict[str, RepoInfo]: + repo_dir = tmp_path / "hip" + repo_dir.mkdir() + return { + "hip": RepoInfo( + url="https://github.com/ROCm/hip.git", + commit="b" * 40, + path=repo_dir, + ) + } + + +class TestExecutePlan: + def test_dry_run_does_not_push(self, tmp_path): + auto = make_automation(dry_run=True) + plan = _make_plan(tmp_path) + + with patch.object(auto, "_setup_remote"), \ + patch.object(auto, "_remote_branch_exists", return_value=False), \ + patch.object(auto, "_create_branch"), \ + patch.object(auto, "run_command") as mock_run: + auto.execute_plan(plan) + + for c in mock_run.call_args_list: + assert "push" not in c.args[0], \ + f"Unexpected push call in dry-run: {c}" + + def test_existing_remote_branch_goes_to_skipped_not_failed(self, tmp_path): + auto = make_automation(dry_run=True) + plan = _make_plan(tmp_path) + + with patch.object(auto, "_setup_remote"), \ + patch.object(auto, "_remote_branch_exists", return_value=True), \ + patch.object(auto, "_create_branch") as mock_create, \ + patch.object(auto, "_push_branch") as mock_push: + auto.execute_plan(plan) + + mock_create.assert_not_called() + mock_push.assert_not_called() + + def test_missing_repo_path_recorded_as_failure(self, tmp_path): + auto = make_automation(dry_run=True) + plan = { + "missing": RepoInfo( + url="https://github.com/ROCm/missing.git", + commit="c" * 40, + path=tmp_path / "nonexistent", + ) + } + # Must not raise; logs the failure and moves on. + auto.execute_plan(plan) + + def test_setup_remote_failure_recorded_not_raised(self, tmp_path): + auto = make_automation(dry_run=True) + plan = _make_plan(tmp_path) + + with patch.object( + auto, "_setup_remote", + side_effect=subprocess.CalledProcessError(1, "git remote"), + ): + auto.execute_plan(plan) # must not raise + + def test_create_branch_failure_recorded_not_raised(self, tmp_path): + auto = make_automation(dry_run=True) + plan = _make_plan(tmp_path) + + with patch.object(auto, "_setup_remote"), \ + patch.object(auto, "_remote_branch_exists", return_value=False), \ + patch.object( + auto, "_create_branch", + side_effect=subprocess.CalledProcessError(1, "git checkout"), + ): + auto.execute_plan(plan) # must not raise + + def test_successful_dry_run_calls_create_branch(self, tmp_path): + auto = make_automation(dry_run=True) + plan = _make_plan(tmp_path) + + with patch.object(auto, "_setup_remote"), \ + patch.object(auto, "_remote_branch_exists", return_value=False), \ + patch.object(auto, "_create_branch") as mock_create, \ + patch.object(auto, "_push_branch") as mock_push: + auto.execute_plan(plan) + + mock_create.assert_called_once_with("b" * 40, plan["hip"].path) + mock_push.assert_called_once_with("hip", plan["hip"].path) + + def test_no_dry_run_calls_push(self, tmp_path): + auto = make_automation(dry_run=False) + plan = _make_plan(tmp_path) + + with patch.object(auto, "_setup_remote"), \ + patch.object(auto, "_remote_branch_exists", return_value=False), \ + patch.object(auto, "_create_branch"), \ + patch.object(auto, "run_command") as mock_run: + auto.execute_plan(plan) + + push_calls = [ + c for c in mock_run.call_args_list + if "push" in c.args[0] + ] + assert len(push_calls) == 1 + + +# --------------------------------------------------------------------------- +# commitid validation +# --------------------------------------------------------------------------- + +class TestCommitidValidation: + def test_valid_sha_accepted(self): + make_automation(commitid="a" * 40) # should not raise + + def test_short_sha_rejected(self): + with pytest.raises(SystemExit): + make_automation(commitid="abc123") + + def test_uppercase_sha_rejected(self): + with pytest.raises(SystemExit): + make_automation(commitid="A" * 40) + + def test_non_hex_rejected(self): + with pytest.raises(SystemExit): + make_automation(commitid="z" * 40)