diff --git a/scripts/claim_inventory.py b/scripts/claim_inventory.py index 768e05f2..bb0e6810 100644 --- a/scripts/claim_inventory.py +++ b/scripts/claim_inventory.py @@ -17,10 +17,14 @@ from scripts.api_host_args import public_api_host from scripts.bounty_refs import BOUNTY_REF_RE +from scripts.public_json_fetch import PublicJsonError, load_public_json DEFAULT_API_HOST = "https://api.mrwk.online" GH_TIMEOUT_SECONDS = 30 GH_LIMIT = 200 +GH_PUBLIC_API_SAFETY_CAP = 201 +GH_ISSUE_SAFETY_CAP = 201 +GH_PR_SAFETY_CAP = 201 GITHUB_URL_RE = re.compile( r"https://github\.com/[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+/" r"(?:issues|pull)/\d+(?:#[A-Za-z0-9_.-]+)?" @@ -571,24 +575,37 @@ def _load_pr_review_comments(repo: str, pr_number: int) -> list[dict[str, Any]]: def _get_json(url: str) -> Any: - request = urllib.request.Request(url, headers={"accept": "application/json"}) try: - with urllib.request.urlopen(request, timeout=GH_TIMEOUT_SECONDS) as response: - return json.loads(response.read().decode("utf-8")) - except (urllib.error.URLError, TimeoutError) as exc: + return load_public_json( + url, + description="public API request", + timeout=GH_TIMEOUT_SECONDS, + ) + except PublicJsonError as exc: raise RuntimeError(f"public API request failed: {url}") from exc def load_public_api_state(api_host: str) -> dict[str, Any]: host = api_host.rstrip("/") - bounties = _get_json(f"{host}/api/v1/bounties?limit={GH_LIMIT}") - activity = _get_json(f"{host}/api/v1/activity?limit={GH_LIMIT}") + bounties = _get_json(f"{host}/api/v1/bounties?limit={GH_PUBLIC_API_SAFETY_CAP}") + if isinstance(bounties, list) and len(bounties) >= GH_PUBLIC_API_SAFETY_CAP: + raise RuntimeError( + f"public bounties list reached the {GH_PUBLIC_API_SAFETY_CAP} item safety cap; " + "use an API-paginated collector before trusting this live report" + ) + activity = _get_json(f"{host}/api/v1/activity?limit={GH_PUBLIC_API_SAFETY_CAP}") data: dict[str, Any] = {} if isinstance(bounties, list): data["bounties"] = bounties if 
isinstance(activity, dict): contributors = activity.get("contributors") if isinstance(contributors, list): + if len(contributors) >= GH_PUBLIC_API_SAFETY_CAP: + raise RuntimeError( + f"public activity contributors list reached the " + f"{GH_PUBLIC_API_SAFETY_CAP} item safety cap; " + "use an API-paginated collector before trusting this live report" + ) data["contributors"] = contributors recent = activity.get("recent") if isinstance(recent, list): @@ -607,11 +624,16 @@ def load_live_inventory(repo: str, api_host: str) -> dict[str, Any]: "--state", "open", "--limit", - str(GH_LIMIT), + str(GH_ISSUE_SAFETY_CAP), "--json", "number,title,url,labels,author", ] ) + if len(issue_list) >= GH_ISSUE_SAFETY_CAP: + raise RuntimeError( + f"gh issue list reached the {GH_ISSUE_SAFETY_CAP} item safety cap; " + "use an API-paginated collector before trusting this live report" + ) issues: list[dict[str, Any]] = [] for issue in issue_list: if ( @@ -643,11 +665,16 @@ def load_live_inventory(repo: str, api_host: str) -> dict[str, Any]: "--state", "open", "--limit", - str(GH_LIMIT), + str(GH_PR_SAFETY_CAP), "--json", "number,title,url,body,author,labels", ] ) + if len(prs) >= GH_PR_SAFETY_CAP: + raise RuntimeError( + f"gh pr list reached the {GH_PR_SAFETY_CAP} item safety cap; " + "use an API-paginated collector before trusting this live report" + ) pull_requests: list[dict[str, Any]] = [] for pr in prs: if not isinstance(pr, dict) or not isinstance(pr.get("number"), int): @@ -680,6 +707,15 @@ def _load_input(path: str) -> dict[str, Any]: return data +def _require_non_empty_arg(parser: argparse.ArgumentParser, option_name: str, value: str) -> str: + stripped = value.strip() + if not stripped: + parser.error(f"{option_name} must be a non-empty value") + if stripped != value: + parser.error(f"{option_name} must not include leading or trailing whitespace") + return value + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser( description="Inventory public 
MergeWork claim surfaces and payout status." @@ -691,7 +727,13 @@ def main(argv: list[str] | None = None) -> int: parser.add_argument("--format", choices=["json", "markdown"], default="markdown") args = parser.parse_args(argv) - data = _load_input(args.input) if args.input else load_live_inventory(args.repo, args.api_host) + if args.input is not None: + data = _load_input(_require_non_empty_arg(parser, "--input", args.input)) + else: + data = load_live_inventory( + _require_non_empty_arg(parser, "--repo", args.repo), + args.api_host, + ) report = analyze_inventory(data, api_host=args.api_host) if args.format == "json": print(json.dumps(report, indent=2, sort_keys=True))
diff --git a/scripts/public_json_fetch.py b/scripts/public_json_fetch.py
new file mode 100644
--- /dev/null
+++ b/scripts/public_json_fetch.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import http.client
+import json
+import urllib.error
+import urllib.request
+from typing import Any
+
+
+class PublicJsonError(RuntimeError):
+    """Raised when a public JSON endpoint cannot be fetched or decoded."""
+
+
+def load_public_json(
+    url: str,
+    *,
+    description: str | None = None,
+    timeout: float = 30,
+    user_agent: str = "mergework-maintenance-script",
+    accept: str = "application/json",
+) -> Any:
+    """Fetch ``url`` and return its body parsed as JSON.
+
+    Args:
+        url: Endpoint to request.
+        description: Label used as the prefix of error messages; falls back
+            to ``"public JSON"`` when omitted or empty.
+        timeout: Timeout in seconds passed to ``urllib.request.urlopen``.
+        user_agent: Value sent in the ``User-Agent`` request header.
+        accept: Value sent in the ``Accept`` request header.
+
+    Returns:
+        The decoded JSON document.
+
+    Raises:
+        PublicJsonError: If the request fails or times out, the server
+            answers with an HTTP error, or the body is not UTF-8 encoded JSON.
+    """
+    label = description or "public JSON"
+    request = urllib.request.Request(
+        url,
+        headers={
+            "Accept": accept,
+            "User-Agent": user_agent,
+        },
+    )
+    try:
+        with urllib.request.urlopen(request, timeout=timeout) as response:
+            body = response.read()
+    except urllib.error.HTTPError as exc:
+        raise PublicJsonError(f"{label} unavailable: HTTP {exc.code}") from exc
+    except (OSError, http.client.HTTPException) as exc:
+        # URLError and TimeoutError are OSError subclasses, so they are still
+        # covered; OSError also catches connection resets while the body is read
+        # and HTTPException catches truncated responses (IncompleteRead).
+        raise PublicJsonError(f"{label} unavailable: {exc}") from exc
+    try:
+        return json.loads(body.decode("utf-8"))
+    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
+        # Report a non-UTF-8 body like any other undecodable payload instead of
+        # leaking UnicodeDecodeError to callers that only catch PublicJsonError.
+        raise PublicJsonError(f"{label} unavailable: invalid JSON") from exc
diff --git a/tests/test_public_json_fetch.py b/tests/test_public_json_fetch.py
new file mode 100644
--- /dev/null
+++ b/tests/test_public_json_fetch.py
@@ -0,0 +1,101 @@
+from __future__ import annotations
+
+import io
+import json
+import urllib.error
+from unittest.mock import patch
+
+import pytest
+
+from scripts.public_json_fetch import PublicJsonError, load_public_json
+
+
+def test_load_public_json_returns_decoded_payload() -> None:
+    payload = {"status": "ok"}
+    response = io.BytesIO(json.dumps(payload).encode("utf-8"))
+    response.status = 200  # type: ignore[attr-defined]
+
+    with patch("urllib.request.urlopen", return_value=response):
+        assert load_public_json("https://example.test/api/v1/status") == payload
+
+
+def test_load_public_json_sets_default_headers() -> None:
+    captured: dict[str, str] = {}
+
+    class FakeResponse(io.BytesIO):
+        def __init__(self) -> None:
+            super().__init__(b"[]")
+
+    def fake_urlopen(request, timeout=30):  # noqa: ANN001
+        captured["accept"] = request.get_header("Accept")
+        captured["user_agent"] = request.get_header("User-agent")
+        captured["timeout"] = str(timeout)
+        return FakeResponse()
+
+    with patch("urllib.request.urlopen", side_effect=fake_urlopen):
+        load_public_json("https://example.test/api/v1/bounties")
+
+    assert captured["accept"] == "application/json"
+    assert captured["user_agent"] == "mergework-maintenance-script"
+    assert captured["timeout"] == "30"
+
+
+def test_load_public_json_wraps_http_error() -> None:
+    error = urllib.error.HTTPError(
+        url="https://example.test/api/v1/bounties",
+        code=503,
+        msg="service unavailable",
+        hdrs=None,
+        fp=None,
+    )
+    with patch("urllib.request.urlopen", side_effect=error):
+        with pytest.raises(PublicJsonError, match="MergeWork API bounty data unavailable: HTTP 503"):
+            load_public_json(
+                "https://example.test/api/v1/bounties",
+                description="MergeWork API bounty data",
+            )
+
+
+def test_load_public_json_wraps_timeout() -> None:
+    with patch("urllib.request.urlopen", side_effect=TimeoutError("timed out")):
+        with pytest.raises(PublicJsonError, match="public JSON unavailable:"):
+            load_public_json("https://example.test/api/v1/status")
+
+
+def test_load_public_json_wraps_invalid_json() -> None:
+    response = io.BytesIO(b"not-json")
+
+    with patch("urllib.request.urlopen", return_value=response):
+        with pytest.raises(PublicJsonError, match="public JSON unavailable: invalid JSON"):
+            load_public_json("https://example.test/api/v1/status")
+
+
+def test_load_public_json_wraps_invalid_utf8() -> None:
+    response = io.BytesIO(b"\xff\xfe")
+
+    with patch("urllib.request.urlopen", return_value=response):
+        with pytest.raises(PublicJsonError, match="public JSON unavailable: invalid JSON"):
+            load_public_json("https://example.test/api/v1/status")
+
+
+def test_load_public_json_wraps_connection_error() -> None:
+    with patch("urllib.request.urlopen", side_effect=ConnectionResetError("reset")):
+        with pytest.raises(PublicJsonError, match="public JSON unavailable: reset"):
+            load_public_json("https://example.test/api/v1/status")
+
+
+def test_load_public_json_honors_custom_timeout() -> None:
+    captured: dict[str, str] = {}
+
+    class FakeResponse(io.BytesIO):
+        def __init__(self) -> None:
+            super().__init__(b"{}")
+
+    def fake_urlopen(request, timeout=30):  # noqa: ANN001
+        captured["timeout"] = str(timeout)
+        return FakeResponse()
+
+    with patch("urllib.request.urlopen", side_effect=fake_urlopen):
+        load_public_json("https://example.test/api/v1/status", timeout=12)
+
+    assert captured["timeout"] == "12"