Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ci(*:skip) Improve update_changelog.py to include all contributors and make it faster #4912

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 123 additions & 80 deletions dev/update_changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@

import pathlib
import re

try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import date
from sys import argv
from typing import Optional

import git
from git import Commit
from github import Github
from github.PullRequest import PullRequest
from github.Repository import Repository
from github.Tag import Tag

try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib


REPO_NAME = "adap/flower"
CHANGELOG_FILE = "framework/docs/source/ref-changelog.md"
Expand All @@ -39,8 +43,8 @@
# Load the TOML configuration
with (pathlib.Path(__file__).parent.resolve() / "changelog_config.toml").open(
"rb"
) as file:
CONFIG = tomllib.load(file)
) as toml_f:
CONFIG = tomllib.load(toml_f)

# Extract types, project, and scope from the config
TYPES = "|".join(CONFIG["type"])
Expand All @@ -52,12 +56,29 @@
PATTERN_TEMPLATE = CONFIG["pattern_template"]
PATTERN = PATTERN_TEMPLATE.format(types=TYPES, projects=PROJECTS, scope=SCOPE)

# Local git repository
LOCAL_REPO = git.Repo(search_parent_directories=True)

# Map PR types to sections in the changelog
PR_TYPE_TO_SECTION = {
"feat": "### New features",
"docs": "### Documentation improvements",
"break": "### Incompatible changes",
"ci": "### Other changes",
"fix": "### Other changes",
"refactor": "### Other changes",
"unknown": "### Unknown changes",
}

# Maximum number of workers in the thread pool
MAX_WORKERS = argv[2] if len(argv) > 2 else 10


def _get_latest_tag(gh_api: Github) -> tuple[Repository, Optional[Tag]]:
def _get_latest_tag(gh_api: Github) -> tuple[Repository, str]:
"""Retrieve the latest tag from the GitHub repository."""
repo = gh_api.get_repo(REPO_NAME)
tags = repo.get_tags()
return repo, tags[0] if tags.totalCount > 0 else None
tags = sorted(LOCAL_REPO.tags, key=lambda t: t.commit.committed_datetime)
return repo, tags[-1].name


def _add_shortlog(new_version: str, shortlog: str) -> None:
Expand Down Expand Up @@ -89,24 +110,64 @@ def _add_shortlog(new_version: str, shortlog: str) -> None:
file.write(line)


def _get_pull_requests_since_tag(
repo: Repository, tag: Tag
) -> tuple[str, set[PullRequest]]:
"""Get a list of pull requests merged into the main branch since a given tag."""
commit_shas = set()
contributors = set()
prs = set()
def _git_commits_since_tag(tag: str) -> list[Commit]:
"""Get a set of commits since a given tag."""
return list(LOCAL_REPO.iter_commits(f"{tag}..origin/main"))

for commit in repo.compare(tag.commit.sha, "main").commits:
commit_shas.add(commit.sha)

def _get_contributors_from_commits(api: Github, commits: list[Commit]) -> set[str]:
"""Get a set of contributors from a set of commits."""
# Get authors and co-authors from the commits
contributors: set[str] = set()
coauthor_names_emails: set[tuple[str, str]] = set()
coauthor_pattern = r"Co-authored-by:\s*(.+?)\s*<(.+?)>"

for commit in commits:
if commit.author.name is None:
continue
if "[bot]" in commit.author.name:
continue
# Find co-authors in the commit message
matches: list[str] = re.findall(coauthor_pattern, commit.message)

contributors.add(commit.author.name)
if matches:
coauthor_names_emails.update(matches)

# Get full names of the GitHub usernames
def _get_user(username: str, email: str) -> Optional[str]:
try:
if user := api.get_user(username):
if user.email == email:
return user.name
except Exception: # pylint: disable=broad-exception-caught
pass
print(f"FAILED to get user: {username} <{email}>")
return None

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
for name in executor.map(lambda x: _get_user(*x), coauthor_names_emails):
if name:
contributors.add(name)
return contributors


def _get_pull_requests_since_tag(
api: Github, repo: Repository, tag: str
) -> tuple[str, set[PullRequest]]:
"""Get a list of pull requests merged into the main branch since a given tag."""
prs = set()

print(f"Retrieving commits since tag '{tag}'...")
commits = _git_commits_since_tag(tag)

print("Retrieving contributors...")
contributors = _get_contributors_from_commits(api, commits)

print("Retrieving pull requests...")
commit_shas = {commit.hexsha for commit in commits}
for pr_info in repo.get_pulls(
state="closed", sort="created", direction="desc", base="main"
state="closed", sort="updated", direction="desc", base="main"
):
if pr_info.merge_commit_sha in commit_shas:
prs.add(pr_info)
Expand Down Expand Up @@ -167,83 +228,47 @@ def _extract_changelog_entry(
}


def _update_changelog(prs: set[PullRequest]) -> bool:
def _update_changelog(prs: set[PullRequest], tag: str) -> bool:
"""Update the changelog file with entries from provided pull requests."""
breaking_changes = False
unknown_changes = False

with open(CHANGELOG_FILE, "r+", encoding="utf-8") as file:
content = file.read()
unreleased_index = content.find("## Unreleased")

# Find the end of the Unreleased section
end_index = content.find(f"## {tag}", unreleased_index + 1)

for section in PR_TYPE_TO_SECTION.values():
if content.find(section, unreleased_index, end_index) == -1:
content = content[:end_index] + f"\n{section}\n\n" + content[end_index:]
end_index = content.find(f"## {tag}", end_index)

if unreleased_index == -1:
print("Unreleased header not found in the changelog.")
return False

# Find the end of the Unreleased section
next_header_index = content.find("## ", unreleased_index + 1)
next_header_index = (
next_header_index if next_header_index != -1 else len(content)
)

for pr_info in prs:
parsed_title = _extract_changelog_entry(pr_info)

# Skip if PR should be skipped or already in changelog
if (
parsed_title.get("scope", "unknown") == "skip"
or f"#{pr_info.number}]" in content
):
# Skip if the PR is already in changelog
if f"#{pr_info.number}]" in content:
continue

# Find section to insert
pr_type = parsed_title.get("type", "unknown")
if pr_type == "feat":
insert_content_index = content.find("### What", unreleased_index + 1)
elif pr_type == "docs":
insert_content_index = content.find(
"### Documentation improvements", unreleased_index + 1
)
elif pr_type == "break":
breaking_changes = True
insert_content_index = content.find(
"### Incompatible changes", unreleased_index + 1
)
elif pr_type in {"ci", "fix", "refactor"}:
insert_content_index = content.find(
"### Other changes", unreleased_index + 1
)
else:
unknown_changes = True
insert_content_index = unreleased_index
section = PR_TYPE_TO_SECTION.get(pr_type, "### Unknown changes")
insert_index = content.find(section, unreleased_index, end_index)

pr_reference = _format_pr_reference(
pr_info.title, pr_info.number, pr_info.html_url
)

content = _insert_entry_no_desc(
content,
pr_reference,
insert_content_index,
insert_index,
)

next_header_index = content.find("## ", unreleased_index + 1)
next_header_index = (
next_header_index if next_header_index != -1 else len(content)
)

if unknown_changes:
content = _insert_entry_no_desc(
content,
"### Unknown changes",
unreleased_index,
)

if not breaking_changes:
content = _insert_entry_no_desc(
content,
"None",
content.find("### Incompatible changes", unreleased_index + 1),
)
# Find the end of the Unreleased section
end_index = content.find(f"## {tag}", end_index)

# Finalize content update
file.seek(0)
Expand All @@ -263,34 +288,52 @@ def _insert_entry_no_desc(
return content


def _bump_minor_version(tag: Tag) -> Optional[str]:
def _bump_minor_version(tag: str) -> Optional[str]:
"""Bump the minor version of the tag."""
match = re.match(r"v(\d+)\.(\d+)\.(\d+)", tag.name)
match = re.match(r"v(\d+)\.(\d+)\.(\d+)", tag)
if match is None:
return None
major, minor, _ = [int(x) for x in match.groups()]
major, minor, _ = (int(x) for x in match.groups())
# Increment the minor version and reset patch version
new_version = f"v{major}.{minor + 1}.0"
return new_version


def _fetch_origin() -> None:
"""Fetch the latest changes from the origin."""
LOCAL_REPO.remote("origin").fetch()


def main() -> None:
"""Update changelog using the descriptions of PRs since the latest tag."""
start = time.time()

# Initialize GitHub Client with provided token (as argument)
gh_api = Github(argv[1])

# Fetch the latest changes from the origin
print("Fetching the latest changes from the origin...")
_fetch_origin()

# Get the repository and the latest tag
print("Retrieving the latest tag...")
repo, latest_tag = _get_latest_tag(gh_api)
if not latest_tag:
print("No tags found in the repository.")
return

shortlog, prs = _get_pull_requests_since_tag(repo, latest_tag)
if _update_changelog(prs):
# Get the shortlog and the pull requests since the latest tag
shortlog, prs = _get_pull_requests_since_tag(gh_api, repo, latest_tag)

# Update the changelog
print("Updating the changelog...")
if _update_changelog(prs, latest_tag):
new_version = _bump_minor_version(latest_tag)
if not new_version:
print("Wrong tag format.")
panh99 marked this conversation as resolved.
Show resolved Hide resolved
return
_add_shortlog(new_version, shortlog)
print("Changelog updated succesfully.")
print(f"Changelog updated successfully in {time.time() - start:.2f}s.")


if __name__ == "__main__":
Expand Down