From 03da4d11a1de2073984521a40dfe9c20d9c11e0a Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Fri, 29 May 2026 09:28:47 -0600 Subject: [PATCH 1/4] Use src-py-lib 0.1.6 GraphQL pagination Amp-Thread-ID: https://ampcode.com/threads/T-019e6fd7-ad23-7585-9a4b-d528a5dfc633 Co-authored-by: Amp --- pyproject.toml | 2 +- .../permissions/sourcegraph.py | 12 +------ tests/unit/test_snapshot.py | 35 +++++++------------ uv.lock | 8 ++--- 4 files changed, 19 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd63e95..28c43d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ dependencies = [ "json5>=0.14.0", "pyyaml>=6.0.3", - "src-py-lib==0.1.5", + "src-py-lib==0.1.6", ] keywords = [ "Sourcegraph" diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index c64c4b4..ec48ba8 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -171,11 +171,10 @@ def list_users_explicit_repo_ids( repository_ids_by_user_id: dict[str, list[str]] = {user_id: [] for user_id in user_ids} pending_pages: list[tuple[str, str | None]] = [(user_id, None) for user_id in user_ids] - graphql_client = _graphql_client_without_auto_pagination(client) while pending_pages: batch = pending_pages[:batch_size] del pending_pages[:batch_size] - data = graphql_client.execute( + data = client.graphql( _user_explicit_repos_batch_query(len(batch)), _user_explicit_repos_batch_variables(batch), follow_pages=False, @@ -237,15 +236,6 @@ def list_repositories_by_ids( return repositories -def _graphql_client_without_auto_pagination(client: src.SourcegraphClient) -> src.GraphQLClient: - return src.GraphQLClient( - url=f"{client.endpoint}/.api/graphql", - headers={"Authorization": f"token {client.token}"}, - label="Sourcegraph", - http=client.http, - ) - - def _batches(values: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]: for start_index in range(0, len(values), batch_size): yield values[start_index : start_index + batch_size] diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3525fc9..bcb72f9 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -177,22 +177,14 @@ def test_list_users_explicit_repos_batches_aliases_and_follows_pages(self) -> No ), ] - class FakeGraphQLClient: - def __init__(self, **_kwargs: object) -> None: - pass - - def execute( - self, - query: str, - variables: src.JSONDict, - *, - follow_pages: bool = True, - ) -> src.JSONDict: - calls.append((query, dict(variables), follow_pages)) - return responses.pop(0) - - def graphql(query: str, variables: object = None) -> src.JSONDict: - return FakeGraphQLClient().execute(query, cast(src.JSONDict, variables or {})) + def graphql( + query: str, + variables: object = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + calls.append((query, dict(cast(src.JSONDict, variables or {})), follow_pages)) + return responses.pop(0) client = cast( src.SourcegraphClient, @@ -203,12 +195,11 @@ def graphql(query: str, variables: object = None) -> src.JSONDict: graphql=graphql, ), ) - with patch.object(permissions_sourcegraph.src, "GraphQLClient", FakeGraphQLClient): - repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( - client, - ["user-1", "user-2"], - batch_size=2, - ) + repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( + client, + ["user-1", "user-2"], + batch_size=2, + ) self.assertEqual( { diff --git a/uv.lock b/uv.lock index e759784..12eb8bd 100644 --- a/uv.lock +++ b/uv.lock @@ -337,7 +337,7 @@ dev = [ requires-dist = [ { name = "json5", specifier = ">=0.14.0" }, { name = "pyyaml", specifier = ">=6.0.3" }, - { name = "src-py-lib", specifier = "==0.1.5" }, + { name = "src-py-lib", specifier = "==0.1.6" }, ] [package.metadata.requires-dev] @@ -349,16 +349,16 @@ dev = [ [[package]] name = "src-py-lib" -version = "0.1.5" +version = "0.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "pydantic" }, { name = "python-dotenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/68/39/dc534f18686d255141982cae8d7935c3cb807a6b98356d8936ed9c2d3b3d/src_py_lib-0.1.5.tar.gz", hash = "sha256:695f0fc0a2c539bd7ffc6c537822dca604fe8718de343c34f973765b31201d69", size = 71613, upload-time = "2026-05-29T08:58:25.545Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/e7/bc59bf44fc2130df83aeef64dc2666f4617f236b61b879dc8d5629609361/src_py_lib-0.1.6.tar.gz", hash = "sha256:e2c5b015e2bb077e6116ad7457654cc81d17d13bc9f05768fa6720719d350f93", size = 71768, upload-time = "2026-05-29T15:19:45.891Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ee/550cbda36b6853584f60f4acbbae781f2e0a38e12811fdcd8731532ed077/src_py_lib-0.1.5-py3-none-any.whl", hash = "sha256:1bafff027ccb68478d5712a5522e7e21dd4ef5fe51b14723fff95dbd6496db30", size = 44873, upload-time = "2026-05-29T08:58:24.475Z" }, + { url = "https://files.pythonhosted.org/packages/7e/eb/cddb0c92806cbb3ebfed4baa32ee6fa8550a2f5ac56c55de92e65f0066ba/src_py_lib-0.1.6-py3-none-any.whl", hash = "sha256:781c82fa42f48268a3b8b1ac7406fa69418dfd3d0ba3bc795b549093d004647a", size = 44956, upload-time = "2026-05-29T15:19:44.559Z" }, ] [[package]] From 865d7b25c459eac7f981579f1c4061a8d278ca6c Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Sat, 30 May 2026 01:39:35 -0600 Subject: [PATCH 2/4] Prepare for memory efficiency analysis --- dev/TODO.md | 5 - dev/analyze-memory.py | 618 +++++++++++++ dev/monitor-sourcegraph-load.sh | 348 +++++++ dev/run-memory-model-sweep.py | 822 +++++++++++++++++ ...ourcegraph-explicit-permissions-tracing.md | 71 +- dev/test-end-to-end.py | 860 ++++++++++++++++-- dev/test-plan.md | 65 ++ maps-example.yaml | 50 +- src/src_auth_perms_sync/cli.py | 10 + src/src_auth_perms_sync/permissions/apply.py | 4 +- .../permissions/command.py | 43 +- .../permissions/full_set.py | 16 +- .../permissions/mapping.py | 161 +++- .../permissions/queries.py | 37 +- .../permissions/snapshot.py | 2 +- .../permissions/sourcegraph.py | 36 +- src/src_auth_perms_sync/permissions/types.py | 5 +- src/src_auth_perms_sync/shared/queries.py | 37 +- src/src_auth_perms_sync/shared/sourcegraph.py | 12 +- src/src_auth_perms_sync/shared/types.py | 6 + tests/unit/test_cli_config.py | 37 + tests/unit/test_maps.py | 301 +++++- 22 files changed, 3383 insertions(+), 163 deletions(-) create mode 100755 dev/analyze-memory.py create mode 100755 dev/monitor-sourcegraph-load.sh create mode 100755 dev/run-memory-model-sweep.py diff --git a/dev/TODO.md b/dev/TODO.md index 7fe3c6a..a6ec344 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -1,10 +1,5 @@ # TODO -## High priority: Bump src-py-lib after Node ID helper release - -- After releasing `src-py-lib` with Sourcegraph Node ID helpers, update - `pyproject.toml` and `uv.lock` to depend on that new version. - ## Medium priority: Lightweight incremental updates - When a new user's account is created, or a new repo is synced from a code host, diff --git a/dev/analyze-memory.py b/dev/analyze-memory.py new file mode 100755 index 0000000..f3f062e --- /dev/null +++ b/dev/analyze-memory.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python3 +"""Fit a Sourcegraph permissions memory model from e2e result JSON. + +The model is intentionally small and dependency-free: + + peak RSS MiB = intercept + users*b1 + repos*b2 + grants*b3 + +Use one command mode per fit. Mixing backup, no-backup, get, set, and restore +runs makes the per-grant coefficient much less useful. +""" + +from __future__ import annotations + +import argparse +import json +import math +import re +import statistics +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +FEATURE_NAMES = ("users", "repos", "grants") +COEFFICIENT_SCALE = { + "users": "bytes/user", + "repos": "bytes/repo", + "grants": "bytes/grant", +} + + +@dataclass(frozen=True) +class WorkloadDimensions: + """Canonical workload dimensions used by the memory model.""" + + users: float | None + repos: float | None + grants: float | None + + +@dataclass(frozen=True) +class MemoryObservation: + """One e2e command result with peak memory and workload dimensions.""" + + source_path: str + variant: str + case_name: str + command: str + iteration: int + peak_resident_megabytes: float + dimensions: WorkloadDimensions + + +@dataclass(frozen=True) +class MemoryModel: + """Fitted linear memory model.""" + + feature_names: tuple[str, ...] + coefficients_megabytes: dict[str, float] + observation_count: int + r_squared: float | None + mean_absolute_error_megabytes: float + p95_absolute_error_megabytes: float + max_absolute_error_megabytes: float + + +@dataclass(frozen=True) +class MemoryEstimate: + """Predicted memory for a proposed users x repos workload.""" + + dimensions: WorkloadDimensions + peak_resident_megabytes: float + peak_resident_megabytes_with_headroom: float + headroom_percent: float + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Fit a fixed + users + repos + grants memory model from e2e JSON.", + ) + parser.add_argument( + "results_json", + nargs="+", + type=Path, + help="One or more JSON files written by dev/test-end-to-end.py --results-json.", + ) + parser.add_argument( + "--variant", + help="Only include one variant, e.g. candidate or baseline.", + ) + parser.add_argument( + "--command", + help="Only include one structured command, e.g. set_full or get.", + ) + parser.add_argument( + "--case-regex", + help="Only include cases whose e2e case name matches this regular expression.", + ) + parser.add_argument( + "--features", + default="users,repos,grants", + help="Comma-separated model features from users,repos,grants (default: all).", + ) + parser.add_argument( + "--min-grants", + type=float, + default=1.0, + help="Drop observations below this grant count (default: 1).", + ) + parser.add_argument( + "--estimate-users", + type=float, + help="Estimate memory for this many users.", + ) + parser.add_argument( + "--estimate-repos", + type=float, + help="Estimate memory for this many repos.", + ) + parser.add_argument( + "--estimate-grants", + type=float, + help="Estimate memory for this many grants; defaults to users * repos.", + ) + parser.add_argument( + "--headroom-percent", + type=float, + default=30.0, + help="Headroom to add to estimates (default: 30).", + ) + parser.add_argument( + "--json", + action="store_true", + help="Write machine-readable JSON instead of a text report.", + ) + arguments = parser.parse_args() + + feature_names = parse_feature_names(arguments.features) + observations = load_observations(arguments.results_json) + filtered_observations = filter_observations( + observations, + variant=arguments.variant, + command=arguments.command, + case_regex=arguments.case_regex, + min_grants=arguments.min_grants, + ) + model_observations = observations_with_features(filtered_observations, feature_names) + minimum_observations = len(feature_names) + 1 + if len(model_observations) < minimum_observations: + print( + "Need at least " + f"{minimum_observations} observations with {', '.join(feature_names)} " + f"to fit this model; found {len(model_observations)}.", + file=sys.stderr, + ) + return 2 + + try: + model = fit_memory_model(model_observations, feature_names) + except ValueError as error: + print(f"Could not fit memory model: {error}", file=sys.stderr) + print( + "Try filtering to one command mode, adding varied users x repos shapes, " + "or using fewer --features.", + file=sys.stderr, + ) + return 2 + + estimate = build_estimate( + model, + feature_names, + estimate_users=arguments.estimate_users, + estimate_repos=arguments.estimate_repos, + estimate_grants=arguments.estimate_grants, + headroom_percent=arguments.headroom_percent, + ) + if arguments.json: + write_json_report(model, model_observations, estimate) + else: + write_text_report(model, model_observations, estimate) + return 0 + + +def parse_feature_names(raw_features: str) -> tuple[str, ...]: + names = tuple(name.strip() for name in raw_features.split(",") if name.strip()) + invalid = sorted(set(names) - set(FEATURE_NAMES)) + if invalid: + raise SystemExit(f"Unknown feature(s): {', '.join(invalid)}") + duplicates = sorted({name for name in names if names.count(name) > 1}) + if duplicates: + raise SystemExit(f"Duplicate feature(s): {', '.join(duplicates)}") + if not names: + raise SystemExit("At least one feature is required.") + return names + + +def load_observations(paths: list[Path]) -> list[MemoryObservation]: + observations: list[MemoryObservation] = [] + for path in paths: + with path.open(encoding="utf-8") as input_file: + payload: object = json.load(input_file) + for result in result_mappings(payload): + observation = observation_from_result(path, result) + if observation is not None: + observations.append(observation) + return observations + + +def result_mappings(payload: object) -> list[dict[str, Any]]: + if isinstance(payload, dict): + mapping = cast(dict[str, Any], payload) + results = mapping.get("results") + if isinstance(results, list): + return mapping_items(cast(list[object], results)) + if "memory" in mapping and "workload" in mapping: + return [mapping] + if isinstance(payload, list): + return mapping_items(cast(list[object], payload)) + return [] + + +def mapping_items(values: list[object]) -> list[dict[str, Any]]: + """Return only dict-like JSON objects from a JSON list.""" + return [cast(dict[str, Any], value) for value in values if isinstance(value, dict)] + + +def observation_from_result(path: Path, result: dict[str, Any]) -> MemoryObservation | None: + memory = object_mapping(result.get("memory")) + workload = object_mapping(result.get("workload")) + if memory is None or workload is None: + return None + peak_resident_megabytes = first_number(memory, ("peak_rss_mb", "external_peak_rss_mb")) + if peak_resident_megabytes is None: + return None + return MemoryObservation( + source_path=str(path), + variant=string_value(result.get("variant")), + case_name=string_value(result.get("case")), + command=string_value(result.get("command")), + iteration=integer_value(result.get("iteration")), + peak_resident_megabytes=peak_resident_megabytes, + dimensions=WorkloadDimensions( + users=first_number( + workload, + ( + "memory_model_user_count", + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "user_count", + "total_users_scanned", + "sourcegraph_user_count", + "total_users", + ), + ), + repos=first_number( + workload, + ( + "memory_model_repo_count", + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "repos_with_explicit_grants", + "loaded_repo_count", + "repo_count", + ), + ), + grants=first_number( + workload, + ( + "memory_model_grant_count", + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "total_grants", + "apply_payload_grant_count", + ), + ), + ), + ) + + +def filter_observations( + observations: list[MemoryObservation], + *, + variant: str | None, + command: str | None, + case_regex: str | None, + min_grants: float, +) -> list[MemoryObservation]: + pattern = re.compile(case_regex) if case_regex else None + filtered: list[MemoryObservation] = [] + for observation in observations: + if variant is not None and observation.variant != variant: + continue + if command is not None and observation.command != command: + continue + if pattern is not None and pattern.search(observation.case_name) is None: + continue + if observation.dimensions.grants is None or observation.dimensions.grants < min_grants: + continue + filtered.append(observation) + return filtered + + +def observations_with_features( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> list[MemoryObservation]: + return [ + observation + for observation in observations + if all(feature_value(observation.dimensions, name) is not None for name in feature_names) + ] + + +def fit_memory_model( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> MemoryModel: + feature_scales = feature_scale_by_name(observations, feature_names) + matrix = [ + [1.0] + + [ + required_feature_value(observation.dimensions, feature_name) + / feature_scales[feature_name] + for feature_name in feature_names + ] + for observation in observations + ] + targets = [observation.peak_resident_megabytes for observation in observations] + scaled_coefficients = solve_normal_equations(matrix, targets) + coefficients = {"intercept": scaled_coefficients[0]} + for feature_index, feature_name in enumerate(feature_names, start=1): + coefficients[feature_name] = ( + scaled_coefficients[feature_index] / feature_scales[feature_name] + ) + predictions = [ + predict_megabytes(coefficients, observation.dimensions) for observation in observations + ] + residuals = [ + target - prediction for target, prediction in zip(targets, predictions, strict=True) + ] + absolute_residuals = [abs(residual) for residual in residuals] + target_mean = statistics.fmean(targets) + residual_sum_squares = sum(residual * residual for residual in residuals) + total_sum_squares = sum((target - target_mean) ** 2 for target in targets) + return MemoryModel( + feature_names=feature_names, + coefficients_megabytes=coefficients, + observation_count=len(observations), + r_squared=( + None if total_sum_squares == 0 else 1.0 - residual_sum_squares / total_sum_squares + ), + mean_absolute_error_megabytes=statistics.fmean(absolute_residuals), + p95_absolute_error_megabytes=percentile(absolute_residuals, 95.0), + max_absolute_error_megabytes=max(absolute_residuals), + ) + + +def feature_scale_by_name( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> dict[str, float]: + scales: dict[str, float] = {} + for feature_name in feature_names: + maximum = max( + abs(required_feature_value(observation.dimensions, feature_name)) + for observation in observations + ) + scales[feature_name] = maximum if maximum > 0 else 1.0 + return scales + + +def solve_normal_equations(matrix: list[list[float]], targets: list[float]) -> list[float]: + column_count = len(matrix[0]) + normal_matrix = [[0.0 for _ in range(column_count)] for _ in range(column_count)] + normal_targets = [0.0 for _ in range(column_count)] + for row, target in zip(matrix, targets, strict=True): + for row_index in range(column_count): + normal_targets[row_index] += row[row_index] * target + for column_index in range(column_count): + normal_matrix[row_index][column_index] += row[row_index] * row[column_index] + return solve_linear_system(normal_matrix, normal_targets) + + +def solve_linear_system(matrix: list[list[float]], values: list[float]) -> list[float]: + size = len(values) + augmented = [matrix[row_index][:] + [values[row_index]] for row_index in range(size)] + for pivot_index in range(size): + pivot_row = max( + range(pivot_index, size), + key=lambda row_index: abs(augmented[row_index][pivot_index]), + ) + pivot_value = augmented[pivot_row][pivot_index] + if abs(pivot_value) < 1e-12: + raise ValueError("features are collinear or the sample is too small") + augmented[pivot_index], augmented[pivot_row] = augmented[pivot_row], augmented[pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[pivot_index][column_index] /= pivot_value + for row_index in range(size): + if row_index == pivot_index: + continue + factor = augmented[row_index][pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[row_index][column_index] -= factor * augmented[pivot_index][column_index] + return [augmented[row_index][size] for row_index in range(size)] + + +def build_estimate( + model: MemoryModel, + feature_names: tuple[str, ...], + *, + estimate_users: float | None, + estimate_repos: float | None, + estimate_grants: float | None, + headroom_percent: float, +) -> MemoryEstimate | None: + if estimate_users is None and estimate_repos is None and estimate_grants is None: + return None + if "users" in feature_names and estimate_users is None: + raise SystemExit("--estimate-users is required because users is in --features.") + if "repos" in feature_names and estimate_repos is None: + raise SystemExit("--estimate-repos is required because repos is in --features.") + if "grants" in feature_names and estimate_grants is None: + if estimate_users is None or estimate_repos is None: + raise SystemExit( + "--estimate-grants is required unless --estimate-users and --estimate-repos " + "are both set." + ) + estimate_grants = estimate_users * estimate_repos + dimensions = WorkloadDimensions( + users=estimate_users, + repos=estimate_repos, + grants=estimate_grants, + ) + peak_resident_megabytes = predict_megabytes(model.coefficients_megabytes, dimensions) + return MemoryEstimate( + dimensions=dimensions, + peak_resident_megabytes=peak_resident_megabytes, + peak_resident_megabytes_with_headroom=peak_resident_megabytes + * (1.0 + headroom_percent / 100.0), + headroom_percent=headroom_percent, + ) + + +def predict_megabytes( + coefficients_megabytes: dict[str, float], dimensions: WorkloadDimensions +) -> float: + prediction = coefficients_megabytes["intercept"] + for feature_name in FEATURE_NAMES: + coefficient = coefficients_megabytes.get(feature_name) + value = feature_value(dimensions, feature_name) + if coefficient is not None and value is not None: + prediction += coefficient * value + return prediction + + +def write_text_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + print(f"Observations used: {model.observation_count}") + print(f"Features: {', '.join(model.feature_names)}") + print("\nCoefficients:") + print(f" intercept: {model.coefficients_megabytes['intercept']:.3f} MiB") + for feature_name in model.feature_names: + coefficient_megabytes = model.coefficients_megabytes[feature_name] + coefficient_bytes = coefficient_megabytes * 1024.0 * 1024.0 + print( + f" {feature_name}: {coefficient_megabytes:.9f} MiB/unit " + f"({coefficient_bytes:.1f} {COEFFICIENT_SCALE[feature_name]})" + ) + r_squared = "n/a" if model.r_squared is None else f"{model.r_squared:.4f}" + print("\nFit quality:") + print(f" R²: {r_squared}") + print(f" mean absolute error: {model.mean_absolute_error_megabytes:.2f} MiB") + print(f" p95 absolute error: {model.p95_absolute_error_megabytes:.2f} MiB") + print(f" max absolute error: {model.max_absolute_error_megabytes:.2f} MiB") + print("\nObserved range:") + print_dimension_range(observations, "users") + print_dimension_range(observations, "repos") + print_dimension_range(observations, "grants") + if estimate is not None: + print("\nEstimate:") + print(f" users: {format_optional_number(estimate.dimensions.users)}") + print(f" repos: {format_optional_number(estimate.dimensions.repos)}") + print(f" grants: {format_optional_number(estimate.dimensions.grants)}") + print(f" peak RSS: {estimate.peak_resident_megabytes:.1f} MiB") + print( + f" with {estimate.headroom_percent:g}% headroom: " + f"{estimate.peak_resident_megabytes_with_headroom:.1f} MiB" + ) + + +def write_json_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + report: dict[str, Any] = { + "observation_count": model.observation_count, + "features": list(model.feature_names), + "coefficients_mib": model.coefficients_megabytes, + "coefficients_bytes": { + feature_name: model.coefficients_megabytes[feature_name] * 1024.0 * 1024.0 + for feature_name in model.feature_names + }, + "fit": { + "r_squared": model.r_squared, + "mean_absolute_error_mib": model.mean_absolute_error_megabytes, + "p95_absolute_error_mib": model.p95_absolute_error_megabytes, + "max_absolute_error_mib": model.max_absolute_error_megabytes, + }, + "observed_range": observed_range_to_json(observations), + "estimate": estimate_to_json(estimate), + } + json.dump(report, sys.stdout, indent=2, sort_keys=True) + sys.stdout.write("\n") + + +def print_dimension_range(observations: list[MemoryObservation], feature_name: str) -> None: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if not values: + print(f" {feature_name}: n/a") + return + print(f" {feature_name}: {format_number(min(values))} .. {format_number(max(values))}") + + +def observed_range_to_json(observations: list[MemoryObservation]) -> dict[str, dict[str, float]]: + ranges: dict[str, dict[str, float]] = {} + for feature_name in FEATURE_NAMES: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if values: + ranges[feature_name] = {"min": min(values), "max": max(values)} + return ranges + + +def estimate_to_json(estimate: MemoryEstimate | None) -> dict[str, Any] | None: + if estimate is None: + return None + return { + "users": estimate.dimensions.users, + "repos": estimate.dimensions.repos, + "grants": estimate.dimensions.grants, + "peak_rss_mib": estimate.peak_resident_megabytes, + "headroom_percent": estimate.headroom_percent, + "peak_rss_mib_with_headroom": estimate.peak_resident_megabytes_with_headroom, + } + + +def object_mapping(value: object) -> dict[str, Any] | None: + return cast(dict[str, Any], value) if isinstance(value, dict) else None + + +def first_number(mapping: dict[str, Any], names: tuple[str, ...]) -> float | None: + for name in names: + value = mapping.get(name) + if isinstance(value, bool): + continue + if isinstance(value, int | float): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + continue + return None + + +def string_value(value: object) -> str: + return value if isinstance(value, str) else "" + + +def integer_value(value: object) -> int: + if isinstance(value, bool): + return 0 + return value if isinstance(value, int) else 0 + + +def feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float | None: + if feature_name == "users": + return dimensions.users + if feature_name == "repos": + return dimensions.repos + if feature_name == "grants": + return dimensions.grants + raise ValueError(f"Unknown feature: {feature_name}") + + +def required_feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float: + value = feature_value(dimensions, feature_name) + if value is None: + raise ValueError(f"Observation is missing feature: {feature_name}") + return value + + +def percentile(values: list[float], percentile_value: float) -> float: + if not values: + return math.nan + sorted_values = sorted(values) + index = math.ceil((percentile_value / 100.0) * len(sorted_values)) - 1 + return sorted_values[min(max(index, 0), len(sorted_values) - 1)] + + +def format_optional_number(value: float | None) -> str: + return "n/a" if value is None else format_number(value) + + +def format_number(value: float) -> str: + return f"{value:.0f}" if value.is_integer() else f"{value:.3f}" + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/monitor-sourcegraph-load.sh b/dev/monitor-sourcegraph-load.sh new file mode 100755 index 0000000..584529c --- /dev/null +++ b/dev/monitor-sourcegraph-load.sh @@ -0,0 +1,348 @@ +#!/usr/bin/env bash +set -euo pipefail + +namespace="${SRC_AUTH_PERMS_SYNC_MONITOR_NAMESPACE:-m}" +interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_INTERVAL_SECONDS:-5}" +postgres_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_INTERVAL_SECONDS:-10}" +statements_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_STATEMENTS_INTERVAL_SECONDS:-30}" +duration_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_DURATION_SECONDS:-}" +output_dir="${SRC_AUTH_PERMS_SYNC_MONITOR_OUTPUT_DIR:-}" +frontend_target="${SRC_AUTH_PERMS_SYNC_MONITOR_FRONTEND_TARGET:-deployment/sourcegraph-frontend}" +postgres_target="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_TARGET:-pod/pgsql-0}" +kubectl_bin="${KUBECTL:-kubectl}" +psql_command="${SRC_AUTH_PERMS_SYNC_MONITOR_PSQL_COMMAND:-psql -X -U sg -d sg}" +stream_logs=true + +usage() { + cat <<'EOF' +Usage: dev/monitor-sourcegraph-load.sh [options] + +Collect timestamped Sourcegraph pod load evidence while the e2e script runs. +Press Ctrl-C to stop, or pass --duration-seconds. + +Options: + --namespace NAME Kubernetes namespace (default: m) + --interval-seconds N Pod/process/cgroup sample interval (default: 5) + --postgres-interval-seconds N pg_stat_activity sample interval (default: 10) + --statements-interval-seconds N pg_stat_statements sample interval (default: 30) + --duration-seconds N Stop automatically after N seconds + --output-dir PATH Output directory (default: /tmp/src-auth-perms-sync-sourcegraph-load-) + --frontend-target TARGET kubectl target for frontend (default: deployment/sourcegraph-frontend) + --postgres-target TARGET kubectl target for Postgres (default: pod/pgsql-0) + --psql-command COMMAND Command to run inside Postgres pod (default: psql -X -U sg -d sg) + --no-logs Do not stream frontend logs + -h, --help Show this help + +Examples: + dev/monitor-sourcegraph-load.sh + + dev/monitor-sourcegraph-load.sh \ + --duration-seconds 1800 \ + --output-dir /tmp/src-auth-perms-sync-load-$(date -u +%Y%m%d-%H%M%S) + +In another terminal, run: + uv run python dev/test-end-to-end.py --trace --sample-interval 0 --external-sample-interval 0 +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --namespace) + namespace="$2" + shift 2 + ;; + --interval-seconds) + interval_seconds="$2" + shift 2 + ;; + --postgres-interval-seconds) + postgres_interval_seconds="$2" + shift 2 + ;; + --statements-interval-seconds) + statements_interval_seconds="$2" + shift 2 + ;; + --duration-seconds) + duration_seconds="$2" + shift 2 + ;; + --output-dir) + output_dir="$2" + shift 2 + ;; + --frontend-target) + frontend_target="$2" + shift 2 + ;; + --postgres-target) + postgres_target="$2" + shift 2 + ;; + --psql-command) + psql_command="$2" + shift 2 + ;; + --no-logs) + stream_logs=false + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 2 + ;; + esac +done + +if [[ -z "${output_dir}" ]]; then + output_dir="/tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S)" +fi +mkdir -p "${output_dir}" + +end_epoch="" +if [[ -n "${duration_seconds}" ]]; then + end_epoch="$(( $(date +%s) + duration_seconds ))" +fi + +pids=() + +timestamp() { + date -u +%Y-%m-%dT%H:%M:%SZ +} + +should_continue() { + [[ -z "${end_epoch}" || "$(date +%s)" -lt "${end_epoch}" ]] +} + +append_header() { + local file="$1" + local title="$2" + { + printf '\n===== %s %s =====\n' "$(timestamp)" "${title}" + } >>"${file}" +} + +run_sample_loop() { + local name="$1" + local sleep_seconds="$2" + local pid + shift 2 + ( + while should_continue; do + "$@" || true + sleep "${sleep_seconds}" + done + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} sampler: pid=${pid} interval=${sleep_seconds}s" +} + +run_stream() { + local name="$1" + local pid + shift + ( + "$@" || true + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} stream: pid=${pid}" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + if [[ ${#pids[@]} -gt 0 ]]; then + kill "${pids[@]}" 2>/dev/null || true + wait "${pids[@]}" 2>/dev/null || true + fi + echo "Stopped Sourcegraph load monitor. Output: ${output_dir}" + exit "${status}" +} + +trap cleanup EXIT INT TERM + +kubectl_exec() { + local target="$1" + shift + "${kubectl_bin}" exec -n "${namespace}" "${target}" -- "$@" +} + +kubectl_exec_stdin() { + local target="$1" + shift + "${kubectl_bin}" exec -i -n "${namespace}" "${target}" -- "$@" +} + +prepare_pg_stat_statements() { + local file="${output_dir}/postgres-statements-setup.log" + append_header "${file}" "create pg_stat_statements extension and reset stats" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select current_database(), current_user; +show shared_preload_libraries; +show track_io_timing; +create extension if not exists pg_stat_statements; +select pg_stat_statements_reset(); +SQL +} + +sample_kubectl_top() { + local file="${output_dir}/kubectl-top-pods-containers.log" + append_header "${file}" "kubectl top pods --containers" + "${kubectl_bin}" top pods -n "${namespace}" --containers >>"${file}" 2>&1 || true +} + +sample_frontend_processes() { + local file="${output_dir}/frontend-processes.log" + append_header "${file}" "${frontend_target} process CPU/RSS" + kubectl_exec "${frontend_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_postgres_processes() { + local file="${output_dir}/postgres-processes.log" + append_header "${file}" "${postgres_target} process CPU/RSS" + kubectl_exec "${postgres_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_cgroups() { + local file="${output_dir}/cgroups.log" + append_header "${file}" "cgroup CPU/memory" + for target in "${frontend_target}" "${postgres_target}"; do + { + echo "--- ${target} ---" + kubectl_exec "${target}" sh -lc ' + echo "cpu.stat" + cat /sys/fs/cgroup/cpu.stat 2>/dev/null || cat /sys/fs/cgroup/cpu/cpu.stat 2>/dev/null || true + echo "memory.current" + cat /sys/fs/cgroup/memory.current 2>/dev/null || cat /sys/fs/cgroup/memory/memory.usage_in_bytes 2>/dev/null || true + echo "memory.events" + cat /sys/fs/cgroup/memory.events 2>/dev/null || true + echo "memory.max" + cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null || true + ' + } >>"${file}" 2>&1 || true + done +} + +sample_postgres_activity() { + local file="${output_dir}/postgres-activity.log" + append_header "${file}" "pg_stat_activity, waits, locks" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + pid, + now() - query_start as age, + state, + wait_event_type, + wait_event, + left(query, 220) as query +from pg_stat_activity +where state <> 'idle' +order by age desc +limit 30; + +select + wait_event_type, + wait_event, + state, + count(*) +from pg_stat_activity +group by 1,2,3 +order by count(*) desc; + +select + locktype, + mode, + granted, + count(*) +from pg_locks +group by 1,2,3 +order by count(*) desc; +SQL +} + +sample_pg_stat_statements() { + local file="${output_dir}/postgres-statements.log" + append_header "${file}" "pg_stat_statements top total_exec_time" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + calls, + round(total_exec_time::numeric, 1) as total_ms, + round(mean_exec_time::numeric, 1) as mean_ms, + rows, + left(query, 260) as query +from pg_stat_statements +order by total_exec_time desc +limit 25; +SQL +} + +snapshot_pod_descriptions() { + local file="${output_dir}/pod-descriptions.log" + append_header "${file}" "kubectl describe selected targets" + "${kubectl_bin}" describe -n "${namespace}" "${frontend_target}" >>"${file}" 2>&1 || true + "${kubectl_bin}" describe -n "${namespace}" "${postgres_target}" >>"${file}" 2>&1 || true +} + +stream_frontend_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f \ + >"${output_dir}/frontend.log" 2>"${output_dir}/frontend-log-errors.log" +} + +stream_frontend_error_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f 2>/dev/null \ + | grep -Ei 'timeout|deadline|database|postgres|graphql|error|slow|cancel' \ + >"${output_dir}/frontend-errors-filtered.log" || true +} + +cat >"${output_dir}/metadata.txt" < int: + return self.users * self.repos + + @property + def name(self) -> str: + return f"u{self.users:05d}-r{self.repos:05d}-g{self.grants:010d}" + + +@dataclass(frozen=True) +class ExternalServiceChoice: + """Code host connection selected for repo sampling.""" + + graphql_id: str + database_id: int + display_name: str + kind: str + url: str + repo_count: int + + +@dataclass(frozen=True) +class GeneratedMap: + """One generated maps.yaml file and its workload dimensions.""" + + case: SweepCase + path: Path + + +@dataclass(frozen=True) +class CommandRunResult: + """One CLI execution result written in analyze-memory.py-compatible shape.""" + + generated_map: GeneratedMap + return_code: int + elapsed_seconds: float + output_path: Path + log_path: Path | None + run_record: dict[str, Any] | None + + +def main() -> int: + parser = build_parser() + arguments = parser.parse_args() + mode = cast(RunMode, arguments.mode) + if mode == "apply-no-backup" and not arguments.allow_apply: + parser.error("--mode apply-no-backup requires --allow-apply") + + config = sourcegraph_config(arguments) + output_dir = arguments.output_dir or default_output_dir(config.src_endpoint) + maps_dir = output_dir / "maps" + output_dir.mkdir(parents=True, exist_ok=True) + maps_dir.mkdir(parents=True, exist_ok=True) + + requested_cases = parse_cases(arguments.cases) + + client = src.SourcegraphClient( + endpoint=config.src_endpoint, + token=config.src_access_token, + http=src.HTTPClient( + timeout=arguments.http_timeout_seconds, + max_connections=max(4, arguments.parallelism), + ), + ) + try: + external_services = list_external_services(client) + inventory_repo_count = sum(service.repo_count for service in external_services) + service = choose_external_service(external_services, arguments.external_service_id) + total_user_count = count_users(client) + cases = requested_cases or default_cases_for_inventory( + total_user_count, + service.repo_count, + ) + max_users = max(sweep_case.users for sweep_case in cases) + max_repos = max(sweep_case.repos for sweep_case in cases) + usernames = list_usernames(client, max_users, arguments.page_size) + repo_names = list_repo_names(client, service, max_repos, arguments.page_size) + finally: + client.http.close() + + generated_maps = write_maps(maps_dir, cases, usernames, repo_names, service) + write_manifest(output_dir, generated_maps, service, config.src_endpoint, inventory_repo_count) + print(f"Generated {len(generated_maps)} maps.yaml file(s) under {maps_dir}") + print( + f"Selected code host: {service.display_name} id={service.database_id} " + f"repos={service.repo_count}; instance repoCount sum={inventory_repo_count}" + ) + + if not arguments.run: + print("Generation only. Re-run with --run to execute the sweep.") + return 0 + + run_results = run_sweep( + generated_maps, + endpoint=config.src_endpoint, + access_token=config.src_access_token, + output_dir=output_dir, + command=arguments.command, + mode=mode, + parallelism=arguments.parallelism, + explicit_permissions_batch_size=arguments.explicit_permissions_batch_size, + http_timeout_seconds=arguments.http_timeout_seconds, + sample_interval=arguments.sample_interval, + trace=arguments.trace, + sourcegraph_user_count=total_user_count, + sourcegraph_inventory_repo_count=inventory_repo_count, + ) + write_results(output_dir, run_results, inventory_repo_count, total_user_count) + return 0 if all(result.return_code == 0 for result in run_results) else 1 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Generate and optionally run maps.yaml memory-model sweep cases.", + ) + parser.add_argument( + "--env-file", + type=Path, + default=Path(".env"), + help="Environment file with SRC_ENDPOINT and SRC_ACCESS_TOKEN (default: .env).", + ) + parser.add_argument("--src-endpoint", help="Override SRC_ENDPOINT for discovery and runs.") + parser.add_argument("--src-access-token", help="Override SRC_ACCESS_TOKEN.") + parser.add_argument( + "--output-dir", + type=Path, + help=( + "Directory for generated maps and result files. " + "Defaults under src-auth-perms-sync-runs/." + ), + ) + parser.add_argument( + "--cases", + default=DEFAULT_CASES, + help=( + "Comma-separated users x repos cases, e.g. '100x10,1000x25', " + "or 'auto' for a gentle inventory-aware sweep. Default: auto." + ), + ) + parser.add_argument( + "--external-service-id", + type=int, + help="Decoded external service DB id to sample repos from. Defaults to largest repoCount.", + ) + parser.add_argument( + "--page-size", + type=int, + default=1000, + help="GraphQL page size for discovery queries (default: 1000).", + ) + parser.add_argument( + "--run", + action="store_true", + help="Run src-auth-perms-sync for each generated maps.yaml file.", + ) + parser.add_argument( + "--mode", + choices=("dry-run", "apply-no-backup"), + default="dry-run", + help="Run mode when --run is set. Default is dry-run.", + ) + parser.add_argument( + "--allow-apply", + action="store_true", + help="Required safety acknowledgement for --mode apply-no-backup.", + ) + parser.add_argument( + "--command", + default=DEFAULT_COMMAND, + help=f"Command used to invoke the CLI (default: {DEFAULT_COMMAND!r}).", + ) + parser.add_argument( + "--parallelism", + type=int, + default=1, + help="CLI --parallelism for sweep runs. Default 1 is gentle on pgsql.", + ) + parser.add_argument( + "--explicit-permissions-batch-size", + type=int, + default=25, + help="CLI --explicit-permissions-batch-size for sweep runs (default: 25).", + ) + parser.add_argument( + "--http-timeout-seconds", + type=float, + default=120.0, + help="HTTP timeout for discovery and CLI runs (default: 120).", + ) + parser.add_argument( + "--sample-interval", + type=float, + default=1.0, + help="CLI --sample-interval for resource samples (default: 1).", + ) + parser.add_argument( + "--trace", + action="store_true", + help="Pass --trace to src-auth-perms-sync sweep runs.", + ) + return parser + + +def sourcegraph_config(arguments: argparse.Namespace) -> SweepSourcegraphConfig: + overrides: dict[str, object] = {} + if arguments.src_endpoint: + overrides["src_endpoint"] = arguments.src_endpoint + if arguments.src_access_token: + overrides["src_access_token"] = arguments.src_access_token + return load_config( + SweepSourcegraphConfig, + env_file=arguments.env_file, + cli_overrides=overrides, + base_dir=Path.cwd(), + resolve_op_refs=True, + require=("src_access_token",), + ) + + +def parse_cases(raw_cases: str) -> list[SweepCase] | None: + if raw_cases.strip().lower() == "auto": + return None + cases: list[SweepCase] = [] + for raw_case in raw_cases.split(","): + case = raw_case.strip().lower() + if not case: + continue + users_text, separator, repos_text = case.partition("x") + if not separator: + raise SystemExit(f"Invalid case {raw_case!r}; expected USERSxREPOS") + try: + users = int(users_text) + repos = int(repos_text) + except ValueError as error: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be integers") from error + if users < 1 or repos < 1: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be >= 1") + cases.append(SweepCase(users=users, repos=repos)) + if not cases: + raise SystemExit("At least one --cases entry is required") + return cases + + +def default_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepCase]: + """Return a safe default sweep that covers user, repo, and grant axes.""" + if user_count < 1: + raise SystemExit("Need at least one Sourcegraph user for an auto sweep") + if repo_count < 1: + raise SystemExit("Need at least one Sourcegraph repo for an auto sweep") + + user_points = bounded_points(user_count, DEFAULT_USER_POINTS) + repo_points = bounded_points(repo_count, DEFAULT_REPO_POINTS) + cases: list[SweepCase] = [SweepCase(users=users, repos=1) for users in user_points] + cases.extend(SweepCase(users=1, repos=repos) for repos in repo_points if repos != 1) + + for users, repos in ( + (1000, 10), + (10000, 10), + (1000, 100), + (100, 1000), + ): + if users <= user_count and repos <= repo_count: + cases.append(SweepCase(users=users, repos=repos)) + + return unique_cases(cases) + + +def bounded_points(available_count: int, candidate_points: Sequence[int]) -> list[int]: + """Return candidate points that fit, plus the exact inventory cap if useful.""" + points = [point for point in candidate_points if point <= available_count] + if available_count not in points and available_count < candidate_points[-1]: + points.append(available_count) + return sorted(set(points)) + + +def unique_cases(cases: Sequence[SweepCase]) -> list[SweepCase]: + """Preserve case order while removing duplicates.""" + seen: set[tuple[int, int]] = set() + unique: list[SweepCase] = [] + for sweep_case in cases: + key = (sweep_case.users, sweep_case.repos) + if key in seen: + continue + seen.add(key) + unique.append(sweep_case) + return unique + + +def list_external_services(client: src.SourcegraphClient) -> list[ExternalServiceChoice]: + services: list[ExternalServiceChoice] = [] + for node in client.stream_connection_nodes( + QUERY_EXTERNAL_SERVICES, + variables={"first": 100, "after": None}, + connection_path=("externalServices",), + page_size=100, + ): + service = cast(dict[str, Any], node) + graphql_id = str(service["id"]) + services.append( + ExternalServiceChoice( + graphql_id=graphql_id, + database_id=src.decode_external_service_id(graphql_id), + display_name=str(service.get("displayName") or ""), + kind=str(service.get("kind") or ""), + url=str(service.get("url") or ""), + repo_count=int(service.get("repoCount") or 0), + ) + ) + if not services: + raise SystemExit("No external services found on the Sourcegraph instance") + return services + + +def choose_external_service( + services: list[ExternalServiceChoice], requested_id: int | None +) -> ExternalServiceChoice: + if requested_id is not None: + for service in services: + if service.database_id == requested_id: + return service + raise SystemExit(f"External service id {requested_id} was not found") + return max(services, key=lambda service: service.repo_count) + + +def list_usernames(client: src.SourcegraphClient, count: int, page_size: int) -> list[str]: + usernames: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_USERNAMES, + connection_path=("users",), + page_size=page_size, + ): + username = node.get("username") + if isinstance(username, str) and username: + usernames.append(username) + if len(usernames) >= count: + break + if len(usernames) < count: + raise SystemExit(f"Need {count} users but discovered only {len(usernames)}") + return usernames + + +def count_users(client: src.SourcegraphClient) -> int: + """Return total users on the Sourcegraph instance.""" + data = client.graphql(QUERY_USER_COUNT) + users = cast(dict[str, Any], data.get("users") or {}) + total_count = users.get("totalCount") + if not isinstance(total_count, int): + raise SystemExit("CountUsers response did not include users.totalCount") + return total_count + + +def list_repo_names( + client: src.SourcegraphClient, + service: ExternalServiceChoice, + count: int, + page_size: int, +) -> list[str]: + repo_names: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_REPOS_BY_EXTERNAL_SERVICE, + variables={"externalService": service.graphql_id}, + connection_path=("repositories",), + page_size=page_size, + ): + name = node.get("name") + if isinstance(name, str) and name: + repo_names.append(name) + if len(repo_names) >= count: + break + if len(repo_names) < count: + raise SystemExit( + f"Need {count} repos from external service id={service.database_id} " + f"but discovered only {len(repo_names)}" + ) + return repo_names + + +def write_maps( + maps_dir: Path, + cases: Sequence[SweepCase], + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> list[GeneratedMap]: + generated: list[GeneratedMap] = [] + for sweep_case in cases: + map_path = maps_dir / f"maps-{sweep_case.name}.yaml" + payload = { + "maps": [ + { + "name": ( + "memory model " + f"users={sweep_case.users} repos={sweep_case.repos} " + f"grants={sweep_case.grants}" + ), + "users": {"usernames": list(usernames[: sweep_case.users])}, + "repos": { + "codeHostConnection": {"id": service.database_id}, + "names": list(repo_names[: sweep_case.repos]), + }, + } + ] + } + with map_path.open("w", encoding="utf-8") as output_file: + output_file.write( + "# Generated by dev/run-memory-model-sweep.py; safe to delete/regenerate.\n" + ) + output_file.write( + f"# users={sweep_case.users} repos={sweep_case.repos} " + f"planned_grants={sweep_case.grants}\n" + ) + yaml.safe_dump(payload, output_file, sort_keys=False, allow_unicode=True) + generated.append(GeneratedMap(case=sweep_case, path=map_path)) + return generated + + +def write_manifest( + output_dir: Path, + generated_maps: Sequence[GeneratedMap], + service: ExternalServiceChoice, + endpoint: str, + inventory_repo_count: int, +) -> None: + manifest = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "endpoint": endpoint, + "external_service": service_to_json(service), + "sourcegraph_inventory_repo_count": inventory_repo_count, + "maps": [ + { + "case": generated_map.case.name, + "users": generated_map.case.users, + "repos": generated_map.case.repos, + "grants": generated_map.case.grants, + "path": str(generated_map.path), + } + for generated_map in generated_maps + ], + } + write_json(output_dir / "manifest.json", manifest) + + +def run_sweep( + generated_maps: Sequence[GeneratedMap], + *, + endpoint: str, + access_token: str, + output_dir: Path, + command: str, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, + sourcegraph_user_count: int, + sourcegraph_inventory_repo_count: int, +) -> list[CommandRunResult]: + results: list[CommandRunResult] = [] + for generated_map in generated_maps: + print(f"Running {generated_map.case.name} ...", flush=True) + started = time.monotonic() + process_output_path = output_dir / f"{generated_map.case.name}.out" + arguments = command_arguments( + command, + generated_map.path, + mode=mode, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + http_timeout_seconds=http_timeout_seconds, + sample_interval=sample_interval, + trace=trace, + ) + environment = command_environment(endpoint, access_token) + process = subprocess.run( + arguments, + cwd=Path.cwd(), + env=environment, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=False, + ) + elapsed_seconds = time.monotonic() - started + process_output_path.write_text(process.stdout, encoding="utf-8") + log_path = log_path_from_output(process.stdout) + run_record = read_run_record(log_path) + result = CommandRunResult( + generated_map=generated_map, + return_code=process.returncode, + elapsed_seconds=elapsed_seconds, + output_path=process_output_path, + log_path=log_path, + run_record=run_record, + ) + results.append(result) + write_results( + output_dir, + results, + inventory_repo_count=sourcegraph_inventory_repo_count, + sourcegraph_user_count=sourcegraph_user_count, + ) + print( + f" return_code={process.returncode} " + f"peak_rss_mb={memory_peak(result.run_record)} " + f"output={process_output_path}", + flush=True, + ) + if process.returncode != 0: + print("Stopping after first failed case.", file=sys.stderr) + break + return results + + +def command_arguments( + command: str, + map_path: Path, + *, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, +) -> list[str]: + arguments = [ + *shlex.split(command), + "--set", + str(map_path.resolve()), + "--full", + "--parallelism", + str(parallelism), + "--explicit-permissions-batch-size", + str(explicit_permissions_batch_size), + "--http-timeout-seconds", + f"{http_timeout_seconds:g}", + "--sample-interval", + f"{sample_interval:g}", + ] + if mode == "apply-no-backup": + arguments.extend(("--apply", "--no-backup")) + if trace: + arguments.append("--trace") + return arguments + + +def command_environment(endpoint: str, access_token: str) -> dict[str, str]: + environment = dict(os.environ) + environment["SRC_ENDPOINT"] = endpoint + environment["SRC_ACCESS_TOKEN"] = access_token + return environment + + +def log_path_from_output(output: str) -> Path | None: + match = LOG_PATH_PATTERN.search(output) + return Path(match.group(1)) if match else None + + +def read_run_record(log_path: Path | None) -> dict[str, Any] | None: + if log_path is None or not log_path.exists(): + return None + run_record: dict[str, Any] | None = None + with log_path.open(encoding="utf-8") as input_file: + for line in input_file: + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + if not isinstance(record, dict): + continue + record_mapping = cast(dict[str, object], record) + if record_mapping.get("event") == "run" and record_mapping.get("phase") == "end": + run_record = cast(dict[str, Any], record_mapping) + return run_record + + +def write_results( + output_dir: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + result_payload = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "results": [ + result_to_json(result, inventory_repo_count, sourcegraph_user_count) + for result in results + ], + "comparisons": [], + } + write_json(output_dir / "results.json", result_payload) + write_results_csv( + output_dir / "results.csv", + results, + inventory_repo_count, + sourcegraph_user_count, + ) + + +def result_to_json( + result: CommandRunResult, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, Any]: + run_record = result.run_record or {} + peak_rss_mb = memory_peak(result.run_record) + case = result.generated_map.case + return { + "variant": "candidate", + "iteration": 1, + "case": case.name, + "arguments": ["--set", str(result.generated_map.path), "--full"], + "return_code": result.return_code, + "elapsed_seconds": round(result.elapsed_seconds, 3), + "log_path": str(result.log_path) if result.log_path else None, + "run_directory": str(result.log_path.parent) if result.log_path else None, + "command": run_record.get("command") or "set_full", + "status": run_record.get("status"), + "jaeger_traces": [], + "memory": { + "peak_rss_mb": peak_rss_mb, + "sampled_peak_rss_mb": None, + "external_peak_rss_mb": None, + "resource_sample_count": 0, + "external_sample_count": 0, + "max_num_fds": run_record.get("num_fds"), + "max_num_threads": run_record.get("num_threads"), + "max_process_cpu_percent": None, + }, + "phase_memory": [], + "artifact_sizes": {}, + "workload": workload_json(case, inventory_repo_count, sourcegraph_user_count), + } + + +def workload_json( + sweep_case: SweepCase, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, int]: + return { + "selected_user_count": sweep_case.users, + "selected_repo_count": sweep_case.repos, + "selected_total_grants": sweep_case.grants, + "memory_model_user_count": sweep_case.users, + "memory_model_repo_count": sweep_case.repos, + "memory_model_grant_count": sweep_case.grants, + "sourcegraph_user_count": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + } + + +def write_results_csv( + path: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + fieldnames = [ + "case", + "users", + "repos", + "grants", + "sourcegraph_users_discovered", + "sourcegraph_inventory_repo_count", + "return_code", + "elapsed_seconds", + "peak_rss_mb", + "log_path", + "map_path", + "output_path", + ] + with path.open("w", encoding="utf-8", newline="") as output_file: + writer = csv.DictWriter(output_file, fieldnames=fieldnames) + writer.writeheader() + for result in results: + case = result.generated_map.case + writer.writerow( + { + "case": case.name, + "users": case.users, + "repos": case.repos, + "grants": case.grants, + "sourcegraph_users_discovered": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + "return_code": result.return_code, + "elapsed_seconds": f"{result.elapsed_seconds:.3f}", + "peak_rss_mb": memory_peak(result.run_record) or "", + "log_path": str(result.log_path) if result.log_path else "", + "map_path": str(result.generated_map.path), + "output_path": str(result.output_path), + } + ) + + +def memory_peak(run_record: Mapping[str, Any] | None) -> float | None: + if run_record is None: + return None + value = run_record.get("peak_rss_mb") + return float(value) if isinstance(value, int | float) else None + + +def write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as output_file: + json.dump(payload, output_file, indent=2, sort_keys=True) + output_file.write("\n") + + +def service_to_json(service: ExternalServiceChoice) -> dict[str, object]: + return { + "graphql_id": service.graphql_id, + "database_id": service.database_id, + "display_name": service.display_name, + "kind": service.kind, + "url": service.url, + "repo_count": service.repo_count, + } + + +def default_output_dir(endpoint: str) -> Path: + host = urlsplit(endpoint).hostname or "sourcegraph" + safe_host = re.sub(r"[^A-Za-z0-9_.-]+", "-", host) + timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d-%H-%M-%S") + return Path("src-auth-perms-sync-runs") / safe_host / "memory-model-sweep" / timestamp + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/sourcegraph-explicit-permissions-tracing.md b/dev/sourcegraph-explicit-permissions-tracing.md index fe8de9a..c61a399 100644 --- a/dev/sourcegraph-explicit-permissions-tracing.md +++ b/dev/sourcegraph-explicit-permissions-tracing.md @@ -95,8 +95,9 @@ To trace the full integration matrix, run the end-to-end script with its own tails each child run log and fetches all traced GraphQL Jaeger traces in the background while that child command is still running. The runner uses `src-py-lib` Config parsing, logging, Sourcegraph endpoint normalization, -`SourcegraphClient.fetch_jaeger_trace_summary()`, and a shared HTTP pool, so -trace summary and retry behavior match the CLI's Sourcegraph client: +`SourcegraphClient.fetch_jaeger_trace()`, `summarize_jaeger_trace()`, and a +shared HTTP pool, so trace fetch and summary behavior match the CLI's +Sourcegraph client: ```bash uv run python dev/test-end-to-end.py \ @@ -116,11 +117,35 @@ The runner writes trace summaries incrementally as JSON Lines. By default, it uses a sibling of `--results-json` or `--results-csv`, named `*-jaeger-traces.jsonl`. Override this with `--jaeger-trace-jsonl PATH`. +The runner also writes complete raw Jaeger trace payloads for in-depth +follow-up. By default, it uses a sibling directory named `*-jaeger-traces`. +Override this with `--jaeger-trace-dir PATH`. Each file is stored by variant, +iteration, case, and trace ID: + +```text +/ + candidate/ + iteration-0001/ + set-full-no-backup-apply/ + .json +``` + +Each raw trace file includes: + +- `trace_request`: CLI-side correlation metadata from the HTTP request and the + surrounding `graphql_query` event, including query name, page number, page + size, cursor presence, query byte count, variable names, response fields, + status, and timing. If `src-py-lib` later logs sanitized GraphQL variable + values, the same field will include them as `variables`, `input_variables`, + or `variable_values`. +- `jaeger_summary`: compact hot-operation and GraphQL-operation summary. +- `jaeger_trace`: the complete Jaeger trace JSON returned by Sourcegraph. + The shared `src-py-lib` `stream_jaeger_trace_summaries()` helper now fetches in parallel for in-process Sourcegraph clients. The end-to-end script still uses a bounded global worker pool because the traced requests happen in child processes and are discovered by tailing their JSON logs. Tune this with -`--jaeger-trace-parallelism N` (default 16). The runner drains outstanding +`--jaeger-trace-parallelism N` (default 8). The runner drains outstanding background collectors once at the end, before it writes JSON/CSV results, so Jaeger collection does not add a blocking phase between child cases. @@ -136,6 +161,46 @@ For each tested batch size and parallelism, record: `sql.conn.query`, and `database.PermsStore.LoadUserPermissions` - retries/timeouts from the CLI log +## Monitor Sourcegraph pod load during e2e runs + +Prefer running the end-to-end script as the single orchestrator. It can start +the Sourcegraph pod/Postgres monitor, collect Jaeger traces in parallel with +each child CLI command, and write all artifact paths into the result JSON: + +```bash +uv run python dev/test-end-to-end.py \ + --trace \ + --monitor-sourcegraph-load \ + --sample-interval 0 \ + --external-sample-interval 0 \ + --results-json /tmp/src-auth-perms-sync-end-to-end-trace.json \ + --results-csv /tmp/src-auth-perms-sync-end-to-end-trace.csv +``` + +By default, monitor output is written beside `--results-json` or +`--results-csv` as `*-sourcegraph-load`, and the monitor's own stdout/stderr is +written to `*-sourcegraph-load.log`. Override the location with +`--monitor-output-dir PATH`. Tune Kubernetes targets and sample intervals with +the `--monitor-*` flags if the test namespace or pod names differ. + +The lower-level helper remains available for focused profiling outside a full +e2e run: + +```bash +dev/monitor-sourcegraph-load.sh \ + --namespace m \ + --output-dir /tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S) +``` + +Stop the helper with Ctrl-C after the e2e run finishes, or add +`--duration-seconds N`. The script samples Kubernetes CPU/memory, frontend and +Postgres processes, cgroup CPU/memory pressure, Postgres active queries/waits/locks, +`pg_stat_statements` when enabled, and frontend logs. Outputs are timestamped +files in the selected directory. On startup, it runs `CREATE EXTENSION IF NOT +EXISTS pg_stat_statements` and `pg_stat_statements_reset()` through +`kubectl exec` against `pod/pgsql-0`, so the statement summary starts clean for +the monitored run. + In a traced sgdev end-to-end run after the matrix was trimmed to avoid overlapping code paths, all 36 cases passed. Child command time summed to about 1,126 seconds. The JSONL trace summary file contained 3,256 GraphQL trace diff --git a/dev/test-end-to-end.py b/dev/test-end-to-end.py index 90e1c73..97f5bb3 100755 --- a/dev/test-end-to-end.py +++ b/dev/test-end-to-end.py @@ -13,19 +13,22 @@ from __future__ import annotations +import contextlib import csv import datetime +import heapq import json import os import re import shlex +import signal import statistics import subprocess import sys import threading import time from collections.abc import Iterable, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future from concurrent.futures import wait as wait_for_futures from dataclasses import dataclass from pathlib import Path @@ -35,37 +38,44 @@ import src_py_lib as src from src_py_lib.clients.sourcegraph import ( DEFAULT_SOURCEGRAPH_ENDPOINT, - JAEGER_TRACE_RETRY_DELAYS_SECONDS, sourcegraph_trace_from_headers, + summarize_jaeger_trace, ) LOG_PATH_PATTERN = re.compile(r"Writing log events to (.+?/log\.json)\.") +SAFE_PATH_PART_PATTERN = re.compile(r"[^A-Za-z0-9_.-]+") DEFAULT_FUTURE_DATE = "2099-01-01" REMOVED_SRC_AUTH_PERMS_SYNC_ENVIRONMENT_PREFIX = "SRC_AUTH_PERMS_SYNC_" DEFAULT_SAMPLE_INTERVAL_SECONDS = 1.0 DEFAULT_REPEAT_COUNT = 1 DEFAULT_JAEGER_TRACE_LIMIT: int | None = None -DEFAULT_JAEGER_TRACE_PARALLELISM = 16 +DEFAULT_JAEGER_TRACE_PARALLELISM = 8 +DEFAULT_JAEGER_INITIAL_DELAY_SECONDS = 35.0 +DEFAULT_JAEGER_RETRY_DELAYS_SECONDS = ( + 2.0, + 5.0, + 10.0, + 20.0, + 30.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, +) DEFAULT_PARALLELISM = 4 DEFAULT_FULL_RESTORE_PARALLELISM = 1 +DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES = False DEFAULT_MEMORY_SUMMARY_LIMIT = 20 DEFAULT_SRC_AUTH_PERMS_SYNC_COMMAND = "uv run src-auth-perms-sync" -WORKLOAD_FIELDS = ( - "user_count", - "total_users", - "total_users_scanned", - "repo_count", - "repos_with_explicit_grants", - "total_grants", - "mapping_count", - "plan_size", - "payload_count", - "target_organizations", - "desired_memberships", - "mutations_succeeded", - "mutations_failed", - "mutations_canceled", -) +DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE = "m" +DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS = 5 +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS = 10 +DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS = 30 +DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET = "deployment/sourcegraph-frontend" +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET = "pod/pgsql-0" +DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND = "psql -X -U sg -d sg" def format_jaeger_retry_delays(delays: Sequence[float]) -> str: @@ -160,6 +170,16 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_FULL_RESTORE_PARALLELISM})" ), ) + include_redundant_scale_cases: bool = src.config_field( + default=DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES, + env_var="SRC_AUTH_PERMS_SYNC_E2E_INCLUDE_REDUNDANT_SCALE_CASES", + cli_flag="--include-redundant-scale-cases", + cli_action="store_true", + help=( + "Also run older overlapping full-scale cases. Default keeps one heavy full " + "snapshot path and uses smaller cases for overlapping coverage." + ), + ) allow_non_test_endpoint: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_E2E_ALLOW_NON_TEST_ENDPOINT", @@ -203,6 +223,17 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_JAEGER_TRACE_PARALLELISM})" ), ) + jaeger_initial_delay_seconds: float = src.config_field( + default=DEFAULT_JAEGER_INITIAL_DELAY_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_INITIAL_DELAY_SECONDS", + cli_flag="--jaeger-initial-delay-seconds", + metavar="SECONDS", + ge=0, + help=( + "Seconds to wait before first fetching each Jaeger trace, to allow OTel tail " + f"sampling to decide (default: {DEFAULT_JAEGER_INITIAL_DELAY_SECONDS:g})" + ), + ) jaeger_trace_jsonl: Path | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_JSONL", @@ -213,14 +244,26 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "of --results-json or --results-csv when --trace is set." ), ) + jaeger_trace_directory: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_DIR", + cli_flag="--jaeger-trace-dir", + metavar="PATH", + help=( + "Directory where complete raw Jaeger trace JSON files are written. Defaults " + "to a sibling directory of --results-json or --results-csv when --trace is set." + ), + ) jaeger_retry_delays: tuple[float, ...] = src.config_field( - default=JAEGER_TRACE_RETRY_DELAYS_SECONDS, + default=DEFAULT_JAEGER_RETRY_DELAYS_SECONDS, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_RETRY_DELAYS", cli_flag="--jaeger-retry-delays", metavar="SECONDS[,SECONDS...]", help=( - "Comma-separated retry delays for Jaeger trace lookup lag " - f"(default: {format_jaeger_retry_delays(JAEGER_TRACE_RETRY_DELAYS_SECONDS)})" + "Comma-separated delays between queued Jaeger trace fetch retries. " + "Each value schedules one retry after the initial fetch; add more values " + "to try for longer " + f"(default: {format_jaeger_retry_delays(DEFAULT_JAEGER_RETRY_DELAYS_SECONDS)})" ), ) sample_interval: float = src.config_field( @@ -271,6 +314,103 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "beside it as *-phases.csv" ), ) + monitor_sourcegraph_load: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_SOURCEGRAPH_LOAD", + cli_flag="--monitor-sourcegraph-load", + cli_action="store_true", + help=( + "Start the Sourcegraph pod/Postgres load monitor for this e2e run and write " + "its output beside the result artifacts." + ), + ) + sourcegraph_monitor_namespace: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NAMESPACE", + cli_flag="--monitor-namespace", + metavar="NAME", + help=( + "Kubernetes namespace for Sourcegraph load monitoring " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE})" + ), + ) + sourcegraph_monitor_output_dir: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_OUTPUT_DIR", + cli_flag="--monitor-output-dir", + metavar="PATH", + help="Directory for Sourcegraph load monitor output; defaults beside result artifacts.", + ) + sourcegraph_monitor_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_INTERVAL_SECONDS", + cli_flag="--monitor-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Pod/process/cgroup monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_postgres_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_INTERVAL_SECONDS", + cli_flag="--monitor-postgres-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Postgres activity monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_statements_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_STATEMENTS_INTERVAL_SECONDS", + cli_flag="--monitor-statements-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "pg_stat_statements monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_frontend_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_FRONTEND_TARGET", + cli_flag="--monitor-frontend-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph frontend " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET})" + ), + ) + sourcegraph_monitor_postgres_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_TARGET", + cli_flag="--monitor-postgres-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph Postgres " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET})" + ), + ) + sourcegraph_monitor_psql_command: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_PSQL_COMMAND", + cli_flag="--monitor-psql-command", + metavar="COMMAND", + help=( + "psql command to run inside the Postgres pod " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND})" + ), + ) + sourcegraph_monitor_no_logs: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NO_LOGS", + cli_flag="--monitor-no-logs", + cli_action="store_true", + help="Do not stream frontend logs while Sourcegraph load monitoring is enabled.", + ) fail_on_memory_regression_percent: float | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_FAIL_ON_MEMORY_REGRESSION_PERCENT", @@ -431,22 +571,136 @@ def sample_once(self) -> None: self.peak_rss_mb = max_optional_float(self.peak_rss_mb, rss_mb) +class SourcegraphLoadMonitor: + """Run the Sourcegraph pod/Postgres monitor for the duration of the e2e suite.""" + + def __init__(self, config: EndToEndConfig, output_dir: Path) -> None: + self.config = config + self.output_dir = output_dir + self.log_path = output_dir.with_name(f"{output_dir.name}.log") + self._log_file: TextIO | None = None + self._process: subprocess.Popen[str] | None = None + + def start(self) -> None: + script_path = sourcegraph_monitor_script_path() + if not script_path.exists(): + raise RuntimeError(f"Sourcegraph load monitor script not found: {script_path}") + self.output_dir.parent.mkdir(parents=True, exist_ok=True) + self.log_path.parent.mkdir(parents=True, exist_ok=True) + command = [ + str(script_path), + "--namespace", + self.config.sourcegraph_monitor_namespace, + "--output-dir", + str(self.output_dir), + "--interval-seconds", + str(self.config.sourcegraph_monitor_interval_seconds), + "--postgres-interval-seconds", + str(self.config.sourcegraph_monitor_postgres_interval_seconds), + "--statements-interval-seconds", + str(self.config.sourcegraph_monitor_statements_interval_seconds), + "--frontend-target", + self.config.sourcegraph_monitor_frontend_target, + "--postgres-target", + self.config.sourcegraph_monitor_postgres_target, + "--psql-command", + self.config.sourcegraph_monitor_psql_command, + ] + if self.config.sourcegraph_monitor_no_logs: + command.append("--no-logs") + print(f"Starting Sourcegraph load monitor: {self.output_dir}") + self._log_file = self.log_path.open("w", encoding="utf-8") + self._process = subprocess.Popen( # noqa: S603 - command is trusted test config. + command, + cwd=Path.cwd(), + stdout=self._log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + ) + self._wait_until_started() + + def stop(self) -> None: + process = self._process + if process is None: + self._close_log_file() + return + if process.poll() is None: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGTERM) + try: + process.wait(timeout=15) + except subprocess.TimeoutExpired: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGKILL) + process.wait(timeout=15) + return_code = process.returncode + self._close_log_file() + if return_code not in {0, -15, 143}: + print( + f"Sourcegraph load monitor exited with status {return_code}; see {self.log_path}", + file=sys.stderr, + ) + else: + print(f"Stopped Sourcegraph load monitor. Output: {self.output_dir}") + + def _wait_until_started(self) -> None: + process = self._process + if process is None: + return + deadline = time.monotonic() + 60 + while time.monotonic() < deadline: + if process.poll() is not None: + raise RuntimeError( + f"Sourcegraph load monitor exited before startup completed; see {self.log_path}" + ) + if self.log_path.exists() and "Started kubectl-top" in self.log_path.read_text( + encoding="utf-8", errors="ignore" + ): + return + time.sleep(0.2) + raise RuntimeError( + f"Timed out waiting for Sourcegraph load monitor startup; see {self.log_path}" + ) + + def _close_log_file(self) -> None: + if self._log_file is not None: + self._log_file.close() + self._log_file = None + + +@dataclass +class JaegerTraceFetchTask: + """One trace fetch request that can be retried across the whole e2e run.""" + + trace_request: dict[str, Any] + future: Future[dict[str, Any]] + fetch_attempts: int = 0 + first_fetch_at: str | None = None + last_fetch_at: str | None = None + + class JaegerTraceFetchPool: - """Fetch Sourcegraph Jaeger trace summaries through one bounded HTTP pool.""" + """Fetch Sourcegraph Jaeger traces through one bounded retry queue.""" def __init__( self, config: EndToEndConfig, *, parallelism: int, + initial_delay_seconds: float, retry_delays_seconds: Sequence[float], jsonl_path: Path | None, + trace_directory: Path | None, ) -> None: + self.initial_delay_seconds = initial_delay_seconds self.retry_delays_seconds = tuple(retry_delays_seconds) - self._executor = ThreadPoolExecutor( - max_workers=parallelism, - thread_name_prefix="JaegerTraceFetch", - ) + self.max_fetch_attempts = len(self.retry_delays_seconds) + 1 + self._trace_directory = trace_directory + self._tasks: list[tuple[float, int, JaegerTraceFetchTask]] = [] + self._condition = threading.Condition() + self._sequence = 0 + self._closed = False self._jsonl_file: TextIO | None = None self._lock = threading.Lock() http = src.HTTPClient( @@ -459,36 +713,166 @@ def __init__( jsonl_path.parent.mkdir(parents=True, exist_ok=True) self._jsonl_file = jsonl_path.open("w", encoding="utf-8") print(f"Writing Jaeger trace summaries incrementally to {jsonl_path}") + if self._trace_directory is not None: + self._trace_directory.mkdir(parents=True, exist_ok=True) + print(f"Writing complete Jaeger traces to {self._trace_directory}") + self._workers = [ + threading.Thread( + target=self._worker, + name=f"JaegerTraceFetch-{worker_number}", + daemon=True, + ) + for worker_number in range(1, parallelism + 1) + ] + for worker in self._workers: + worker.start() def submit( self, trace_request: dict[str, Any], collector: JaegerTraceCollector, ) -> Future[dict[str, Any]]: - future = src.submit_with_log_context(self._executor, self._fetch_summary, trace_request) + future: Future[dict[str, Any]] = Future() future.add_done_callback(lambda completed: self._record_summary(collector, completed)) + task = JaegerTraceFetchTask( + trace_request=trace_request, + future=future, + ) + self._schedule(task, self.initial_delay_seconds) return future def close(self) -> None: - self._executor.shutdown(wait=True) + with self._condition: + self._closed = True + self._condition.notify_all() + for worker in self._workers: + worker.join() self._client.http.close() if self._jsonl_file is not None: self._jsonl_file.close() - def _fetch_summary(self, trace_request: dict[str, Any]) -> dict[str, Any]: + def _schedule(self, task: JaegerTraceFetchTask, delay_seconds: float) -> None: + with self._condition: + self._sequence += 1 + heapq.heappush( + self._tasks, + (time.monotonic() + delay_seconds, self._sequence, task), + ) + self._condition.notify() + + def _worker(self) -> None: + while True: + task = self._next_ready_task() + if task is None: + return + self._process(task) + + def _next_ready_task(self) -> JaegerTraceFetchTask | None: + with self._condition: + while True: + if self._closed and not self._tasks: + return None + if not self._tasks: + self._condition.wait() + continue + ready_at, _sequence, task = self._tasks[0] + delay_seconds = ready_at - time.monotonic() + if delay_seconds > 0: + self._condition.wait(delay_seconds) + continue + heapq.heappop(self._tasks) + return task + + def _process(self, task: JaegerTraceFetchTask) -> None: + if task.future.done(): + return + summary = self._fetch_summary(task) + if summary.get("jaeger_found") is True or not self._should_retry(task, summary): + task.future.set_result(summary) + return + self._schedule(task, self._retry_delay_seconds(task.fetch_attempts)) + + def _fetch_summary(self, task: JaegerTraceFetchTask) -> dict[str, Any]: + task.fetch_attempts += 1 + now = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds") + if task.first_fetch_at is None: + task.first_fetch_at = now + task.last_fetch_at = now try: - trace = sourcegraph_trace_from_request(trace_request) - summary = self._client.fetch_jaeger_trace_summary( - trace, - retry_delays_seconds=self.retry_delays_seconds, - ).to_json() - return {**trace_request, **summary} + trace = sourcegraph_trace_from_request(task.trace_request) + jaeger_trace = self._client.fetch_jaeger_trace( + trace.trace_id, + retry_delays_seconds=(0.0,), + ) + summary = summarize_jaeger_trace(trace, jaeger_trace).to_json() + try: + trace_path = self._write_complete_trace(task, jaeger_trace, summary) + if trace_path is not None: + summary["jaeger_trace_path"] = str(trace_path) + except OSError as write_error: + summary["jaeger_trace_write_error"] = f"{type(write_error).__name__}: {write_error}" + return self._with_fetch_fields(task, summary) except Exception as exception: # noqa: BLE001 - keep long-running evidence collection alive. - return { - **trace_request, - "jaeger_found": False, - "error": f"{type(exception).__name__}: {exception}", - } + return self._with_fetch_fields( + task, + { + **task.trace_request, + "jaeger_found": False, + "error": f"{type(exception).__name__}: {exception}", + }, + ) + + def _with_fetch_fields( + self, task: JaegerTraceFetchTask, summary: dict[str, Any] + ) -> dict[str, Any]: + return { + **task.trace_request, + **summary, + "fetch_attempts": task.fetch_attempts, + "first_fetch_at": task.first_fetch_at, + "last_fetch_at": task.last_fetch_at, + "max_fetch_attempts": self.max_fetch_attempts, + } + + def _write_complete_trace( + self, + task: JaegerTraceFetchTask, + jaeger_trace: dict[str, Any], + summary: dict[str, Any], + ) -> Path | None: + if self._trace_directory is None: + return None + path = complete_jaeger_trace_path(self._trace_directory, task.trace_request) + payload = { + "collected_at": task.last_fetch_at, + "fetch_attempts": task.fetch_attempts, + "max_fetch_attempts": self.max_fetch_attempts, + "trace_request": task.trace_request, + "jaeger_summary": summary, + "jaeger_trace": jaeger_trace, + } + path.parent.mkdir(parents=True, exist_ok=True) + temporary_path = path.with_name( + f".{path.name}.tmp-{threading.get_ident()}-{time.monotonic_ns()}" + ) + temporary_path.write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + temporary_path.replace(path) + return path + + def _should_retry(self, task: JaegerTraceFetchTask, summary: dict[str, Any]) -> bool: + if self._closed or task.fetch_attempts >= self.max_fetch_attempts: + return False + error = str(summary.get("error") or "") + return error.startswith(("HTTP 404", "HTTP 502", "HTTP 503", "HTTP 504")) + + def _retry_delay_seconds(self, fetch_attempts: int) -> float: + if not self.retry_delays_seconds: + return 0.0 + delay_index = min(fetch_attempts - 1, len(self.retry_delays_seconds) - 1) + return self.retry_delays_seconds[delay_index] def _record_summary( self, @@ -527,6 +911,8 @@ def __init__( self.iteration = iteration self.case_name = case_name self.summaries: list[dict[str, Any]] = [] + self._graphql_queries_by_span: dict[tuple[str, str], dict[str, Any]] = {} + self._trace_requests_by_graphql_span: dict[tuple[str, str], dict[str, Any]] = {} self._requests_by_trace_id: dict[str, dict[str, Any]] = {} self._queued_trace_ids: set[str] = set() self._futures: list[Future[dict[str, Any]]] = [] @@ -599,15 +985,22 @@ def _record_line(self, line: str) -> None: return if not isinstance(record, dict): return + self._record_graphql_query_metadata(cast(dict[str, Any], record)) trace_request = graphql_trace_request_from_record(cast(dict[str, Any], record)) if trace_request is None: return trace_request.update( {"variant": self.variant, "iteration": self.iteration, "case": self.case_name} ) + graphql_span_key = self._graphql_span_key_for_http_record(cast(dict[str, Any], record)) trace_id = trace_request["trace_id"] submit_request: dict[str, Any] | None = None with self._lock: + if graphql_span_key is not None: + graphql_query = self._graphql_queries_by_span.get(graphql_span_key) + if graphql_query is not None: + trace_request["graphql_query"] = dict(graphql_query) + self._trace_requests_by_graphql_span[graphql_span_key] = trace_request existing_request = self._requests_by_trace_id.get(trace_id) if existing_request is None or trace_summary_duration_ms( trace_request @@ -621,6 +1014,29 @@ def _record_line(self, line: str) -> None: with self._lock: self._futures.append(future) + def _record_graphql_query_metadata(self, record: dict[str, Any]) -> None: + metadata = graphql_query_metadata_from_record(record) + if metadata is None: + return + span_key = graphql_query_span_key(record) + if span_key is None: + return + with self._lock: + existing_metadata = self._graphql_queries_by_span.get(span_key, {}) + merged_metadata = existing_metadata | metadata + self._graphql_queries_by_span[span_key] = merged_metadata + trace_request = self._trace_requests_by_graphql_span.get(span_key) + if trace_request is not None: + trace_request["graphql_query"] = dict(merged_metadata) + + @staticmethod + def _graphql_span_key_for_http_record(record: dict[str, Any]) -> tuple[str, str] | None: + trace_id = optional_string(record.get("trace")) + parent_span_id = optional_string(record.get("parent_span")) + if trace_id is None or parent_span_id is None: + return None + return trace_id, parent_span_id + def _submit_limited_requests(self) -> None: if self.limit is None: return @@ -879,14 +1295,22 @@ def run_end_to_end(config: EndToEndConfig) -> None: all_failures: list[str] = [] all_jaeger_collectors: list[JaegerTraceCollector] = [] jaeger_trace_fetch_pool = create_jaeger_trace_fetch_pool(config) + sourcegraph_load_monitor = create_sourcegraph_load_monitor(config) latest_baseline_repositories: set[str] = set() try: + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.start() with src.event( "end_to_end_matrix", repeat=config.repeat, variant_count=len(variants), trace=config.trace, + sourcegraph_load_monitor=sourcegraph_load_monitor is not None, ) as matrix_summary: + if sourcegraph_load_monitor is not None: + matrix_summary["sourcegraph_load_monitor_dir"] = str( + sourcegraph_load_monitor.output_dir + ) for iteration in range(1, config.repeat + 1): for variant in variants: with src.stage("matrix_variant", variant=variant.name, iteration=iteration): @@ -915,6 +1339,8 @@ def run_end_to_end(config: EndToEndConfig) -> None: wait_for_jaeger_trace_collectors(all_jaeger_collectors) if jaeger_trace_fetch_pool is not None: jaeger_trace_fetch_pool.close() + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.stop() if all_failures: print("\nFailures:", file=sys.stderr) for failure in all_failures: @@ -928,7 +1354,7 @@ def run_end_to_end(config: EndToEndConfig) -> None: print_phase_memory_summary(all_results, config.memory_summary_limit) comparisons = compare_variants(all_results) print_comparison_summary(comparisons) - write_results_files(all_results, comparisons, config) + write_results_files(all_results, comparisons, config, sourcegraph_load_monitor) raise_for_memory_regressions(comparisons, config) @@ -955,8 +1381,10 @@ def create_jaeger_trace_fetch_pool( return JaegerTraceFetchPool( config, parallelism=config.jaeger_trace_parallelism, + initial_delay_seconds=config.jaeger_initial_delay_seconds, retry_delays_seconds=config.jaeger_retry_delays, jsonl_path=jaeger_trace_jsonl_path(config), + trace_directory=jaeger_trace_directory(config), ) @@ -971,6 +1399,56 @@ def jaeger_trace_jsonl_path(config: EndToEndConfig) -> Path | None: return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}.jsonl" +def jaeger_trace_directory(config: EndToEndConfig) -> Path: + """Return the directory where complete raw Jaeger traces should be stored.""" + if config.jaeger_trace_directory is not None: + return config.jaeger_trace_directory + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-jaeger-traces") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}" + + +def create_sourcegraph_load_monitor(config: EndToEndConfig) -> SourcegraphLoadMonitor | None: + """Return the Sourcegraph load monitor for this run, if enabled.""" + if not config.monitor_sourcegraph_load: + return None + return SourcegraphLoadMonitor(config, sourcegraph_monitor_output_dir(config)) + + +def sourcegraph_monitor_output_dir(config: EndToEndConfig) -> Path: + """Return where Sourcegraph pod/Postgres monitor artifacts should be stored.""" + if config.sourcegraph_monitor_output_dir is not None: + return config.sourcegraph_monitor_output_dir + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-sourcegraph-load") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-sourcegraph-load-{stamp}" + + +def sourcegraph_monitor_script_path() -> Path: + """Return the lower-level monitor script used by the e2e orchestrator.""" + return Path(__file__).resolve().with_name("monitor-sourcegraph-load.sh") + + +def complete_jaeger_trace_path(trace_directory: Path, trace_request: dict[str, Any]) -> Path: + """Return the stable per-trace path for a complete Jaeger trace payload.""" + variant = safe_path_part(trace_request.get("variant"), default="variant") + iteration = int_field(trace_request, "iteration") or 0 + case_name = safe_path_part(trace_request.get("case"), default="case") + trace_id = safe_path_part(trace_request.get("trace_id"), default="trace") + return trace_directory / variant / f"iteration-{iteration:04d}" / case_name / f"{trace_id}.json" + + +def safe_path_part(value: object, *, default: str) -> str: + """Return a filesystem-safe path segment for generated trace artifacts.""" + text = str(value) if value is not None else "" + safe_text = SAFE_PATH_PART_PATTERN.sub("-", text).strip("-.") + return safe_text[:120] or default + + def command_environment(config: EndToEndConfig) -> dict[str, str]: """Return a deterministic child environment for CLI config parsing.""" environment = dict(os.environ) @@ -1187,7 +1665,7 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: ), CommandCase( name="get-sync-saml-orgs-dry-run", - arguments=("--get", "--sync-saml-orgs"), + arguments=("--get", "--user", config.user, "--sync-saml-orgs"), expected_log_command="get_sync_saml_orgs", must_contain=("Wrote before-snapshot", "Dry run complete"), ), @@ -1325,29 +1803,30 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne ) baseline_snapshot = snapshot_path(dry_run_result) - try: - runner.run( - CommandCase( - name="set-full-apply", - arguments=( - "--set", - "--apply", - "--parallelism", - str(config.parallelism), - ), - expected_log_command="set_full", - must_contain=("VALIDATION OK",), + if config.include_redundant_scale_cases: + try: + runner.run( + CommandCase( + name="set-full-apply", + arguments=( + "--set", + "--apply", + "--parallelism", + str(config.parallelism), + ), + expected_log_command="set_full", + must_contain=("VALIDATION OK",), + ) ) - ) - finally: - runner.run( - restore_full_apply_case( - "restore-full-apply-cleanup", - baseline_snapshot, - config, - no_backup=False, + finally: + runner.run( + restore_full_apply_case( + "restore-full-apply-cleanup", + baseline_snapshot, + config, + no_backup=False, + ) ) - ) try: runner.run( @@ -1374,14 +1853,14 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne ) ) - # Covers the combined set+SAML dispatch and SAML dry-run path without - # repeating the full set apply and full restore cleanup paths, which are - # already covered above. + # Covers combined set+SAML dispatch and SAML dry-run with a user-scoped + # set path, so the default suite keeps only one expensive full-snapshot + # case. Pass --include-redundant-scale-cases to restore older overlap. runner.run( CommandCase( - name="set-full-sync-saml-orgs-dry-run", - arguments=("--set", "--sync-saml-orgs"), - expected_log_command="set_full_sync_saml_orgs", + name="set-user-sync-saml-orgs-dry-run", + arguments=("--set", "maps.yaml", "--user", config.user, "--sync-saml-orgs"), + expected_log_command="set_user_sync_saml_orgs", must_contain=("Dry run complete",), ) ) @@ -1602,20 +2081,178 @@ def parse_log_timestamp(value: object) -> datetime.datetime | None: def workload_from_records(records: list[dict[str, Any]]) -> dict[str, int | float | str]: - """Collect stable workload-size fields so memory can be normalized.""" + """Collect named workload dimensions from structured log records. + + Earlier e2e summaries used raw field names from unrelated events, which made + values like `total_users` and `repo_count` ambiguous. Keep this summary + event-aware so each key says what it counts. + """ workload: dict[str, int | float | str] = {} for record in records: - for field_name in WORKLOAD_FIELDS: - value = record.get(field_name) - if isinstance(value, int | float): - old_value = workload.get(field_name) - if not isinstance(old_value, int | float) or value > old_value: - workload[field_name] = value - elif isinstance(value, str) and field_name not in workload: - workload[field_name] = value + event_name = optional_string(record.get("event")) + phase = optional_string(record.get("phase")) + if event_name == "capture_explicit_grants": + record_workload_max(workload, "sourcegraph_user_count", record.get("total_users")) + if phase == "end": + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name in {"build_snapshot", "build_user_scoped_snapshot"} and phase == "end": + record_workload_max(workload, "snapshot_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "snapshot_repos_with_explicit_grants_max", + record.get("repos_with_explicit_grants"), + ) + record_workload_max(workload, "snapshot_total_grants_max", record.get("total_grants")) + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name == "user_explicit_repos_batch_fetch" and phase == "end": + record_workload_max(workload, "batch_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "batch_fetched_grant_count_max", + record.get("fetched_grant_count") + if "fetched_grant_count" in record + else record.get("repo_count"), + ) + elif event_name == "load_repos_by_external_service" and phase == "end": + record_workload_max(workload, "loaded_repo_count", record.get("repo_count")) + record_workload_max( + workload, + "expected_repo_count", + record.get("expected_repo_count"), + ) + elif event_name == "apply_username_overwrites": + record_workload_max(workload, "apply_payload_count", record.get("payload_count")) + record_workload_max( + workload, + "apply_payload_grant_count", + record.get("payload_grant_count") + if "payload_grant_count" in record + else record.get("total_users"), + ) + record_workload_max(workload, "parallelism", record.get("parallelism")) + if phase == "end": + record_workload_max( + workload, + "apply_mutations_succeeded", + record.get("succeeded"), + ) + record_workload_max(workload, "apply_mutations_failed", record.get("failed")) + record_workload_max(workload, "apply_mutations_canceled", record.get("canceled")) + elif ( + event_name + in { + "cmd_get", + "cmd_restore", + "cmd_restore_user_scoped", + "cmd_set", + "cmd_set_additive_user", + "cmd_set_additive_users_without_explicit_perms", + } + and phase == "end" + ): + record_command_workload(workload, record) + elif event_name in {"sync_saml_orgs", "cmd_sync_saml_orgs"} and phase == "end": + record_workload_max( + workload, + "target_organizations", + record.get("target_organizations"), + ) + record_workload_max(workload, "desired_memberships", record.get("desired_memberships")) + + record_workload_model_dimensions(workload) return workload +def record_command_workload(workload: dict[str, int | float | str], record: dict[str, Any]) -> None: + """Copy command-level counts using names that preserve their meaning.""" + event_name = optional_string(record.get("event")) + repo_count = record.get("repo_count") + total_grants = record.get("total_grants") + if event_name == "cmd_set": + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + elif event_name == "cmd_get": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "selected_total_grants", total_grants) + elif event_name == "cmd_restore": + record_workload_max(workload, "restore_snapshot_repo_count", record.get("snapshot_repos")) + record_workload_max( + workload, + "restore_snapshot_total_grants", + record.get("snapshot_grants"), + ) + elif event_name == "cmd_set_additive_user": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + + record_workload_max(workload, "mapping_count", record.get("mapping_count")) + record_workload_max(workload, "mutations_succeeded", record.get("mutations_succeeded")) + record_workload_max(workload, "mutations_failed", record.get("mutations_failed")) + record_workload_max(workload, "mutations_canceled", record.get("mutations_canceled")) + + +def record_workload_model_dimensions(workload: dict[str, int | float | str]) -> None: + """Add the canonical dimensions used by memory modeling.""" + user_count = max_workload_number( + workload, + ( + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "sourcegraph_user_count", + ), + ) + repo_count = max_workload_number( + workload, + ( + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "loaded_repo_count", + ), + ) + grant_count = max_workload_number( + workload, + ( + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "apply_payload_grant_count", + ), + ) + if user_count is not None: + workload["memory_model_user_count"] = user_count + if repo_count is not None: + workload["memory_model_repo_count"] = repo_count + if grant_count is not None: + workload["memory_model_grant_count"] = grant_count + + +def max_workload_number( + workload: dict[str, int | float | str], field_names: Sequence[str] +) -> int | float | None: + """Return the largest numeric value found for the supplied workload fields.""" + values = [ + value + for field_name in field_names + if isinstance((value := workload.get(field_name)), int | float) + ] + return max(values) if values else None + + +def record_workload_max( + workload: dict[str, int | float | str], field_name: str, value: object +) -> None: + """Record the maximum numeric value for a named workload dimension.""" + if isinstance(value, bool) or not isinstance(value, int | float): + return + old_value = workload.get(field_name) + if not isinstance(old_value, int | float) or value > old_value: + workload[field_name] = value + + def artifact_sizes_for_run(log_path: Path) -> dict[str, int]: """Return sizes of JSON artifacts in the same run directory as the log.""" run_directory = log_path.parent @@ -1636,6 +2273,54 @@ def wait_for_jaeger_trace_collectors(collectors: list[JaegerTraceCollector]) -> collector.wait() +def graphql_query_metadata_from_record(record: dict[str, Any]) -> dict[str, Any] | None: + """Return correlation metadata from a structured `graphql_query` log record.""" + if record.get("event") != "graphql_query": + return None + metadata: dict[str, Any] = { + "span_id": record.get("span"), + "parent_span_id": record.get("parent_span"), + "trace_id": record.get("trace"), + } + phase = record.get("phase") + if phase == "start": + metadata["started_at"] = record.get("ts") + elif phase == "end": + metadata["ended_at"] = record.get("ts") + for field_name in ( + "cursor_present", + "duration_ms", + "error_type", + "graphql_client", + "page_number", + "page_size", + "query_bytes", + "query_name", + "response_fields", + "status", + "url", + "variable_names", + # Current src-py-lib logs variable names only. Keep these optional fields + # so raw trace artifacts automatically include values if the GraphQL log + # event grows an opt-in sanitized-variable field later. + "input_variables", + "variable_values", + "variables", + ): + if field_name in record: + metadata[field_name] = record[field_name] + return {key: value for key, value in metadata.items() if value is not None} + + +def graphql_query_span_key(record: dict[str, Any]) -> tuple[str, str] | None: + """Return the `(trace_id, span_id)` key for a GraphQL query log span.""" + trace_id = optional_string(record.get("trace")) + span_id = optional_string(record.get("span")) + if trace_id is None or span_id is None: + return None + return trace_id, span_id + + def graphql_trace_request_from_record(record: dict[str, Any]) -> dict[str, Any] | None: if record.get("event") != "http_request" or record.get("phase") != "end": return None @@ -1997,9 +2682,10 @@ def write_results_files( results: list[CommandResult], comparisons: list[CaseComparison], config: EndToEndConfig, + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: if config.results_json is not None: - write_results_json(config.results_json, results, comparisons) + write_results_json(config.results_json, results, comparisons, sourcegraph_load_monitor) if config.results_csv is not None: write_results_csv(config.results_csv, results) phase_csv = phase_results_csv_path(config.results_csv) @@ -2010,12 +2696,20 @@ def write_results_json( path: Path, results: list[CommandResult], comparisons: list[CaseComparison], + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: path.parent.mkdir(parents=True, exist_ok=True) + sourcegraph_monitor: dict[str, Any] | None = None + if sourcegraph_load_monitor is not None: + sourcegraph_monitor = { + "output_dir": str(sourcegraph_load_monitor.output_dir), + "log_path": str(sourcegraph_load_monitor.log_path), + } with path.open("w", encoding="utf-8") as output_file: json.dump( { "generated_at": datetime.datetime.now(datetime.UTC).isoformat(), + "sourcegraph_load_monitor": sourcegraph_monitor, "results": [result_to_json(result) for result in results], "comparisons": [comparison_to_json(comparison) for comparison in comparisons], }, @@ -2240,7 +2934,11 @@ def normalized_memory(result: CommandResult) -> dict[str, float]: if peak_rss_mb is None: return {} normalized: dict[str, float] = {} - for field_name in ("user_count", "total_users", "repo_count", "total_grants"): + for field_name in ( + "memory_model_user_count", + "memory_model_repo_count", + "memory_model_grant_count", + ): value = result.workload.get(field_name) if isinstance(value, int | float) and value > 0: normalized[f"peak_rss_mb_per_{field_name}"] = peak_rss_mb / float(value) diff --git a/dev/test-plan.md b/dev/test-plan.md index 521d165..d71a735 100644 --- a/dev/test-plan.md +++ b/dev/test-plan.md @@ -118,6 +118,71 @@ These numbers don't have published baselines yet; this run *creates* them. The deliverable is "we now know `--set --apply` with `--parallelism 16` hits N MB RSS and W seconds of snapshot wall-clock at G grants." +### Memory-per-grant model + +Generate exact users × repos maps and, when ready, run them through the CLI: + +```bash +uv run python dev/run-memory-model-sweep.py + +uv run python dev/run-memory-model-sweep.py \ + --run \ + --parallelism 1 +``` + +The runner writes generated maps and `results.json` under +`src-auth-perms-sync-runs//memory-model-sweep//`. +It uses an inventory-aware `--cases auto` sweep and dry-run mode by default. +On an instance with 1K+ visible repos, `auto` includes repo-axis points up to +1K repos and mixed cases up to 100K planned grants. Use explicit cases for +larger stress points, and use `--mode apply-no-backup --allow-apply` only on a +scratch instance: + +```bash +uv run python dev/run-memory-model-sweep.py \ + --cases '1x1,10000x1,1x1000,100x1000,1000x1000,10000x1000' \ + --run \ + --parallelism 1 +``` + +Fit memory from repeated e2e JSON results instead of dividing one run's +`peak_rss_mb` by one run's grants: + +```bash +uv run python dev/analyze-memory.py results/*.json \ + --command set_full \ + --case-regex 'set-full' \ + --features users,repos,grants \ + --estimate-users 10000 \ + --estimate-repos 100 +``` + +The analyzer fits: + +```text +peak RSS MiB = intercept + users*b1 + repos*b2 + grants*b3 +``` + +Use one command mode per fit (`set_full` with backup, `set_full --no-backup`, +`restore`, etc.). Mixing modes smears fixed snapshot / apply costs into the +per-grant coefficient. + +On the sgdev test instance with 10,001 users and 1,023 visible repos, a +dry-run `10000x1000` case planned 10M grants and measured about 651 MiB peak +RSS. The grants-only fit across 14 dry-run observations was roughly 69 MiB +fixed plus 61 bytes per planned grant. Re-measure after meaningful mapping or +snapshot changes; these numbers describe dry-run planning memory, not apply +mutation throughput. + +The e2e `workload` object now uses event-aware names. In older result JSON, +`total_users: 40004` came from `apply_username_overwrites` and meant "username +entries in mutation payloads" (`4 mutated repos × 10001 users`), not total +Sourcegraph users. Likewise `repo_count: 575` came from a batch fetch and meant +"grant rows fetched for 25 users" (`25 × 23`), not distinct repos. New results +expose those as `apply_payload_grant_count` and +`batch_fetched_grant_count_max`, plus canonical `memory_model_user_count`, +`memory_model_repo_count`, and `memory_model_grant_count` fields for modeling. + --- ## Failure injection (scenario e) diff --git a/maps-example.yaml b/maps-example.yaml index c854791..e5dfb7e 100644 --- a/maps-example.yaml +++ b/maps-example.yaml @@ -1,26 +1,37 @@ # Auth provider → code host connection mapping rules # Maintain this file using auth-providers.yaml and code-hosts.yaml as references. # Those files are generated under src-auth-perms-sync-runs//. +# +# These examples cover every supported filter field: +# - users.authProvider: clientID, configID, displayName, samlGroup, serviceID, type +# - users.emails (verified email addresses) +# - users.usernames +# - repos.codeHostConnection: config, displayName, id, kind, url +# - repos.names +# - repos.regexes maps: -- name: All users from Line of Business 1 - User Group 1 get access to all repos synced from service account 1 +- name: SAML group users get all repos synced from one service account users: authProvider: + configID: okta samlGroup: LOB1-GROUP1 + type: saml repos: codeHostConnection: config: username: LOB1-SA1 -- name: All users from Line of Business 1 - User Group 2 get access to all repos synced from service account 2 +- name: Users from one exact auth provider get repos from one exact code host connection users: authProvider: - samlGroup: LOB1-GROUP2 + clientID: sourcegraph + displayName: Okta SAML + serviceID: https://idp.example.com/saml repos: codeHostConnection: - config: - username: LOB1-SA2 + id: 12 - name: All Okta SAML users get access to all Bitbucket repos users: @@ -31,9 +42,36 @@ maps: codeHostConnection: kind: BITBUCKETSERVER +- name: All builtin users get repos from the GitHub Cloud connection + users: + authProvider: + type: builtin + repos: + codeHostConnection: + displayName: GitHub Cloud + +- name: All builtin users get repos from the GitHub URL connection + users: + authProvider: + type: builtin + repos: + codeHostConnection: + url: https://github.com/ + +- name: Exact user gets named repos + users: + emails: + - alice@example.com + - bob@example.com + repos: + names: + - github.com/example/private-repo + - name: All builtin users get access to all repos under the github.com/example org, from any code host connection users: authProvider: type: builtin repos: - regex: https://github.com/example/.* + regexes: + - ^github\.com/example/.* + - ^gitlab\.com/example/.* diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 8761abb..7e25143 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -192,6 +192,14 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi ge=1, help="Max attempts per HTTP request before giving up (default: 5)", ) + http_timeout_seconds: float = src.config_field( + default=60.0, + env_var="SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS", + cli_flag="--http-timeout-seconds", + metavar="SECONDS", + gt=0, + help="HTTP read timeout per request in seconds (default: 60)", + ) sample_interval: float = src.config_field( default=10.0, env_var="SRC_AUTH_PERMS_SYNC_SAMPLE_INTERVAL", @@ -387,6 +395,7 @@ def run_fields( "explicit_permissions_batch_size": config.explicit_permissions_batch_size, "trace": config.trace, "max_attempts": config.max_attempts, + "http_timeout_seconds": config.http_timeout_seconds, "no_backup": config.no_backup, "sample_interval": config.sample_interval, "user_created_after": config.created_after, @@ -404,6 +413,7 @@ def run_with_client( ) -> None: """Create a client, run the selected command, and always close HTTP resources.""" http = src.HTTPClient( + timeout=config.http_timeout_seconds, user_agent="src-auth-perms-sync/0.1 (+python)", max_attempts=config.max_attempts, max_connections=config.parallelism, diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 7849855..14969ba 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -301,12 +301,12 @@ def _apply_repo_overwrite_plans( ) -> shared_types.MutationCounts: """Dispatch per-repo overwrite mutations with bounded in-flight work.""" max_pending_futures = max(1, parallelism * 2) - total_users = sum(len(overwrite.usernames) for overwrite in overwrites) + payload_grant_count = sum(len(overwrite.usernames) for overwrite in overwrites) with src.event( "apply_username_overwrites", payload_count=len(overwrites), parallelism=parallelism, - total_users=total_users, + payload_grant_count=payload_grant_count, max_pending_futures=max_pending_futures, ) as batch_event: succeeded = 0 diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 95d4041..fbdc00c 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -386,7 +386,14 @@ def cmd_set_additive_user( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() - user = _resolve_user_identifier(client, user_identifier) + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) + user = _resolve_user_identifier( + client, + user_identifier, + include_emails=include_user_emails, + ) if user_created_after is not None: candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) if user["id"] not in candidate_user_ids: @@ -446,6 +453,9 @@ def cmd_set_additive_users_without_explicit_perms( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) resolved_mappings = resolve_additive_mappings(context) candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) log.info("Received %d non-deleted user candidate(s).", len(candidates)) @@ -455,7 +465,11 @@ def cmd_set_additive_users_without_explicit_perms( for candidate in candidates: if permissions_sourcegraph.user_has_explicit_repos(client, candidate["id"]): continue - user = permissions_sourcegraph.get_user_by_id(client, candidate["id"]) + user = permissions_sourcegraph.get_user_by_id( + client, + candidate["id"], + include_emails=include_user_emails, + ) if user is None: log.warning( "Skipping user candidate %s: user no longer exists.", @@ -492,18 +506,33 @@ def cmd_set_additive_users_without_explicit_perms( def _resolve_user_identifier( - client: src.SourcegraphClient, user_identifier: str + client: src.SourcegraphClient, + user_identifier: str, + *, + include_emails: bool = False, ) -> shared_types.User: """Resolve username/email input to one Sourcegraph user.""" user: shared_types.User | None if "@" in user_identifier: user = permissions_sourcegraph.get_user_by_email( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_username(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_username( + client, + user_identifier, + include_emails=include_emails, + ) else: user = permissions_sourcegraph.get_user_by_username( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_email(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_email( + client, + user_identifier, + include_emails=include_emails, + ) if user is None: raise SystemExit(f"No Sourcegraph user found for {user_identifier!r}.") if user["username"] != user_identifier: diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 1ef3fda..6439fdc 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -98,6 +98,7 @@ def _capture_full_set_snapshot_state( explicit_permissions_batch_size: int, bind_id_mode: str, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" total_users = shared_sourcegraph.count_users(client) @@ -110,7 +111,11 @@ def _capture_full_set_snapshot_state( before_timestamp = backups.backup_timestamp() before_snapshot = permission_snapshot.build_snapshot( client, - shared_sourcegraph.list_users_streaming(client, collect_into=users), + shared_sourcegraph.list_users_streaming( + client, + collect_into=users, + include_emails=include_user_emails, + ), parallelism, bind_id_mode, input_path, @@ -140,6 +145,7 @@ def _load_full_set_snapshot_state( bind_id_mode: str, capture_before: bool, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load all users, optionally with a before-snapshot.""" if capture_before: @@ -150,10 +156,14 @@ def _load_full_set_snapshot_state( explicit_permissions_batch_size, bind_id_mode, worker_pool, + include_user_emails=include_user_emails, ) log.info("Loading users from %s ...", client.endpoint) - users = shared_sourcegraph.list_users_with_accounts(client) + users = shared_sourcegraph.list_users_with_accounts( + client, + include_emails=include_user_emails, + ) log.info("Received %d total users.", len(users)) return _FullSetUserState(users=users) @@ -656,6 +666,7 @@ def _load_full_set_plan( retain_saml_group_users: bool, worker_pool: ThreadPoolExecutor | None = None, ) -> _FullSetLoadedPlan: + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) user_state = _load_full_set_snapshot_state( client, input_path, @@ -664,6 +675,7 @@ def _load_full_set_plan( bind_id_mode, capture_before=capture_before, worker_pool=worker_pool, + include_user_emails=include_user_emails, ) before_path: Path | None = None if capture_before: diff --git a/src/src_auth_perms_sync/permissions/mapping.py b/src/src_auth_perms_sync/permissions/mapping.py index 904a63e..33a9fb0 100644 --- a/src/src_auth_perms_sync/permissions/mapping.py +++ b/src/src_auth_perms_sync/permissions/mapping.py @@ -1,11 +1,11 @@ """Permission mapping resolution: validate rules and match users/repos. Each mapping rule has a `users:` section and a `repos:` section, each -containing one or more matchers (today: `authProvider`, -`codeHostConnection`, and `regex`). Within a matcher, the supplied -keys AND together against the discovered auth-provider / external- -service entries. Across mapping rules, `cmd_set` unions the per-repo -user sets at apply time — see `src/src_auth_perms_sync/permissions/types.py` for the rationale. +containing one or more matchers. Within a matcher, the supplied keys +AND together against the discovered auth-provider / external-service +entries. Across sibling matchers, results intersect. Across mapping +rules, `cmd_set` unions the per-repo user sets at apply time — see +`src/src_auth_perms_sync/permissions/types.py` for the rationale. Adding a new matcher type: @@ -103,7 +103,13 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: ) -_KNOWN_USER_MATCHERS: set[str] = {"authProvider"} +_KNOWN_USER_MATCHERS: set[str] = {"authProvider", "emails", "usernames"} +_KNOWN_REPO_MATCHERS: set[str] = {"codeHostConnection", "names", "regexes"} + + +def mapping_rules_need_user_emails(mapping_rules: list[permission_types.MappingRule]) -> bool: + """Return whether any mapping rule filters users by verified email.""" + return any("emails" in mapping["users"] for mapping in mapping_rules) def _validate_users_section(section: dict[str, object], prefix: str) -> list[str]: @@ -123,6 +129,10 @@ def _validate_users_section(section: dict[str, object], prefix: str) -> list[str ) if "samlGroup" in auth_provider: errors.extend(_validate_saml_group(auth_provider, prefix)) + if "emails" in section: + errors.extend(_validate_string_list(section["emails"], prefix, "users.emails")) + if "usernames" in section: + errors.extend(_validate_string_list(section["usernames"], prefix, "users.usernames")) return errors @@ -157,7 +167,7 @@ def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str """Reject unknown matcher keys and validate `codeHostConnection:` shape.""" errors: list[str] = [] for key in section: - if key not in {"codeHostConnection", "regex"}: + if key not in _KNOWN_REPO_MATCHERS: errors.append(f"{prefix}: unknown repos matcher {key!r}") code_host_section = cast(dict[str, object] | None, section.get("codeHostConnection")) if code_host_section is not None: @@ -190,17 +200,46 @@ def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str f"key/value pairs to deep-subset-match against the service's " f"parsed config (got {type(code_host_section['config']).__name__})" ) - regex = section.get("regex") - if regex is not None: - if not isinstance(regex, str): - errors.append(f"{prefix}: repos.regex must be a string (got {type(regex).__name__})") - elif not regex: - errors.append(f"{prefix}: repos.regex is an empty string") - else: - try: - re.compile(regex) - except re.error as exception: - errors.append(f"{prefix}: repos.regex is not a valid Python regex: {exception}") + if "names" in section: + errors.extend(_validate_string_list(section["names"], prefix, "repos.names")) + regexes = section.get("regexes") + if regexes is not None: + errors.extend(_validate_regexes(regexes, prefix)) + return errors + + +def _validate_regexes(value: object, prefix: str) -> list[str]: + """Validate list-based regex filters.""" + errors = _validate_string_list(value, prefix, "repos.regexes") + if errors: + return errors + + for index, pattern in enumerate(cast(list[str], value)): + try: + re.compile(pattern) + except re.error as exception: + errors.append( + f"{prefix}: repos.regexes[{index}] is not a valid Python regex: {exception}" + ) + return errors + + +def _validate_string_list(value: object, prefix: str, path: str) -> list[str]: + """Validate list-based exact-match filters.""" + if not isinstance(value, list): + return [f"{prefix}: {path} must be a list of strings (got {type(value).__name__})"] + + items = cast(list[object], value) + errors: list[str] = [] + if not items: + errors.append(f"{prefix}: {path} is empty (matches nothing)") + for index, item in enumerate(items): + if not isinstance(item, str): + errors.append( + f"{prefix}: {path}[{index}] must be a string (got {type(item).__name__} {item!r})" + ) + elif not item: + errors.append(f"{prefix}: {path}[{index}] is an empty string") return errors @@ -243,6 +282,15 @@ def resolve_users( saml_groups_attribute_names, ) } + elif key == "emails": + current_ids = { + user["id"] for user in _users_matching_emails(cast(list[str], matcher), all_users) + } + elif key == "usernames": + current_ids = { + user["id"] + for user in _users_matching_usernames(cast(list[str], matcher), all_users) + } else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -273,6 +321,12 @@ def user_matches_users_section( saml_groups_attribute_names, ): return False + elif key == "emails": + if not _user_matches_emails(user, cast(list[str], matcher)): + return False + elif key == "usernames": + if user["username"] not in cast(list[str], matcher): + return False else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -280,6 +334,38 @@ def user_matches_users_section( return True +def _users_matching_emails( + emails: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users with at least one verified email in `emails`.""" + matched = [user for user in all_users if _user_matches_emails(user, emails)] + log.info(" emails → %d user(s) matched %d email(s)", len(matched), len(set(emails))) + return matched + + +def _user_matches_emails(user: shared_types.User, emails: list[str]) -> bool: + """Match only verified emails, mirroring Sourcegraph's `user(email:)` lookup.""" + email_set = set(emails) + return any( + user_email["verified"] and user_email["email"] in email_set + for user_email in user.get("emails", []) + ) + + +def _users_matching_usernames( + usernames: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users whose Sourcegraph username is listed exactly.""" + username_set = set(usernames) + matched = [user for user in all_users if user["username"] in username_set] + log.info( + " usernames → %d user(s) matched %d username(s)", + len(matched), + len(username_set), + ) + return matched + + def _users_matching_auth_provider( matcher: permission_types.AuthProviderMatcher, all_users: list[shared_types.User], @@ -449,7 +535,7 @@ def resolve_repos( matched_ids: set[str] | None = None repo_index: dict[str, permission_types.Repository] = {} - ordered_keys = [key for key in ("codeHostConnection", "regex") if key in section] + ordered_keys = [key for key in ("codeHostConnection", "names", "regexes") if key in section] for key in ordered_keys: matcher = section[key] if key == "codeHostConnection": @@ -458,13 +544,20 @@ def resolve_repos( services_by_id, repos_by_external_service_id, ) - elif key == "regex": + elif key == "names": + candidate_repos = ( + [repo_index[repo_id] for repo_id in matched_ids] + if matched_ids is not None + else list(all_repos_by_id.values()) + ) + repos = _repos_matching_names(cast(list[str], matcher), candidate_repos) + elif key == "regexes": candidate_repos = ( [repo_index[repo_id] for repo_id in matched_ids] if matched_ids is not None else list(all_repos_by_id.values()) ) - repos = _repos_matching_regex(cast(str, matcher), candidate_repos) + repos = _repos_matching_regexes(cast(list[str], matcher), candidate_repos) else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -479,6 +572,16 @@ def resolve_repos( return [repo_index[repo_id] for repo_id in matched_ids] +def _repos_matching_names( + names: list[str], repos: list[permission_types.Repository] +) -> list[permission_types.Repository]: + """Return repos whose Sourcegraph name is listed exactly.""" + name_set = set(names) + matched = [repo for repo in repos if repo["name"] in name_set] + log.info(" names → %d repo(s) matched %d name(s)", len(matched), len(name_set)) + return matched + + def _repos_matching_code_host_connection( matcher: permission_types.CodeHostConnectionMatcher, services_by_id: dict[int, permission_types.ExternalService], @@ -505,22 +608,26 @@ def _repos_matching_code_host_connection( return list(matched_repos.values()) -def _repos_matching_regex( - pattern: str, repos: list[permission_types.Repository] +def _repos_matching_regexes( + patterns: list[str], repos: list[permission_types.Repository] ) -> list[permission_types.Repository]: - """Return repos whose name matches `pattern` using Python `re`. + """Return repos whose name matches any pattern using Python `re`. Sourcegraph repo names usually omit the URL scheme (for example `github.com/example/repo`). To keep URL-looking operator patterns useful, also test `https://`. """ - compiled = re.compile(pattern) + compiled_patterns = [re.compile(pattern) for pattern in patterns] matched = [ repo for repo in repos - if compiled.search(repo["name"]) or compiled.search(f"https://{repo['name']}") + if any( + compiled_pattern.search(repo["name"]) + or compiled_pattern.search(f"https://{repo['name']}") + for compiled_pattern in compiled_patterns + ) ] - log.info(" regex → %d repo(s) matched %r", len(matched), pattern) + log.info(" regexes → %d repo(s) matched %d pattern(s)", len(matched), len(patterns)) return matched diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index afa83b5..7a7e2ec 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -89,32 +89,57 @@ } """ -QUERY_USER_BY_USERNAME = f""" +USER_EMAIL_FIELDS = """ +emails { + email + verified +} +""" + + +def user_fields(*, include_emails: bool = False) -> str: + """Return user fields, adding emails only when downstream matching needs them.""" + if include_emails: + return f"{USER_FIELDS}\n{USER_EMAIL_FIELDS}" + return USER_FIELDS + + +def query_user_by_username(*, include_emails: bool = False) -> str: + return f""" query UserByUsername($username: String!) {{ user(username: $username) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_EMAIL = f""" + +def query_user_by_email(*, include_emails: bool = False) -> str: + return f""" query UserByEmail($email: String!) {{ user(email: $email) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_ID = f""" + +def query_user_by_id(*, include_emails: bool = False) -> str: + return f""" query UserByID($id: ID!) {{ node(id: $id) {{ ... on User {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} }} """ + +QUERY_USER_BY_USERNAME = query_user_by_username() +QUERY_USER_BY_EMAIL = query_user_by_email() +QUERY_USER_BY_ID = query_user_by_id() + QUERY_SITE_USERS = """ query SiteUsers($limit: Int!, $offset: Int!, $createdAt: SiteUsersDateRangeInput) { site { diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index ae3928c..3073205 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -230,7 +230,7 @@ def _fetch( user["username"]: repository_ids_by_user_id.get(user["id"], []) for user in batch_users } - fetch_event["repo_count"] = sum( + fetch_event["fetched_grant_count"] = sum( len(repository_ids) for repository_ids in repository_ids_by_username.values() ) fetch_event["per_user_failures"] = failures diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index ec48ba8..ed4e32b 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -41,29 +41,53 @@ def list_repos_for_external_service( ] -def get_user_by_username(client: src.SourcegraphClient, username: str) -> shared_types.User | None: +def get_user_by_username( + client: src.SourcegraphClient, + username: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the exact Sourcegraph user for `username`, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_USERNAME, cast(src.JSONDict, {"username": username})), + client.graphql( + queries.query_user_by_username(include_emails=include_emails), + cast(src.JSONDict, {"username": username}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_email(client: src.SourcegraphClient, email: str) -> shared_types.User | None: +def get_user_by_email( + client: src.SourcegraphClient, + email: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the user owning the verified email address, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_EMAIL, cast(src.JSONDict, {"email": email})), + client.graphql( + queries.query_user_by_email(include_emails=include_emails), + cast(src.JSONDict, {"email": email}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_id(client: src.SourcegraphClient, user_id: str) -> shared_types.User | None: +def get_user_by_id( + client: src.SourcegraphClient, + user_id: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Hydrate a User node by GraphQL ID.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_ID, cast(src.JSONDict, {"id": user_id})), + client.graphql( + queries.query_user_by_id(include_emails=include_emails), + cast(src.JSONDict, {"id": user_id}), + ), ) return cast(shared_types.User | None, data.get("node")) diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index f57ffab..a320892 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -85,11 +85,14 @@ class CodeHostConnectionMatcher(TypedDict, total=False): class UsersFilter(TypedDict, total=False): authProvider: AuthProviderMatcher + emails: list[str] + usernames: list[str] class ReposFilter(TypedDict, total=False): codeHostConnection: CodeHostConnectionMatcher - regex: str + names: list[str] + regexes: list[str] class MappingRule(TypedDict): diff --git a/src/src_auth_perms_sync/shared/queries.py b/src/src_auth_perms_sync/shared/queries.py index 4ffb1e4..c833e42 100644 --- a/src/src_auth_perms_sync/shared/queries.py +++ b/src/src_auth_perms_sync/shared/queries.py @@ -38,15 +38,25 @@ } """ -QUERY_USERS = """ -query ListUsers($first: Int!, $after: String) { - users(first: $first, after: $after) { - nodes { +USER_EMAIL_FIELDS = """ emails { + email + verified + } +""" + + +def query_users(*, include_emails: bool = False) -> str: + """Return the users page query, adding email fields only when requested.""" + email_fields = USER_EMAIL_FIELDS if include_emails else "" + return f""" +query ListUsers($first: Int!, $after: String) {{ + users(first: $first, after: $after) {{ + nodes {{ id username builtinAuth - externalAccounts(first: 50) { - nodes { +{email_fields} externalAccounts(first: 50) {{ + nodes {{ serviceType serviceID clientID @@ -56,10 +66,13 @@ # Admin. Returns null for serviceType where the resolver does # not expose data (e.g. plain GitHub OAuth without SSO). accountData - } - } - } - pageInfo { hasNextPage endCursor } - } -} + }} + }} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} """ + + +QUERY_USERS = query_users() diff --git a/src/src_auth_perms_sync/shared/sourcegraph.py b/src/src_auth_perms_sync/shared/sourcegraph.py index 3b39d94..f138e2d 100644 --- a/src/src_auth_perms_sync/shared/sourcegraph.py +++ b/src/src_auth_perms_sync/shared/sourcegraph.py @@ -32,11 +32,15 @@ def count_users(client: src.SourcegraphClient) -> int: return cast(int, data["users"]["totalCount"]) -def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types.User]: +def list_users_with_accounts( + client: src.SourcegraphClient, + *, + include_emails: bool = False, +) -> list[shared_types.User]: return [ cast(shared_types.User, node) for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ) @@ -46,6 +50,8 @@ def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types def list_users_streaming( client: src.SourcegraphClient, collect_into: list[shared_types.User] | None = None, + *, + include_emails: bool = False, ) -> Iterator[shared_types.User]: """Stream ListUsers pages one at a time, yielding each User as it arrives. @@ -59,7 +65,7 @@ def list_users_streaming( streaming benefit in one pass — no double-pagination. """ for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ): diff --git a/src/src_auth_perms_sync/shared/types.py b/src/src_auth_perms_sync/shared/types.py index 9429a41..7ac9096 100644 --- a/src/src_auth_perms_sync/shared/types.py +++ b/src/src_auth_perms_sync/shared/types.py @@ -30,11 +30,17 @@ class ExternalAccountConnection(TypedDict): nodes: list[ExternalAccount] +class UserEmail(TypedDict): + email: str + verified: bool + + class User(TypedDict): id: str username: str builtinAuth: bool externalAccounts: ExternalAccountConnection + emails: NotRequired[list[UserEmail]] @dataclass(frozen=True, slots=True) diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index fe6ef17..7c47f77 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -180,6 +180,15 @@ def test_explicit_permissions_batch_size_rejects_values_below_one(self) -> None: with self.assertRaisesRegex(shared_config.ConfigError, "greater than or equal to 1"): load_config_from_env(SRC_AUTH_PERMS_SYNC_EXPLICIT_PERMISSIONS_BATCH_SIZE="0") + def test_http_timeout_config_is_loaded_from_env(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="90") + + self.assertEqual(90, config.http_timeout_seconds) + + def test_http_timeout_rejects_values_at_or_below_zero(self) -> None: + with self.assertRaisesRegex(shared_config.ConfigError, "greater than 0"): + load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="0") + def test_trace_config_is_loaded_from_env(self) -> None: config = load_config_from_env(SRC_AUTH_PERMS_SYNC_TRACE="true") @@ -212,6 +221,33 @@ def capture_client( self.assertEqual(1, len(captured_clients)) self.assertTrue(captured_clients[0].trace) + def test_run_with_client_uses_configured_http_timeout(self) -> None: + configuration = make_config(http_timeout_seconds=75.0) + command = cli.resolve_command(configuration) + captured_clients: list[src.SourcegraphClient] = [] + + def capture_client( + _config: cli.SrcAuthPermissionsSyncConfig, + _command: cli.ResolvedCommand, + client: src.SourcegraphClient, + _worker_pool: ThreadPoolExecutor, + ) -> None: + captured_clients.append(client) + + with ( + ThreadPoolExecutor(max_workers=1) as worker_pool, + mock.patch.object(cli, "run_command", side_effect=capture_client), + ): + cli.run_with_client( + configuration, + command, + "https://sourcegraph.example.com", + worker_pool, + ) + + self.assertEqual(1, len(captured_clients)) + self.assertEqual(75.0, captured_clients[0].http.timeout) + def test_validate_config_rejects_multiple_set_modes(self) -> None: self.assert_config_error( make_config(set_path=Path("maps.yaml"), full=True, user="alice"), @@ -249,6 +285,7 @@ def test_run_fields_include_concrete_command(self) -> None: self.assertEqual(True, fields["apply_flag"]) self.assertEqual(25, fields["explicit_permissions_batch_size"]) self.assertEqual(False, fields["trace"]) + self.assertEqual(60.0, fields["http_timeout_seconds"]) def test_run_command_passes_primary_data_to_combined_sync(self) -> None: configuration = make_config(get=True, sync_saml_organizations=True) diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index 34bb483..96303f9 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -1,12 +1,18 @@ from __future__ import annotations +import base64 +import itertools import tempfile import unittest from pathlib import Path +from typing import cast import yaml -from src_auth_perms_sync.permissions import maps +from src_auth_perms_sync.permissions import mapping, maps +from src_auth_perms_sync.permissions import queries as permission_queries +from src_auth_perms_sync.permissions import types as permission_types +from src_auth_perms_sync.shared import queries as shared_queries from src_auth_perms_sync.shared import types as shared_types @@ -73,3 +79,296 @@ def test_count_users_per_provider_counts_each_user_once_per_provider(self) -> No self.assertEqual(1, counts[maps.BUILTIN_PROVIDER_KEY]) self.assertEqual(1, counts[("saml", "https://idp.example.com", "sourcegraph")]) self.assertEqual(1, counts[("github", "https://github.com/", "github-client")]) + + +class MappingTests(unittest.TestCase): + def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: + rules_without_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "users": {"usernames": ["alice"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + rules_with_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "users": {"emails": ["alice@example.com"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + + self.assertFalse(mapping.mapping_rules_need_user_emails(rules_without_email_filters)) + self.assertTrue(mapping.mapping_rules_need_user_emails(rules_with_email_filters)) + + def test_user_filter_matchers_intersect_without_expanding_selection(self) -> None: + providers: list[shared_types.AuthProvider] = [ + { + "serviceType": "builtin", + "serviceID": "", + "clientID": "", + "displayName": "Builtin", + "isBuiltin": True, + "configID": "", + } + ] + users = [ + self.make_user("user-1", "alice", True, "alice@example.com", True), + self.make_user("user-2", "bob", True, "bob@example.com", True), + self.make_user("user-3", "carol", True, "carol@example.com", False), + self.make_user("user-4", "dana", False, "dana@example.com", True), + ] + user_filters: dict[str, object] = { + "authProvider": {"type": "builtin"}, + "emails": ["alice@example.com", "carol@example.com", "dana@example.com"], + "usernames": ["alice", "bob", "carol"], + } + single_filter_usernames = { + name: self.usernames_for( + mapping.resolve_users({name: matcher}, users, providers), + ) + for name, matcher in user_filters.items() + } + + for filter_count in range(2, len(user_filters) + 1): + for filter_names in itertools.combinations(user_filters, filter_count): + matched_usernames = self.usernames_for( + mapping.resolve_users( + {name: user_filters[name] for name in filter_names}, + users, + providers, + ) + ) + expected_usernames = self.intersection_for(filter_names, single_filter_usernames) + + self.assertEqual(expected_usernames, matched_usernames) + for name in filter_names: + self.assertLessEqual(matched_usernames, single_filter_usernames[name]) + + self.assertEqual( + {"alice"}, + self.usernames_for(mapping.resolve_users(user_filters, users, providers)), + ) + + def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + example_private_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + example_public_repo = self.make_repo("repo-4", "github.com/example/public-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + example_private_repo["id"]: example_private_repo, + gitlab_repo["id"]: gitlab_repo, + example_public_repo["id"]: example_public_repo, + } + services_by_id = { + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise"), + 2: self.make_external_service(2, "GITHUB", "GitHub Cloud"), + } + repos_by_external_service_id = { + 1: [sourcegraph_repo, example_private_repo, gitlab_repo], + 2: [example_public_repo], + } + repo_filters: dict[str, object] = { + "codeHostConnection": {"id": 1}, + "names": [ + "github.com/example/private-repo", + "gitlab.com/example/private-repo", + ], + "regexes": [ + r"^github\.com/example/", + r"^gitlab\.com/example/", + ], + } + single_filter_repo_names = { + name: self.repo_names_for( + mapping.resolve_repos( + {name: matcher}, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + for name, matcher in repo_filters.items() + } + + for filter_count in range(2, len(repo_filters) + 1): + for filter_names in itertools.combinations(repo_filters, filter_count): + matched_repo_names = self.repo_names_for( + mapping.resolve_repos( + {name: repo_filters[name] for name in filter_names}, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + expected_repo_names = self.intersection_for(filter_names, single_filter_repo_names) + + self.assertEqual(expected_repo_names, matched_repo_names) + for name in filter_names: + self.assertLessEqual(matched_repo_names, single_filter_repo_names[name]) + + self.assertEqual( + {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + self.repo_names_for( + mapping.resolve_repos( + repo_filters, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ), + ) + + def test_validate_mapping_rules_accepts_string_list_filters(self) -> None: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "users": { + "emails": ["alice@example.com"], + "usernames": ["alice"], + }, + "repos": { + "names": ["github.com/example/private-repo"], + "regexes": [r"^github\.com/example/"], + }, + } + ], + ) + ) + + def test_repos_regexes_match_any_pattern(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + github_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + github_repo["id"]: github_repo, + gitlab_repo["id"]: gitlab_repo, + } + + matched_repos = mapping.resolve_repos( + { + "regexes": [ + r"^github\.com/example/", + r"^gitlab\.com/example/", + ], + }, + {}, + {}, + all_repos, + ) + + self.assertEqual( + {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + self.repo_names_for(matched_repos), + ) + + def test_validate_mapping_rules_rejects_non_string_list_filters(self) -> None: + with self.assertRaises(SystemExit) as raised: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "users": { + "emails": "alice@example.com", + "usernames": [""], + }, + "repos": { + "names": [123], + "regexes": ["["], + }, + }, + { + "users": {"usernames": ["alice"]}, + "repos": {"regex": r"^github\.com/example/"}, + }, + ], + ) + ) + + message = str(raised.exception) + self.assertIn("users.emails must be a list of strings", message) + self.assertIn("users.usernames[0] is an empty string", message) + self.assertIn("repos.names[0] must be a string", message) + self.assertIn("repos.regexes[0] is not a valid Python regex", message) + self.assertIn("unknown repos matcher 'regex'", message) + + def make_user( + self, + user_id: str, + username: str, + builtin_auth: bool, + email: str, + verified: bool, + ) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": builtin_auth, + "emails": [{"email": email, "verified": verified}], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + def make_external_service( + self, + external_service_id: int, + kind: str, + display_name: str, + ) -> permission_types.ExternalService: + graphql_id = base64.b64encode(f"ExternalService:{external_service_id}".encode()).decode() + return { + "id": graphql_id, + "kind": kind, + "displayName": display_name, + "url": f"https://code-host-{external_service_id}.example.com", + "repoCount": 0, + "createdAt": "2026-05-30T00:00:00Z", + "updatedAt": "2026-05-30T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": "{}", + } + + def usernames_for(self, users: list[shared_types.User]) -> set[str]: + return {user["username"] for user in users} + + def repo_names_for(self, repos: list[permission_types.Repository]) -> set[str]: + return {repo["name"] for repo in repos} + + def intersection_for( + self, names: tuple[str, ...], sets_by_name: dict[str, set[str]] + ) -> set[str]: + matched = set(sets_by_name[names[0]]) + for name in names[1:]: + matched &= sets_by_name[name] + return matched + + +class QueryTests(unittest.TestCase): + def test_user_email_fields_are_opt_in(self) -> None: + self.assertNotIn("emails {", shared_queries.QUERY_USERS) + self.assertNotIn("emails {", shared_queries.query_users()) + self.assertIn("emails {", shared_queries.query_users(include_emails=True)) + + self.assertNotIn("emails {", permission_queries.QUERY_USER_BY_ID) + self.assertNotIn("emails {", permission_queries.query_user_by_id()) + self.assertIn("emails {", permission_queries.query_user_by_id(include_emails=True)) From 9b7136ae3e5aaf18442853811ca167a614bede1b Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Sat, 30 May 2026 01:51:57 -0600 Subject: [PATCH 3/4] Optimize full-set permission planning memory Amp-Thread-ID: https://ampcode.com/threads/T-019e7603-f8fe-755c-bddf-53e486dbf952 Co-authored-by: Amp --- dev/TODO.md | 7 + dev/mapping-efficiency.md | 140 ++++++++++++++++++ dev/test-plan.md | 10 +- .../permissions/full_set.py | 34 +++-- tests/unit/test_maps.py | 85 ++++++++++- 5 files changed, 260 insertions(+), 16 deletions(-) create mode 100644 dev/mapping-efficiency.md diff --git a/dev/TODO.md b/dev/TODO.md index a6ec344..afe2acd 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -64,6 +64,13 @@ If/when we revisit: 3. Add a CLI flag (e.g. `--cross-check-capture`) gated behind a clear "this doubles capture cost" warning. +## Low priority: Grouped full-set plan if memory is still too high + +Phase 1 now avoids per-repo username sets for non-overlapping full-set maps. +If memory remains too high after re-measuring, implement the Phase 2 grouped +plan in [mapping-efficiency.md](./mapping-efficiency.md): combine map-entry +overlays into final groups of repos that share the same desired username tuple. + ## Low priority: Expand group-membership filters beyond SAML `allowGroups`-style enforcement exists on more than just SAML, but only diff --git a/dev/mapping-efficiency.md b/dev/mapping-efficiency.md new file mode 100644 index 0000000..f8ccb51 --- /dev/null +++ b/dev/mapping-efficiency.md @@ -0,0 +1,140 @@ +# Mapping efficiency + +## Rectangular maps example + +Input maps + +```yaml +maps: + - name: engineers get generated repos + users: + usernames: + - alice + - bob + - carol + repos: + names: + - repo-1 + - repo-2 + - repo-3 +``` + +Current: Repo-centric plan + +repo-1 -> (alice, bob, carol) +repo-2 -> (alice, bob, carol) +repo-3 -> (alice, bob, carol) + +Grouped plan + +(alice, bob, carol) -> (repo-1, repo-2, repo-3) + +## Current semantics + +Each `maps:` entry is naturally a grouped rule: + +```text +selected users × selected repos +``` + +The full-set command must combine all entries before mutating Sourcegraph, +because `setRepositoryPermissionsForUsers` overwrites a repo's whole explicit +permission list. The required final state is: + +```text +desired_users(repo) = union(users_i for each map_i where repo is in repos_i) +``` + +Only after this union is known can the command safely apply per-repo overwrite +mutations. + +## Phase 1: lazy per-repo union sets + +The old full-set planner immediately expanded every map entry into: + +```text +repo_id -> set(username) +``` + +That is expensive for rectangular maps such as `10000 users × 1000 repos`: +the username strings are shared, but each repo owns a large Python set with one +hash-table entry per planned grant. + +Phase 1 keeps the existing downstream plan shape: + +```text +repo_id -> tuple(username) +``` + +but builds it more carefully: + +1. For a non-overlapping map entry, create one sorted username tuple and reuse + that same tuple for every matched repo. +2. If a later map entry touches a repo that already has users, promote only + that repo to a temporary set and union the usernames. +3. Convert only promoted repos back to sorted tuples after all map entries are + processed. + +This preserves the hard invariant while avoiding the large per-repo sets in +the common non-overlapping rectangular case. + +Measured on the sgdev test instance, the dry-run `10000x1000` case planned 10M +grants. Before Phase 1 it peaked at about 651 MiB RSS; after Phase 1 it peaked +at about 68 MiB RSS. + +## Phase 2: final grouped plan, if needed + +If Phase 1 is not enough, store the combined final plan as groups of repos that +share the same final user set: + +```text +tuple(username) -> tuple(repo_id) +``` + +This is not just one group per `maps:` entry. Map entries are input overlays; +final groups are the compressed result after every map entry has been unioned +onto the repo space. + +Example: + +```text +map A: alice,bob -> repo-1,repo-2 +map B: bob,chris -> repo-2,repo-3 + +final: +alice,bob -> repo-1 +alice,bob,chris -> repo-2 +bob,chris -> repo-3 +``` + +One practical data model would be: + +```python +@dataclass(frozen=True) +class RepositoryPermissionGroup: + usernames: tuple[str, ...] + repository_ids: tuple[str, ...] + + +@dataclass(frozen=True) +class FullSetPlan: + groups: tuple[RepositoryPermissionGroup, ...] + repo_names: dict[str, str] + repo_to_group_index: dict[str, int] + + def usernames_for_repo(self, repo_id: str) -> tuple[str, ...]: + return self.groups[self.repo_to_group_index[repo_id]].usernames +``` + +Apply still happens per repo: + +```text +for group in groups: + for repo_id in group.repository_ids: + setRepositoryPermissionsForUsers(repo_id, group.usernames) +``` + +Phase 2 touches more code than Phase 1: projected snapshots, diffs, +short-circuit filtering, apply iteration, and validation all currently expect +direct `repo_id -> usernames` lookups. Do it only if Phase 1 measurements still +show unacceptable memory use. diff --git a/dev/test-plan.md b/dev/test-plan.md index d71a735..8312541 100644 --- a/dev/test-plan.md +++ b/dev/test-plan.md @@ -168,11 +168,11 @@ Use one command mode per fit (`set_full` with backup, `set_full --no-backup`, per-grant coefficient. On the sgdev test instance with 10,001 users and 1,023 visible repos, a -dry-run `10000x1000` case planned 10M grants and measured about 651 MiB peak -RSS. The grants-only fit across 14 dry-run observations was roughly 69 MiB -fixed plus 61 bytes per planned grant. Re-measure after meaningful mapping or -snapshot changes; these numbers describe dry-run planning memory, not apply -mutation throughput. +dry-run `10000x1000` case planned 10M grants. Before the lazy-union planner, +it measured about 651 MiB peak RSS; after Phase 1 in +[mapping-efficiency.md](./mapping-efficiency.md), the same case measured about +68 MiB. Re-measure after meaningful mapping or snapshot changes; these numbers +describe dry-run planning memory, not apply mutation throughput. The e2e `workload` object now uses event-aware names. In older result JSON, `total_users: 40004` came from `apply_username_overwrites` and meant "username diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 6439fdc..6b79458 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -266,12 +266,13 @@ def _write_noop_full_set_snapshots( return before_path, after_path, diff_path, maps_backup_path -def _plan_full_set_permissions( +def plan_full_set_permissions( context: permission_types.MappingContext, users: list[shared_types.User], ) -> _FullSetPlan: """Resolve mapping rules into one repo-to-users overwrite plan.""" - repo_usernames: dict[str, set[str]] = {} + expected_users: dict[str, tuple[str, ...]] = {} + union_usernames_by_repo_id: dict[str, set[str]] = {} repo_names: dict[str, str] = {} for mapping_index, mapping in enumerate(context.mapping_rules, start=1): @@ -303,15 +304,28 @@ def _plan_full_set_permissions( log.warning(" No repos matched — skipping rule.") continue - matched_usernames = tuple(user["username"] for user in matched_users) + matched_usernames = tuple(sorted({user["username"] for user in matched_users})) for repo in matched_repos: - bucket = repo_usernames.setdefault(repo["id"], set()) - repo_names[repo["id"]] = repo["name"] - bucket.update(matched_usernames) + repo_id = repo["id"] + repo_names[repo_id] = repo["name"] + union_usernames = union_usernames_by_repo_id.get(repo_id) + if union_usernames is not None: + union_usernames.update(matched_usernames) + continue + + existing_usernames = expected_users.get(repo_id) + if existing_usernames is not None: + union_usernames = set(existing_usernames) + union_usernames.update(matched_usernames) + union_usernames_by_repo_id[repo_id] = union_usernames + del expected_users[repo_id] + continue + + expected_users[repo_id] = matched_usernames + + for repo_id, usernames in union_usernames_by_repo_id.items(): + expected_users[repo_id] = tuple(sorted(usernames)) - expected_users = { - repo_id: tuple(sorted(usernames)) for repo_id, usernames in repo_usernames.items() - } total_grants = sum(len(usernames) for usernames in expected_users.values()) if expected_users: log.info( @@ -699,7 +713,7 @@ def _load_full_set_plan( user_state.users, user_created_after, ) - plan = _plan_full_set_permissions(context, users) + plan = plan_full_set_permissions(context, users) snapshot_state = _compact_full_set_snapshot_state(user_state, users) saml_group_users = ( saml_groups.compact_saml_group_users( diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index 96303f9..3dcbe77 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -9,7 +9,7 @@ import yaml -from src_auth_perms_sync.permissions import mapping, maps +from src_auth_perms_sync.permissions import full_set, mapping, maps from src_auth_perms_sync.permissions import queries as permission_queries from src_auth_perms_sync.permissions import types as permission_types from src_auth_perms_sync.shared import queries as shared_queries @@ -363,6 +363,89 @@ def intersection_for( return matched +class FullSetPlanningTests(unittest.TestCase): + def test_full_set_plan_reuses_user_tuple_for_non_overlapping_repos(self) -> None: + users = [self.make_user("user-1", "bob"), self.make_user("user-2", "alice")] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + ] + context = self.make_context( + [ + { + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + } + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob"), plan.expected_users["repo-2"]) + self.assertIs(plan.expected_users["repo-1"], plan.expected_users["repo-2"]) + self.assertEqual(4, plan.total_grants) + + def test_full_set_plan_unions_only_overlapping_repos(self) -> None: + users = [ + self.make_user("user-1", "alice"), + self.make_user("user-2", "bob"), + self.make_user("user-3", "chris"), + ] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + self.make_repo("repo-3", "github.com/example/three"), + ] + context = self.make_context( + [ + { + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + }, + { + "users": {"usernames": ["bob", "chris"]}, + "repos": {"names": ["github.com/example/two", "github.com/example/three"]}, + }, + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob", "chris"), plan.expected_users["repo-2"]) + self.assertEqual(("bob", "chris"), plan.expected_users["repo-3"]) + self.assertEqual(7, plan.total_grants) + + def make_context( + self, + mapping_rules: list[permission_types.MappingRule], + repositories: list[permission_types.Repository], + ) -> permission_types.MappingContext: + return permission_types.MappingContext( + mapping_rules=mapping_rules, + providers=[], + saml_groups_attribute_names={}, + services_by_id={}, + repos_by_external_service_id={}, + all_repos_by_id={repository["id"]: repository for repository in repositories}, + ) + + def make_user(self, user_id: str, username: str) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": True, + "emails": [], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + class QueryTests(unittest.TestCase): def test_user_email_fields_are_opt_in(self) -> None: self.assertNotIn("emails {", shared_queries.QUERY_USERS) From 72eba384bd33dfa48dedc8adccdbd8e717401f89 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:35:12 -0600 Subject: [PATCH 4/4] Refactor maps schema, expose Python API, refactor subcommands --- README.md | 77 +- dev/TODO.md | 30 + dev/mapping-efficiency.md | 43 +- dev/run-memory-model-sweep.py | 352 ++++++++- ...ourcegraph-explicit-permissions-tracing.md | 55 ++ maps-example.yaml | 137 ++-- src/src_auth_perms_sync/__init__.py | 12 +- src/src_auth_perms_sync/cli.py | 257 ++++--- src/src_auth_perms_sync/permissions/apply.py | 47 ++ .../permissions/command.py | 19 +- .../permissions/full_set.py | 18 +- .../permissions/mapping.py | 673 ++++++++++-------- src/src_auth_perms_sync/permissions/maps.py | 42 +- .../permissions/restore.py | 10 +- src/src_auth_perms_sync/permissions/types.py | 25 +- .../permissions/workflow.py | 24 +- src/src_auth_perms_sync/shared/types.py | 1 + tests/integration/test_cli_entrypoint.py | 6 +- tests/unit/test_apply.py | 102 +++ tests/unit/test_cli_config.py | 240 +++++-- tests/unit/test_maps.py | 149 +++- 21 files changed, 1661 insertions(+), 658 deletions(-) create mode 100644 tests/unit/test_apply.py diff --git a/README.md b/README.md index bc31c32..f57fca0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # src-auth-perms-sync - + src-auth-perms-sync automates Sourcegraph's Explicit Permissions GraphQL API, setting user-to-repo permissions based on mapping rules, for example: @@ -92,34 +92,48 @@ Feel free to open issues or PRs, but responses are best effort. - Requires Python 3.11 - Recommended: Use a Python virtual environment -### Install from a GitHub Release - -Use this when the VM can reach GitHub and PyPI during install: +### Install from PyPI ```bash -python3.11 -m venv .venv -. .venv/bin/activate -pip install \ - "https://github.com/sourcegraph/src-auth-perms-sync/releases/download/v0.1.0/src_auth_perms_sync-0.1.0-py3-none-any.whl" -``` +pip install src-auth-perms-sync -### Restricted/offline install from a GitHub Release +# Run the CLI +src-auth-perms-sync --help +``` -Use this when the VM cannot reach package indexes during install +### Restricted / offline install from a GitHub Release -Download the .tar.gz file from the GitHub release +Download the .tar.gz file from [a GitHub release](https://github.com/sourcegraph/src-auth-perms-sync/releases) ```bash tar -xzf src-auth-perms-sync-linux-x64.tar.gz -python3.11 -m venv .venv -. .venv/bin/activate + pip install --no-index --find-links ./wheelhouse src-auth-perms-sync + +# Run the CLI +src-auth-perms-sync --help ``` -After either install method, run the CLI from the activated virtual environment: +### Import into your own Python script -```bash -src-auth-perms-sync --help +```python +from pathlib import Path + +import src_auth_perms_sync as src + +config = src.Config( + src_endpoint="https://sourcegraph.example.com", + src_access_token="sgp_...", + maps_path=Path("/absolute/path/to/maps.yaml"), + apply=False, # Dry run (default), set to True to make changes +) + +succeeded = src.Set(config) + +# Other command wrappers: +# succeeded = src.Get(config) +# succeeded = src.Restore(config) +# succeeded = src.SyncSamlOrgs(config) ``` ## Inputs @@ -136,21 +150,20 @@ src-auth-perms-sync --help - A list of filters for users - A list of filters for repos - See [maps-example.yaml](./maps-example.yaml) - - An empty maps.yaml file is created for you on the first --get run + - An empty maps.yaml file is created for you on the first `get` run ## Usage: Permission sync 1. **Get auth providers and code hosts** ```bash - uv run src-auth-perms-sync [--get] + uv run src-auth-perms-sync get ``` - Queries the Sourcegraph instance for auth providers and code host connections - Writes generated reference files `auth-providers.yaml` and `code-hosts.yaml` under `src-auth-perms-sync-runs//` - Creates an empty `maps.yaml` if it doesn't exist - - Runs by default when no command is selected 2. **Configure mapping rules** @@ -161,20 +174,20 @@ src-auth-perms-sync --help 3. **Set: Dry run** ```bash - uv run src-auth-perms-sync --set maps.yaml --full + uv run src-auth-perms-sync set --maps-path maps.yaml --full ``` 4. **Set: Apply** ```bash - uv run src-auth-perms-sync --set maps.yaml --full --apply + uv run src-auth-perms-sync set --maps-path maps.yaml --full --apply ``` 5. **Restore: Dry run** ```bash - uv run src-auth-perms-sync \ - --restore backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json + uv run src-auth-perms-sync restore \ + --restore-path backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json ``` - Roll back the explicit-permissions state on the @@ -183,8 +196,8 @@ src-auth-perms-sync --help 6. **Restore: Apply** ```bash - uv run src-auth-perms-sync \ - --restore backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json \ + uv run src-auth-perms-sync restore \ + --restore-path backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json \ --apply ``` @@ -193,7 +206,7 @@ src-auth-perms-sync --help 1. **Get user and org metadata** ```bash - uv run src-auth-perms-sync --sync-saml-orgs + uv run src-auth-perms-sync sync-saml-orgs ``` - Queries the Sourcegraph instance for auth providers, users, users' SAML groups, and orgs @@ -202,11 +215,11 @@ src-auth-perms-sync --help 2. **Apply org sync** ```bash - uv run src-auth-perms-sync --sync-saml-orgs --apply + uv run src-auth-perms-sync sync-saml-orgs --apply ``` - Creates the orgs if they don't exist, and sync the members from the SAML groups to the orgs - - `--sync-saml-orgs` can also be added to a `--set` run, to run both at the same time + - `--sync-saml-orgs` can also be added to a `get` or `set` run, to run both at the same time ## Options @@ -231,9 +244,9 @@ src-auth-perms-sync-runs/endpoint/ - The `src-auth-perms-sync-runs` dir is created under your current working directory - The `endpoint` dir is created with the hostname from `SRC_ENDPOINT` - If `maps.yaml` doesn't exist already, it'll be created for you -- `auth-providers.yaml` and `code-hosts.yaml` are created / replaced by the `--get` command, +- `auth-providers.yaml` and `code-hosts.yaml` are created / replaced by the `get` command, for you to copy values from, to use in your `maps.yaml` -- Only one `maps.yaml` file can be used at a time per Sourcegraph instance, as each `--set --apply` +- Only one `maps.yaml` file can be used at a time per Sourcegraph instance, as each `set --apply` command resets the state on the Sourcegraph instance to the `maps.yaml` file which was used - Each run of the script creates a new `timestamp-command` dir under the `runs` dir, with: - A log file @@ -243,4 +256,4 @@ src-auth-perms-sync-runs/endpoint/ - An `after.json` file, capturing the new state - A `diff.json` file, a shorter, reviewable file containing the diffs between before and after - + diff --git a/dev/TODO.md b/dev/TODO.md index afe2acd..7f92cdb 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -1,5 +1,35 @@ # TODO +## High priority: Customer feedback + +- Allow use as either CLI or importable module +- Take maps.yaml as a constructor object? + +## High priority: Instrument with OpenTelemetry — in progress + +- [ ] Add OTel-native traces, metrics, and wide log events in `src-py-lib`. +- [ ] Add shared OTel bootstrap config/helpers with `--otel` and standard + `OTEL_*` env-var-backed CLI args. +- [ ] Replace custom trace-context propagation with OTel W3C propagation. +- [ ] Instrument shared HTTP and GraphQL clients manually, preserving safe + sanitized attributes and Sourcegraph-specific metadata. +- [ ] Rename Sourcegraph debug tracing from `--trace` to `--fetch-sg-traces`. +- [ ] Wire `src-auth-perms-sync` to the shared OTel bootstrap without doing + import-time logger/provider setup. +- [ ] Verify pyright, tests, and CLI help in both repos. + +## High priority: Reduce worst-case full-permission sync load + +- Use the stress-run evidence in + [sourcegraph-explicit-permissions-tracing.md](./sourcegraph-explicit-permissions-tracing.md) + to request Sourcegraph bulk explicit-permission read and write APIs. +- Add an explicit destructive/performance-test mode to the e2e runner so giant + stress runs can skip or defer full restore cleanup when the goal is finding + the server-side breaking point. +- Revisit full snapshot capture once Sourcegraph exposes a bulk read path; + replace aliased `User.permissionsInfo.repositories(source: API)` calls before + raising concurrency further. + ## Medium priority: Lightweight incremental updates - When a new user's account is created, or a new repo is synced from a code host, diff --git a/dev/mapping-efficiency.md b/dev/mapping-efficiency.md index f8ccb51..65f8870 100644 --- a/dev/mapping-efficiency.md +++ b/dev/mapping-efficiency.md @@ -19,15 +19,45 @@ maps: - repo-3 ``` -Current: Repo-centric plan +### Original + +- Repo-centric plan, but each and every single repo gets a full copy of the list of users, + so the memory storage size is truly users x repos +- If you your list of users is 1,000 users, and 10 MB RAM, and you have 1,000 repos, + then this is 1,000,000 repo+user pairs, which is 1,000 x 10 MB RAM = 10 GB RAM +- This is a "full square" repo-1 -> (alice, bob, carol) repo-2 -> (alice, bob, carol) -repo-3 -> (alice, bob, carol) +repo-3 -> (alice, bob, dan) +repo-4 -> (alice, bob, dan) + +### Current: Groups of users + +- We anticipate that many users will be grouped up into a small number of sets, + and that most repos' perms will be one of the sets +- This example cuts in ~half the amount memory consumed by lists of users as the Current example + +user-group-a = (alice, bob, carol) +user-group-b = (alice, bob, dan) + +repo-1 -> user-group-a +repo-2 -> user-group-a +repo-3 -> user-group-b +repo-4 -> user-group-b -Grouped plan +### Phase 2: Groups of users x Groups of repos -(alice, bob, carol) -> (repo-1, repo-2, repo-3) +- Realistically, we anticipate that many repos will also be grouped up into a small number of sets + +user-group-a = (alice, bob, carol) +user-group-b = (alice, bob, dan) + +repo-group-1 = (repo-1, repo-2) +repo-group-2 = (repo-3, repo-4) + +user-group-a -> repo-group-1 +user-group-b -> repo-group-2 ## Current semantics @@ -37,6 +67,11 @@ Each `maps:` entry is naturally a grouped rule: selected users × selected repos ``` +The maps schema keeps this restrictive: `users:` and `repos:` are selector +maps, top-level selectors inside each map are ANDed together, and values inside +one selector list are ORed together. To OR across selectors, write more +top-level `maps:` entries. + The full-set command must combine all entries before mutating Sourcegraph, because `setRepositoryPermissionsForUsers` overwrites a repo's whole explicit permission list. The required final state is: diff --git a/dev/run-memory-model-sweep.py b/dev/run-memory-model-sweep.py index 816ff44..6896941 100755 --- a/dev/run-memory-model-sweep.py +++ b/dev/run-memory-model-sweep.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 """Generate and optionally run maps.yaml files for memory-model sweeps. -The generated maps use exact `users.usernames` and `repos.names` filters so -each case has a known planned grant count: `users * repos`. +The generated maps use exact `users.usernames` and `repos.names` filters. +Different workload shapes stress different parts of mapping resolution and +full-set planning, while preserving known selected user/repo/grant counts. By default this script only generates the maps. Pass `--run` to execute the CLI in dry-run mode. Pass `--mode apply-no-backup --allow-apply` only on a @@ -24,7 +25,7 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, cast +from typing import Any, Literal, TypeAlias, cast from urllib.parse import urlsplit import src_py_lib as src @@ -82,6 +83,27 @@ DEFAULT_COMMAND = "uv run src-auth-perms-sync" LOG_PATH_PATTERN = re.compile(r"Writing log events to (.+?/log\.json)\.") RunMode = Literal["dry-run", "apply-no-backup"] +SweepSuite = Literal["gentle", "breaking"] +StressShape: TypeAlias = Literal[ + "rectangle", + "user-shards", + "repo-shards", + "diagonal-shards", + "duplicate-rules", +] +STRESS_SHAPES: tuple[StressShape, ...] = ( + "rectangle", + "user-shards", + "repo-shards", + "diagonal-shards", + "duplicate-rules", +) +BREAKING_SHAPES: tuple[StressShape, ...] = ( + "rectangle", + "user-shards", + "repo-shards", + "duplicate-rules", +) class SweepSourcegraphConfig(src.SourcegraphClientConfig): @@ -90,18 +112,33 @@ class SweepSourcegraphConfig(src.SourcegraphClientConfig): @dataclass(frozen=True) class SweepCase: - """One users x repos planned-permissions case.""" + """One generated workload case.""" users: int repos: int + shape: StressShape = "rectangle" + rule_count: int = 1 @property def grants(self) -> int: - return self.users * self.repos + """Final unique planned grants after map-entry unioning.""" + return unique_grant_count(self) + + @property + def raw_rule_grants(self) -> int: + """Total per-rule rectangle grants before cross-rule unioning.""" + return raw_rule_grant_count(self) + + @property + def map_rule_count(self) -> int: + return map_rule_count(self) @property def name(self) -> str: - return f"u{self.users:05d}-r{self.repos:05d}-g{self.grants:010d}" + return ( + f"{self.shape}-m{self.map_rule_count:03d}-" + f"u{self.users:05d}-r{self.repos:05d}-g{self.grants:012d}" + ) @dataclass(frozen=True) @@ -140,8 +177,11 @@ def main() -> int: parser = build_parser() arguments = parser.parse_args() mode = cast(RunMode, arguments.mode) + suite = cast(SweepSuite, arguments.suite) if mode == "apply-no-backup" and not arguments.allow_apply: parser.error("--mode apply-no-backup requires --allow-apply") + if arguments.rule_count < 1: + parser.error("--rule-count must be >= 1") config = sourcegraph_config(arguments) output_dir = arguments.output_dir or default_output_dir(config.src_endpoint) @@ -164,10 +204,13 @@ def main() -> int: inventory_repo_count = sum(service.repo_count for service in external_services) service = choose_external_service(external_services, arguments.external_service_id) total_user_count = count_users(client) - cases = requested_cases or default_cases_for_inventory( + base_cases = requested_cases or default_cases_for_inventory( total_user_count, service.repo_count, + suite=suite, ) + shapes = parse_shapes(arguments.shapes, suite) + cases = expand_cases(base_cases, shapes, arguments.rule_count) max_users = max(sweep_case.users for sweep_case in cases) max_repos = max(sweep_case.repos for sweep_case in cases) usernames = list_usernames(client, max_users, arguments.page_size) @@ -176,7 +219,14 @@ def main() -> int: client.http.close() generated_maps = write_maps(maps_dir, cases, usernames, repo_names, service) - write_manifest(output_dir, generated_maps, service, config.src_endpoint, inventory_repo_count) + write_manifest( + output_dir, + generated_maps, + service, + config.src_endpoint, + inventory_repo_count, + total_user_count, + ) print(f"Generated {len(generated_maps)} maps.yaml file(s) under {maps_dir}") print( f"Selected code host: {service.display_name} id={service.database_id} " @@ -226,6 +276,15 @@ def build_parser() -> argparse.ArgumentParser: "Defaults under src-auth-perms-sync-runs/." ), ) + parser.add_argument( + "--suite", + choices=("gentle", "breaking"), + default="gentle", + help=( + "Case-size preset. gentle keeps the previous low-risk auto sweep; " + "breaking adds larger dimensions intended to find the failure point." + ), + ) parser.add_argument( "--cases", default=DEFAULT_CASES, @@ -234,6 +293,22 @@ def build_parser() -> argparse.ArgumentParser: "or 'auto' for a gentle inventory-aware sweep. Default: auto." ), ) + parser.add_argument( + "--shapes", + default="auto", + help=( + "Comma-separated workload shapes: " + f"{', '.join(STRESS_SHAPES)}. " + "Default auto means rectangle for --suite gentle and a mixed set " + "for --suite breaking." + ), + ) + parser.add_argument( + "--rule-count", + type=int, + default=10, + help="Target map-rule/selector count for multi-rule shapes (default: 10).", + ) parser.add_argument( "--external-service-id", type=int, @@ -338,12 +413,37 @@ def parse_cases(raw_cases: str) -> list[SweepCase] | None: return cases -def default_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepCase]: - """Return a safe default sweep that covers user, repo, and grant axes.""" +def parse_shapes(raw_shapes: str, suite: SweepSuite) -> tuple[StressShape, ...]: + """Return workload shapes requested by the operator.""" + if raw_shapes.strip().lower() == "auto": + return BREAKING_SHAPES if suite == "breaking" else ("rectangle",) + + valid_shapes = set(STRESS_SHAPES) + shapes: list[StressShape] = [] + for raw_shape in raw_shapes.split(","): + shape = raw_shape.strip().lower() + if not shape: + continue + if shape not in valid_shapes: + raise SystemExit( + f"Invalid shape {raw_shape!r}; expected one of {', '.join(STRESS_SHAPES)}" + ) + shapes.append(shape) + if not shapes: + raise SystemExit("At least one --shapes entry is required") + return tuple(dict.fromkeys(shapes)) + + +def default_cases_for_inventory( + user_count: int, repo_count: int, *, suite: SweepSuite +) -> list[SweepCase]: + """Return an inventory-aware sweep that covers user, repo, and grant axes.""" if user_count < 1: raise SystemExit("Need at least one Sourcegraph user for an auto sweep") if repo_count < 1: raise SystemExit("Need at least one Sourcegraph repo for an auto sweep") + if suite == "breaking": + return breaking_cases_for_inventory(user_count, repo_count) user_points = bounded_points(user_count, DEFAULT_USER_POINTS) repo_points = bounded_points(repo_count, DEFAULT_REPO_POINTS) @@ -362,6 +462,48 @@ def default_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepC return unique_cases(cases) +def breaking_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepCase]: + """Return larger cases ordered from likely-safe to likely-breaking.""" + capped_users = min(user_count, 10000) + capped_repos = min(repo_count, 50000) + candidate_dimensions = ( + (1, capped_repos), + (100, capped_repos), + (1000, min(capped_repos, 1000)), + (capped_users, 1), + (capped_users, 100), + (capped_users, 1000), + (capped_users, 5000), + (capped_users, 10000), + (capped_users, 25000), + (capped_users, capped_repos), + ) + cases = [ + SweepCase(users=users, repos=repos) + for users, repos in candidate_dimensions + if users <= user_count and repos <= repo_count + ] + return unique_cases(cases) + + +def expand_cases( + base_cases: Sequence[SweepCase], shapes: Sequence[StressShape], rule_count: int +) -> list[SweepCase]: + """Expand dimensions into the requested workload shapes.""" + cases: list[SweepCase] = [] + for base_case in base_cases: + for shape in shapes: + cases.append( + SweepCase( + users=base_case.users, + repos=base_case.repos, + shape=shape, + rule_count=rule_count if shape != "rectangle" else 1, + ) + ) + return unique_cases(cases) + + def bounded_points(available_count: int, candidate_points: Sequence[int]) -> list[int]: """Return candidate points that fit, plus the exact inventory cap if useful.""" points = [point for point in candidate_points if point <= available_count] @@ -372,10 +514,10 @@ def bounded_points(available_count: int, candidate_points: Sequence[int]) -> lis def unique_cases(cases: Sequence[SweepCase]) -> list[SweepCase]: """Preserve case order while removing duplicates.""" - seen: set[tuple[int, int]] = set() + seen: set[tuple[int, int, StressShape, int]] = set() unique: list[SweepCase] = [] for sweep_case in cases: - key = (sweep_case.users, sweep_case.repos) + key = (sweep_case.users, sweep_case.repos, sweep_case.shape, sweep_case.rule_count) if key in seen: continue seen.add(key) @@ -472,6 +614,71 @@ def list_repo_names( return repo_names +def map_rule_count(sweep_case: SweepCase) -> int: + """Return the actual map-rule count for this case.""" + if sweep_case.shape == "rectangle": + return 1 + if sweep_case.shape == "duplicate-rules": + return sweep_case.rule_count + if sweep_case.shape == "user-shards": + return min(sweep_case.users, sweep_case.rule_count) + if sweep_case.shape == "repo-shards": + return min(sweep_case.repos, sweep_case.rule_count) + return min(sweep_case.users, sweep_case.repos, sweep_case.rule_count) + + +def user_selector_count(sweep_case: SweepCase) -> int: + """Return the number of user selectors emitted across all map rules.""" + return map_rule_count(sweep_case) + + +def repository_selector_count(sweep_case: SweepCase) -> int: + """Return the number of repository selectors emitted across all map rules.""" + return map_rule_count(sweep_case) + + +def unique_grant_count(sweep_case: SweepCase) -> int: + """Return final unique grants after unioning all map entries.""" + if sweep_case.shape == "diagonal-shards": + return sum( + user_count * repo_count + for user_count, repo_count in zip( + chunk_lengths(sweep_case.users, map_rule_count(sweep_case)), + chunk_lengths(sweep_case.repos, map_rule_count(sweep_case)), + strict=True, + ) + ) + return sweep_case.users * sweep_case.repos + + +def raw_rule_grant_count(sweep_case: SweepCase) -> int: + """Return total per-rule grants before cross-rule unioning.""" + if sweep_case.shape == "duplicate-rules": + return sweep_case.users * sweep_case.repos * sweep_case.rule_count + return unique_grant_count(sweep_case) + + +def chunk_lengths(total: int, chunk_count: int) -> list[int]: + """Return near-even chunk lengths for `total` items.""" + if chunk_count < 1: + raise ValueError("chunk_count must be >= 1") + base, extra = divmod(total, chunk_count) + return [base + (1 if index < extra else 0) for index in range(chunk_count)] + + +def chunked_values(values: Sequence[str], chunk_count: int) -> list[list[str]]: + """Split values into near-even non-empty chunks.""" + lengths = chunk_lengths(len(values), chunk_count) + chunks: list[list[str]] = [] + offset = 0 + for length in lengths: + if length < 1: + continue + chunks.append(list(values[offset : offset + length])) + offset += length + return chunks + + def write_maps( maps_dir: Path, cases: Sequence[SweepCase], @@ -482,21 +689,14 @@ def write_maps( generated: list[GeneratedMap] = [] for sweep_case in cases: map_path = maps_dir / f"maps-{sweep_case.name}.yaml" + rules = map_rules_for_case( + sweep_case, + usernames[: sweep_case.users], + repo_names[: sweep_case.repos], + service, + ) payload = { - "maps": [ - { - "name": ( - "memory model " - f"users={sweep_case.users} repos={sweep_case.repos} " - f"grants={sweep_case.grants}" - ), - "users": {"usernames": list(usernames[: sweep_case.users])}, - "repos": { - "codeHostConnection": {"id": service.database_id}, - "names": list(repo_names[: sweep_case.repos]), - }, - } - ] + "maps": rules, } with map_path.open("w", encoding="utf-8") as output_file: output_file.write( @@ -504,28 +704,115 @@ def write_maps( ) output_file.write( f"# users={sweep_case.users} repos={sweep_case.repos} " - f"planned_grants={sweep_case.grants}\n" + f"planned_grants={sweep_case.grants} " + f"raw_rule_grants={sweep_case.raw_rule_grants} " + f"shape={sweep_case.shape} map_rules={sweep_case.map_rule_count}\n" ) yaml.safe_dump(payload, output_file, sort_keys=False, allow_unicode=True) generated.append(GeneratedMap(case=sweep_case, path=map_path)) return generated +def map_rules_for_case( + sweep_case: SweepCase, + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> list[dict[str, object]]: + """Build map rules for one workload shape.""" + if sweep_case.shape == "rectangle": + return [map_rule(sweep_case, 1, usernames, repo_names, service)] + if sweep_case.shape == "user-shards": + return [ + map_rule(sweep_case, index, user_chunk, repo_names, service) + for index, user_chunk in enumerate( + chunked_values(usernames, sweep_case.map_rule_count), start=1 + ) + ] + if sweep_case.shape == "repo-shards": + return [ + map_rule(sweep_case, index, usernames, repo_chunk, service) + for index, repo_chunk in enumerate( + chunked_values(repo_names, sweep_case.map_rule_count), start=1 + ) + ] + if sweep_case.shape == "diagonal-shards": + return [ + map_rule(sweep_case, index, user_chunk, repo_chunk, service) + for index, (user_chunk, repo_chunk) in enumerate( + zip( + chunked_values(usernames, sweep_case.map_rule_count), + chunked_values(repo_names, sweep_case.map_rule_count), + strict=True, + ), + start=1, + ) + ] + if sweep_case.shape == "duplicate-rules": + return [ + map_rule(sweep_case, index, usernames, repo_names, service) + for index in range(1, sweep_case.map_rule_count + 1) + ] + raise AssertionError(f"Unhandled shape {sweep_case.shape!r}") + + +def map_rule( + sweep_case: SweepCase, + index: int, + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> dict[str, object]: + """Build one rectangular map rule.""" + return { + "name": f"memory model {sweep_case.shape} rule {index}/{sweep_case.map_rule_count}", + "users": username_selector(usernames), + "repos": repository_selector(repo_names, service), + } + + +def username_selector(usernames: Sequence[str]) -> dict[str, object]: + return {"usernames": list(usernames)} + + +def repository_selector( + repo_names: Sequence[str], service: ExternalServiceChoice +) -> dict[str, object]: + return { + "codeHostConnection": { + "kind": service.kind, + "displayName": service.display_name, + "url": service.url, + }, + "names": list(repo_names), + } + + def write_manifest( output_dir: Path, generated_maps: Sequence[GeneratedMap], service: ExternalServiceChoice, endpoint: str, inventory_repo_count: int, + sourcegraph_user_count: int, ) -> None: manifest = { "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), "endpoint": endpoint, "external_service": service_to_json(service), + "sourcegraph_user_count": sourcegraph_user_count, "sourcegraph_inventory_repo_count": inventory_repo_count, "maps": [ { "case": generated_map.case.name, + "shape": generated_map.case.shape, + "map_rule_count": generated_map.case.map_rule_count, + "user_selector_count": user_selector_count(generated_map.case), + "repository_selector_count": repository_selector_count(generated_map.case), + "selected_user_count": generated_map.case.users, + "selected_repo_count": generated_map.case.repos, + "selected_total_grants": generated_map.case.grants, + "raw_rule_grant_count": generated_map.case.raw_rule_grants, "users": generated_map.case.users, "repos": generated_map.case.repos, "grants": generated_map.case.grants, @@ -704,6 +991,7 @@ def result_to_json( "variant": "candidate", "iteration": 1, "case": case.name, + "shape": case.shape, "arguments": ["--set", str(result.generated_map.path), "--full"], "return_code": result.return_code, "elapsed_seconds": round(result.elapsed_seconds, 3), @@ -735,6 +1023,10 @@ def workload_json( "selected_user_count": sweep_case.users, "selected_repo_count": sweep_case.repos, "selected_total_grants": sweep_case.grants, + "raw_rule_grant_count": sweep_case.raw_rule_grants, + "map_rule_count": sweep_case.map_rule_count, + "user_selector_count": user_selector_count(sweep_case), + "repository_selector_count": repository_selector_count(sweep_case), "memory_model_user_count": sweep_case.users, "memory_model_repo_count": sweep_case.repos, "memory_model_grant_count": sweep_case.grants, @@ -751,6 +1043,9 @@ def write_results_csv( ) -> None: fieldnames = [ "case", + "shape", + "map_rule_count", + "raw_rule_grants", "users", "repos", "grants", @@ -771,6 +1066,9 @@ def write_results_csv( writer.writerow( { "case": case.name, + "shape": case.shape, + "map_rule_count": case.map_rule_count, + "raw_rule_grants": case.raw_rule_grants, "users": case.users, "repos": case.repos, "grants": case.grants, diff --git a/dev/sourcegraph-explicit-permissions-tracing.md b/dev/sourcegraph-explicit-permissions-tracing.md index c61a399..c9ed120 100644 --- a/dev/sourcegraph-explicit-permissions-tracing.md +++ b/dev/sourcegraph-explicit-permissions-tracing.md @@ -220,6 +220,49 @@ Fetch Jaeger traces immediately for long runs. In that same full matrix, older trace IDs were no longer available by the time the run finished. Focused reruns with immediate fetches gave stable Jaeger data. +## Stress-run findings + +The deliberately hard stress map used about 10,001 users and 1,000 repos, +planning roughly 10.23 million explicit grants. The sgdev instance exposes +1,000 `test-repo-*` repos today, not the intended 1 million, so the current +stress profile hammers a 10k-user × 1k-repo permission matrix. + +The stress run showed a clear server-side cliff: + +| Case | Elapsed | Notes | +| --- | ---: | --- | +| `set-full-dry-run` | 214s | planning/snapshot only, peak CLI RSS about 669 MiB | +| `set-full-no-backup-apply` | 4,067s | ~1,023 overwrites; hit `repo not found` | +| `restore-full-no-backup-cleanup` | 5,679s | re-snapshot and huge overwrite cleanup | + +Postgres was busy but not obviously resource-starved: + +| Component | Max CPU | Max memory | Notes | +| --- | ---: | ---: | --- | +| `pgsql-0/pgsql` | ~1235m | ~3231 MiB | below the 4G memory limit; no observed OOM/throttling | +| `sourcegraph-frontend` | ~954m | ~561 MiB | frontend was not memory-bound | +| `src-serve-git/src-cli` | ~1610m | ~388 MiB | incidental load during the same window | + +`pg_stat_statements` made the dominant work visible. The stress run spent most +of its database time in explicit-permissions read and write helpers: + +| Sourcegraph operation | Calls | Total time | Mean time | +| --- | ---: | ---: | ---: | +| `permsStore.ListUserPermissions` | 19,974 | 30,862.6s | 1,545ms | +| `permsStore.upsertUserRepoPermissions-range1` | 472 | 1,178.8s | 2,497ms | + +Compared with focused traces at normal scale, `ListUserPermissions` became much +slower under the large explicit-perms state. This reinforces that the bottleneck +is Sourcegraph server/database work, not local Python CPU. The CLI can avoid +some redundant work, but Sourcegraph still needs a bulk read path and probably a +more efficient bulk overwrite path for very large explicit permission sets. + +One live-instance behavior is now expected: if Sourcegraph returns a GraphQL +application error showing that a repo/user disappeared between planning and the +mutation, `src-auth-perms-sync` logs a skipped mutation and continues. The next +scheduled run will re-plan against the then-current users/repos. Other GraphQL +application errors still fail normally. + For current `src-auth-perms-sync`, `UserExplicitReposBatch` requests only repo IDs from `User.permissionsInfo.repositories(source: API)`. A focused traced batch for one user with 19 explicit repos showed per-user fanout: @@ -319,6 +362,15 @@ Important requirements: Expected benefit: replace hundreds or thousands of per-repo resolver SQL spans per request with one indexed `user_repo_permissions` join per user batch. +The stress profile also needs attention on the write path. A 10k-user × +1k-repo full set planned about 10.23 million grants. The apply path then spent +about 67.8 minutes before a live/deleted-repo race interrupted the run, with +`pg_stat_statements` showing `permsStore.upsertUserRepoPermissions-range1` at +472 calls, 1,178.8 seconds total, and 2.5 seconds mean. A purpose-built bulk +overwrite API that accepts many repo/user edges at once, streams or stages the +input server-side, and avoids repeated per-repo permission reconciliation would +make worst-case full syncs much safer. + ## Copy/paste request Title: Add a bulk GraphQL read path for explicit repository permissions @@ -352,3 +404,6 @@ Acceptance criteria: latency visible. - `src-auth-perms-sync` can replace its aliased `User.permissionsInfo.repositories(source: API)` calls with this API. +- Follow-up: evaluate a bulk overwrite API for large full-set applies. The + stress run planned about 10.23 million grants and observed + `permsStore.upsertUserRepoPermissions-range1` averaging about 2.5s per call. diff --git a/maps-example.yaml b/maps-example.yaml index e5dfb7e..49afbd6 100644 --- a/maps-example.yaml +++ b/maps-example.yaml @@ -1,38 +1,48 @@ -# Auth provider → code host connection mapping rules -# Maintain this file using auth-providers.yaml and code-hosts.yaml as references. -# Those files are generated under src-auth-perms-sync-runs//. -# -# These examples cover every supported filter field: -# - users.authProvider: clientID, configID, displayName, samlGroup, serviceID, type -# - users.emails (verified email addresses) -# - users.usernames -# - repos.codeHostConnection: config, displayName, id, kind, url -# - repos.names -# - repos.regexes +# User → Repo permission mapping rules -maps: +# Maintain your maps.yaml file, using the values from auth-providers.yaml and code-hosts.yaml, +# which are created by the --get command, under `src-auth-perms-sync-runs//` -- name: SAML group users get all repos synced from one service account - users: - authProvider: - configID: okta - samlGroup: LOB1-GROUP1 - type: saml - repos: - codeHostConnection: - config: - username: LOB1-SA1 +# Schema details: +# maps: list[map] +# - name: string +# users: map +# authProvider: map +# type: string +# serviceID: string +# clientID: string +# displayName: string +# configID: string +# samlGroup: string +# emails: list[string] # exact verified email addresses +# emailRegexes: list[string] # Python regexes for verified email addresses +# usernames: list[string] # exact Sourcegraph usernames +# usernameRegexes: list[string] # Python regexes for Sourcegraph usernames +# repos: map +# codeHostConnection: map +# displayName: string +# kind: string +# url: string +# username: string +# names: list[string] # exact Sourcegraph repo names +# nameRegexes: list[string] # Python regexes for Sourcegraph repo names -- name: Users from one exact auth provider get repos from one exact code host connection +# Filter scopes: +# - Children of lists are ORed together (casting a wider net) +# - Children of maps are ANDed together (casting a narrower net) + +maps: + +# Widest net +- name: All users get all repos users: - authProvider: - clientID: sourcegraph - displayName: Okta SAML - serviceID: https://idp.example.com/saml + usernameRegexes: + - '.*' repos: - codeHostConnection: - id: 12 + nameRegexes: + - '.*' +# Wide net - name: All Okta SAML users get access to all Bitbucket repos users: authProvider: @@ -42,36 +52,79 @@ maps: codeHostConnection: kind: BITBUCKETSERVER -- name: All builtin users get repos from the GitHub Cloud connection +# Medium net +- name: | + Members of samlGroup LOB1-GROUP1, from any auth provider + get any repos cloned using username LOB1-SA1, from any code host users: authProvider: - type: builtin + samlGroup: LOB1-GROUP1 repos: codeHostConnection: - displayName: GitHub Cloud + username: LOB1-SA1 -- name: All builtin users get repos from the GitHub URL connection +# Narrower net +- name: | + Members of samlGroup LOB1-GROUP1 from the okta saml provider + get repos cloned from a specific Bitbucket code host connection users: authProvider: - type: builtin + configID: okta + samlGroup: LOB1-GROUP1 + type: saml repos: codeHostConnection: - url: https://github.com/ + displayName: 'BITBUCKETSERVER #1' + kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 -- name: Exact user gets named repos +# Even narrower net +- name: | + Alice and Bob get access to bitbucket.example.com/example/private-repo, + if they are members of LOB1-GROUP1 from okta saml users: + authProvider: + configID: okta + samlGroup: LOB1-GROUP1 + type: saml emails: - alice@example.com - bob@example.com repos: + codeHostConnection: + displayName: Bitbucket + kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 names: - - github.com/example/private-repo + - bitbucket.example.com/example/private-repo -- name: All builtin users get access to all repos under the github.com/example org, from any code host connection +# Narrowest net +- name: Alice gets private-repo repo, if all stars align users: authProvider: - type: builtin + clientID: https://sourcegraph.example.com/.auth/saml/metadata + configID: okta + displayName: Okta + samlGroup: LOB1-GROUP1 + serviceID: http://www.okta.com/example123 + type: saml + emails: + - alice@example.com + emailRegexes: + - '@example\.com$' + usernames: + - alice + usernameRegexes: + - '^alice$' repos: - regexes: - - ^github\.com/example/.* - - ^gitlab\.com/example/.* + codeHostConnection: + displayName: 'BITBUCKETSERVER #1' + kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 + names: + - bitbucket.example.com/example/private-repo + nameRegexes: + - '^bitbucket\.example\.com/example/private-repo$' diff --git a/src/src_auth_perms_sync/__init__.py b/src/src_auth_perms_sync/__init__.py index cfcdeaf..8b16d8e 100644 --- a/src/src_auth_perms_sync/__init__.py +++ b/src/src_auth_perms_sync/__init__.py @@ -1 +1,11 @@ -"""Project package for src-auth-perms-sync.""" +"""Importable API for src-auth-perms-sync.""" + +from .cli import Config, Get, Restore, Set, SyncSamlOrgs + +__all__ = [ + "Config", + "Get", + "Restore", + "Set", + "SyncSamlOrgs", +] diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 7e25143..a01b2c9 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -9,15 +9,18 @@ from __future__ import annotations +import argparse import logging import os import sys +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path from typing import Literal, NoReturn, TypeAlias import src_py_lib as src +from src_py_lib.utils import config as config_utils from .orgs import command as organizations_command from .permissions import command as permissions_command @@ -28,6 +31,13 @@ log = logging.getLogger(__name__) CommandName: TypeAlias = Literal["get", "set", "restore", "sync_saml_orgs"] +CLI_COMMAND_NAMES: tuple[str, ...] = ("get", "set", "restore", "sync-saml-orgs") +CLI_COMMAND_NAME_BY_ARGUMENT: dict[str, CommandName] = { + "get": "get", + "set": "set", + "restore": "restore", + "sync-saml-orgs": "sync_saml_orgs", +} LogCommandName: TypeAlias = Literal[ "get", "set_full", @@ -77,33 +87,42 @@ class ResolvedCommand: @property def set_mode(self) -> permission_types.SetCommandMode | None: - """Return the concrete `--set` mode when this is a set command.""" + """Return the concrete set mode when this is a set command.""" if self.set_options is None: return None return self.set_options.mode -class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfig): +@dataclass(frozen=True) +class CliInput: + """Parsed CLI command and runtime config.""" + + command_name: CommandName + config: Config + + +class Config(src.SourcegraphClientConfig, src.LoggingConfig): """Config values loaded from defaults, .env, environment, and CLI flags.""" - get: bool = src.config_field( - default=False, - env_var="SRC_AUTH_PERMS_SYNC_GET", - cli_flag="--get", - cli_action="store_true", - help="Query the SG instance and write/refresh auth-providers.yaml and code-hosts.yaml", + maps_path: Path = src.config_field( + default=Path("maps.yaml"), + env_var="SRC_AUTH_PERMS_SYNC_MAPS_PATH", + cli_flag="--maps-path", + metavar="FILE", + help=( + "Maps YAML file for the set command.\n" + "Defaults to maps.yaml under src-auth-perms-sync-runs//.\n" + "Relative / short paths are resolved from that directory." + ), ) - set_path: Path | None = src.config_field( + restore_path: Path | None = src.config_field( default=None, - env_var="SRC_AUTH_PERMS_SYNC_SET", - cli_flag="--set", - cli_nargs="?", - cli_const="maps.yaml", + env_var="SRC_AUTH_PERMS_SYNC_RESTORE_PATH", + cli_flag="--restore-path", metavar="FILE", help=( - "Read the YAML config file and execute the mapping rules.\n" - "Defaults to maps.yaml under src-auth-perms-sync-runs//.\n" - "Relative paths are resolved from that path." + "Snapshot JSON file for the restore command.\n" + "Relative paths are resolved under 'src-auth-perms-sync-runs//.'" ), ) full: bool = src.config_field( @@ -111,7 +130,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi env_var="SRC_AUTH_PERMS_SYNC_FULL", cli_flag="--full", cli_action="store_true", - help="With --set: run the full overwrite reconciliation mode (default)", + help="With the set command: run the full overwrite reconciliation mode (default)", ) user: str | None = src.config_field( default=None, @@ -135,16 +154,6 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi pattern=r"^\d{4}-\d{2}-\d{2}$", help="Process Sourcegraph users created on or after this date", ) - restore_path: Path | None = src.config_field( - default=None, - env_var="SRC_AUTH_PERMS_SYNC_RESTORE", - cli_flag="--restore", - metavar="FILE", - help=( - "Restore explicit-permissions state to match the given snapshot JSON file.\n" - "Relative paths are resolved under 'src-auth-perms-sync-runs//.'" - ), - ) sync_saml_organizations: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_SYNC_SAML_ORGS", @@ -223,55 +232,53 @@ def config_error(message: str) -> NoReturn: raise SystemExit(2) -def validate_config(config: SrcAuthPermissionsSyncConfig) -> None: +def validate_config(command_name: CommandName, config: Config) -> None: """Validate cross-field CLI/config constraints.""" - validate_command_selection(config) - validate_user_filter_selection(config) - validate_set_mode_selection(config) + validate_command_options(command_name, config) + validate_user_filter_selection(command_name, config) + validate_set_mode_selection(command_name, config) -def validate_command_selection(config: SrcAuthPermissionsSyncConfig) -> None: - """Validate compatible top-level command flags.""" - if sum((config.get, config.set_path is not None, config.restore_path is not None)) > 1: - config_error("choose only one of --get, --set, or --restore") - if config.restore_path is not None and config.sync_saml_organizations: - config_error("--sync-saml-orgs can run by itself or with --get or --set") +def validate_command_options(command_name: CommandName, config: Config) -> None: + """Validate options that only make sense with specific commands.""" + if config.sync_saml_organizations and command_name not in {"get", "set"}: + config_error("--sync-saml-orgs can only be combined with get or set") + if command_name == "restore" and config.restore_path is None: + config_error("restore requires --restore-path") + if config.restore_path is not None and command_name != "restore": + config_error("--restore-path requires the restore command") -def validate_user_filter_selection(config: SrcAuthPermissionsSyncConfig) -> None: +def validate_user_filter_selection(command_name: CommandName, config: Config) -> None: """Validate user-scope filters and their compatible commands.""" user_identifier_filters = sum((config.user is not None, config.users_without_explicit_perms)) if user_identifier_filters > 1: config_error("choose only one of --user or --users-without-explicit-perms") user_filter_selected = user_identifier_filters > 0 or config.created_after is not None - user_filter_allowed = ( - config.get - or config.set_path is not None - or (config.restore_path is None and not config.sync_saml_organizations) - ) + user_filter_allowed = command_name in {"get", "set"} if user_filter_selected and not user_filter_allowed: config_error( - "--user, --users-without-explicit-perms, and --created-after require --get or --set" + "--user, --users-without-explicit-perms, and --created-after require get or set" ) -def validate_set_mode_selection(config: SrcAuthPermissionsSyncConfig) -> None: - """Validate `--set` mode flags.""" - if config.full and config.set_path is None: - config_error("--full requires --set") +def validate_set_mode_selection(command_name: CommandName, config: Config) -> None: + """Validate set command mode flags.""" + if config.full and command_name != "set": + config_error("--full requires the set command") - if config.set_path is None: + if command_name != "set": return if sum((config.full, config.user is not None, config.users_without_explicit_perms)) > 1: config_error( - "with --set, choose at most one of --full, --user, or --users-without-explicit-perms" + "with set, choose at most one of --full, --user, or --users-without-explicit-perms" ) -def set_command_options(config: SrcAuthPermissionsSyncConfig) -> permission_types.SetCommandOptions: - """Return the validated `--set` mode options.""" +def set_command_options(config: Config) -> permission_types.SetCommandOptions: + """Return the validated set mode options.""" if config.user is not None: return permission_types.SetCommandOptions( mode="user", @@ -289,38 +296,36 @@ def set_command_options(config: SrcAuthPermissionsSyncConfig) -> permission_type ) -def resolve_command(config: SrcAuthPermissionsSyncConfig) -> ResolvedCommand: +def resolve_command(command_name: CommandName, config: Config) -> ResolvedCommand: """Return the command execution plan derived from config.""" run_mode = "apply" if config.apply else "dry-run" - if config.set_path is not None: + if command_name == "set": return resolve_set_command(config, run_mode) - if config.restore_path is not None: + if command_name == "restore": return ResolvedCommand( name="restore", log_name="restore", artifact_name=f"restore-{run_mode}", ) - if config.get and config.sync_saml_organizations: + if command_name == "get" and config.sync_saml_organizations: return ResolvedCommand( name="get", log_name="get_sync_saml_orgs", artifact_name=f"get-sync-saml-orgs-{run_mode}", sync_saml_organizations=True, ) - if config.get: + if command_name == "get": return ResolvedCommand(name="get", log_name="get", artifact_name="get") - if config.sync_saml_organizations: - return ResolvedCommand( - name="sync_saml_orgs", - log_name="sync_saml_orgs", - artifact_name=f"sync-saml-orgs-{run_mode}", - sync_saml_organizations=True, - ) - return ResolvedCommand(name="get", log_name="get", artifact_name="get") + return ResolvedCommand( + name="sync_saml_orgs", + log_name="sync_saml_orgs", + artifact_name=f"sync-saml-orgs-{run_mode}", + sync_saml_organizations=True, + ) -def resolve_set_command(config: SrcAuthPermissionsSyncConfig, run_mode: str) -> ResolvedCommand: - """Return resolved metadata for the selected `--set` command mode.""" +def resolve_set_command(config: Config, run_mode: str) -> ResolvedCommand: + """Return resolved metadata for the selected set command mode.""" set_options = set_command_options(config) log_names = ( SYNC_SET_COMMAND_LOG_NAMES if config.sync_saml_organizations else SET_COMMAND_LOG_NAMES @@ -339,24 +344,35 @@ def resolve_set_command(config: SrcAuthPermissionsSyncConfig, run_mode: str) -> ) -def load_config() -> SrcAuthPermissionsSyncConfig: - """Parse and validate CLI/environment config.""" - config = src.parse_args( - SrcAuthPermissionsSyncConfig, - description=__doc__, - base_dir=Path("."), +def load_cli(argv: Sequence[str] | None = None) -> CliInput: + """Parse and validate the CLI command plus environment/config options.""" + parser = argparse.ArgumentParser( + description=__doc__.strip() if __doc__ is not None else None, + formatter_class=argparse.RawDescriptionHelpFormatter, + usage="%(prog)s {get,set,restore,sync-saml-orgs} [options]", ) - validate_config(config) - return config + parser.add_argument("command", choices=CLI_COMMAND_NAMES, help="Command to run") + config_utils.add_config_arguments(parser, Config) + arguments = parser.parse_args(argv) + try: + config = config_utils.load_config_from_args( + Config, + arguments, + base_dir=Path("."), + resolve_op_refs=True, + ) + except src.ConfigError as exception: + parser.error(str(exception)) + command_name = CLI_COMMAND_NAME_BY_ARGUMENT[arguments.command] + validate_config(command_name, config) + return CliInput(command_name=command_name, config=config) -def endpoint_scoped_config( - config: SrcAuthPermissionsSyncConfig, endpoint: str -) -> SrcAuthPermissionsSyncConfig: +def endpoint_scoped_config(command_name: CommandName, config: Config, endpoint: str) -> Config: """Return config with relative operator artifact paths scoped to this endpoint.""" updates: dict[str, object] = {} - if config.set_path is not None: - updates["set_path"] = backups.endpoint_artifact_path(endpoint, config.set_path) + if command_name == "set": + updates["maps_path"] = backups.endpoint_artifact_path(endpoint, config.maps_path) if config.restore_path is not None: updates["restore_path"] = backups.endpoint_artifact_path(endpoint, config.restore_path) if not updates: @@ -364,25 +380,21 @@ def endpoint_scoped_config( return config.model_copy(update=updates) -def require_set_input_file(config: SrcAuthPermissionsSyncConfig) -> None: +def require_set_input_file(config: Config) -> None: """Exit with a clear error if the selected maps file is missing.""" - if config.set_path is None: - return - if config.set_path.is_file(): + if config.maps_path.is_file(): return - if config.set_path.exists(): - raise SystemExit(f"--set input path is not a file: {config.set_path}") + if config.maps_path.exists(): + raise SystemExit(f"set input path is not a file: {config.maps_path}") raise SystemExit( - "--set input file does not exist: " - f"{config.set_path}\n" - "Run `uv run src-auth-perms-sync --get` to create the default maps.yaml, " + "set input file does not exist: " + f"{config.maps_path}\n" + "Run `uv run src-auth-perms-sync get` to create the default maps.yaml, " "or pass a path to an existing maps file." ) -def run_fields( - config: SrcAuthPermissionsSyncConfig, command: ResolvedCommand, endpoint: str -) -> dict[str, object]: +def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[str, object]: """Return run-level fields for structured logging.""" return { "cli_cmd": command.log_name, @@ -406,7 +418,7 @@ def run_fields( def run_with_client( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, endpoint: str, worker_pool: ThreadPoolExecutor, @@ -431,7 +443,7 @@ def run_with_client( def run_command( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, client: src.SourcegraphClient, worker_pool: ThreadPoolExecutor, @@ -466,19 +478,18 @@ def run_command( def run_set( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, ) -> run_context.CommandData: """Run the selected repo-permission sync command.""" - assert config.set_path is not None assert command.set_options is not None require_set_input_file(config) return permissions_command.cmd_set( client, - config.set_path, + config.maps_path, command.set_options, dry_run=not config.apply, parallelism=config.parallelism, @@ -494,7 +505,7 @@ def run_set( def run_restore( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, @@ -514,7 +525,7 @@ def run_restore( def run_sync_saml_organizations( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, command_data: run_context.CommandData, @@ -535,7 +546,7 @@ def run_sync_saml_organizations( def run_get( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, @@ -577,14 +588,47 @@ def reraise_system_exit_with_logged_error(exception: SystemExit) -> NoReturn: raise exception -def main() -> None: - config = load_config() - command = resolve_command(config) +def Get(config: Config) -> bool: + """Run repository permission discovery and return whether it succeeded.""" + return _run("get", config) + + +def Set(config: Config) -> bool: + """Run repository permission reconciliation and return whether it succeeded.""" + return _run("set", config) + + +def Restore(config: Config) -> bool: + """Run repository permission restore and return whether it succeeded.""" + return _run("restore", config) + + +def SyncSamlOrgs(config: Config) -> bool: + """Run SAML organization sync and return whether it succeeded.""" + return _run("sync_saml_orgs", config) + + +def _run(command_name: CommandName, config: Config) -> bool: + """Run a command and return whether it completed successfully.""" + try: + _run_or_raise(command_name, config) + except SystemExit as exception: + return exception.code in (None, 0) + except Exception: + log.exception("src-auth-perms-sync run failed.") + return False + return True + + +def _run_or_raise(command_name: CommandName, config: Config) -> None: + """Run src-auth-perms-sync, preserving CLI-style exceptions.""" + validate_config(command_name, config) + command = resolve_command(command_name, config) try: endpoint = src.normalize_sourcegraph_endpoint(config.src_endpoint) except ValueError as error: config_error(str(error)) - config = endpoint_scoped_config(config, endpoint) + config = endpoint_scoped_config(command_name, config, endpoint) run_timestamp = backups.backup_timestamp() run_directory = backups.artifact_run_directory( run_timestamp, @@ -614,3 +658,8 @@ def main() -> None: run_with_client(config, command, endpoint, worker_pool) except SystemExit as exception: reraise_system_exit_with_logged_error(exception) + + +def main() -> None: + cli_input = load_cli() + _run_or_raise(cli_input.command_name, cli_input.config) diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 14969ba..302cef6 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -26,6 +26,12 @@ log = logging.getLogger(__name__) +MISSING_MUTATION_RESOURCE_TERMS = ( + "repo", + "repository", + "user", +) + @dataclass class CircuitBreaker: @@ -179,6 +185,21 @@ def _mutate_repo_permission_for_user( ) +def is_missing_mutation_resource_error(exception: BaseException) -> bool: + """Return whether a mutation failed because its repo/user disappeared. + + Sourcegraph instances are live systems: users and repos can be deleted + between discovery/planning and the eventual mutation. Those races should + be logged and skipped, not treated as backend-health failures. + """ + if not isinstance(exception, src.GraphQLError): + return False + message = str(exception).lower() + if not any(term in message for term in MISSING_MUTATION_RESOURCE_TERMS): + return False + return "not found" in message or "could not resolve" in message + + def _apply_permission_changes( client: src.SourcegraphClient, changes: Sequence[PermissionChange], @@ -198,6 +219,7 @@ def _apply_permission_changes( succeeded = 0 failed = 0 canceled = 0 + skipped = 0 breaker = CircuitBreaker() with run_context.thread_pool(parallelism, worker_pool) as executor: futures = { @@ -228,6 +250,17 @@ def _apply_permission_changes( canceled += 1 continue except Exception as exception: + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s %s → %s (id=%d): repo/user no longer exists: %s", + action, + change.username, + change.repo_name, + src.decode_repository_id(change.repo_id), + exception, + ) + continue failed += 1 breaker.record(success=False) log.error( @@ -246,11 +279,13 @@ def _apply_permission_changes( batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled + batch_event["skipped"] = skipped batch_event["circuit_broken"] = breaker.is_open() return shared_types.MutationCounts( succeeded=succeeded, failed=failed, canceled=canceled, + skipped=skipped, ) @@ -312,6 +347,7 @@ def _apply_repo_overwrite_plans( succeeded = 0 failed = 0 canceled = 0 + skipped = 0 submitted_count = 0 submissions_stopped = False breaker = CircuitBreaker() @@ -371,6 +407,15 @@ def _stop_submissions() -> None: canceled += 1 continue except Exception as exception: + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s (id=%d): repo/user no longer exists: %s", + overwrite.repository_name, + src.decode_repository_id(overwrite.repository_id), + exception, + ) + continue failed += 1 breaker.record(success=False) log.error( @@ -395,12 +440,14 @@ def _stop_submissions() -> None: batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled + batch_event["skipped"] = skipped batch_event["circuit_broken"] = breaker.is_open() batch_event["submitted"] = submitted_count return shared_types.MutationCounts( succeeded=succeeded, failed=failed, canceled=canceled, + skipped=skipped, ) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index fbdc00c..0cf2830 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Any, cast +from typing import Any import src_py_lib as src @@ -43,7 +43,7 @@ class _ResolvedMapping: index: int name: str - users_section: dict[str, object] + user_selector: permission_types.UserSelector repos: list[permission_types.Repository] @@ -52,9 +52,9 @@ def resolve_additive_mappings(context: permission_types.MappingContext) -> list[ resolved: list[_ResolvedMapping] = [] for mapping_index, mapping in enumerate(context.mapping_rules, start=1): name = mapping.get("name", f"") - repos_section = cast(dict[str, object], mapping["repos"]) + repository_selector = mapping["repos"] matched_repos = permissions_mapping.resolve_repos( - repos_section, + repository_selector, context.services_by_id, context.repos_by_external_service_id, context.all_repos_by_id, @@ -72,7 +72,7 @@ def resolve_additive_mappings(context: permission_types.MappingContext) -> list[ _ResolvedMapping( index=mapping_index, name=name, - users_section=cast(dict[str, object], mapping["users"]), + user_selector=mapping["users"], repos=matched_repos, ) ) @@ -317,7 +317,7 @@ def cmd_set( retain_saml_group_users: bool = False, worker_pool: ThreadPoolExecutor | None = None, ) -> run_context.CommandData: - """Dispatch the selected `--set` mode.""" + """Dispatch the selected set mode.""" if options.mode == "full": return permissions_full_set.cmd_set_full( client, @@ -550,8 +550,8 @@ def _plan_additions_for_user( """Return missing additive permission edges for one user.""" desired_repos: dict[str, permission_types.Repository] = {} for resolved_mapping in resolved_mappings: - if not permissions_mapping.user_matches_users_section( - resolved_mapping.users_section, + if not permissions_mapping.user_matches_user_selector( + resolved_mapping.user_selector, user, context.providers, context.saml_groups_attribute_names, @@ -677,8 +677,9 @@ def _apply_additive_permissions( worker_pool=worker_pool, ) log.info( - "Additive apply done. %d succeeded, %d failed, %d canceled.", + "Additive apply done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 6b79458..a3534d8 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Any, cast +from typing import Any import src_py_lib as src @@ -279,11 +279,11 @@ def plan_full_set_permissions( name = mapping.get("name", f"") log.info("=== Mapping %d / %d: %s ===", mapping_index, len(context.mapping_rules), name) - users_section = cast(dict[str, object], mapping["users"]) - repos_section = cast(dict[str, object], mapping["repos"]) + user_selector = mapping["users"] + repository_selector = mapping["repos"] matched_users = permissions_mapping.resolve_users( - users_section, + user_selector, users, context.providers, context.saml_groups_attribute_names, @@ -294,7 +294,7 @@ def plan_full_set_permissions( continue matched_repos = permissions_mapping.resolve_repos( - repos_section, + repository_selector, context.services_by_id, context.repos_by_external_service_id, context.all_repos_by_id, @@ -504,8 +504,9 @@ def _apply_full_set_plans( worker_pool=worker_pool, ) log.info( - "Apply done. %d succeeded, %d failed, %d canceled.", + "Apply done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) @@ -525,6 +526,7 @@ def _record_full_set_event_fields( command_event["repo_count"] = len(plan.expected_users) command_event["total_grants"] = plan.total_grants command_event["mutations_succeeded"] = apply_result.mutations.succeeded + command_event["mutations_skipped"] = apply_result.mutations.skipped command_event["mutations_failed"] = apply_result.mutations.failed command_event["mutations_canceled"] = apply_result.mutations.canceled command_event["full_short_circuit"] = apply_result.full_short_circuit @@ -586,7 +588,7 @@ def _finish_full_set_apply_with_backup( log.info( "To roll back the explicit-permissions state captured in " "the before-snapshot, run:\n" - " uv run src-auth-perms-sync --restore %s --apply", + " uv run src-auth-perms-sync restore --restore-path %s --apply", before_path, ) @@ -600,7 +602,7 @@ def _raise_for_failed_full_set_apply( log.error( "RUN FAILED: %d mutation(s) failed, %d canceled by circuit " "breaker (out of %d planned). Review the log file and the " - "before/after snapshots for details, then re-run --set --apply " + "before/after snapshots for details, then re-run set --apply " "(after addressing the underlying cause) to retry the " "remaining work.", apply_result.mutations.failed, diff --git a/src/src_auth_perms_sync/permissions/mapping.py b/src/src_auth_perms_sync/permissions/mapping.py index 33a9fb0..69d78f1 100644 --- a/src/src_auth_perms_sync/permissions/mapping.py +++ b/src/src_auth_perms_sync/permissions/mapping.py @@ -1,16 +1,15 @@ """Permission mapping resolution: validate rules and match users/repos. -Each mapping rule has a `users:` section and a `repos:` section, each -containing one or more matchers. Within a matcher, the supplied keys -AND together against the discovered auth-provider / external-service -entries. Across sibling matchers, results intersect. Across mapping -rules, `cmd_set` unions the per-repo user sets at apply time — see +Each mapping rule has a `users:` section and a `repos:` section. Top-level +selectors under each section AND together to keep each rule restrictive. +Values inside each supplied selector list OR together. Across mapping rules, +`cmd_set` unions the per-repo user sets at apply time — see `src/src_auth_perms_sync/permissions/types.py` for the rationale. Adding a new matcher type: 1. Add the TypedDict in `src/src_auth_perms_sync/permissions/types.py`. - 2. Add it as a sibling key on `UsersFilter` or `ReposFilter`. + 2. Add it as a sibling key on `UserSelector` or `RepositorySelector`. 3. Add a branch in `resolve_users` / `resolve_repos` below. 4. Add structural validation in `validate_mapping_rules`. 5. Add an example rule using the new matcher to `maps-example.yaml`. @@ -20,7 +19,7 @@ import logging import re -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any, cast import json5 @@ -50,7 +49,7 @@ "configID", "samlGroup", } -CODE_HOST_MATCHER_FIELDS: set[str] = {"id", "kind", "displayName", "url", "config"} +CODE_HOST_MATCHER_FIELDS: set[str] = {"kind", "displayName", "url", "username"} AUTH_PROVIDER_VALUE_MATCHES: tuple[tuple[str, str], ...] = ( ("type", "serviceType"), ("serviceID", "serviceID"), @@ -58,7 +57,15 @@ ("displayName", "displayName"), ("configID", "configID"), ) -CODE_HOST_VALUE_MATCHES: tuple[str, ...] = ("kind", "displayName", "url") +CODE_HOST_DIRECT_VALUE_MATCHES: tuple[str, ...] = ("kind", "displayName", "url") +USER_SELECTOR_FIELDS: set[str] = { + "authProvider", + "emails", + "emailRegexes", + "usernames", + "usernameRegexes", +} +REPOSITORY_SELECTOR_FIELDS: set[str] = {"codeHostConnection", "names", "nameRegexes"} # --------------------------------------------------------------------------- @@ -66,7 +73,7 @@ # --------------------------------------------------------------------------- -def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: +def validate_mapping_rules(rules: Sequence[object]) -> None: """Fail fast on structural problems in the YAML before doing any work. Catches operator typos that would otherwise produce confusing partial @@ -81,20 +88,37 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: bugs. """ errors: list[str] = [] - for rule_index, rule in enumerate(rules, start=1): + for rule_index, rule_object in enumerate(rules, start=1): + if not isinstance(rule_object, dict): + errors.append( + f"mapping {rule_index}: each `maps:` entry must be a mapping " + f"(got {type(rule_object).__name__})" + ) + continue + + rule = cast(Mapping[str, object], rule_object) label = rule.get("name") or f"" prefix = f"mapping {rule_index} ({label!r})" - users_section = cast(dict[str, object], rule.get("users") or {}) - repos_section = cast(dict[str, object], rule.get("repos") or {}) - - if not users_section: - errors.append(f"{prefix}: `users:` section is empty (matches no users)") - if not repos_section: - errors.append(f"{prefix}: `repos:` section is empty (matches no repos)") - - errors.extend(_validate_users_section(users_section, prefix)) - errors.extend(_validate_repos_section(repos_section, prefix)) + errors.extend(_validate_mapping_name(rule.get("name"), prefix)) + errors.extend( + _validate_selector_section( + rule.get("users"), + prefix, + "users", + USER_SELECTOR_FIELDS, + _validate_user_selector, + ) + ) + errors.extend( + _validate_selector_section( + rule.get("repos"), + prefix, + "repos", + REPOSITORY_SELECTOR_FIELDS, + _validate_repository_selector, + ) + ) if errors: bullet = "\n - " @@ -103,40 +127,116 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: ) -_KNOWN_USER_MATCHERS: set[str] = {"authProvider", "emails", "usernames"} -_KNOWN_REPO_MATCHERS: set[str] = {"codeHostConnection", "names", "regexes"} - - def mapping_rules_need_user_emails(mapping_rules: list[permission_types.MappingRule]) -> bool: """Return whether any mapping rule filters users by verified email.""" - return any("emails" in mapping["users"] for mapping in mapping_rules) + return any( + "emails" in mapping["users"] or "emailRegexes" in mapping["users"] + for mapping in mapping_rules + ) + +def _validate_mapping_name(value: object, prefix: str) -> list[str]: + """Validate the required human-readable mapping name.""" + if value is None: + return [f"{prefix}: `name:` is missing"] + if not isinstance(value, str): + return [f"{prefix}: `name:` must be a string (got {type(value).__name__})"] + if not value: + return [f"{prefix}: `name:` is empty"] + return [] + + +def _validate_selector_section( + value: object, + prefix: str, + section_name: str, + known_fields: set[str], + validate_selector: Callable[[dict[str, object], str, str], list[str]], +) -> list[str]: + """Validate a top-level user or repo selector mapping.""" + if value is None: + return [f"{prefix}: `{section_name}:` section is missing"] + if not isinstance(value, dict): + return [ + f"{prefix}: `{section_name}:` must be a selector mapping (got {type(value).__name__})" + ] -def _validate_users_section(section: dict[str, object], prefix: str) -> list[str]: - """Reject unknown matcher keys and validate each matcher's shape.""" + selector = cast(dict[str, object], value) errors: list[str] = [] - for key in section: - if key not in _KNOWN_USER_MATCHERS: - errors.append(f"{prefix}: unknown users matcher {key!r}") - auth_provider = cast(dict[str, object] | None, section.get("authProvider")) + if not selector: + errors.append(f"{prefix}: `{section_name}:` section is empty (matches nothing)") + return errors + + for field_name in sorted(set(selector) - known_fields): + errors.append(f"{prefix}: unknown {section_name} field {field_name!r}") + errors.extend(validate_selector(selector, prefix, section_name)) + return errors + + +def _validate_user_selector( + selector: dict[str, object], prefix: str, selector_path: str +) -> list[str]: + """Validate one user selector's ANDed matcher fields.""" + errors: list[str] = [] + auth_provider = selector.get("authProvider") if auth_provider is not None: - unknown = set(auth_provider) - AUTH_PROVIDER_MATCHER_FIELDS - for field_name in sorted(unknown): - errors.append(f"{prefix}: unknown authProvider field {field_name!r}") - if not auth_provider: - errors.append( - f"{prefix}: authProvider is empty (would match every provider on the instance)" + errors.extend(_validate_auth_provider_matcher(auth_provider, prefix, selector_path)) + if "emails" in selector: + errors.extend(_validate_string_list(selector["emails"], prefix, f"{selector_path}.emails")) + if "emailRegexes" in selector: + errors.extend( + _validate_regexes(selector["emailRegexes"], prefix, f"{selector_path}.emailRegexes") + ) + if "usernames" in selector: + errors.extend( + _validate_string_list(selector["usernames"], prefix, f"{selector_path}.usernames") + ) + if "usernameRegexes" in selector: + errors.extend( + _validate_regexes( + selector["usernameRegexes"], prefix, f"{selector_path}.usernameRegexes" ) - if "samlGroup" in auth_provider: - errors.extend(_validate_saml_group(auth_provider, prefix)) - if "emails" in section: - errors.extend(_validate_string_list(section["emails"], prefix, "users.emails")) - if "usernames" in section: - errors.extend(_validate_string_list(section["usernames"], prefix, "users.usernames")) + ) + return errors + + +def _validate_repository_selector( + selector: dict[str, object], prefix: str, selector_path: str +) -> list[str]: + """Validate one repository selector's ANDed matcher fields.""" + errors: list[str] = [] + code_host_connection = selector.get("codeHostConnection") + if code_host_connection is not None: + errors.extend( + _validate_code_host_connection_matcher(code_host_connection, prefix, selector_path) + ) + if "names" in selector: + errors.extend(_validate_string_list(selector["names"], prefix, f"{selector_path}.names")) + if "nameRegexes" in selector: + errors.extend( + _validate_regexes(selector["nameRegexes"], prefix, f"{selector_path}.nameRegexes") + ) return errors -def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[str]: +def _validate_auth_provider_matcher(value: object, prefix: str, selector_path: str) -> list[str]: + """Validate an `authProvider:` matcher.""" + path = f"{selector_path}.authProvider" + if not isinstance(value, dict): + return [f"{prefix}: {path} must be a mapping (got {type(value).__name__})"] + + auth_provider = cast(dict[str, object], value) + errors: list[str] = [] + for field_name in sorted(set(auth_provider) - AUTH_PROVIDER_MATCHER_FIELDS): + errors.append(f"{prefix}: unknown {path} field {field_name!r}") + if not auth_provider: + errors.append(f"{prefix}: {path} is empty (would match every provider on the instance)") + if "samlGroup" in auth_provider: + errors.extend(_validate_saml_group(auth_provider, prefix, path)) + return errors + + +def _validate_saml_group(auth_provider: dict[str, object], prefix: str, path: str) -> list[str]: """`authProvider.samlGroup`, if present, must be a non-empty string and incompatible with a non-SAML `type:` (the rule could never match). """ @@ -144,12 +244,12 @@ def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[ value = auth_provider["samlGroup"] if not isinstance(value, str): errors.append( - f"{prefix}: authProvider.samlGroup must be a single group-name " + f"{prefix}: {path}.samlGroup must be a single group-name " f"string (got {type(value).__name__} {value!r}); to OR multiple " - f"groups, write multiple rules" + f"groups, add multiple top-level maps entries" ) elif not value: - errors.append(f"{prefix}: authProvider.samlGroup is an empty string") + errors.append(f"{prefix}: {path}.samlGroup is an empty string") declared_type = auth_provider.get("type") if ( isinstance(declared_type, str) @@ -157,60 +257,44 @@ def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[ and declared_type != saml_groups.SAML_SERVICE_TYPE ): errors.append( - f"{prefix}: authProvider.samlGroup is set but authProvider.type " + f"{prefix}: {path}.samlGroup is set but {path}.type " f"is {declared_type!r}; only SAML providers carry group claims" ) return errors -def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str]: - """Reject unknown matcher keys and validate `codeHostConnection:` shape.""" +def _validate_code_host_connection_matcher( + value: object, prefix: str, selector_path: str +) -> list[str]: + """Validate a `codeHostConnection:` matcher.""" + path = f"{selector_path}.codeHostConnection" + if not isinstance(value, dict): + return [f"{prefix}: {path} must be a mapping (got {type(value).__name__})"] + + code_host_section = cast(dict[str, object], value) errors: list[str] = [] - for key in section: - if key not in _KNOWN_REPO_MATCHERS: - errors.append(f"{prefix}: unknown repos matcher {key!r}") - code_host_section = cast(dict[str, object] | None, section.get("codeHostConnection")) - if code_host_section is not None: - unknown = set(code_host_section) - CODE_HOST_MATCHER_FIELDS - for field_name in sorted(unknown): - errors.append(f"{prefix}: unknown codeHostConnection field {field_name!r}") - if not (set(code_host_section) & CODE_HOST_MATCHER_FIELDS): - errors.append( - f"{prefix}: codeHostConnection is empty (would match every " - f"external service on the instance); supply at least one of " - f"{sorted(CODE_HOST_MATCHER_FIELDS)}" - ) - if "id" in code_host_section: - external_service_id = code_host_section["id"] - if external_service_id is None or external_service_id == "": - errors.append( - f"{prefix}: codeHostConnection.id, if supplied, must be " - f"a non-empty integer (e.g. `id: 5`)" - ) - elif not isinstance(external_service_id, int) or isinstance(external_service_id, bool): - errors.append( - f"{prefix}: codeHostConnection.id must be an integer " - f"(got {type(external_service_id).__name__} {external_service_id!r}); " - f"the YAML config holds the decoded DB primary key, not the " - f"opaque base64 GraphQL Node ID" - ) - if "config" in code_host_section and not isinstance(code_host_section["config"], dict): + for field_name in sorted(set(code_host_section) - CODE_HOST_MATCHER_FIELDS): + errors.append(f"{prefix}: unknown {path} field {field_name!r}") + if not code_host_section: + errors.append( + f"{prefix}: {path} is empty (would match every external service on " + f"the instance); supply at least one of {sorted(CODE_HOST_MATCHER_FIELDS)}" + ) + for field_name in sorted(CODE_HOST_MATCHER_FIELDS & set(code_host_section)): + field_value = code_host_section[field_name] + if not isinstance(field_value, str): errors.append( - f"{prefix}: codeHostConnection.config must be a mapping of " - f"key/value pairs to deep-subset-match against the service's " - f"parsed config (got {type(code_host_section['config']).__name__})" + f"{prefix}: {path}.{field_name} must be a string " + f"(got {type(field_value).__name__} {field_value!r})" ) - if "names" in section: - errors.extend(_validate_string_list(section["names"], prefix, "repos.names")) - regexes = section.get("regexes") - if regexes is not None: - errors.extend(_validate_regexes(regexes, prefix)) + elif not field_value: + errors.append(f"{prefix}: {path}.{field_name} is an empty string") return errors -def _validate_regexes(value: object, prefix: str) -> list[str]: +def _validate_regexes(value: object, prefix: str, path: str) -> list[str]: """Validate list-based regex filters.""" - errors = _validate_string_list(value, prefix, "repos.regexes") + errors = _validate_string_list(value, prefix, path) if errors: return errors @@ -218,9 +302,7 @@ def _validate_regexes(value: object, prefix: str) -> list[str]: try: re.compile(pattern) except re.error as exception: - errors.append( - f"{prefix}: repos.regexes[{index}] is not a valid Python regex: {exception}" - ) + errors.append(f"{prefix}: {path}[{index}] is not a valid Python regex: {exception}") return errors @@ -249,12 +331,12 @@ def _validate_string_list(value: object, prefix: str, path: str) -> list[str]: def resolve_users( - section: dict[str, object], + selector: permission_types.UserSelector, all_users: list[shared_types.User], all_providers: list[shared_types.AuthProvider], saml_groups_attribute_names: saml_groups.SamlGroupsAttributeNameByProvider | None = None, ) -> list[shared_types.User]: - """Return users matching ALL matchers under `users:` (intersection). + """Return users matching ALL top-level selectors under `users:`. `saml_groups_attribute_names` overrides the default `"groups"` SAML assertion attribute name per (serviceID, clientID) — see @@ -262,110 +344,180 @@ def resolve_users( `None`, every SAML provider falls back to the default. Only consulted by the `authProvider.samlGroup` sub-field. - Empty section returns an empty user set — `validate_mapping_rules` + Empty sections return an empty user set — `validate_mapping_rules` rejects this at config-load time, so this branch only fires for programmatic callers. """ - if not section: + if not selector: return [] - users_by_id: dict[str, shared_types.User] = {user["id"]: user for user in all_users} - matched_ids: set[str] | None = None - for key, matcher in section.items(): - if key == "authProvider": - current_ids = { + selector_matches: list[set[str]] = [] + auth_provider = selector.get("authProvider") + if auth_provider is not None: + selector_matches.append( + { user["id"] for user in _users_matching_auth_provider( - cast(permission_types.AuthProviderMatcher, matcher), + auth_provider, all_users, all_providers, saml_groups_attribute_names, ) } - elif key == "emails": - current_ids = { - user["id"] for user in _users_matching_emails(cast(list[str], matcher), all_users) - } - elif key == "usernames": - current_ids = { - user["id"] - for user in _users_matching_usernames(cast(list[str], matcher), all_users) - } - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown users matcher {key!r}") - matched_ids = current_ids if matched_ids is None else matched_ids & current_ids + ) + + emails = selector.get("emails") + if emails is not None: + selector_matches.append( + {user["id"] for user in _users_matching_email_values(emails, all_users)} + ) + + email_regexes = selector.get("emailRegexes") + if email_regexes is not None: + selector_matches.append( + {user["id"] for user in _users_matching_email_regexes(email_regexes, all_users)} + ) + + usernames = selector.get("usernames") + if usernames is not None: + selector_matches.append( + {user["id"] for user in _users_matching_username_values(usernames, all_users)} + ) + + username_regexes = selector.get("usernameRegexes") + if username_regexes is not None: + selector_matches.append( + {user["id"] for user in _users_matching_username_regexes(username_regexes, all_users)} + ) + + if not selector_matches: + return [] + + matched_ids = selector_matches[0] + for current_ids in selector_matches[1:]: + matched_ids &= current_ids if not matched_ids: return [] - assert matched_ids is not None - return [users_by_id[user_id] for user_id in matched_ids] + return [user for user in all_users if user["id"] in matched_ids] -def user_matches_users_section( - section: dict[str, object], +def user_matches_user_selector( + selector: permission_types.UserSelector, user: shared_types.User, all_providers: list[shared_types.AuthProvider], saml_groups_attribute_names: saml_groups.SamlGroupsAttributeNameByProvider | None = None, ) -> bool: - """Return whether one user matches ALL matchers under `users:`.""" - if not section: + """Return whether one user matches ALL top-level selectors under `users:`.""" + if not selector: return False - for key, matcher in section.items(): - if key == "authProvider": - if not _user_matches_auth_provider( - cast(permission_types.AuthProviderMatcher, matcher), - user, - all_providers, - saml_groups_attribute_names, - ): - return False - elif key == "emails": - if not _user_matches_emails(user, cast(list[str], matcher)): - return False - elif key == "usernames": - if user["username"] not in cast(list[str], matcher): - return False - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown users matcher {key!r}") - return True + auth_provider = selector.get("authProvider") + if auth_provider is not None and not _user_matches_auth_provider( + auth_provider, + user, + all_providers, + saml_groups_attribute_names, + ): + return False + emails = selector.get("emails") + if emails is not None and not _user_matches_email(user, set(emails), []): + return False -def _users_matching_emails( + email_regexes = selector.get("emailRegexes") + if email_regexes is not None and not _user_matches_email( + user, set(), _compiled_regexes(email_regexes) + ): + return False + + usernames = selector.get("usernames") + if usernames is not None and not _text_matches(user["username"], set(usernames), []): + return False + + username_regexes = selector.get("usernameRegexes") + if username_regexes is None: + return True + return _text_matches(user["username"], set(), _compiled_regexes(username_regexes)) + + +def _users_matching_email_values( emails: list[str], all_users: list[shared_types.User] ) -> list[shared_types.User]: - """Return users with at least one verified email in `emails`.""" - matched = [user for user in all_users if _user_matches_emails(user, emails)] - log.info(" emails → %d user(s) matched %d email(s)", len(matched), len(set(emails))) + """Return users with at least one verified email equal to a listed email.""" + exact_values = set(emails) + matched = [user for user in all_users if _user_matches_email(user, exact_values, [])] + log.info( + " emails → %d user(s) matched %d email selector(s)", + len(matched), + len(exact_values), + ) + return matched + + +def _users_matching_email_regexes( + email_regexes: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users with at least one verified email matching a listed regex.""" + patterns = _compiled_regexes(email_regexes) + matched = [user for user in all_users if _user_matches_email(user, set(), patterns)] + log.info( + " emailRegexes → %d user(s) matched %d email regex selector(s)", + len(matched), + len(set(email_regexes)), + ) return matched -def _user_matches_emails(user: shared_types.User, emails: list[str]) -> bool: +def _user_matches_email( + user: shared_types.User, exact_values: set[str], patterns: list[re.Pattern[str]] +) -> bool: """Match only verified emails, mirroring Sourcegraph's `user(email:)` lookup.""" - email_set = set(emails) return any( - user_email["verified"] and user_email["email"] in email_set + user_email["verified"] and _text_matches(user_email["email"], exact_values, patterns) for user_email in user.get("emails", []) ) -def _users_matching_usernames( +def _users_matching_username_values( usernames: list[str], all_users: list[shared_types.User] ) -> list[shared_types.User]: - """Return users whose Sourcegraph username is listed exactly.""" - username_set = set(usernames) - matched = [user for user in all_users if user["username"] in username_set] + """Return users whose Sourcegraph username equals a listed username.""" + exact_values = set(usernames) + matched = [user for user in all_users if _text_matches(user["username"], exact_values, [])] log.info( - " usernames → %d user(s) matched %d username(s)", + " usernames → %d user(s) matched %d username selector(s)", len(matched), - len(username_set), + len(exact_values), ) return matched +def _users_matching_username_regexes( + username_regexes: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users whose Sourcegraph username matches a listed regex.""" + patterns = _compiled_regexes(username_regexes) + matched = [user for user in all_users if _text_matches(user["username"], set(), patterns)] + log.info( + " usernameRegexes → %d user(s) matched %d username regex selector(s)", + len(matched), + len(set(username_regexes)), + ) + return matched + + +def _compiled_regexes(regexes: list[str]) -> list[re.Pattern[str]]: + """Return compiled regexes.""" + return [re.compile(pattern) for pattern in regexes] + + +def _text_matches(value: str, exact_values: set[str], patterns: list[re.Pattern[str]]) -> bool: + """Return whether text matches exact values or any regex.""" + if value in exact_values: + return True + return any(pattern.search(value) for pattern in patterns) + + def _users_matching_auth_provider( matcher: permission_types.AuthProviderMatcher, all_users: list[shared_types.User], @@ -520,65 +672,77 @@ def _user_has_saml_group_in_provider( def resolve_repos( - section: dict[str, object], + selector: permission_types.RepositorySelector, services_by_id: dict[int, permission_types.ExternalService], repos_by_external_service_id: dict[int, list[permission_types.Repository]], all_repos_by_id: dict[str, permission_types.Repository], ) -> list[permission_types.Repository]: - """Return repos matching ALL matchers under `repos:` (intersection). + """Return repos matching ALL top-level selectors under `repos:`. - Empty section returns an empty repo set; `validate_mapping_rules` + Empty sections return an empty repo set; `validate_mapping_rules` rejects this at config-load time. """ - if not section: + if not selector: return [] - matched_ids: set[str] | None = None - repo_index: dict[str, permission_types.Repository] = {} - ordered_keys = [key for key in ("codeHostConnection", "names", "regexes") if key in section] - for key in ordered_keys: - matcher = section[key] - if key == "codeHostConnection": - repos = _repos_matching_code_host_connection( - cast(permission_types.CodeHostConnectionMatcher, matcher), - services_by_id, - repos_by_external_service_id, - ) - elif key == "names": - candidate_repos = ( - [repo_index[repo_id] for repo_id in matched_ids] - if matched_ids is not None - else list(all_repos_by_id.values()) - ) - repos = _repos_matching_names(cast(list[str], matcher), candidate_repos) - elif key == "regexes": - candidate_repos = ( - [repo_index[repo_id] for repo_id in matched_ids] - if matched_ids is not None - else list(all_repos_by_id.values()) - ) - repos = _repos_matching_regexes(cast(list[str], matcher), candidate_repos) - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown repos matcher {key!r}") - current_ids = {repo["id"] for repo in repos} - for repo in repos: - repo_index[repo["id"]] = repo - matched_ids = current_ids if matched_ids is None else matched_ids & current_ids + selector_matches: list[set[str]] = [] + repo_index = dict(all_repos_by_id) + candidate_repos = list(all_repos_by_id.values()) + code_host_connection = selector.get("codeHostConnection") + if code_host_connection is not None: + repos = _repos_matching_code_host_connection( + code_host_connection, + services_by_id, + repos_by_external_service_id, + ) + repo_index.update({repo["id"]: repo for repo in repos}) + candidate_repos = repos + selector_matches.append({repo["id"] for repo in repos}) + + names = selector.get("names") + if names is not None: + selector_matches.append(_repo_ids_matching_names(names, candidate_repos)) + + name_regexes = selector.get("nameRegexes") + if name_regexes is not None: + selector_matches.append(_repo_ids_matching_name_regexes(name_regexes, candidate_repos)) + + if not selector_matches: + return [] + + matched_ids = selector_matches[0] + for current_ids in selector_matches[1:]: + matched_ids &= current_ids if not matched_ids: return [] - assert matched_ids is not None - return [repo_index[repo_id] for repo_id in matched_ids] + return [repo for repo in repo_index.values() if repo["id"] in matched_ids] -def _repos_matching_names( +def _repo_ids_matching_names( names: list[str], repos: list[permission_types.Repository] -) -> list[permission_types.Repository]: - """Return repos whose Sourcegraph name is listed exactly.""" - name_set = set(names) - matched = [repo for repo in repos if repo["name"] in name_set] - log.info(" names → %d repo(s) matched %d name(s)", len(matched), len(name_set)) +) -> set[str]: + """Return repo IDs whose Sourcegraph name equals a listed name.""" + exact_values = set(names) + matched = {repo["id"] for repo in repos if _repo_name_matches(repo["name"], exact_values, [])} + log.info( + " names → %d repo(s) matched %d name selector(s)", + len(matched), + len(exact_values), + ) + return matched + + +def _repo_ids_matching_name_regexes( + name_regexes: list[str], repos: list[permission_types.Repository] +) -> set[str]: + """Return repo IDs whose Sourcegraph name matches a listed regex.""" + patterns = _compiled_regexes(name_regexes) + matched = {repo["id"] for repo in repos if _repo_name_matches(repo["name"], set(), patterns)} + log.info( + " nameRegexes → %d repo(s) matched %d name regex selector(s)", + len(matched), + len(set(name_regexes)), + ) return matched @@ -608,74 +772,47 @@ def _repos_matching_code_host_connection( return list(matched_repos.values()) -def _repos_matching_regexes( - patterns: list[str], repos: list[permission_types.Repository] -) -> list[permission_types.Repository]: - """Return repos whose name matches any pattern using Python `re`. +def _repo_name_matches( + repository_name: str, exact_values: set[str], patterns: list[re.Pattern[str]] +) -> bool: + """Return whether a repo name matches exact values or regexes. Sourcegraph repo names usually omit the URL scheme (for example - `github.com/example/repo`). To keep URL-looking operator patterns - useful, also test `https://`. + `github.com/example/repo`). To keep URL-looking operator regexes useful, + also test `https://` for regex matches. Exact matches remain + exact Sourcegraph repo names. """ - compiled_patterns = [re.compile(pattern) for pattern in patterns] - matched = [ - repo - for repo in repos - if any( - compiled_pattern.search(repo["name"]) - or compiled_pattern.search(f"https://{repo['name']}") - for compiled_pattern in compiled_patterns - ) - ] - log.info(" regexes → %d repo(s) matched %d pattern(s)", len(matched), len(patterns)) - return matched + if repository_name in exact_values: + return True + return any( + pattern.search(repository_name) or pattern.search(f"https://{repository_name}") + for pattern in patterns + ) def _services_matching( services_by_id: dict[int, permission_types.ExternalService], matcher: permission_types.CodeHostConnectionMatcher, ) -> list[permission_types.ExternalService]: - """AND across the supplied matcher fields. If `id` is supplied we - short-circuit to a single candidate; remaining fields then act as a - defensive cross-check against an ES recreated/renamed under the - same id. Without `id`, every other supplied field is a primary - discriminator across the full service list. - """ - if "id" in matcher: - single_service = services_by_id.get(matcher["id"]) - if single_service is None: - return [] - candidates = [single_service] - else: - candidates = list(services_by_id.values()) - + """AND across the supplied human-readable code-host matcher fields.""" matched: list[permission_types.ExternalService] = [] matcher_values = cast(Mapping[str, object], matcher) - for service in candidates: + for service in services_by_id.values(): service_values = cast(Mapping[str, object], service) if not all( field_name not in matcher_values or matcher_values[field_name] == service_values[field_name] - for field_name in CODE_HOST_VALUE_MATCHES + for field_name in CODE_HOST_DIRECT_VALUE_MATCHES ): continue - if "config" in matcher and not _config_subset_matches( - matcher["config"], _parsed_service_config(service) - ): + if "username" in matcher and matcher["username"] != _service_username(service): continue matched.append(service) return matched def _parsed_service_config(service: permission_types.ExternalService) -> dict[str, Any]: - """Best-effort parse of `ExternalService.config` (JSONC string). - - Returns an empty dict if the config is missing or unparseable — - callers treat that as "no keys to match against", so a `config:` - matcher against such a service simply fails to match instead of - raising. Sourcegraph's resolver returns a JSON object string, so - parse failures here are anomalies worth not crashing on. - """ + """Best-effort parse of `ExternalService.config` (JSONC string).""" raw_config = service.get("config") if not raw_config: return {} @@ -688,46 +825,10 @@ def _parsed_service_config(service: permission_types.ExternalService) -> dict[st return cast(dict[str, Any], parsed) -def _config_subset_matches(matcher_config: dict[str, Any], service_config: dict[str, Any]) -> bool: - """True iff every key in `matcher_config` is present in `service_config` - with a matching value. Nested dicts are matched recursively - (subset semantics); lists and scalars are matched by equality. - - Sourcegraph's `REDACTED` sentinel is left as-is on the service side: - a matcher that names a redacted key (e.g. `token`) compares against - the literal `"REDACTED"` string and almost certainly fails to - match — exactly the semantics we want, since the operator can't - have known the real secret value. - """ - for key, expected in matcher_config.items(): - if key not in service_config: - return False - actual = service_config[key] - if isinstance(expected, dict) and isinstance(actual, dict): - if not _config_subset_matches( - cast(dict[str, Any], expected), cast(dict[str, Any], actual) - ): - return False - continue - if expected != actual: - return False - return True - - -def referenced_external_service_ids(rules: list[permission_types.MappingRule]) -> set[int]: - """Collect all external_service IDs referenced by the mapping rules. - - Returns integer DB primary keys (the YAML-facing form). Used by - `cmd_set` to pre-flight-warn about any IDs that the live instance - doesn't know about, before per-mapping resolution runs. - """ - referenced: set[int] = set() - for rule in rules: - repos_section = rule.get("repos") or {} - code_host_section = repos_section.get("codeHostConnection") - if code_host_section and "id" in code_host_section: - referenced.add(code_host_section["id"]) - return referenced +def _service_username(service: permission_types.ExternalService) -> str | None: + """Return the code-host username from `ExternalService.config`, if present.""" + username = _parsed_service_config(service).get("username") + return username if isinstance(username, str) else None def _format_matcher(matcher: dict[str, object]) -> str: diff --git a/src/src_auth_perms_sync/permissions/maps.py b/src/src_auth_perms_sync/permissions/maps.py index 98ff4c6..d6cb86d 100644 --- a/src/src_auth_perms_sync/permissions/maps.py +++ b/src/src_auth_perms_sync/permissions/maps.py @@ -133,26 +133,21 @@ def count_users_per_provider( def external_service_to_yaml(service: permission_types.ExternalService) -> dict[str, Any]: """Render an external service for the YAML config. - Keys mirror Sourcegraph GraphQL `ExternalService` field names directly - (camelCase). Every scalar field exposed by the GraphQL schema is - surfaced here, including the JSONC `config` blob (parsed and emitted - as a nested mapping). Sourcegraph's `config` resolver redacts secrets - by replacing their values with the literal string `"REDACTED"`; we - strip those keys recursively via `_strip_redacted` so the YAML - contains no useless redaction placeholders. Nested arrays - (e.g. `webhooks[]`, `exclude[]`) are walked too. - - `id` is the decoded integer DB primary key, NOT the opaque base64 - GraphQL Node ID — operators copy this into mapping rules' `repos. - codeHostConnection.id` field, where the integer form is much - friendlier than `RXh0ZXJuYWxTZXJ2aWNlOjU=`. + Keys mirror the human-readable Sourcegraph GraphQL `ExternalService` + fields that maps can match. The opaque GraphQL `id` is omitted; + maps should identify code host connections with `kind`, `displayName`, + `url`, and/or `username`. + + The JSONC `config` blob is parsed only to lift its top-level + `username` into the read-only discovery YAML. The rest of `config` + is intentionally omitted because maps no longer support matching + code-host connections by arbitrary config subtrees. Optional / nullable fields are omitted when null/empty so the YAML stays readable. Booleans are always emitted (true or false) so the discovered state is explicit. """ rendered: dict[str, Any] = { - "id": src.decode_external_service_id(service["id"]), "kind": service["kind"], "displayName": service["displayName"], "url": service["url"], @@ -181,21 +176,22 @@ def external_service_to_yaml(service: permission_types.ExternalService) -> dict[ raw_config = service.get("config") if raw_config: try: - parsed_config = cast(dict[str, Any], json5.loads(raw_config)) + parsed_config = cast(Any, json5.loads(raw_config)) except ValueError: - # Unparsable JSONC: surface the raw string verbatim so the - # operator can still see what's there. Stripping doesn't - # apply since we have no structure to walk. - rendered["config"] = raw_config + pass else: - rendered["config"] = _strip_redacted(parsed_config) + if isinstance(parsed_config, dict): + config_values = cast(dict[str, Any], parsed_config) + username = config_values.get("username") + if isinstance(username, str) and username: + rendered["username"] = username return rendered def dump_auth_providers_yaml(path: Path, providers: list[dict[str, Any]]) -> None: header = ( "# Sourcegraph auth provider configs.\n" - "# Generated/refreshed by: src-auth-perms-sync --get\n" + "# Generated/refreshed by: src-auth-perms-sync get\n" "# Use these values when writing maps.yaml rules under `users.authProvider`.\n" "# This file is read-only reference data; edit maps.yaml, not this file.\n" ) @@ -205,9 +201,9 @@ def dump_auth_providers_yaml(path: Path, providers: list[dict[str, Any]]) -> Non def dump_code_hosts_yaml(path: Path, code_hosts: list[dict[str, Any]]) -> None: header = ( "# Sourcegraph code host connection configs.\n" - "# Generated/refreshed by: src-auth-perms-sync --get\n" + "# Generated/refreshed by: src-auth-perms-sync get\n" "# Use these values when writing maps.yaml rules under `repos.codeHostConnection`.\n" - "# Secrets from ExternalService.config are stripped.\n" + "# ExternalService.config.username is surfaced as top-level `username` when present.\n" "# This file is read-only reference data; edit maps.yaml, not this file.\n" ) _dump_readonly_discovery_yaml(path, header, "codeHostConnections", code_hosts) diff --git a/src/src_auth_perms_sync/permissions/restore.py b/src/src_auth_perms_sync/permissions/restore.py index da7ab1a..5ab1ece 100644 --- a/src/src_auth_perms_sync/permissions/restore.py +++ b/src/src_auth_perms_sync/permissions/restore.py @@ -489,6 +489,12 @@ def _log_user_scoped_restore_done(mutations: _UserScopedRestoreMutationResult) - mutations.additions.succeeded, mutations.removals.succeeded, ) + skipped = mutations.additions.skipped + mutations.removals.skipped + if skipped: + log.warning( + "Scoped restore skipped %d vanished repo/user mutation(s); the next run will re-plan.", + skipped, + ) def _restore_command_name(dry_run: bool) -> str: @@ -725,8 +731,9 @@ def _apply_restore_overwrites( worker_pool=worker_pool, ) log.info( - "Restore done. %d succeeded, %d failed, %d canceled.", + "Restore done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) @@ -744,6 +751,7 @@ def _record_restore_event_fields( command_event["repos_short_circuited"] = plan.skipped_repo_count command_event["snapshot_grants"] = snapshot_state.target_snapshot["stats"]["total_grants"] command_event["mutations_succeeded"] = mutations.succeeded + command_event["mutations_skipped"] = mutations.skipped command_event["mutations_failed"] = mutations.failed command_event["mutations_canceled"] = mutations.canceled diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index a320892..07da01a 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal, NotRequired, TypeAlias, TypedDict +from typing import Any, Literal, TypeAlias, TypedDict from ..shared import types as shared_types @@ -16,7 +16,7 @@ @dataclass(frozen=True) class SetCommandOptions: - """Operator-selected mode for `--set`.""" + """Operator-selected mode for the set command.""" mode: SetCommandMode user_identifier: str | None = None @@ -76,29 +76,34 @@ class AuthProviderMatcher(TypedDict, total=False): class CodeHostConnectionMatcher(TypedDict, total=False): """Match repos by Sourcegraph code-host connection discovery fields.""" - id: int kind: str displayName: str url: str - config: dict[str, Any] + username: str + +class UserSelector(TypedDict, total=False): + """User selectors. Fields AND together; values inside each field OR.""" -class UsersFilter(TypedDict, total=False): authProvider: AuthProviderMatcher emails: list[str] + emailRegexes: list[str] usernames: list[str] + usernameRegexes: list[str] -class ReposFilter(TypedDict, total=False): +class RepositorySelector(TypedDict, total=False): + """Repository selectors. Fields AND together; values inside each field OR.""" + codeHostConnection: CodeHostConnectionMatcher names: list[str] - regexes: list[str] + nameRegexes: list[str] class MappingRule(TypedDict): - name: NotRequired[str] - users: UsersFilter - repos: ReposFilter + name: str + users: UserSelector + repos: RepositorySelector class ConfigFile(TypedDict, total=False): diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index d882d95..7f0ab66 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -30,9 +30,9 @@ def load_discovery( dict[tuple[str, str], str], ]: """Fetch auth providers + external services and resolve the SAML attribute - names map, with consistent logging. Shared by --get and --set; returns the - raw lists so each caller can transform them as needed (YAML form for --get, - keyed-by-id dict for --set). + names map, with consistent logging. Shared by get and set; returns the + raw lists so each caller can transform them as needed (YAML form for get, + keyed-by-id dict for set). Both commands need exactly the same instance state to do their work, so centralizing this avoids drift in which providers/services are considered @@ -143,7 +143,6 @@ def load_mapping_context_for_rules( len(all_repos_by_id), len(services_by_id), ) - warn_unknown_external_services(mapping_rules, services_by_id) return permission_types.MappingContext( mapping_rules=mapping_rules, providers=providers, @@ -154,23 +153,6 @@ def load_mapping_context_for_rules( ) -def warn_unknown_external_services( - mapping_rules: list[permission_types.MappingRule], - services_by_id: dict[int, permission_types.ExternalService], -) -> None: - """Warn when maps reference code-host connection IDs absent on the instance.""" - for external_service_id in sorted( - permissions_mapping.referenced_external_service_ids(mapping_rules) - ): - if external_service_id not in services_by_id: - log.warning( - "External service id %s is referenced by the maps but " - "is not present on the instance — rules using it will " - "resolve to zero repos.", - external_service_id, - ) - - def snapshot_path( input_path: Path, timestamp: str, diff --git a/src/src_auth_perms_sync/shared/types.py b/src/src_auth_perms_sync/shared/types.py index 7ac9096..46f2032 100644 --- a/src/src_auth_perms_sync/shared/types.py +++ b/src/src_auth_perms_sync/shared/types.py @@ -54,6 +54,7 @@ class MutationCounts: succeeded: int = 0 failed: int = 0 canceled: int = 0 + skipped: int = 0 @dataclass(frozen=True, slots=True) diff --git a/tests/integration/test_cli_entrypoint.py b/tests/integration/test_cli_entrypoint.py index a936982..a0802e9 100644 --- a/tests/integration/test_cli_entrypoint.py +++ b/tests/integration/test_cli_entrypoint.py @@ -15,6 +15,10 @@ def test_module_help_prints_usage(self) -> None: ) self.assertIn("src-auth-perms-sync", completed_process.stdout) - self.assertIn("--set", completed_process.stdout) + self.assertIn("set:\n- Explicit repo permissions", completed_process.stdout) + self.assertIn("Organizations and memberships\n\nSee", completed_process.stdout) + self.assertIn("{get,set,restore,sync-saml-orgs}", completed_process.stdout) + self.assertIn("--maps-path", completed_process.stdout) + self.assertIn("--restore-path", completed_process.stdout) self.assertIn("--sync-saml-orgs", completed_process.stdout) self.assertEqual("", completed_process.stderr) diff --git a/tests/unit/test_apply.py b/tests/unit/test_apply.py new file mode 100644 index 0000000..59326a0 --- /dev/null +++ b/tests/unit/test_apply.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import unittest +from typing import Any, cast + +import src_py_lib as src + +from src_auth_perms_sync.permissions import apply +from src_auth_perms_sync.permissions import types as permission_types + + +class _FakeSourcegraphClient: + def __init__(self, exception: BaseException | None = None) -> None: + self.exception = exception + self.calls: list[tuple[str, dict[str, Any]]] = [] + + def graphql(self, query: str, variables: src.JSONDict) -> dict[str, Any]: + self.calls.append((query, dict(variables))) + if self.exception is not None: + raise self.exception + return {} + + +class ApplyTests(unittest.TestCase): + def test_repo_not_found_overwrite_is_skipped_not_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'repo not found: id=264'}]", + is_application_error=True, + ) + ) + counts = apply.apply_username_overwrites( + cast(src.SourcegraphClient, client), + [ + permission_types.RepositoryUsernameOverwrite( + repository_id=src.encode_repository_id(264), + repository_name="test-repo-0241", + usernames=("alice",), + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(1, counts.skipped) + self.assertEqual(0, counts.failed) + self.assertEqual(0, counts.canceled) + self.assertEqual(1, len(client.calls)) + + def test_user_not_found_addition_is_skipped_not_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'user not found: id=123'}]", + is_application_error=True, + ) + ) + counts = apply.apply_additions( + cast(src.SourcegraphClient, client), + [ + apply.PermissionAddition( + user_id="VXNlcjoxMjM=", + username="deleted-user", + repo_id=src.encode_repository_id(264), + repo_name="test-repo-0241", + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(1, counts.skipped) + self.assertEqual(0, counts.failed) + self.assertEqual(0, counts.canceled) + self.assertEqual(1, len(client.calls)) + + def test_non_missing_graphql_error_is_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'permission denied'}]", + is_application_error=True, + ) + ) + counts = apply.apply_username_overwrites( + cast(src.SourcegraphClient, client), + [ + permission_types.RepositoryUsernameOverwrite( + repository_id=src.encode_repository_id(264), + repository_name="test-repo-0241", + usernames=("alice",), + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(0, counts.skipped) + self.assertEqual(1, counts.failed) + self.assertEqual(0, counts.canceled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index 7c47f77..7f3ae0c 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -2,6 +2,7 @@ import contextlib import io +import os import tempfile import unittest from concurrent.futures import ThreadPoolExecutor @@ -12,21 +13,22 @@ import src_py_lib as src from src_py_lib.utils import config as shared_config +import src_auth_perms_sync from src_auth_perms_sync import cli from src_auth_perms_sync.shared import backups -def make_config(**updates: object) -> cli.SrcAuthPermissionsSyncConfig: - base_config = cli.SrcAuthPermissionsSyncConfig( +def make_config(**updates: object) -> cli.Config: + base_config = cli.Config( src_endpoint="https://sourcegraph.example.com", src_access_token="secret", ) return base_config.model_copy(update=updates) -def load_config_from_env(**env: str) -> cli.SrcAuthPermissionsSyncConfig: +def load_config_from_env(**env: str) -> cli.Config: return shared_config.load_config( - cli.SrcAuthPermissionsSyncConfig, + cli.Config, env_file=None, env={ "SRC_ENDPOINT": "https://sourcegraph.example.com", @@ -38,40 +40,75 @@ def load_config_from_env(**env: str) -> cli.SrcAuthPermissionsSyncConfig: class CliConfigTests(unittest.TestCase): - def test_resolve_command_defaults_to_get(self) -> None: - command = cli.resolve_command(make_config()) + def test_resolve_command_uses_explicit_command_name(self) -> None: + command = cli.resolve_command("get", make_config()) self.assertEqual("get", command.name) self.assertEqual("get", command.log_name) self.assertEqual("get", command.artifact_name) - - def test_resolve_command_prefers_explicit_commands(self) -> None: self.assertEqual( - "set", cli.resolve_command(make_config(set_path=Path("maps.yaml"), full=True)).name + "set", + cli.resolve_command("set", make_config(maps_path=Path("maps.yaml"), full=True)).name, ) self.assertEqual( - "restore", cli.resolve_command(make_config(restore_path=Path("snapshot.json"))).name + "restore", + cli.resolve_command("restore", make_config(restore_path=Path("snapshot.json"))).name, ) self.assertEqual( - "sync_saml_orgs", cli.resolve_command(make_config(sync_saml_organizations=True)).name + "sync_saml_orgs", + cli.resolve_command("sync_saml_orgs", make_config()).name, ) + def test_maps_path_does_not_select_set_command(self) -> None: + command = cli.resolve_command("get", make_config(maps_path=Path("custom-maps.yaml"))) + + self.assertEqual("get", command.name) + + def test_load_cli_returns_command_and_config_options(self) -> None: + with ( + tempfile.TemporaryDirectory() as directory, + mock.patch.dict( + os.environ, + { + "SRC_ENDPOINT": "https://sourcegraph.example.com", + "SRC_ACCESS_TOKEN": "secret", + }, + clear=True, + ), + ): + env_file = Path(directory) / ".env" + env_file.write_text("") + cli_input = cli.load_cli( + ["set", "--env-file", str(env_file), "--maps-path", "custom-maps.yaml"] + ) + + self.assertEqual("set", cli_input.command_name) + self.assertEqual(Path("custom-maps.yaml"), cli_input.config.maps_path) + + def test_restore_path_config_loads_without_selecting_a_command(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_RESTORE_PATH="snapshot.json") + + self.assertEqual(Path.cwd() / "snapshot.json", config.restore_path) + def test_set_command_options_match_each_incremental_mode(self) -> None: self.assertEqual( - "full", cli.set_command_options(make_config(set_path=Path("maps.yaml"))).mode + "full", + cli.set_command_options(make_config(maps_path=Path("maps.yaml"))).mode, ) self.assertEqual( ("user", "alice"), ( - cli.set_command_options(make_config(set_path=Path("maps.yaml"), user="alice")).mode, cli.set_command_options( - make_config(set_path=Path("maps.yaml"), user="alice") + make_config(maps_path=Path("maps.yaml"), user="alice") + ).mode, + cli.set_command_options( + make_config(maps_path=Path("maps.yaml"), user="alice") ).user_identifier, ), ) users_without_permissions = cli.set_command_options( make_config( - set_path=Path("maps.yaml"), + maps_path=Path("maps.yaml"), users_without_explicit_perms=True, created_after="2026-01-01", ) @@ -79,16 +116,17 @@ def test_set_command_options_match_each_incremental_mode(self) -> None: self.assertEqual("users_without_explicit_perms", users_without_permissions.mode) self.assertEqual("2026-01-01", users_without_permissions.user_created_after) filtered_full = cli.set_command_options( - make_config(set_path=Path("maps.yaml"), created_after="2026-01-01") + make_config(maps_path=Path("maps.yaml"), created_after="2026-01-01") ) self.assertEqual("full", filtered_full.mode) self.assertEqual("2026-01-01", filtered_full.user_created_after) def test_resolve_command_includes_set_mode_names(self) -> None: user_command = cli.resolve_command( - make_config(set_path=Path("maps.yaml"), user="alice", apply=True) + "set", + make_config(maps_path=Path("maps.yaml"), user="alice", apply=True), ) - full_command = cli.resolve_command(make_config(set_path=Path("maps.yaml"))) + full_command = cli.resolve_command("set", make_config(maps_path=Path("maps.yaml"))) self.assertEqual("set_user", user_command.log_name) self.assertEqual("set-add-user-apply", user_command.artifact_name) @@ -97,9 +135,14 @@ def test_resolve_command_includes_set_mode_names(self) -> None: self.assertEqual("set-dry-run", full_command.artifact_name) def test_resolve_command_includes_combined_sync_names(self) -> None: - get_command = cli.resolve_command(make_config(get=True, sync_saml_organizations=True)) + get_command = cli.resolve_command("get", make_config(sync_saml_organizations=True)) set_command = cli.resolve_command( - make_config(set_path=Path("maps.yaml"), apply=True, sync_saml_organizations=True) + "set", + make_config( + maps_path=Path("maps.yaml"), + apply=True, + sync_saml_organizations=True, + ), ) self.assertEqual("get", get_command.name) @@ -111,56 +154,72 @@ def test_resolve_command_includes_combined_sync_names(self) -> None: self.assertEqual("set-sync-saml-orgs-apply", set_command.artifact_name) self.assertTrue(set_command.sync_saml_organizations) - def test_validate_config_rejects_multiple_commands(self) -> None: + def test_validate_config_allows_sync_saml_orgs_with_get_or_set(self) -> None: + cli.validate_config("get", make_config(sync_saml_organizations=True)) + cli.validate_config( + "set", + make_config(maps_path=Path("maps.yaml"), sync_saml_organizations=True), + ) + + def test_validate_config_rejects_sync_saml_orgs_with_restore_or_sync_command(self) -> None: self.assert_config_error( - make_config(get=True, set_path=Path("maps.yaml"), full=True), - "choose only one", + "restore", + make_config(restore_path=Path("snapshot.json"), sync_saml_organizations=True), + "can only be combined with get or set", + ) + self.assert_config_error( + "sync_saml_orgs", + make_config(sync_saml_organizations=True), + "can only be combined with get or set", ) - def test_validate_config_allows_sync_saml_orgs_with_get_or_set(self) -> None: - cli.validate_config(make_config(get=True, sync_saml_organizations=True)) - cli.validate_config(make_config(set_path=Path("maps.yaml"), sync_saml_organizations=True)) + def test_validate_config_rejects_restore_without_restore_path(self) -> None: + self.assert_config_error("restore", make_config(), "restore requires --restore-path") - def test_validate_config_rejects_sync_saml_orgs_with_restore(self) -> None: + def test_validate_config_rejects_restore_path_without_restore(self) -> None: self.assert_config_error( - make_config(restore_path=Path("snapshot.json"), sync_saml_organizations=True), - "with --get or --set", + "get", + make_config(restore_path=Path("snapshot.json")), + "--restore-path requires the restore command", ) def test_validate_config_rejects_set_modes_without_set(self) -> None: - self.assert_config_error(make_config(full=True), "requires --set") + self.assert_config_error("get", make_config(full=True), "requires the set command") def test_validate_config_allows_get_user_filters_without_set(self) -> None: - cli.validate_config(make_config(user="alice")) - cli.validate_config(make_config(users_without_explicit_perms=True)) - cli.validate_config(make_config(created_after="2026-01-01")) + cli.validate_config("get", make_config(user="alice")) + cli.validate_config("get", make_config(users_without_explicit_perms=True)) + cli.validate_config("get", make_config(created_after="2026-01-01")) def test_validate_config_rejects_get_user_filter_conflicts(self) -> None: self.assert_config_error( + "get", make_config(user="alice", users_without_explicit_perms=True), "choose only one of --user or --users-without-explicit-perms", ) def test_validate_config_rejects_user_filters_on_non_get_set_commands(self) -> None: self.assert_config_error( + "restore", make_config(restore_path=Path("snapshot.json"), user="alice"), - "require --get or --set", + "require get or set", ) def test_validate_config_allows_set_without_explicit_mode(self) -> None: - cli.validate_config(make_config(set_path=Path("maps.yaml"))) + cli.validate_config("set", make_config(maps_path=Path("maps.yaml"))) def test_created_after_config_accepts_yyyy_mm_dd_date_arguments(self) -> None: config = load_config_from_env(SRC_AUTH_PERMS_SYNC_CREATED_AFTER="2026-01-01") self.assertEqual("2026-01-01", config.created_after) - cli.validate_config(make_config(get=True, created_after="2026-01-01")) + cli.validate_config("get", make_config(created_after="2026-01-01")) cli.validate_config( + "set", make_config( - set_path=Path("maps.yaml"), + maps_path=Path("maps.yaml"), user="alice", created_after="2026-01-01", - ) + ), ) def test_created_after_config_rejects_values_outside_yyyy_mm_dd_shape(self) -> None: @@ -196,11 +255,11 @@ def test_trace_config_is_loaded_from_env(self) -> None: def test_run_with_client_enables_sourcegraph_trace_collection(self) -> None: configuration = make_config(trace=True) - command = cli.resolve_command(configuration) + command = cli.resolve_command("get", configuration) captured_clients: list[src.SourcegraphClient] = [] def capture_client( - _config: cli.SrcAuthPermissionsSyncConfig, + _config: cli.Config, _command: cli.ResolvedCommand, client: src.SourcegraphClient, _worker_pool: ThreadPoolExecutor, @@ -223,11 +282,11 @@ def capture_client( def test_run_with_client_uses_configured_http_timeout(self) -> None: configuration = make_config(http_timeout_seconds=75.0) - command = cli.resolve_command(configuration) + command = cli.resolve_command("get", configuration) captured_clients: list[src.SourcegraphClient] = [] def capture_client( - _config: cli.SrcAuthPermissionsSyncConfig, + _config: cli.Config, _command: cli.ResolvedCommand, client: src.SourcegraphClient, _worker_pool: ThreadPoolExecutor, @@ -250,7 +309,8 @@ def capture_client( def test_validate_config_rejects_multiple_set_modes(self) -> None: self.assert_config_error( - make_config(set_path=Path("maps.yaml"), full=True, user="alice"), + "set", + make_config(maps_path=Path("maps.yaml"), full=True, user="alice"), "choose at most one", ) @@ -258,24 +318,35 @@ def test_require_set_input_file_reports_missing_maps_file(self) -> None: with tempfile.TemporaryDirectory() as directory: existing_path = Path(directory) / "maps.yaml" existing_path.write_text("maps: []\n") - cli.require_set_input_file(make_config(set_path=existing_path)) + cli.require_set_input_file(make_config(maps_path=existing_path)) with self.assertRaises(SystemExit) as exit_context: - cli.require_set_input_file(make_config(set_path=Path(directory) / "missing.yaml")) - self.assertIn("--set input file does not exist", str(exit_context.exception)) + cli.require_set_input_file(make_config(maps_path=Path(directory) / "missing.yaml")) + self.assertIn("set input file does not exist", str(exit_context.exception)) def test_endpoint_scoped_config_rewrites_relative_artifact_paths(self) -> None: - scoped_config = cli.endpoint_scoped_config( - make_config(set_path=Path("maps.yaml"), restore_path=Path("snapshot.json")), + scoped_set_config = cli.endpoint_scoped_config( + "set", + make_config(maps_path=Path("maps.yaml")), "https://sourcegraph.example.com", ) endpoint_directory = Path.cwd() / backups.ARTIFACTS_DIR_NAME / "sourcegraph.example.com" - self.assertEqual(endpoint_directory / "maps.yaml", scoped_config.set_path) - self.assertEqual(endpoint_directory / "snapshot.json", scoped_config.restore_path) + self.assertEqual(endpoint_directory / "maps.yaml", scoped_set_config.maps_path) + + scoped_restore_config = cli.endpoint_scoped_config( + "restore", + make_config(restore_path=Path("snapshot.json")), + "https://sourcegraph.example.com", + ) + self.assertEqual(endpoint_directory / "snapshot.json", scoped_restore_config.restore_path) def test_run_fields_include_concrete_command(self) -> None: - configuration = make_config(set_path=Path("maps.yaml"), user="alice", apply=True) - command = cli.resolve_command(configuration) + configuration = make_config( + maps_path=Path("maps.yaml"), + user="alice", + apply=True, + ) + command = cli.resolve_command("set", configuration) fields = cli.run_fields(configuration, command, "https://sourcegraph.example.com") @@ -288,8 +359,8 @@ def test_run_fields_include_concrete_command(self) -> None: self.assertEqual(60.0, fields["http_timeout_seconds"]) def test_run_command_passes_primary_data_to_combined_sync(self) -> None: - configuration = make_config(get=True, sync_saml_organizations=True) - command = cli.resolve_command(configuration) + configuration = make_config(sync_saml_organizations=True) + command = cli.resolve_command("get", configuration) client = cast(src.SourcegraphClient, object()) sourcegraph_site_config = object() command_data = cli.run_context.CommandData() @@ -320,14 +391,71 @@ def test_run_command_passes_primary_data_to_combined_sync(self) -> None: worker_pool, ) + def test_package_exports_programmatic_runner_and_config(self) -> None: + self.assertIs(src_auth_perms_sync.Config, cli.Config) + self.assertIs(src_auth_perms_sync.Get, cli.Get) + self.assertIs(src_auth_perms_sync.Set, cli.Set) + self.assertIs(src_auth_perms_sync.Restore, cli.Restore) + self.assertIs(src_auth_perms_sync.SyncSamlOrgs, cli.SyncSamlOrgs) + self.assertEqual( + ["Config", "Get", "Restore", "Set", "SyncSamlOrgs"], + src_auth_perms_sync.__all__, + ) + + def test_programmatic_runner_uses_supplied_config(self) -> None: + configuration = make_config(parallelism=1, sample_interval=0) + captured: list[tuple[cli.Config, cli.ResolvedCommand, str]] = [] + + def capture_run( + scoped_config: cli.Config, + command: cli.ResolvedCommand, + endpoint: str, + _worker_pool: ThreadPoolExecutor, + ) -> None: + captured.append((scoped_config, command, endpoint)) + + with ( + mock.patch.object(cli, "run_with_client", side_effect=capture_run), + mock.patch.object( + cli.src, + "logging_settings_from_config", + return_value=object(), + ), + mock.patch.object(cli.src, "logging", return_value=contextlib.nullcontext(None)), + ): + self.assertTrue(src_auth_perms_sync.Get(configuration)) + + self.assertEqual(1, len(captured)) + scoped_config, command, endpoint = captured[0] + self.assertIs(configuration, scoped_config) + self.assertEqual("get", command.name) + self.assertEqual("https://sourcegraph.example.com", endpoint) + + def test_programmatic_runner_returns_false_on_failure(self) -> None: + configuration = make_config(parallelism=1, sample_interval=0) + + with ( + mock.patch.object(cli, "run_with_client", side_effect=SystemExit(1)), + mock.patch.object( + cli.src, + "logging_settings_from_config", + return_value=object(), + ), + mock.patch.object(cli.src, "logging", return_value=contextlib.nullcontext(None)), + ): + self.assertFalse(src_auth_perms_sync.Get(configuration)) + def assert_config_error( - self, config: cli.SrcAuthPermissionsSyncConfig, expected_message: str + self, + command_name: cli.CommandName, + config: cli.Config, + expected_message: str, ) -> None: captured_stderr = io.StringIO() with ( contextlib.redirect_stderr(captured_stderr), self.assertRaises(SystemExit) as exit_context, ): - cli.validate_config(config) + cli.validate_config(command_name, config) self.assertEqual(2, exit_context.exception.code) self.assertIn(expected_message, captured_stderr.getvalue()) diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index 3dcbe77..d0d74f3 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -2,6 +2,7 @@ import base64 import itertools +import json import tempfile import unittest from pathlib import Path @@ -80,6 +81,33 @@ def test_count_users_per_provider_counts_each_user_once_per_provider(self) -> No self.assertEqual(1, counts[("saml", "https://idp.example.com", "sourcegraph")]) self.assertEqual(1, counts[("github", "https://github.com/", "github-client")]) + def test_external_service_to_yaml_lifts_username_without_config(self) -> None: + service: permission_types.ExternalService = { + "id": "RXh0ZXJuYWxTZXJ2aWNlOjE=", + "kind": "BITBUCKETSERVER", + "displayName": "Bitbucket LOB1", + "url": "https://bitbucket.example.com/", + "repoCount": 0, + "createdAt": "2026-05-30T00:00:00Z", + "updatedAt": "2026-05-30T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": json.dumps({"username": "LOB1-SA1", "token": "REDACTED"}), + } + + rendered = maps.external_service_to_yaml(service) + + self.assertEqual("LOB1-SA1", rendered["username"]) + self.assertNotIn("config", rendered) + class MappingTests(unittest.TestCase): def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: @@ -87,6 +115,7 @@ def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: list[permission_types.MappingRule], [ { + "name": "username only", "users": {"usernames": ["alice"]}, "repos": {"names": ["github.com/example/private-repo"]}, } @@ -96,6 +125,7 @@ def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: list[permission_types.MappingRule], [ { + "name": "email only", "users": {"emails": ["alice@example.com"]}, "repos": {"names": ["github.com/example/private-repo"]}, } @@ -122,23 +152,30 @@ def test_user_filter_matchers_intersect_without_expanding_selection(self) -> Non self.make_user("user-3", "carol", True, "carol@example.com", False), self.make_user("user-4", "dana", False, "dana@example.com", True), ] - user_filters: dict[str, object] = { + user_fields: dict[str, object] = { "authProvider": {"type": "builtin"}, "emails": ["alice@example.com", "carol@example.com", "dana@example.com"], + "emailRegexes": [r"^(alice|bob|carol)@example\.com$"], "usernames": ["alice", "bob", "carol"], + "usernameRegexes": [r"^(alice|dana)$"], } single_filter_usernames = { name: self.usernames_for( - mapping.resolve_users({name: matcher}, users, providers), + mapping.resolve_users( + cast(permission_types.UserSelector, {name: matcher}), users, providers + ), ) - for name, matcher in user_filters.items() + for name, matcher in user_fields.items() } - for filter_count in range(2, len(user_filters) + 1): - for filter_names in itertools.combinations(user_filters, filter_count): + for filter_count in range(2, len(user_fields) + 1): + for filter_names in itertools.combinations(user_fields, filter_count): matched_usernames = self.usernames_for( mapping.resolve_users( - {name: user_filters[name] for name in filter_names}, + cast( + permission_types.UserSelector, + {name: user_fields[name] for name in filter_names}, + ), users, providers, ) @@ -151,7 +188,11 @@ def test_user_filter_matchers_intersect_without_expanding_selection(self) -> Non self.assertEqual( {"alice"}, - self.usernames_for(mapping.resolve_users(user_filters, users, providers)), + self.usernames_for( + mapping.resolve_users( + cast(permission_types.UserSelector, user_fields), users, providers + ) + ), ) def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> None: @@ -166,41 +207,41 @@ def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> Non example_public_repo["id"]: example_public_repo, } services_by_id = { - 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise"), - 2: self.make_external_service(2, "GITHUB", "GitHub Cloud"), + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise", "enterprise-sync"), + 2: self.make_external_service(2, "GITHUB", "GitHub Cloud", "cloud-sync"), } repos_by_external_service_id = { 1: [sourcegraph_repo, example_private_repo, gitlab_repo], 2: [example_public_repo], } - repo_filters: dict[str, object] = { - "codeHostConnection": {"id": 1}, + repository_fields: dict[str, object] = { + "codeHostConnection": {"username": "enterprise-sync"}, "names": [ "github.com/example/private-repo", "gitlab.com/example/private-repo", ], - "regexes": [ - r"^github\.com/example/", - r"^gitlab\.com/example/", - ], + "nameRegexes": [r"^github\.com/example/"], } single_filter_repo_names = { name: self.repo_names_for( mapping.resolve_repos( - {name: matcher}, + cast(permission_types.RepositorySelector, {name: matcher}), services_by_id, repos_by_external_service_id, all_repos, ) ) - for name, matcher in repo_filters.items() + for name, matcher in repository_fields.items() } - for filter_count in range(2, len(repo_filters) + 1): - for filter_names in itertools.combinations(repo_filters, filter_count): + for filter_count in range(2, len(repository_fields) + 1): + for filter_names in itertools.combinations(repository_fields, filter_count): matched_repo_names = self.repo_names_for( mapping.resolve_repos( - {name: repo_filters[name] for name in filter_names}, + cast( + permission_types.RepositorySelector, + {name: repository_fields[name] for name in filter_names}, + ), services_by_id, repos_by_external_service_id, all_repos, @@ -213,10 +254,10 @@ def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> Non self.assertLessEqual(matched_repo_names, single_filter_repo_names[name]) self.assertEqual( - {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + {"github.com/example/private-repo"}, self.repo_names_for( mapping.resolve_repos( - repo_filters, + cast(permission_types.RepositorySelector, repository_fields), services_by_id, repos_by_external_service_id, all_repos, @@ -224,26 +265,29 @@ def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> Non ), ) - def test_validate_mapping_rules_accepts_string_list_filters(self) -> None: + def test_validate_mapping_rules_accepts_flat_text_selector_lists(self) -> None: mapping.validate_mapping_rules( cast( list[permission_types.MappingRule], [ { + "name": "flat selector lists", "users": { "emails": ["alice@example.com"], + "emailRegexes": [r"^team-.*@example\.com$"], "usernames": ["alice"], + "usernameRegexes": [r"^team-.*"], }, "repos": { "names": ["github.com/example/private-repo"], - "regexes": [r"^github\.com/example/"], + "nameRegexes": [r"^github\.com/example/"], }, } ], ) ) - def test_repos_regexes_match_any_pattern(self) -> None: + def test_repository_name_matches_any_pattern(self) -> None: sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") github_repo = self.make_repo("repo-2", "github.com/example/private-repo") gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") @@ -255,7 +299,7 @@ def test_repos_regexes_match_any_pattern(self) -> None: matched_repos = mapping.resolve_repos( { - "regexes": [ + "nameRegexes": [ r"^github\.com/example/", r"^gitlab\.com/example/", ], @@ -270,25 +314,56 @@ def test_repos_regexes_match_any_pattern(self) -> None: self.repo_names_for(matched_repos), ) - def test_validate_mapping_rules_rejects_non_string_list_filters(self) -> None: + def test_username_matches_any_pattern(self) -> None: + providers: list[shared_types.AuthProvider] = [] + users = [ + self.make_user("user-1", "alice", True, "alice@example.com", True), + self.make_user("user-2", "test_user_00001", True, "one@example.com", True), + self.make_user("user-3", "test_user_00100", True, "hundred@example.com", True), + self.make_user("user-4", "service-account", True, "service@example.com", True), + ] + + matched_users = mapping.resolve_users( + {"usernameRegexes": [r"^(alice|test_user_00[0-9]{3})$"]}, + users, + providers, + ) + + self.assertEqual( + {"alice", "test_user_00001", "test_user_00100"}, + self.usernames_for(matched_users), + ) + + def test_validate_mapping_rules_rejects_invalid_text_matchers(self) -> None: with self.assertRaises(SystemExit) as raised: mapping.validate_mapping_rules( cast( list[permission_types.MappingRule], [ { + "name": "invalid flat selector lists", "users": { "emails": "alice@example.com", "usernames": [""], }, + "repos": {"names": [123], "nameRegexes": ["["]}, + }, + { + "name": "invalid code host field", + "users": {"usernames": ["alice"]}, "repos": { - "names": [123], - "regexes": ["["], + "codeHostConnection": {"config": {"username": "old"}, "id": 1}, + "regex": r"^github\.com/example/", }, }, + { + "name": "invalid username regex", + "users": {"usernameRegexes": ["["]}, + "repos": {"names": ["github.com/example/private-repo"]}, + }, { "users": {"usernames": ["alice"]}, - "repos": {"regex": r"^github\.com/example/"}, + "repos": {"names": ["github.com/example/private-repo"]}, }, ], ) @@ -298,8 +373,12 @@ def test_validate_mapping_rules_rejects_non_string_list_filters(self) -> None: self.assertIn("users.emails must be a list of strings", message) self.assertIn("users.usernames[0] is an empty string", message) self.assertIn("repos.names[0] must be a string", message) - self.assertIn("repos.regexes[0] is not a valid Python regex", message) - self.assertIn("unknown repos matcher 'regex'", message) + self.assertIn("repos.nameRegexes[0] is not a valid Python regex", message) + self.assertIn("users.usernameRegexes[0] is not a valid Python regex", message) + self.assertIn("unknown repos field 'regex'", message) + self.assertIn("unknown repos.codeHostConnection field 'config'", message) + self.assertIn("unknown repos.codeHostConnection field 'id'", message) + self.assertIn("`name:` is missing", message) def make_user( self, @@ -325,6 +404,7 @@ def make_external_service( external_service_id: int, kind: str, display_name: str, + username: str | None = None, ) -> permission_types.ExternalService: graphql_id = base64.b64encode(f"ExternalService:{external_service_id}".encode()).decode() return { @@ -345,7 +425,7 @@ def make_external_service( "supportsRepoExclusion": False, "creator": None, "lastUpdater": None, - "config": "{}", + "config": json.dumps({"username": username} if username else {}), } def usernames_for(self, users: list[shared_types.User]) -> set[str]: @@ -373,6 +453,7 @@ def test_full_set_plan_reuses_user_tuple_for_non_overlapping_repos(self) -> None context = self.make_context( [ { + "name": "alice and bob get example repos", "users": {"usernames": ["alice", "bob"]}, "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, } @@ -401,10 +482,12 @@ def test_full_set_plan_unions_only_overlapping_repos(self) -> None: context = self.make_context( [ { + "name": "alice and bob get first repos", "users": {"usernames": ["alice", "bob"]}, "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, }, { + "name": "bob and chris get second repos", "users": {"usernames": ["bob", "chris"]}, "repos": {"names": ["github.com/example/two", "github.com/example/three"]}, },