From 14b387132c11c3dbd2d5c487a9a92a7998053406 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:01:46 -0600 Subject: [PATCH 01/17] Tidy permission sync logs Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- dev/TODO.md | 4 - src/src_auth_perms_sync/cli.py | 26 ++-- .../permissions/command.py | 76 ++++++------ .../permissions/full_set.py | 10 +- .../permissions/restore.py | 12 +- .../permissions/snapshot.py | 115 ++++++++++-------- .../permissions/workflow.py | 6 +- src/src_auth_perms_sync/shared/site_config.py | 17 +-- tests/integration/test_cli_entrypoint.py | 2 +- tests/unit/test_cli_config.py | 102 ++++++++++++++-- tests/unit/test_restore.py | 12 +- tests/unit/test_snapshot.py | 37 +++--- 12 files changed, 265 insertions(+), 154 deletions(-) diff --git a/dev/TODO.md b/dev/TODO.md index 72d9cb7..f46d1f0 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -6,10 +6,6 @@ - Additive modes, to add new users’ perms quickly, without the extraneous load on the database of a full sync -- Take a list of usernames and/or email addresses as input, - query users on the instance for these, - then trigger a perms sync for found users -- Query the instance for all new users, which do not yet have explicit perms - Query the instance for all new repos, which do not yet have explicit perms ### Full: Overwrite all perms diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 5dc018c..96157c1 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -47,6 +47,7 @@ "users", "users_without_explicit_perms", "created_after", + "no_backup", "explicit_permissions_batch_size", *COMMON_CONFIG_FIELDS, ) @@ -336,8 +337,6 @@ def validate_command_options(command_name: CommandName, config: Config) -> None: """Validate options that only make sense with specific commands.""" if command_name == "get" and config.apply: config_error("--apply cannot be used with the read-only get command") - if command_name == "get" and config.no_backup: - config_error("--no-backup cannot be used with the read-only get command") if config.sync_saml_organizations and command_name != "set": config_error("--sync-saml-orgs can only be combined with set") if command_name == "restore" and config.restore_path is None: @@ -507,12 +506,7 @@ def require_set_input_file(maps_path: Path) -> None: def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[str, object]: """Return run-level fields for structured logging.""" - return { - "cli_cmd": command.log_name, - "base_cmd": command.name, - "set_mode": command.set_mode, - "sync_saml_orgs_flag": command.sync_saml_organizations, - "apply_flag": config.apply, + fields: dict[str, object] = { "endpoint": endpoint, "parallelism": config.parallelism, "explicit_permissions_batch_size": config.explicit_permissions_batch_size, @@ -520,13 +514,22 @@ def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[ "open_telemetry": config.open_telemetry, "max_attempts": config.max_attempts, "http_timeout_seconds": config.http_timeout_seconds, - "no_backup": config.no_backup, "sample_interval": config.sample_interval, - "user_created_after": config.created_after, "artifacts_dir": str(backups.endpoint_artifacts_directory(endpoint)), "python_version": sys.version.split()[0], "pid": os.getpid(), } + if command.name != "get": + fields["apply"] = config.apply + if config.no_backup: + fields["no_backup"] = True + if command.set_mode is not None: + fields["set_mode"] = command.set_mode + if command.sync_saml_organizations: + fields["sync_saml_orgs"] = True + if config.created_after is not None: + fields["created_after"] = config.created_after + return fields def run_with_client( @@ -690,6 +693,7 @@ def run_get( sourcegraph_site_config.saml_groups_attribute_name_by_config_id ), auth_providers_by_config_id=sourcegraph_site_config.auth_providers_by_config_id, + do_backup=not config.no_backup, retain_saml_group_users=False, worker_pool=worker_pool, ) @@ -767,7 +771,7 @@ def _run_or_raise(command_name: CommandName, config: Config) -> None: backups.run_artifacts_context(run_directory, run_timestamp), src.logging( config, - command=command.log_name, + command=command.name, git_cwd=__file__, logging_config=logging_settings, run_fields=run_fields(config, command, endpoint), diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 1d4e37b..7c1863c 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -93,6 +93,7 @@ def cmd_get( bind_id_mode: str, saml_groups_attribute_name_by_config_id: dict[str, str], auth_providers_by_config_id: dict[str, dict[str, Any]], + do_backup: bool, retain_saml_group_users: bool = False, worker_pool: ThreadPoolExecutor | None = None, ) -> run_context.CommandData: @@ -115,16 +116,17 @@ def cmd_get( `auth-providers.yaml` alongside the GraphQL-discovered fields. Providers without an explicit `configID` get only the GraphQL-derived view. """ - with src.span( - "cmd_get", - code_hosts_path=str(code_hosts_path), - auth_providers_path=str(auth_providers_path), - maps_path=str(maps_path), - user_identifiers=user_identifiers, - users_without_explicit_perms=users_without_explicit_perms, - user_created_after=user_created_after, - parallelism=parallelism, - ) as cmd_event: + cmd_fields: dict[str, Any] = {} + if user_identifiers: + cmd_fields["user_identifiers"] = user_identifiers + if users_without_explicit_perms: + cmd_fields["users_without_explicit_perms"] = True + if user_created_after is not None: + cmd_fields["created_after"] = user_created_after + if not do_backup: + cmd_fields["backup"] = False + + with src.span("cmd_get", **cmd_fields) as cmd_event: raw_providers, raw_services, attribute_names_by_provider = load_discovery( client, saml_groups_attribute_name_by_config_id ) @@ -147,7 +149,7 @@ def cmd_get( saml_group_counts = saml_groups.count_users_per_saml_group( users, attribute_names_by_provider ) - cmd_event["user_count"] = len(users) + cmd_event["selected_user_count"] = len(users) cmd_event["saml_providers_with_groups"] = len(saml_group_counts) providers = [ @@ -175,31 +177,37 @@ def cmd_get( permissions_maps.dump_code_hosts_yaml(code_hosts_path, services) permissions_maps.dump_auth_providers_yaml(auth_providers_path, providers) + cmd_event["code_hosts_path"] = str(code_hosts_path) + cmd_event["auth_providers_path"] = str(auth_providers_path) + cmd_event["maps_path"] = str(maps_path) log.info("Wrote %s and %s", code_hosts_path, auth_providers_path) - timestamp = backups.backup_timestamp() - before_snapshot = permission_snapshot.build_snapshot( - client, - users, - parallelism, - bind_id_mode, - maps_path, - total_users=len(users), - explicit_permissions_batch_size=explicit_permissions_batch_size, - worker_pool=worker_pool, - ) - before_path = snapshot_path(maps_path, timestamp, client.endpoint, "get", "before") - permission_snapshot.write_snapshot(before_path, before_snapshot) - cmd_event["beforesnapshot_path"] = str(before_path) - maps_backup_path = write_maps_backup(maps_path, timestamp, client.endpoint, "get") - if maps_backup_path is not None: - cmd_event["maps_backup_path"] = str(maps_backup_path) - log.info( - "Wrote before-snapshot: %s (%d repo(s) with explicit grants, %d total grant(s)).", - before_path, - before_snapshot["stats"]["repos_with_explicit_grants"], - before_snapshot["stats"]["total_grants"], - ) + if do_backup: + timestamp = backups.backup_timestamp() + before_snapshot = permission_snapshot.build_snapshot( + client, + users, + parallelism, + bind_id_mode, + maps_path, + expected_user_count=len(users), + explicit_permissions_batch_size=explicit_permissions_batch_size, + worker_pool=worker_pool, + ) + before_path = snapshot_path(maps_path, timestamp, client.endpoint, "get", "before") + permission_snapshot.write_snapshot(before_path, before_snapshot) + cmd_event["before_snapshot_path"] = str(before_path) + maps_backup_path = write_maps_backup(maps_path, timestamp, client.endpoint, "get") + if maps_backup_path is not None: + cmd_event["maps_backup_path"] = str(maps_backup_path) + log.info( + "Wrote before-snapshot: %s (%d repo(s) with explicit grants, %d total grant(s)).", + before_path, + before_snapshot["stats"]["repos_with_explicit_grants"], + before_snapshot["stats"]["total_grants"], + ) + else: + log.info("Skipping get before-snapshot and maps backup because --no-backup is set.") saml_group_users = ( saml_groups.compact_saml_group_users( users, diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index b027383..00857fb 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -101,11 +101,11 @@ def _capture_full_set_snapshot_state( include_user_emails: bool = False, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" - total_users = shared_sourcegraph.count_users(client) + expected_user_count = shared_sourcegraph.count_users(client) users: list[shared_types.User] = [] log.info( "Streaming %d users from %s while capturing before-snapshot in parallel ...", - total_users, + expected_user_count, client.endpoint, ) before_timestamp = backups.backup_timestamp() @@ -119,7 +119,7 @@ def _capture_full_set_snapshot_state( parallelism, bind_id_mode, input_path, - total_users=total_users, + expected_user_count=expected_user_count, explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, ) @@ -426,7 +426,7 @@ def _filter_full_set_plans( pending_overwrites: list[permission_types.RepositoryUsernameOverwrite] = [] for overwrite in overwrites: current_repo = before_repos_map.get(overwrite.repository_id) - current_usernames = current_repo["explicit_permissions_users"] if current_repo else [] + current_usernames = current_repo["users"] if current_repo else [] expected_list = list(overwrite.usernames) if current_usernames == expected_list or sorted(current_usernames) == expected_list: skipped_repo_ids.add(overwrite.repository_id) @@ -557,7 +557,7 @@ def _finish_full_set_apply_with_backup( parallelism, bind_id_mode, input_path, - total_users=len(snapshot_state.users), + expected_user_count=len(snapshot_state.users), explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, ) diff --git a/src/src_auth_perms_sync/permissions/restore.py b/src/src_auth_perms_sync/permissions/restore.py index 26902c9..b671b00 100644 --- a/src/src_auth_perms_sync/permissions/restore.py +++ b/src/src_auth_perms_sync/permissions/restore.py @@ -544,11 +544,11 @@ def _capture_restore_snapshot_state( worker_pool: ThreadPoolExecutor | None = None, ) -> RestoreSnapshotState: """Capture the live full-instance state needed to plan a restore.""" - total_users = shared_sourcegraph.count_users(client) + expected_user_count = shared_sourcegraph.count_users(client) log.info( "Streaming %d users from %s and capturing current explicit-permissions " "state in parallel ...", - total_users, + expected_user_count, client.endpoint, ) users: list[shared_types.User] = [] @@ -558,7 +558,7 @@ def _capture_restore_snapshot_state( parallelism, bind_id_mode, snapshot_path, - total_users=total_users, + expected_user_count=expected_user_count, explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, ) @@ -583,9 +583,9 @@ def plan_full_restore(snapshot_state: RestoreSnapshotState) -> RestorePlan: overwrites: list[permission_types.RepositoryUsernameOverwrite] = [] skipped_repo_count = 0 for repo_id, repo_snapshot in target_repos.items(): - target_usernames = repo_snapshot["explicit_permissions_users"] + target_usernames = repo_snapshot["users"] current_repo = current_repos.get(repo_id) - current_usernames = current_repo["explicit_permissions_users"] if current_repo else [] + current_usernames = current_repo["users"] if current_repo else [] if current_usernames == target_usernames or sorted(current_usernames) == target_usernames: skipped_repo_count += 1 continue @@ -774,7 +774,7 @@ def _finish_restore_apply_with_backup( parallelism, bind_id_mode, snapshot_path, - total_users=len(snapshot_state.users), + expected_user_count=len(snapshot_state.users), explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, ) diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index 63cd2d7..3357413 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -25,7 +25,7 @@ class RepoSnapshot(TypedDict): name: str - explicit_permissions_users: list[str] + users: list[str] class SnapshotStats(TypedDict): @@ -152,7 +152,7 @@ class UserScopedSnapshotDiff(TypedDict): users: list[UserScopedSnapshotDiffEntry] -SNAPSHOT_SCHEMA_VERSION: int = 3 +SNAPSHOT_SCHEMA_VERSION: int = 4 USER_SCOPED_SNAPSHOT_KIND = "user_scope" SNAPSHOT_DIFF_SCHEMA_VERSION: int = 1 @@ -162,7 +162,7 @@ def capture_explicit_grants( users: Iterable[SnapshotUserInput], parallelism: int, explicit_permissions_batch_size: int, - total_users: int | None = None, + expected_user_count: int | None = None, worker_pool: ThreadPoolExecutor | None = None, ) -> tuple[dict[str, RepoSnapshot], int]: """Build the per-repo inverse index of explicit-API grants. @@ -178,17 +178,17 @@ def capture_explicit_grants( scale this overlaps the entire ListUsers pagination time with capture work, removing it from the critical path. - `total_users`, when supplied, enables percentage + ETA in the + `expected_user_count`, when supplied, enables percentage + ETA in the progress log lines. Callers that have already paid for `count_users()` (e.g. `cmd_set` / `cmd_restore` in their --apply branches) should pass it through; otherwise progress reports just show running counts and - rate. Reports fire at every ~10% of `total_users` (or every 1000 + rate. Reports fire at every ~10% of `expected_user_count` (or every 1000 completed when total is unknown). Sourcegraph only supports READ repository permissions, so snapshots store only the usernames that have explicit repository grants. - Returns `(repos, user_count)` so callers (e.g. `build_snapshot`) + Returns `(repos, scanned_user_count)` so callers (e.g. `build_snapshot`) that need the user-count statistic don't have to materialize the iterator twice or measure it themselves. """ @@ -209,7 +209,7 @@ def _fetch( "user_explicit_repos_batch_fetch", level="DEBUG", omit_success_status=True, - user_count=len(batch_users), + batch_user_count=len(batch_users), ) as fetch_event: try: repository_ids_by_user_id = permissions_sourcegraph.list_users_explicit_repo_ids( @@ -233,7 +233,8 @@ def _fetch( fetch_event["fetched_grant_count"] = sum( len(repository_ids) for repository_ids in repository_ids_by_username.values() ) - fetch_event["per_user_failures"] = failures + if failures: + fetch_event["user_permission_lookup_failures"] = failures return repository_ids_by_username, failures def _fetch_one_user_at_a_time( @@ -259,32 +260,33 @@ def _fetch_one_user_at_a_time( repository_ids_by_user_id[user["id"]] = [] return repository_ids_by_user_id, failures - with src.span( - "capture_explicit_grants", - total_users=total_users, - explicit_permissions_batch_size=explicit_permissions_batch_size, - ) as capture_event: + span_fields: dict[str, Any] = {} + if expected_user_count is not None: + span_fields["expected_user_count"] = expected_user_count + + with src.span("capture_explicit_grants", **span_fields) as capture_event: capture_failures = 0 futures: dict[Any, list[SnapshotUserInput]] = {} - submitted_user_count = 0 + scanned_user_count = 0 max_pending_batches = max(1, parallelism * 2) + src.debug("capture_explicit_grants_queue", max_pending_batches=max_pending_batches) def _submit_batch( executor: ThreadPoolExecutor, batch_users: list[SnapshotUserInput], ) -> None: - nonlocal submitted_user_count + nonlocal scanned_user_count if not batch_users: return submitted_batch = list(batch_users) - submitted_user_count += len(submitted_batch) + scanned_user_count += len(submitted_batch) future = src.submit_with_log_context(executor, _fetch, submitted_batch) futures[future] = submitted_batch # Progress reporting: every 10% when total is known (max 10 # lines), every 1000 otherwise. Avoids drowning the operator on # tiny instances and gives steady feedback on large ones. - progress_step = max(1, total_users // 10) if total_users else 1000 + progress_step = max(1, expected_user_count // 10) if expected_user_count else 1000 # Start the timer BEFORE submission. The submit-while-iterating # loop blocks on ListUsers pagination, but workers process # already-submitted tasks during those blocks — so by the time @@ -321,19 +323,19 @@ def _record_completed_futures(done_futures: Iterable[Any]) -> None: ) if completed >= next_progress_report or ( - all_users_submitted and completed == submitted_user_count + all_users_submitted and completed == scanned_user_count ): elapsed = time.perf_counter() - progress_started rate = completed / elapsed if elapsed > 0 else 0.0 - if total_users: - remaining = max(total_users - completed, 0) + if expected_user_count: + remaining = max(expected_user_count - completed, 0) eta_seconds = remaining / rate if rate > 0 else 0.0 log.info( "Captured explicit permissions for %d / %d users (%.0f%%) " "in %.0fs (%.0f users/sec, ETA %.0fs).", completed, - total_users, - 100.0 * completed / total_users, + expected_user_count, + 100.0 * completed / expected_user_count, elapsed, rate, eta_seconds, @@ -367,32 +369,34 @@ def _record_completed_futures(done_futures: Iterable[Any]) -> None: while futures: done_futures, _ = wait(futures, return_when=FIRST_COMPLETED) _record_completed_futures(done_futures) - capture_event["user_count"] = submitted_user_count - capture_event["per_user_failures"] = capture_failures - capture_event["max_pending_batches"] = max_pending_batches + capture_event["scanned_user_count"] = scanned_user_count + if capture_failures: + capture_event["user_permission_lookup_failures"] = capture_failures # Stable sort: users alphabetical within each repo. for usernames in usernames_by_repository_id.values(): usernames.sort() + repository_count = len(usernames_by_repository_id) with src.span( - "hydrate_explicit_repository_names", - repository_count=len(usernames_by_repository_id), + "hydrate_explicit_repository_names", repository_count=repository_count ) as hydrate_event: repositories_by_id = permissions_sourcegraph.list_repositories_by_ids( client, usernames_by_repository_id.keys(), ) - hydrate_event["hydrated_repository_count"] = len(repositories_by_id) + missing_repository_count = repository_count - len(repositories_by_id) + if missing_repository_count: + hydrate_event["missing_repository_count"] = missing_repository_count repos_out: dict[str, RepoSnapshot] = {} for repository_id, usernames in usernames_by_repository_id.items(): repos_out[repository_id] = { "name": _snapshot_repository_name(repositories_by_id, repository_id), - "explicit_permissions_users": usernames, + "users": usernames, } - return repos_out, submitted_user_count + return repos_out, scanned_user_count def _snapshot_repository_name( @@ -416,7 +420,7 @@ def build_snapshot( bind_id_mode: str, config_path: Path | None = None, *, - total_users: int | None = None, + expected_user_count: int | None = None, explicit_permissions_batch_size: int, worker_pool: ThreadPoolExecutor | None = None, ) -> Snapshot: @@ -427,16 +431,16 @@ def build_snapshot( batched work as the iterator yields, so ListUsers pagination overlaps with UserExplicitRepos work. - `total_users`, when known, drives percentage + ETA in the per-batch - progress log lines emitted by `capture_explicit_grants`. + `expected_user_count`, when known, drives percentage + ETA in the + per-batch progress log lines emitted by `capture_explicit_grants`. """ - with src.span("build_snapshot", bind_id_mode=bind_id_mode) as build_event: - repos, user_count = capture_explicit_grants( + with src.span("build_snapshot") as build_event: + repos, scanned_user_count = capture_explicit_grants( client, users, parallelism, explicit_permissions_batch_size, - total_users=total_users, + expected_user_count=expected_user_count, worker_pool=worker_pool, ) pending = permissions_sourcegraph.list_pending_bind_ids(client) @@ -448,14 +452,15 @@ def build_snapshot( distinct_users: set[str] = set() total_grants = 0 for repo in repos.values(): - for username in repo["explicit_permissions_users"]: + for username in repo["users"]: distinct_users.add(username) total_grants += 1 - build_event["user_count"] = user_count + build_event["scanned_user_count"] = scanned_user_count build_event["repos_with_explicit_grants"] = len(repos) build_event["users_with_explicit_grants"] = len(distinct_users) build_event["total_grants"] = total_grants - build_event["pending_bindIDs_count"] = len(pending) + if pending: + build_event["pending_bindIDs_count"] = len(pending) return { "schema_version": SNAPSHOT_SCHEMA_VERSION, @@ -466,7 +471,7 @@ def build_snapshot( "config_sha256": config_sha, "pending_bindIDs": sorted(pending), "stats": { - "total_users_scanned": user_count, + "total_users_scanned": scanned_user_count, "users_with_explicit_grants": len(distinct_users), "repos_with_explicit_grants": len(repos), "total_grants": total_grants, @@ -517,7 +522,7 @@ def _fetch(user: SnapshotUser) -> tuple[SnapshotUser, list[permission_types.Repo "id": fetched_user["id"], "explicit_repositories": sorted(repos, key=lambda repo: repo["name"]), } - capture_event["user_count"] = len(scoped_users) + capture_event["scanned_user_count"] = len(scoped_users) capture_event["total_grants"] = sum( len(user_snapshot["explicit_repositories"]) for user_snapshot in scoped_users.values() ) @@ -533,7 +538,7 @@ def build_user_scoped_snapshot( worker_pool: ThreadPoolExecutor | None = None, ) -> UserScopedSnapshot: """Capture a reversible snapshot for only the supplied users.""" - with src.span("build_user_scoped_snapshot", bind_id_mode=bind_id_mode) as build_event: + with src.span("build_user_scoped_snapshot") as build_event: scoped_users = capture_user_scoped_explicit_grants( client, users, @@ -550,7 +555,7 @@ def build_user_scoped_snapshot( users_with_explicit_grants = sum( 1 for user_snapshot in scoped_users.values() if user_snapshot["explicit_repositories"] ) - build_event["user_count"] = len(scoped_users) + build_event["scanned_user_count"] = len(scoped_users) build_event["users_with_explicit_grants"] = users_with_explicit_grants build_event["total_grants"] = total_grants @@ -612,8 +617,8 @@ def _write_repo_snapshot_value(output: TextIO, repo: RepoSnapshot, indent: int) output.write(f'{field_indent}"name": ') json.dump(repo["name"], output) output.write(",\n") - output.write(f'{field_indent}"explicit_permissions_users": ') - _write_string_list(output, repo["explicit_permissions_users"], indent + 2) + output.write(f'{field_indent}"users": ') + _write_string_list(output, repo["users"], indent + 2) output.write("\n" + " " * indent + "}") @@ -807,13 +812,25 @@ def _validate_snapshot_schema_version(path: Path, version: object) -> None: ) +def _encode_repo_snapshot_raw(path: Path, repo_id: str, raw_repo: dict[str, Any]) -> RepoSnapshot: + raw_usernames = raw_repo.get("users") + if not isinstance(raw_usernames, list): + raise SystemExit(f"{path}: repo {repo_id} is missing a users list.") + usernames = cast(list[object], raw_usernames) + return { + "name": cast(str, raw_repo["name"]), + "users": [str(username) for username in usernames], + } + + def _encode_full_snapshot_raw(path: Path, raw: dict[str, Any]) -> Snapshot: _validate_snapshot_schema_version(path, raw.get("schema_version")) if raw.get("snapshot_kind") == USER_SCOPED_SNAPSHOT_KIND: raise SystemExit(f"{path}: snapshot_kind is 'user_scope', expected full repo snapshot.") - on_disk_repos = cast(dict[str, RepoSnapshot], raw.get("repos", {})) + on_disk_repos = cast(dict[str, dict[str, Any]], raw.get("repos", {})) raw["repos"] = { - src.encode_repository_id(int(repo_id)): repo for repo_id, repo in on_disk_repos.items() + src.encode_repository_id(int(repo_id)): _encode_repo_snapshot_raw(path, repo_id, repo) + for repo_id, repo in on_disk_repos.items() } return cast(Snapshot, raw) @@ -891,7 +908,7 @@ def _sorted_usernames(values: Sequence[str]) -> Sequence[str]: def _repo_usernames(repo: RepoSnapshot | None) -> Sequence[str]: if repo is None: return () - return repo["explicit_permissions_users"] + return repo["users"] def _sorted_username_diff_counts( @@ -1485,7 +1502,7 @@ def _repositories_by_id( def _permission_count(repo_snapshot: RepoSnapshot | None) -> int: if repo_snapshot is None: return 0 - return len(repo_snapshot["explicit_permissions_users"]) + return len(repo_snapshot["users"]) def _snapshot_diff_side(snapshot: Snapshot | UserScopedSnapshot) -> SnapshotDiffSide: diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index 3e42cc6..6e0ea27 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -282,7 +282,7 @@ def projected_snapshot_repo_for_id( return None return { "name": repo_names[repo_id], - "explicit_permissions_users": list(usernames), + "users": list(usernames), } return before_snapshot["repos"].get(repo_id) @@ -316,7 +316,7 @@ def projected_snapshot_stats( if repo_id in expected_users: continue repo_count += 1 - usernames = repo["explicit_permissions_users"] + usernames = repo["users"] users_with_explicit_grants.update(usernames) total_grants += len(usernames) for usernames in expected_users.values(): @@ -450,7 +450,7 @@ def validate_post_apply( for repo_id in mutated_repo_ids: expected = list(expected_users.get(repo_id, ())) actual_repo = after["repos"].get(repo_id) - actual = actual_repo["explicit_permissions_users"] if actual_repo else [] + actual = actual_repo["users"] if actual_repo else [] if expected == actual: continue expected_set = set(expected) diff --git a/src/src_auth_perms_sync/shared/site_config.py b/src/src_auth_perms_sync/shared/site_config.py index 133d46a..560bb51 100644 --- a/src/src_auth_perms_sync/shared/site_config.py +++ b/src/src_auth_perms_sync/shared/site_config.py @@ -98,13 +98,6 @@ def validate_site_config(client: src.SourcegraphClient) -> SiteConfig: enabled = bool(user_mapping.get("enabled", False)) enable_username_changes = bool(contents.get("auth.enableUsernameChanges", False)) - log.info( - "Site config: permissions.userMapping.enabled=%s bindID=%s auth.enableUsernameChanges=%s", - enabled, - bind_id_enum, - enable_username_changes, - ) - safety_errors: list[str] = [] if not enabled: @@ -156,6 +149,16 @@ def validate_site_config(client: src.SourcegraphClient) -> SiteConfig: ) raise SystemExit("FATAL: " + "\n\n".join(message_sections)) + log.info( + "Site config validation passed: " + "permissions.userMapping.enabled=%s " + "bindID=%s " + "auth.enableUsernameChanges=%s", + enabled, + bind_id_enum, + enable_username_changes, + ) + return SiteConfig( bind_id_mode=bind_id_enum, auth_providers_by_config_id=_extract_auth_providers_by_config_id(contents), diff --git a/tests/integration/test_cli_entrypoint.py b/tests/integration/test_cli_entrypoint.py index efc968a..499a5fa 100644 --- a/tests/integration/test_cli_entrypoint.py +++ b/tests/integration/test_cli_entrypoint.py @@ -53,7 +53,7 @@ def test_command_help_prints_command_specific_options(self) -> None: ) self.assertNotIn("--apply", get_help.stdout) - self.assertNotIn("--no-backup", get_help.stdout) + self.assertIn("--no-backup", get_help.stdout) self.assertNotIn("--sync-saml-orgs", get_help.stdout) self.assertIn("--users USERS", get_help.stdout) self.assertNotIn("--user USER", get_help.stdout) diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index b6bafb7..53c099a 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -7,6 +7,7 @@ import unittest from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from types import SimpleNamespace from typing import cast from unittest import mock @@ -15,6 +16,7 @@ import src_auth_perms_sync from src_auth_perms_sync import cli +from src_auth_perms_sync.permissions import command as permissions_command from src_auth_perms_sync.shared import backups @@ -217,17 +219,15 @@ def test_validate_config_rejects_sync_saml_orgs_without_set(self) -> None: "can only be combined with set", ) - def test_validate_config_rejects_mutating_options_with_get(self) -> None: + def test_validate_config_rejects_apply_with_get(self) -> None: self.assert_config_error( "get", make_config(apply=True), "--apply cannot be used with the read-only get command", ) - self.assert_config_error( - "get", - make_config(no_backup=True), - "--no-backup cannot be used with the read-only get command", - ) + + def test_validate_config_allows_get_no_backup(self) -> None: + cli.validate_config("get", make_config(no_backup=True)) def test_validate_config_rejects_restore_without_restore_path(self) -> None: self.assert_config_error("restore", make_config(), "restore requires --restore-path") @@ -410,7 +410,7 @@ def test_config_with_default_paths_only_defaults_omitted_maps_path(self) -> None ) self.assertEqual(Path("snapshot.json"), restore_config.restore_path) - def test_run_fields_include_concrete_command(self) -> None: + def test_run_fields_include_command_arguments_without_command_duplicates(self) -> None: configuration = make_config( maps_path=Path("maps.yaml"), users=("alice",), @@ -420,15 +420,97 @@ def test_run_fields_include_concrete_command(self) -> None: fields = cli.run_fields(configuration, command, "https://sourcegraph.example.com") - self.assertEqual("set_users", fields["cli_cmd"]) - self.assertEqual("set", fields["base_cmd"]) self.assertEqual("users", fields["set_mode"]) - self.assertEqual(True, fields["apply_flag"]) + self.assertEqual(True, fields["apply"]) + self.assertNotIn("cli_cmd", fields) + self.assertNotIn("base_cmd", fields) self.assertEqual(25, fields["explicit_permissions_batch_size"]) self.assertEqual(False, fields["fetch_sg_traces"]) self.assertEqual(False, fields["open_telemetry"]) self.assertEqual(60.0, fields["http_timeout_seconds"]) + def test_run_fields_omit_irrelevant_false_flags(self) -> None: + configuration = make_config() + command = cli.resolve_command("get", configuration) + + fields = cli.run_fields(configuration, command, "https://sourcegraph.example.com") + + self.assertNotIn("apply", fields) + self.assertNotIn("no_backup", fields) + self.assertNotIn("set_mode", fields) + self.assertNotIn("sync_saml_orgs", fields) + self.assertNotIn("created_after", fields) + + def test_run_fields_include_no_backup_only_when_set(self) -> None: + configuration = make_config(no_backup=True) + command = cli.resolve_command("get", configuration) + + fields = cli.run_fields(configuration, command, "https://sourcegraph.example.com") + + self.assertEqual(True, fields["no_backup"]) + + def test_run_get_passes_no_backup_to_permission_command(self) -> None: + configuration = make_config(no_backup=True) + client = cast( + src.SourcegraphClient, + SimpleNamespace(endpoint="https://sourcegraph.example.com"), + ) + sourcegraph_site_config = cli.site_config.SiteConfig( + bind_id_mode="USERNAME", + auth_providers_by_config_id={}, + saml_groups_attribute_name_by_config_id={}, + ) + worker_pool = cast(ThreadPoolExecutor, object()) + + with ( + mock.patch.object( + cli.permissions_maps, "create_maps_yaml_if_missing", return_value=False + ), + mock.patch.object( + cli.permissions_command, + "cmd_get", + return_value=cli.run_context.CommandData(), + ) as cmd_get, + ): + cli.run_get(configuration, client, sourcegraph_site_config, worker_pool) + + self.assertFalse(cmd_get.call_args.kwargs["do_backup"]) + + def test_cmd_get_no_backup_skips_snapshot_artifacts(self) -> None: + client = cast( + src.SourcegraphClient, + SimpleNamespace(endpoint="https://sourcegraph.example.com"), + ) + + with ( + mock.patch.object(permissions_command, "load_discovery", return_value=([], [], {})), + mock.patch.object(permissions_command, "_load_get_users", return_value=[]), + mock.patch.object(permissions_command.permissions_maps, "dump_code_hosts_yaml"), + mock.patch.object(permissions_command.permissions_maps, "dump_auth_providers_yaml"), + mock.patch.object( + permissions_command.permission_snapshot, "build_snapshot" + ) as build_snapshot, + mock.patch.object(permissions_command, "write_maps_backup") as write_maps_backup, + ): + permissions_command.cmd_get( + client, + Path("code-hosts.yaml"), + Path("auth-providers.yaml"), + Path("maps.yaml"), + user_identifiers=(), + users_without_explicit_perms=False, + user_created_after=None, + parallelism=1, + explicit_permissions_batch_size=25, + bind_id_mode="USERNAME", + saml_groups_attribute_name_by_config_id={}, + auth_providers_by_config_id={}, + do_backup=False, + ) + + build_snapshot.assert_not_called() + write_maps_backup.assert_not_called() + def test_run_command_passes_set_data_to_combined_sync(self) -> None: configuration = make_config(sync_saml_organizations=True) command = cli.resolve_command("set", configuration) diff --git a/tests/unit/test_restore.py b/tests/unit/test_restore.py index ea4f70c..bb909e8 100644 --- a/tests/unit/test_restore.py +++ b/tests/unit/test_restore.py @@ -70,24 +70,20 @@ def test_plan_full_restore_skips_repos_that_already_match(self) -> None: def make_repo_snapshot( self, name: str, - explicit_permissions_users: list[str], + users: list[str], ) -> permission_snapshot.RepoSnapshot: return { "name": name, - "explicit_permissions_users": explicit_permissions_users, + "users": users, } def make_snapshot( self, repos: dict[str, permission_snapshot.RepoSnapshot], ) -> permission_snapshot.Snapshot: - total_grants = sum( - len(repo_snapshot["explicit_permissions_users"]) for repo_snapshot in repos.values() - ) + total_grants = sum(len(repo_snapshot["users"]) for repo_snapshot in repos.values()) users_with_explicit_grants = { - username - for repo_snapshot in repos.values() - for username in repo_snapshot["explicit_permissions_users"] + username for repo_snapshot in repos.values() for username in repo_snapshot["users"] } return { "schema_version": permission_snapshot.SNAPSHOT_SCHEMA_VERSION, diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bcb72f9..f8e195e 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -66,25 +66,25 @@ def list_repositories_by_ids( side_effect=list_repositories_by_ids, ), ): - repos, user_count = permission_snapshot.capture_explicit_grants( + repos, scanned_user_count = permission_snapshot.capture_explicit_grants( cast(src.SourcegraphClient, object()), users, parallelism=1, explicit_permissions_batch_size=25, - total_users=len(users), + expected_user_count=len(users), ) - self.assertEqual(3, user_count) + self.assertEqual(3, scanned_user_count) self.assertEqual([repo_one_id, repo_two_id], hydrated_repository_ids) self.assertEqual( { repo_one_id: { "name": "github.com/sourcegraph/one", - "explicit_permissions_users": ["alice", "carol"], + "users": ["alice", "carol"], }, repo_two_id: { "name": "github.com/sourcegraph/two", - "explicit_permissions_users": ["carol"], + "users": ["carol"], }, }, repos, @@ -123,15 +123,15 @@ def list_repo_ids( ), patch.object(permission_snapshot, "wait", side_effect=recording_wait), ): - _, user_count = permission_snapshot.capture_explicit_grants( + _, scanned_user_count = permission_snapshot.capture_explicit_grants( cast(src.SourcegraphClient, object()), users, parallelism=2, explicit_permissions_batch_size=1, - total_users=len(users), + expected_user_count=len(users), ) - self.assertEqual(9, user_count) + self.assertEqual(9, scanned_user_count) self.assertTrue(pending_counts) self.assertLessEqual(max(pending_counts), 4) @@ -225,7 +225,7 @@ def graphql( self.assertEqual(repo_two["id"], calls[2][1]["repo1"]) self.assertEqual(repo_three["id"], calls[2][1]["repo2"]) - def test_write_snapshot_uses_username_list_for_explicit_permissions(self) -> None: + def test_write_snapshot_uses_short_users_key_for_explicit_permissions(self) -> None: snapshot = self.make_snapshot() with tempfile.TemporaryDirectory() as directory_name: @@ -233,18 +233,23 @@ def test_write_snapshot_uses_username_list_for_explicit_permissions(self) -> Non permission_snapshot.write_snapshot(snapshot_path, snapshot) on_disk = json.loads(snapshot_path.read_text()) + loaded_snapshot = permission_snapshot.read_snapshot(snapshot_path) self.assertEqual( ["alice", "bob"], - on_disk["repos"]["1"]["explicit_permissions_users"], + on_disk["repos"]["1"]["users"], ) - self.assertNotIn("explicit_user_permissions", on_disk["repos"]["1"]) + self.assertEqual( + ["alice", "bob"], + loaded_snapshot["repos"][src.encode_repository_id(1)]["users"], + ) + self.assertEqual({"name", "users"}, set(on_disk["repos"]["1"])) def test_snapshot_diff_omits_unchanged_users(self) -> None: before = self.make_snapshot() after = self.make_snapshot() repo_id = src.encode_repository_id(1) - after["repos"][repo_id]["explicit_permissions_users"] = ["alice", "carol"] + after["repos"][repo_id]["users"] = ["alice", "carol"] diff = permission_snapshot.build_snapshot_diff(before, after) @@ -292,11 +297,11 @@ def test_write_projected_snapshot_keeps_after_repos_out_of_memory(self) -> None: self.assertEqual({}, after["repos"]) self.assertEqual( ["alice", "carol"], - after_on_disk["repos"]["1"]["explicit_permissions_users"], + after_on_disk["repos"]["1"]["users"], ) self.assertEqual( ["dana"], - after_on_disk["repos"]["2"]["explicit_permissions_users"], + after_on_disk["repos"]["2"]["users"], ) self.assertEqual(2, diff_on_disk["summary"]["repos_changed"]) self.assertEqual(2, diff_on_disk["summary"]["grants_added"]) @@ -330,7 +335,7 @@ def test_read_snapshot_rejects_old_schema_versions(self) -> None: with self.assertRaises(SystemExit) as exit_context: permission_snapshot.read_snapshot(snapshot_path) - self.assertIn("expected 3", str(exit_context.exception)) + self.assertIn("expected 4", str(exit_context.exception)) def make_snapshot(self) -> permission_snapshot.Snapshot: return { @@ -350,7 +355,7 @@ def make_snapshot(self) -> permission_snapshot.Snapshot: "repos": { src.encode_repository_id(1): { "name": "github.com/sourcegraph/example", - "explicit_permissions_users": ["alice", "bob"], + "users": ["alice", "bob"], } }, } From ebf02cee8d027ac0862ec8cd7bbf1b654cae5f45 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:07:03 -0600 Subject: [PATCH 02/17] Move queue limit into startup config log Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- src/src_auth_perms_sync/cli.py | 9 ++++++++- src/src_auth_perms_sync/permissions/snapshot.py | 1 - tests/unit/test_cli_config.py | 8 ++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 96157c1..f82fcb1 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -532,6 +532,13 @@ def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[ return fields +def startup_config_fields(config: Config) -> dict[str, object]: + """Return the startup config snapshot plus derived runtime limits.""" + fields = src.config_snapshot(config) + fields["SRC_AUTH_PERMS_SYNC_MAX_PENDING_BATCHES"] = max(1, config.parallelism * 2) + return fields + + def run_with_client( config: Config, command: ResolvedCommand, @@ -770,7 +777,7 @@ def _run_or_raise(command_name: CommandName, config: Config) -> None: with ( backups.run_artifacts_context(run_directory, run_timestamp), src.logging( - config, + startup_config_fields(config), command=command.name, git_cwd=__file__, logging_config=logging_settings, diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index 3357413..86c8d94 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -269,7 +269,6 @@ def _fetch_one_user_at_a_time( futures: dict[Any, list[SnapshotUserInput]] = {} scanned_user_count = 0 max_pending_batches = max(1, parallelism * 2) - src.debug("capture_explicit_grants_queue", max_pending_batches=max_pending_batches) def _submit_batch( executor: ThreadPoolExecutor, diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index 53c099a..cd05d85 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -449,6 +449,14 @@ def test_run_fields_include_no_backup_only_when_set(self) -> None: self.assertEqual(True, fields["no_backup"]) + def test_startup_config_fields_include_derived_queue_limit(self) -> None: + configuration = make_config(parallelism=3) + + fields = cli.startup_config_fields(configuration) + + self.assertEqual(6, fields["SRC_AUTH_PERMS_SYNC_MAX_PENDING_BATCHES"]) + self.assertEqual("provided", fields["SRC_ACCESS_TOKEN"]) + def test_run_get_passes_no_backup_to_permission_command(self) -> None: configuration = make_config(no_backup=True) client = cast( From 520cb8c060ed9412d383ab12ff3da6f779ed59c3 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:07:31 -0600 Subject: [PATCH 03/17] Rename external services -> code hosts --- dev/memory-efficiency-generate.py | 2 +- src/src_auth_perms_sync/permissions/maps.py | 2 +- src/src_auth_perms_sync/permissions/workflow.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/memory-efficiency-generate.py b/dev/memory-efficiency-generate.py index bcbad6b..7dc6dbc 100755 --- a/dev/memory-efficiency-generate.py +++ b/dev/memory-efficiency-generate.py @@ -551,7 +551,7 @@ def list_external_services(client: src.SourcegraphClient) -> list[ExternalServic ) ) if not services: - raise SystemExit("No external services found on the Sourcegraph instance") + raise SystemExit("No code hosts found on the Sourcegraph instance") return services diff --git a/src/src_auth_perms_sync/permissions/maps.py b/src/src_auth_perms_sync/permissions/maps.py index 8f37bb0..f723d15 100644 --- a/src/src_auth_perms_sync/permissions/maps.py +++ b/src/src_auth_perms_sync/permissions/maps.py @@ -131,7 +131,7 @@ def count_users_per_provider( def external_service_to_yaml(service: permission_types.ExternalService) -> dict[str, Any]: - """Render an external service for the YAML config. + """Render a code host for the YAML config. Keys mirror the human-readable Sourcegraph GraphQL `ExternalService` fields that maps can match. The opaque GraphQL `id` is omitted; diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index 6e0ea27..20057dd 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -29,7 +29,7 @@ def load_discovery( list[permission_types.ExternalService], dict[tuple[str, str], str], ]: - """Fetch auth providers + external services and resolve the SAML attribute + """Fetch auth providers + code hosts and resolve the SAML attribute names map, with consistent logging. Shared by get and set; returns the raw lists so each caller can transform them as needed (YAML form for get, keyed-by-id dict for set). @@ -43,9 +43,9 @@ def load_discovery( providers = shared_sourcegraph.list_auth_providers(client) log.info("Received %d auth providers.", len(providers)) - log.info("Loading external services from %s ...", client.endpoint) + log.info("Loading code hosts from %s ...", client.endpoint) services = permissions_sourcegraph.list_external_services(client) - log.info("Received %d external services.", len(services)) + log.info("Received %d code hosts.", len(services)) saml_attribute_names = saml_groups.attribute_names_by_provider_key( providers, saml_groups_attribute_name_by_config_id From 2da98e9e4541bb2ad2ae6f837dd3118ebca96bd2 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:10:16 -0600 Subject: [PATCH 04/17] Stop logging derived queue limit Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- src/src_auth_perms_sync/cli.py | 9 +-------- tests/unit/test_cli_config.py | 8 -------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index f82fcb1..96157c1 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -532,13 +532,6 @@ def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[ return fields -def startup_config_fields(config: Config) -> dict[str, object]: - """Return the startup config snapshot plus derived runtime limits.""" - fields = src.config_snapshot(config) - fields["SRC_AUTH_PERMS_SYNC_MAX_PENDING_BATCHES"] = max(1, config.parallelism * 2) - return fields - - def run_with_client( config: Config, command: ResolvedCommand, @@ -777,7 +770,7 @@ def _run_or_raise(command_name: CommandName, config: Config) -> None: with ( backups.run_artifacts_context(run_directory, run_timestamp), src.logging( - startup_config_fields(config), + config, command=command.name, git_cwd=__file__, logging_config=logging_settings, diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index cd05d85..53c099a 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -449,14 +449,6 @@ def test_run_fields_include_no_backup_only_when_set(self) -> None: self.assertEqual(True, fields["no_backup"]) - def test_startup_config_fields_include_derived_queue_limit(self) -> None: - configuration = make_config(parallelism=3) - - fields = cli.startup_config_fields(configuration) - - self.assertEqual(6, fields["SRC_AUTH_PERMS_SYNC_MAX_PENDING_BATCHES"]) - self.assertEqual("provided", fields["SRC_ACCESS_TOKEN"]) - def test_run_get_passes_no_backup_to_permission_command(self) -> None: configuration = make_config(no_backup=True) client = cast( From ebabaef98c7d33823c921d7ba8a54b1c8fd3a27a Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:10:39 -0600 Subject: [PATCH 05/17] Simplify log lines --- src/src_auth_perms_sync/permissions/command.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 7c1863c..f6b83d6 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -258,7 +258,7 @@ def _load_get_users( parse_cli_date(user_created_after, "--created-after") ) candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) - log.info("Received %d non-deleted user candidate(s).", len(candidates)) + log.info("Received %d user(s)", len(candidates)) users: list[shared_types.User] = [] for candidate in candidates: @@ -479,7 +479,7 @@ def cmd_set_additive_users_without_explicit_perms( ) resolved_mappings = resolve_additive_mappings(context) candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) - log.info("Received %d non-deleted user candidate(s).", len(candidates)) + log.info("Received %d user(s)", len(candidates)) users: list[shared_types.User] = [] additions: list[permissions_apply.PermissionAddition] = [] From 082197bc530f6f873617527be534a21b55566d6b Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:42:16 -0600 Subject: [PATCH 06/17] Speed up user-candidate permission scans Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- .../permissions/command.py | 197 ++++++++++++++---- .../permissions/queries.py | 14 -- .../permissions/sourcegraph.py | 193 +++++++++++++---- src/src_auth_perms_sync/shared/run_context.py | 96 ++++++++- tests/unit/test_permissions_sourcegraph.py | 142 +++++++++++++ 5 files changed, 556 insertions(+), 86 deletions(-) create mode 100644 tests/unit/test_permissions_sourcegraph.py diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index f6b83d6..27536ac 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -139,6 +139,9 @@ def cmd_get( user_identifiers=user_identifiers, users_without_explicit_perms=users_without_explicit_perms, user_created_after=user_created_after, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + worker_pool=worker_pool, ) counts = permissions_maps.count_users_per_provider(users) # SAML-only: tally distinct users per (serviceID, clientID, group) @@ -232,6 +235,9 @@ def _load_get_users( user_identifiers: tuple[str, ...], users_without_explicit_perms: bool, user_created_after: str | None, + parallelism: int, + explicit_permissions_batch_size: int, + worker_pool: ThreadPoolExecutor | None, ) -> list[shared_types.User]: """Load the Sourcegraph users selected by get/set-compatible user filters.""" if user_identifiers: @@ -257,23 +263,43 @@ def _load_get_users( created_after_filter = sourcegraph_datetime_filter( parse_cli_date(user_created_after, "--created-after") ) - candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) - log.info("Received %d user(s)", len(candidates)) - - users: list[shared_types.User] = [] - for candidate in candidates: - if users_without_explicit_perms and permissions_sourcegraph.user_has_explicit_repos( - client, candidate["id"] - ): - continue - user = permissions_sourcegraph.get_user_by_id(client, candidate["id"]) - if user is None: - log.warning( - "Skipping user candidate %s: user no longer exists.", - candidate["username"], - ) - continue - users.append(user) + candidates = permissions_sourcegraph.list_site_user_candidates( + client, + created_after_filter, + parallelism=parallelism, + worker_pool=worker_pool, + ) + log.info("Loaded %d active user candidate(s).", len(candidates)) + if users_without_explicit_perms: + log.info( + "Checking %d active user candidate(s) for existing explicit repo permissions " + "in batches of %d ...", + len(candidates), + explicit_permissions_batch_size, + ) + explicit_user_ids = permissions_sourcegraph.user_ids_with_explicit_repos( + client, + [candidate["id"] for candidate in candidates], + batch_size=explicit_permissions_batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + ) + candidates = [ + candidate for candidate in candidates if candidate["id"] not in explicit_user_ids + ] + log.info( + "Selected %d active user candidate(s) without explicit repo permissions; " + "skipped %d with existing explicit permissions.", + len(candidates), + len(explicit_user_ids), + ) + + users = _hydrate_site_user_candidates( + client, + candidates, + parallelism=parallelism, + worker_pool=worker_pool, + ) log.info("Selected %d user(s) for get output.", len(users)) return users @@ -315,6 +341,73 @@ def _load_all_get_users(client: src.SourcegraphClient) -> list[shared_types.User return users +def _hydrate_site_user_candidates( + client: src.SourcegraphClient, + candidates: list[shared_types.SiteUserCandidate], + *, + include_emails: bool = False, + parallelism: int, + worker_pool: ThreadPoolExecutor | None, +) -> list[shared_types.User]: + """Hydrate filtered site-user candidates into full user metadata.""" + if not candidates: + return [] + + log.info( + "Hydrating Sourcegraph metadata for %d selected user candidate(s) with parallelism=%d ...", + len(candidates), + parallelism, + ) + + def hydrate_user(candidate: shared_types.SiteUserCandidate) -> shared_types.User | None: + return permissions_sourcegraph.get_user_by_id( + client, + candidate["id"], + include_emails=include_emails, + ) + + hydrated_users = run_context.parallel_map( + hydrate_user, + candidates, + parallelism=parallelism, + worker_pool=worker_pool, + progress_label="Hydrated selected Sourcegraph user metadata", + ) + users = [user for user in hydrated_users if user is not None] + missing_user_count = len(hydrated_users) - len(users) + if missing_user_count: + log.warning( + "Skipped %d selected user candidate(s) that no longer exist.", + missing_user_count, + ) + log.info("Hydrated metadata for %d selected user(s).", len(users)) + return users + + +def _log_user_planning_progress( + completed: int, + total_count: int, + started: float, + *, + grant_count: int, +) -> None: + elapsed = time.perf_counter() - started + rate = completed / elapsed if elapsed > 0 else 0.0 + remaining = max(total_count - completed, 0) + eta_seconds = remaining / rate if rate > 0 else 0.0 + log.info( + "Planned additive grants for %d / %d selected user(s) (%.0f%%) " + "in %.0fs (%.0f users/sec, ETA %.0fs): grant_count=%d.", + completed, + total_count, + 100.0 * completed / total_count, + elapsed, + rate, + eta_seconds, + grant_count, + ) + + def cmd_set( client: src.SourcegraphClient, input_path: Path, @@ -364,6 +457,7 @@ def cmd_set( options.user_created_after, dry_run, parallelism, + explicit_permissions_batch_size, bind_id_mode, saml_groups_attribute_name_by_config_id, do_backup, @@ -452,6 +546,7 @@ def cmd_set_additive_users_without_explicit_perms( user_created_after: str | None, dry_run: bool, parallelism: int, + explicit_permissions_batch_size: int, bind_id_mode: str, saml_groups_attribute_name_by_config_id: dict[str, str], do_backup: bool, @@ -478,25 +573,49 @@ def cmd_set_additive_users_without_explicit_perms( context.mapping_rules ) resolved_mappings = resolve_additive_mappings(context) - candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) - log.info("Received %d user(s)", len(candidates)) + candidates = permissions_sourcegraph.list_site_user_candidates( + client, + created_after_filter, + parallelism=parallelism, + worker_pool=worker_pool, + ) + log.info("Loaded %d active user candidate(s).", len(candidates)) + log.info( + "Checking %d active user candidate(s) for existing explicit repo permissions, " + "in batches of %d ...", + len(candidates), + explicit_permissions_batch_size, + ) + explicit_user_ids = permissions_sourcegraph.user_ids_with_explicit_repos( + client, + [candidate["id"] for candidate in candidates], + batch_size=explicit_permissions_batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + ) + candidates = [ + candidate for candidate in candidates if candidate["id"] not in explicit_user_ids + ] + log.info( + "Selected %d active user candidate(s) without explicit repo permissions; " + "skipped %d with existing explicit permissions.", + len(candidates), + len(explicit_user_ids), + ) - users: list[shared_types.User] = [] + users = _hydrate_site_user_candidates( + client, + candidates, + include_emails=include_user_emails, + parallelism=parallelism, + worker_pool=worker_pool, + ) additions: list[permissions_apply.PermissionAddition] = [] - for candidate in candidates: - if permissions_sourcegraph.user_has_explicit_repos(client, candidate["id"]): - continue - user = permissions_sourcegraph.get_user_by_id( - client, - candidate["id"], - include_emails=include_user_emails, - ) - if user is None: - log.warning( - "Skipping user candidate %s: user no longer exists.", - candidate["username"], - ) - continue + started = time.perf_counter() + progress_step = max(1, len(users) // 10) if users else 1 + next_progress_report = progress_step + log.info("Planning additive grants for %d selected user(s) ...", len(users)) + for completed, user in enumerate(users, start=1): user_additions = _plan_additions_for_user( client, context, @@ -504,8 +623,16 @@ def cmd_set_additive_users_without_explicit_perms( user, existing_repo_ids=set(), ) - users.append(user) additions.extend(user_additions) + if completed >= next_progress_report or completed == len(users): + _log_user_planning_progress( + completed, + len(users), + started, + grant_count=len(additions), + ) + while next_progress_report <= completed: + next_progress_report += progress_step log.info( "Planned additive grants for %d user(s) with no explicit grants.", diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index 7a7e2ec..71f7a58 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -179,20 +179,6 @@ def query_user_by_id(*, include_emails: bool = False) -> str: } """ -QUERY_USER_EXPLICIT_REPO_EXISTS = """ -query UserExplicitRepoExists($id: ID!) { - node(id: $id) { - ... on User { - permissionsInfo { - repositories(source: API, first: 1) { - nodes { id } - } - } - } - } -} -""" - # Used as part of post-apply validation: any of OUR bindIDs appearing in # this list means the bindID didn't resolve to a real user (typically a # username typo or a recent rename — would fail for our case since we diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index ed4e32b..f7355d6 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -3,17 +3,21 @@ from __future__ import annotations import logging +import time from collections.abc import Iterable, Iterator, Sequence +from concurrent.futures import ThreadPoolExecutor from typing import Any, cast import src_py_lib as src +from ..shared import run_context from ..shared import sourcegraph as shared_sourcegraph from ..shared import types as shared_types from . import queries from . import types as permission_types log = logging.getLogger(__name__) +SITE_USER_CANDIDATE_PAGE_SIZE = 1000 def list_external_services(client: src.SourcegraphClient) -> list[permission_types.ExternalService]: @@ -95,52 +99,171 @@ def get_user_by_id( def list_site_user_candidates( client: src.SourcegraphClient, created_after: str | None, + *, + parallelism: int = 1, + worker_pool: ThreadPoolExecutor | None = None, ) -> list[shared_types.SiteUserCandidate]: """Return non-deleted site users, optionally filtered by creation time.""" - candidates: list[shared_types.SiteUserCandidate] = [] - offset = 0 created_filter = {"gte": created_after} if created_after is not None else None - while True: - data = cast( - dict[str, Any], - client.graphql( - queries.QUERY_SITE_USERS, - cast( - src.JSONDict, - { - "limit": shared_sourcegraph.DEFAULT_PAGE_SIZE, - "offset": offset, - "createdAt": created_filter, - }, - ), - ), + created_filter_label = f" created on or after {created_after}" if created_after else "" + log.info("Querying active Sourcegraph user candidates%s ...", created_filter_label) + started = time.perf_counter() + first_page, total_count = _site_user_candidate_page( + client, + created_filter, + offset=0, + page_size=SITE_USER_CANDIDATE_PAGE_SIZE, + ) + if not first_page or len(first_page) >= total_count: + return first_page + + # If the server caps `nodes(limit:)` below our requested page size, use + # the observed first-page width so parallel offset requests do not skip + # rows. + page_size = len(first_page) + page_count = (total_count + page_size - 1) // page_size + log.info( + "Loading %d active Sourcegraph user candidate(s)%s across %d page(s) " + "of %d users/page with parallelism=%d ...", + total_count, + created_filter_label, + page_count, + page_size, + parallelism, + ) + pages: list[tuple[int, list[shared_types.SiteUserCandidate]]] = [(0, first_page)] + + def fetch_page(offset: int) -> tuple[int, list[shared_types.SiteUserCandidate]]: + nodes, _ = _site_user_candidate_page( + client, + created_filter, + offset=offset, + page_size=SITE_USER_CANDIDATE_PAGE_SIZE, + ) + return offset, nodes + + pages.extend( + run_context.parallel_map( + fetch_page, + range(page_size, total_count, page_size), + parallelism=parallelism, + worker_pool=worker_pool, + progress_label="Loaded active Sourcegraph user candidate pages", ) - site_users = cast(dict[str, Any], data["site"]["users"]) - total_count = int(cast(float, site_users["totalCount"])) - nodes = cast(list[shared_types.SiteUserCandidate], site_users["nodes"]) - candidates.extend(nodes) - if not nodes or len(candidates) >= total_count: - return candidates - offset += len(nodes) + ) + candidates = _dedupe_site_user_candidate_pages(pages) + _log_user_candidate_load_progress(len(candidates), total_count, started) + return candidates -def user_has_explicit_repos(client: src.SourcegraphClient, user_id: str) -> bool: - """Return whether the user has any explicit API repository grant.""" +def _site_user_candidate_page( + client: src.SourcegraphClient, + created_filter: dict[str, str] | None, + *, + offset: int, + page_size: int, +) -> tuple[list[shared_types.SiteUserCandidate], int]: data = cast( dict[str, Any], client.graphql( - queries.QUERY_USER_EXPLICIT_REPO_EXISTS, - cast(src.JSONDict, {"id": user_id}), + queries.QUERY_SITE_USERS, + cast( + src.JSONDict, + { + "limit": page_size, + "offset": offset, + "createdAt": created_filter, + }, + ), ), ) - node = cast(dict[str, Any] | None, data.get("node")) - if node is None: - return False - permissions_info = cast(dict[str, Any] | None, node.get("permissionsInfo")) - if permissions_info is None: - return False - repositories = cast(dict[str, Any], permissions_info["repositories"]) - return bool(src.json_list(repositories.get("nodes"))) + site_users = cast(dict[str, Any], data["site"]["users"]) + total_count = int(cast(float, site_users["totalCount"])) + nodes = cast(list[shared_types.SiteUserCandidate], site_users["nodes"]) + return nodes, total_count + + +def _dedupe_site_user_candidate_pages( + pages: Iterable[tuple[int, Sequence[shared_types.SiteUserCandidate]]], +) -> list[shared_types.SiteUserCandidate]: + candidates: list[shared_types.SiteUserCandidate] = [] + seen_user_ids: set[str] = set() + for _, page_candidates in sorted(pages, key=lambda page: page[0]): + for candidate in page_candidates: + user_id = candidate["id"] + if user_id in seen_user_ids: + continue + seen_user_ids.add(user_id) + candidates.append(candidate) + return candidates + + +def _log_user_candidate_load_progress(completed: int, total_count: int, started: float) -> None: + elapsed = time.perf_counter() - started + rate = completed / elapsed if elapsed > 0 else 0.0 + remaining = max(total_count - completed, 0) + eta_seconds = remaining / rate if rate > 0 else 0.0 + log.info( + "Loaded %d / %d active Sourcegraph user candidate(s) (%.0f%%) " + "in %.0fs (%.0f users/sec, ETA %.0fs).", + completed, + total_count, + 100.0 * completed / total_count, + elapsed, + rate, + eta_seconds, + ) + + +def user_ids_with_explicit_repos( + client: src.SourcegraphClient, + user_ids: Sequence[str], + *, + batch_size: int, + parallelism: int, + worker_pool: ThreadPoolExecutor | None = None, +) -> set[str]: + """Return user IDs that have at least one explicit API repository grant.""" + batches = list(_batches(tuple(dict.fromkeys(user_ids)), batch_size)) + + def fetch_batch(batch: Sequence[str]) -> set[str]: + return _user_ids_with_explicit_repos_batch(client, batch) + + explicit_user_ids: set[str] = set() + for batch_result in run_context.parallel_map( + fetch_batch, + batches, + parallelism=parallelism, + worker_pool=worker_pool, + progress_label="Checked explicit repo permissions for user batches", + ): + explicit_user_ids.update(batch_result) + return explicit_user_ids + + +def _user_ids_with_explicit_repos_batch( + client: src.SourcegraphClient, + user_ids: Sequence[str], +) -> set[str]: + data = client.graphql( + _user_explicit_repos_batch_query(len(user_ids)), + _user_explicit_repo_exists_batch_variables(user_ids), + follow_pages=False, + ) + explicit_user_ids: set[str] = set() + for index, user_id in enumerate(user_ids): + connection = _user_explicit_repos_connection(data, index) + if connection is not None and src.json_list(connection.get("nodes")): + explicit_user_ids.add(user_id) + return explicit_user_ids + + +def _user_explicit_repo_exists_batch_variables(user_ids: Sequence[str]) -> src.JSONDict: + variables: src.JSONDict = {"first": 1} + for index, user_id in enumerate(user_ids): + variables[f"user{index}"] = user_id + variables[f"after{index}"] = None + return variables def list_user_explicit_repos( diff --git a/src/src_auth_perms_sync/shared/run_context.py b/src/src_auth_perms_sync/shared/run_context.py index 3034327..77efa1f 100644 --- a/src/src_auth_perms_sync/shared/run_context.py +++ b/src/src_auth_perms_sync/shared/run_context.py @@ -2,13 +2,22 @@ from __future__ import annotations -from collections.abc import Generator -from concurrent.futures import ThreadPoolExecutor +import logging +import time +from collections.abc import Callable, Generator, Iterable +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import contextmanager from dataclasses import dataclass +from typing import TypeVar, cast + +import src_py_lib as src from . import types as shared_types +log = logging.getLogger(__name__) +Input = TypeVar("Input") +Output = TypeVar("Output") + @dataclass(frozen=True) class CommandData: @@ -32,3 +41,86 @@ def thread_pool( max_workers=parallelism, thread_name_prefix="sg-worker" ) as created_pool: yield created_pool + + +def parallel_map( + function: Callable[[Input], Output], + items: Iterable[Input], + *, + parallelism: int, + worker_pool: ThreadPoolExecutor | None = None, + progress_label: str | None = None, +) -> list[Output]: + """Map `function` over `items` using the run worker pool, preserving order.""" + values = list(items) + total_count = len(values) + if total_count == 0: + return [] + + started = time.perf_counter() + if parallelism <= 1: + results: list[Output] = [] + for completed, value in enumerate(values, start=1): + results.append(function(value)) + if progress_label is not None and _parallel_progress_due(completed, total_count): + _log_parallel_progress(progress_label, completed, total_count, started) + return results + + results_by_index: dict[int, Output] = {} + pending_futures: dict[Future[Output], int] = {} + next_index = 0 + completed = 0 + max_pending = max(1, parallelism * 2) + + def submit_next(executor: ThreadPoolExecutor) -> None: + nonlocal next_index + while next_index < total_count and len(pending_futures) < max_pending: + value = values[next_index] + future = cast( + Future[Output], + src.submit_with_log_context(executor, function, value), + ) + pending_futures[future] = next_index + next_index += 1 + + with thread_pool(parallelism, worker_pool) as executor: + submit_next(executor) + while pending_futures: + done_futures, _ = wait(pending_futures, return_when=FIRST_COMPLETED) + for future in done_futures: + index = pending_futures.pop(future) + results_by_index[index] = future.result() + completed += 1 + if progress_label is not None and _parallel_progress_due( + completed, + total_count, + ): + _log_parallel_progress(progress_label, completed, total_count, started) + submit_next(executor) + return [results_by_index[index] for index in range(total_count)] + + +def _parallel_progress_due(completed: int, total_count: int) -> bool: + return completed == total_count or completed % max(1, total_count // 10) == 0 + + +def _log_parallel_progress( + progress_label: str, + completed: int, + total_count: int, + started: float, +) -> None: + elapsed = time.perf_counter() - started + rate = completed / elapsed if elapsed > 0 else 0.0 + remaining = max(total_count - completed, 0) + eta_seconds = remaining / rate if rate > 0 else 0.0 + log.info( + "%s: %d / %d complete (%.0f%%) in %.0fs (%.0f/sec, ETA %.0fs).", + progress_label, + completed, + total_count, + 100.0 * completed / total_count, + elapsed, + rate, + eta_seconds, + ) diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py new file mode 100644 index 0000000..990a5d1 --- /dev/null +++ b/tests/unit/test_permissions_sourcegraph.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import threading +import unittest +from typing import cast + +import src_py_lib as src + +from src_auth_perms_sync.permissions import sourcegraph as permissions_sourcegraph + + +class _SiteUsersClient: + def __init__(self, total_count: int) -> None: + self.total_count = total_count + self.calls: list[src.JSONDict] = [] + self.lock = threading.Lock() + + def graphql( + self, + query: str, + variables: src.JSONDict | None = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + del query, follow_pages + if variables is None: + raise AssertionError("expected SiteUsers variables") + with self.lock: + self.calls.append(dict(variables)) + + limit_value = variables.get("limit") + offset_value = variables.get("offset") + if not isinstance(limit_value, int) or not isinstance(offset_value, int): + raise AssertionError("expected integer limit and offset") + + page_nodes: list[dict[str, object]] = [] + for user_number in range(offset_value, min(offset_value + limit_value, self.total_count)): + page_nodes.append( + { + "id": f"user-{user_number}", + "username": f"user-{user_number}", + "email": None, + "createdAt": "2026-06-09T00:00:00Z", + "deletedAt": None, + } + ) + return cast( + src.JSONDict, + {"site": {"users": {"totalCount": self.total_count, "nodes": page_nodes}}}, + ) + + +class _ExplicitReposClient: + def __init__(self, explicit_user_ids: set[str]) -> None: + self.explicit_user_ids = explicit_user_ids + self.calls: list[src.JSONDict] = [] + + def graphql( + self, + query: str, + variables: src.JSONDict | None = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + del query + if variables is None: + raise AssertionError("expected explicit-repo variables") + if follow_pages: + raise AssertionError("existence batch should not ask the client to follow pages") + self.calls.append(dict(variables)) + + response: dict[str, object] = {} + for variable_name, variable_value in variables.items(): + if not variable_name.startswith("user"): + continue + if not isinstance(variable_value, str): + raise AssertionError("expected user ID variable") + user_index = int(variable_name.removeprefix("user")) + permission_nodes: list[dict[str, str]] = [] + if variable_value in self.explicit_user_ids: + permission_nodes.append({"id": "repo-with-explicit-grant"}) + response[f"user{user_index}"] = { + "permissionsInfo": { + "repositories": { + "nodes": permission_nodes, + "pageInfo": {"hasNextPage": False, "endCursor": None}, + } + } + } + return cast(src.JSONDict, response) + + +class PermissionsSourcegraphTests(unittest.TestCase): + def test_list_site_user_candidates_uses_larger_pages(self) -> None: + client = _SiteUsersClient(total_count=2500) + + candidates = permissions_sourcegraph.list_site_user_candidates( + cast(src.SourcegraphClient, client), + None, + parallelism=4, + ) + + self.assertEqual(len(candidates), 2500) + self.assertEqual(candidates[0]["id"], "user-0") + self.assertEqual(candidates[-1]["id"], "user-2499") + limits, offsets = _site_users_call_page_args(client.calls) + self.assertEqual(set(limits), {1000}) + self.assertEqual(sorted(offsets), [0, 1000, 2000]) + + def test_user_ids_with_explicit_repos_batches_existence_checks(self) -> None: + client = _ExplicitReposClient({"user-2", "user-3"}) + + explicit_user_ids = permissions_sourcegraph.user_ids_with_explicit_repos( + cast(src.SourcegraphClient, client), + ["user-1", "user-2", "user-3"], + batch_size=2, + parallelism=1, + ) + + self.assertEqual(explicit_user_ids, {"user-2", "user-3"}) + self.assertEqual([call["first"] for call in client.calls], [1, 1]) + self.assertEqual( + [[call.get("user0"), call.get("user1")] for call in client.calls], + [["user-1", "user-2"], ["user-3", None]], + ) + + +def _site_users_call_page_args(calls: list[src.JSONDict]) -> tuple[list[int], list[int]]: + limits: list[int] = [] + offsets: list[int] = [] + for call in calls: + limit_value = call.get("limit") + offset_value = call.get("offset") + if not isinstance(limit_value, int) or not isinstance(offset_value, int): + raise AssertionError("expected integer limit and offset") + limits.append(limit_value) + offsets.append(offset_value) + return limits, offsets + + +if __name__ == "__main__": + unittest.main() From 4066550d0c950e24de83cb1456acadd2640d35e4 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:51:41 -0600 Subject: [PATCH 07/17] Deduplicate bounded parallel worker loops Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- src/src_auth_perms_sync/orgs/sync.py | 233 ++++++++------- src/src_auth_perms_sync/permissions/apply.py | 268 ++++++++---------- .../permissions/snapshot.py | 218 +++++++------- src/src_auth_perms_sync/shared/run_context.py | 213 +++++++++++--- tests/unit/test_snapshot.py | 4 +- 5 files changed, 531 insertions(+), 405 deletions(-) diff --git a/src/src_auth_perms_sync/orgs/sync.py b/src/src_auth_perms_sync/orgs/sync.py index 1b5819a..8fd228f 100644 --- a/src/src_auth_perms_sync/orgs/sync.py +++ b/src/src_auth_perms_sync/orgs/sync.py @@ -8,7 +8,7 @@ import re import time from collections.abc import Iterable -from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed +from concurrent.futures import CancelledError, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path from typing import Any, cast @@ -550,41 +550,44 @@ def _load_current_organization_states( lookup_batch_count=len(name_batches), member_page_size=ORGANIZATION_MEMBER_PAGE_SIZE, ) as load_event: - with run_context.thread_pool(parallelism, worker_pool) as executor: - futures = { - src.submit_with_log_context( - executor, _fetch_organization_batch, client, batch - ): batch - for batch in name_batches - } - for future in as_completed(futures): - result = future.result() - batch_current_user = result["current_user"] - if current_user is None: - current_user = batch_current_user - elif current_user["id"] != batch_current_user["id"]: - raise RuntimeError( - "currentUser changed between organization lookup batches " - f"({current_user['username']} vs {batch_current_user['username']})" - ) - states.update(result["states"]) - - existing_states = [state for state in states.values() if state.id is not None] - load_event["existing_organizations_needing_member_pages"] = len(existing_states) - if existing_states: - member_futures = { - src.submit_with_log_context( - executor, - _fetch_all_members, - client, - state, - ): state - for state in existing_states - } - for future in as_completed(member_futures): - state = member_futures[future] - for member in future.result(): - state.members_by_id[member["id"]] = member + + def fetch_organization_batch( + batch: list[str], + ) -> organization_types.OrganizationBatchLookup: + return _fetch_organization_batch(client, batch) + + for result in run_context.parallel_map( + fetch_organization_batch, + name_batches, + parallelism=parallelism, + worker_pool=worker_pool, + ): + batch_current_user = result["current_user"] + if current_user is None: + current_user = batch_current_user + elif current_user["id"] != batch_current_user["id"]: + raise RuntimeError( + "currentUser changed between organization lookup batches " + f"({current_user['username']} vs {batch_current_user['username']})" + ) + states.update(result["states"]) + + existing_states = [state for state in states.values() if state.id is not None] + load_event["existing_organizations_needing_member_pages"] = len(existing_states) + + def fetch_members( + state: organization_types.OrganizationState, + ) -> tuple[organization_types.OrganizationState, list[organization_types.OrgMember]]: + return state, _fetch_all_members(client, state) + + for state, members in run_context.parallel_map( + fetch_members, + existing_states, + parallelism=parallelism, + worker_pool=worker_pool, + ): + for member in members: + state.members_by_id[member["id"]] = member load_event["existing_organizations"] = sum(1 for state in states.values() if state.id) load_event["total_current_members"] = sum( len(state.members_by_id) for state in states.values() @@ -774,36 +777,41 @@ def _apply_create_organizations( succeeded = 0 failed = 0 canceled = 0 - with run_context.thread_pool(parallelism, worker_pool) as executor: - futures = { - src.submit_with_log_context( - executor, - _create_organization, - client, - organization_name, - current_user, - ): organization_name - for organization_name in organization_names - } - for future in as_completed(futures): - organization_name = futures[future] - try: - state = future.result() - current_states[organization_name] = state - succeeded += 1 - breaker.record(success=True) - log.info(" OK create org %s.", organization_name) - except CancelledError: - canceled += 1 - continue - except Exception as exception: - failed += 1 - breaker.record(success=False) - log.error(" FAIL create org %s: %s", organization_name, exception) - if breaker.is_open(): - for pending_future in futures: - if not pending_future.done(): - pending_future.cancel() + + def create_organization(organization_name: str) -> organization_types.OrganizationState: + return _create_organization(client, organization_name, current_user) + + def record_result( + result: run_context.ParallelResult[str, organization_types.OrganizationState], + ) -> None: + nonlocal succeeded, failed, canceled + organization_name = result.item + if result.exception is None: + state = result.value + if state is None: + raise RuntimeError(f"create org {organization_name} returned no state") + current_states[organization_name] = state + succeeded += 1 + breaker.record(success=True) + log.info(" OK create org %s.", organization_name) + return + if isinstance(result.exception, CancelledError): + canceled += 1 + return + failed += 1 + breaker.record(success=False) + log.error(" FAIL create org %s: %s", organization_name, result.exception) + + summary = run_context.parallel_process( + create_organization, + organization_names, + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=record_result, + should_stop=breaker.is_open, + ) + if breaker.is_open(): + canceled += summary.unsubmitted_count batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled @@ -878,49 +886,58 @@ def _apply_user_changes( succeeded = 0 failed = 0 canceled = 0 - with run_context.thread_pool(parallelism, worker_pool) as executor: - futures = { - src.submit_with_log_context( - executor, - _apply_user_change, - client, - change, - current_states[change.organization_name], + + def apply_change(change: organization_types.OrganizationUserChange) -> None: + _apply_user_change( + client, + change, + current_states[change.organization_name], + change_kind, + ) + + def record_result( + result: run_context.ParallelResult[ + organization_types.OrganizationUserChange, + None, + ], + ) -> None: + nonlocal succeeded, failed, canceled + change = result.item + if result.exception is None: + succeeded += 1 + breaker.record(success=True) + log.info( + " OK %s %s %s org %s.", change_kind, - ): change - for change in changes - } - for future in as_completed(futures): - change = futures[future] - try: - future.result() - succeeded += 1 - breaker.record(success=True) - log.info( - " OK %s %s %s org %s.", - change_kind, - change.username, - "to" if change_kind == "add" else "from", - change.organization_name, - ) - except CancelledError: - canceled += 1 - continue - except Exception as exception: - failed += 1 - breaker.record(success=False) - log.error( - " FAIL %s %s %s org %s: %s", - change_kind, - change.username, - "to" if change_kind == "add" else "from", - change.organization_name, - exception, - ) - if breaker.is_open(): - for pending_future in futures: - if not pending_future.done(): - pending_future.cancel() + change.username, + "to" if change_kind == "add" else "from", + change.organization_name, + ) + return + if isinstance(result.exception, CancelledError): + canceled += 1 + return + failed += 1 + breaker.record(success=False) + log.error( + " FAIL %s %s %s org %s: %s", + change_kind, + change.username, + "to" if change_kind == "add" else "from", + change.organization_name, + result.exception, + ) + + summary = run_context.parallel_process( + apply_change, + changes, + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=record_result, + should_stop=breaker.is_open, + ) + if breaker.is_open(): + canceled += summary.unsubmitted_count batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 1dc8dd3..77a1938 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -6,14 +6,7 @@ import threading from collections import deque from collections.abc import Sequence -from concurrent.futures import ( - FIRST_COMPLETED, - CancelledError, - Future, - ThreadPoolExecutor, - as_completed, - wait, -) +from concurrent.futures import CancelledError, ThreadPoolExecutor from dataclasses import dataclass, field from typing import TypeAlias, cast @@ -221,61 +214,65 @@ def _apply_permission_changes( canceled = 0 skipped = 0 breaker = CircuitBreaker() - with run_context.thread_pool(parallelism, worker_pool) as executor: - futures = { - src.submit_with_log_context( - executor, - _mutate_repo_permission_for_user, - client, - change, - mutation, - event_name, - ): change - for change in changes - } - for future in as_completed(futures): - change = futures[future] - try: - future.result() - succeeded += 1 - breaker.record(success=True) - log.info( - " OK %s %s → %s (id=%d).", - action, - change.username, - change.repo_name, - src.decode_repository_id(change.repo_id), - ) - except CancelledError: - canceled += 1 - continue - except Exception as exception: - if is_missing_mutation_resource_error(exception): - skipped += 1 - log.warning( - " SKIP %s %s → %s (id=%d): repo/user no longer exists: %s", - action, - change.username, - change.repo_name, - src.decode_repository_id(change.repo_id), - exception, - ) - continue - failed += 1 - breaker.record(success=False) - log.error( - " FAIL %s %s → %s (id=%d): %s", - action, - change.username, - change.repo_name, - src.decode_repository_id(change.repo_id), - exception, - ) - - if breaker.is_open(): - for pending_future in futures: - if not pending_future.done(): - pending_future.cancel() + + def mutate_change(change: PermissionChange) -> None: + _mutate_repo_permission_for_user( + client, + change, + mutation, + event_name, + ) + + def record_result(result: run_context.ParallelResult[PermissionChange, None]) -> None: + nonlocal succeeded, failed, canceled, skipped + change = result.item + exception = result.exception + if exception is None: + succeeded += 1 + breaker.record(success=True) + log.info( + " OK %s %s → %s (id=%d).", + action, + change.username, + change.repo_name, + src.decode_repository_id(change.repo_id), + ) + return + if isinstance(exception, CancelledError): + canceled += 1 + return + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s %s → %s (id=%d): repo/user no longer exists: %s", + action, + change.username, + change.repo_name, + src.decode_repository_id(change.repo_id), + exception, + ) + return + failed += 1 + breaker.record(success=False) + log.error( + " FAIL %s %s → %s (id=%d): %s", + action, + change.username, + change.repo_name, + src.decode_repository_id(change.repo_id), + exception, + ) + + summary = run_context.parallel_process( + mutate_change, + changes, + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=record_result, + should_stop=breaker.is_open, + ) + if breaker.is_open(): + canceled += summary.unsubmitted_count batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled @@ -348,101 +345,74 @@ def _apply_repo_overwrite_plans( failed = 0 canceled = 0 skipped = 0 - submitted_count = 0 - submissions_stopped = False breaker = CircuitBreaker() - overwrite_iterator = iter(overwrites) - futures: dict[Future[None], permission_types.RepositoryUsernameOverwrite] = {} - - def _submit_next(executor: ThreadPoolExecutor) -> bool: - nonlocal submitted_count - try: - overwrite = next(overwrite_iterator) - except StopIteration: - return False - future = cast( - Future[None], - src.submit_with_log_context( - executor, - set_repo_permissions_for_usernames, - client, - overwrite.repository_id, - overwrite.usernames, - ), + + def apply_overwrite(overwrite: permission_types.RepositoryUsernameOverwrite) -> None: + set_repo_permissions_for_usernames( + client, + overwrite.repository_id, + overwrite.usernames, ) - futures[future] = overwrite - submitted_count += 1 - return True - def _stop_submissions() -> None: - nonlocal submissions_stopped - if submissions_stopped: + def record_result( + result: run_context.ParallelResult[ + permission_types.RepositoryUsernameOverwrite, + None, + ], + ) -> None: + nonlocal succeeded, failed, canceled, skipped + overwrite = result.item + exception = result.exception + if exception is None: + succeeded += 1 + breaker.record(success=True) + log.info( + " OK %s (id=%d) — %d users.", + overwrite.repository_name, + src.decode_repository_id(overwrite.repository_id), + len(overwrite.usernames), + ) + return + if isinstance(exception, CancelledError): + # Cancelled by the breaker; not counted as a failure because + # we never gave the server a chance to apply it. + canceled += 1 return - submissions_stopped = True - for pending_future in futures: - if not pending_future.done(): - pending_future.cancel() - - with run_context.thread_pool(parallelism, worker_pool) as executor: - while len(futures) < max_pending_futures and _submit_next(executor): - pass - - while futures: - done_futures, _ = wait(futures, return_when=FIRST_COMPLETED) - for future in done_futures: - overwrite = futures.pop(future) - try: - future.result() - succeeded += 1 - breaker.record(success=True) - log.info( - " OK %s (id=%d) — %d users.", - overwrite.repository_name, - src.decode_repository_id(overwrite.repository_id), - len(overwrite.usernames), - ) - except CancelledError: - # Cancelled by the breaker; not counted as a failure - # because we never gave the server a chance to apply it. - canceled += 1 - continue - except Exception as exception: - if is_missing_mutation_resource_error(exception): - skipped += 1 - log.warning( - " SKIP %s (id=%d): repo/user no longer exists: %s", - overwrite.repository_name, - src.decode_repository_id(overwrite.repository_id), - exception, - ) - continue - failed += 1 - breaker.record(success=False) - log.error( - " FAIL %s (id=%d): %s", - overwrite.repository_name, - src.decode_repository_id(overwrite.repository_id), - exception, - ) - - if breaker.is_open(): - _stop_submissions() - - while ( - not submissions_stopped - and len(futures) < max_pending_futures - and _submit_next(executor) - ): - pass - - if submissions_stopped: - canceled += len(overwrites) - submitted_count + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s (id=%d): repo/user no longer exists: %s", + overwrite.repository_name, + src.decode_repository_id(overwrite.repository_id), + exception, + ) + return + failed += 1 + breaker.record(success=False) + log.error( + " FAIL %s (id=%d): %s", + overwrite.repository_name, + src.decode_repository_id(overwrite.repository_id), + exception, + ) + + summary = run_context.parallel_process( + apply_overwrite, + overwrites, + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=record_result, + should_stop=breaker.is_open, + max_pending=max_pending_futures, + ) + if breaker.is_open(): + canceled += summary.unsubmitted_count batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled batch_event["skipped"] = skipped batch_event["circuit_broken"] = breaker.is_open() - batch_event["submitted"] = submitted_count + batch_event["submitted"] = summary.submitted_count return shared_types.MutationCounts( succeeded=succeeded, failed=failed, diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index 86c8d94..61d602a 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -8,7 +8,7 @@ import logging import time from collections.abc import Callable, Iterable, Sequence -from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, as_completed, wait +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, TextIO, TypeAlias, TypedDict, cast @@ -266,108 +266,108 @@ def _fetch_one_user_at_a_time( with src.span("capture_explicit_grants", **span_fields) as capture_event: capture_failures = 0 - futures: dict[Any, list[SnapshotUserInput]] = {} scanned_user_count = 0 max_pending_batches = max(1, parallelism * 2) - def _submit_batch( - executor: ThreadPoolExecutor, - batch_users: list[SnapshotUserInput], - ) -> None: - nonlocal scanned_user_count - if not batch_users: - return - submitted_batch = list(batch_users) - scanned_user_count += len(submitted_batch) - future = src.submit_with_log_context(executor, _fetch, submitted_batch) - futures[future] = submitted_batch - # Progress reporting: every 10% when total is known (max 10 # lines), every 1000 otherwise. Avoids drowning the operator on # tiny instances and gives steady feedback on large ones. progress_step = max(1, expected_user_count // 10) if expected_user_count else 1000 - # Start the timer BEFORE submission. The submit-while-iterating - # loop blocks on ListUsers pagination, but workers process - # already-submitted tasks during those blocks — so by the time - # the submit loop finishes, many futures may already be done. - # Anchoring `progress_started` here means the first progress - # line shows real wall-clock work time, not zero. + # Start the timer BEFORE submission. Iterating `users` may block on + # ListUsers pagination, but workers process already-submitted tasks + # during those blocks — so progress reflects real wall-clock work. progress_started = time.perf_counter() completed = 0 next_progress_report = progress_step - all_users_submitted = False - - def _record_completed_futures(done_futures: Iterable[Any]) -> None: - nonlocal capture_failures, completed, next_progress_report - for future in done_futures: - submitted_batch = futures.pop(future) - completed += len(submitted_batch) - try: - repository_ids_by_username, failures = future.result() - capture_failures += failures - for username, repository_ids in repository_ids_by_username.items(): - for repository_id in repository_ids: - usernames_by_repository_id.setdefault( - repository_id, - [], - ).append(username) - except Exception as exception: - # Don't blow up the whole capture; warn so the operator - # can see the users whose grants were treated as empty. - capture_failures += len(submitted_batch) - log.warning( - "Failed to fetch explicit grants for %d user(s): %s", - len(submitted_batch), - exception, - ) + last_reported_completed = 0 - if completed >= next_progress_report or ( - all_users_submitted and completed == scanned_user_count - ): - elapsed = time.perf_counter() - progress_started - rate = completed / elapsed if elapsed > 0 else 0.0 - if expected_user_count: - remaining = max(expected_user_count - completed, 0) - eta_seconds = remaining / rate if rate > 0 else 0.0 - log.info( - "Captured explicit permissions for %d / %d users (%.0f%%) " - "in %.0fs (%.0f users/sec, ETA %.0fs).", - completed, - expected_user_count, - 100.0 * completed / expected_user_count, - elapsed, - rate, - eta_seconds, - ) - else: - log.info( - "Captured explicit permissions for %d users in %.0fs (%.0f users/sec).", - completed, - elapsed, - rate, - ) - while next_progress_report <= completed: - next_progress_report += progress_step - - # Submit-while-iterating. Iterating `users` may block on each - # ListUsers page when a streaming iterator is passed; during those - # blocks, workers continue processing already-submitted tasks. - with run_context.thread_pool(parallelism, worker_pool) as executor: + def _user_batches() -> Iterable[list[SnapshotUserInput]]: batch_users: list[SnapshotUserInput] = [] for user in users: batch_users.append(user) if len(batch_users) >= explicit_permissions_batch_size: - _submit_batch(executor, batch_users) + yield list(batch_users) batch_users = [] - if len(futures) >= max_pending_batches: - done_futures, _ = wait(futures, return_when=FIRST_COMPLETED) - _record_completed_futures(done_futures) - _submit_batch(executor, batch_users) - all_users_submitted = True - - while futures: - done_futures, _ = wait(futures, return_when=FIRST_COMPLETED) - _record_completed_futures(done_futures) + if batch_users: + yield list(batch_users) + + def _log_progress(*, force: bool = False) -> None: + nonlocal last_reported_completed, next_progress_report + if completed == 0 or (not force and completed < next_progress_report): + return + if completed == last_reported_completed: + return + elapsed = time.perf_counter() - progress_started + rate = completed / elapsed if elapsed > 0 else 0.0 + if expected_user_count: + remaining = max(expected_user_count - completed, 0) + eta_seconds = remaining / rate if rate > 0 else 0.0 + log.info( + "Captured explicit permissions for %d / %d users (%.0f%%) " + "in %.0fs (%.0f users/sec, ETA %.0fs).", + completed, + expected_user_count, + 100.0 * completed / expected_user_count, + elapsed, + rate, + eta_seconds, + ) + else: + log.info( + "Captured explicit permissions for %d users in %.0fs (%.0f users/sec).", + completed, + elapsed, + rate, + ) + last_reported_completed = completed + while next_progress_report <= completed: + next_progress_report += progress_step + + def _record_completed_batch( + result: run_context.ParallelResult[ + list[SnapshotUserInput], + tuple[dict[str, list[str]], int], + ], + ) -> None: + nonlocal capture_failures, completed, scanned_user_count + submitted_batch = result.item + completed += len(submitted_batch) + scanned_user_count += len(submitted_batch) + if result.exception is not None: + # Don't blow up the whole capture; warn so the operator can + # see the users whose grants were treated as empty. + capture_failures += len(submitted_batch) + log.warning( + "Failed to fetch explicit grants for %d user(s): %s", + len(submitted_batch), + result.exception, + ) + _log_progress() + return + if result.value is None: + raise RuntimeError("explicit-grant batch fetch returned no result") + repository_ids_by_username, failures = result.value + capture_failures += failures + for username, repository_ids in repository_ids_by_username.items(): + for repository_id in repository_ids: + usernames_by_repository_id.setdefault( + repository_id, + [], + ).append(username) + _log_progress() + + # Submit-while-iterating. Iterating `users` may block on each + # ListUsers page when a streaming iterator is passed; during those + # blocks, workers continue processing already-submitted tasks. + run_context.parallel_process( + _fetch, + _user_batches(), + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=_record_completed_batch, + max_pending=max_pending_batches, + ) + _log_progress(force=True) capture_event["scanned_user_count"] = scanned_user_count if capture_failures: capture_event["user_permission_lookup_failures"] = capture_failures @@ -499,28 +499,30 @@ def _fetch(user: SnapshotUser) -> tuple[SnapshotUser, list[permission_types.Repo fetch_event["repo_count"] = len(repos) return user, repos + def _fetch_or_empty( + user: SnapshotUser, + ) -> tuple[SnapshotUser, list[permission_types.Repository]]: + try: + return _fetch(user) + except Exception as exception: + log.warning( + "Failed to fetch scoped explicit grants for user=%s: %s", + user["username"], + exception, + ) + return user, [] + with src.span("capture_user_scoped_explicit_grants") as capture_event: - futures: dict[Any, SnapshotUser] = {} - with run_context.thread_pool(parallelism, worker_pool) as executor: - for user in users: - futures[src.submit_with_log_context(executor, _fetch, user)] = user - for future in as_completed(futures): - user = futures[future] - fetched_user: SnapshotUser - repos: list[permission_types.Repository] - try: - fetched_user, repos = future.result() - except Exception as exception: - log.warning( - "Failed to fetch scoped explicit grants for user=%s: %s", - user["username"], - exception, - ) - fetched_user, repos = user, [] - scoped_users[fetched_user["username"]] = { - "id": fetched_user["id"], - "explicit_repositories": sorted(repos, key=lambda repo: repo["name"]), - } + for fetched_user, repos in run_context.parallel_map( + _fetch_or_empty, + users, + parallelism=parallelism, + worker_pool=worker_pool, + ): + scoped_users[fetched_user["username"]] = { + "id": fetched_user["id"], + "explicit_repositories": sorted(repos, key=lambda repo: repo["name"]), + } capture_event["scanned_user_count"] = len(scoped_users) capture_event["total_grants"] = sum( len(user_snapshot["explicit_repositories"]) for user_snapshot in scoped_users.values() diff --git a/src/src_auth_perms_sync/shared/run_context.py b/src/src_auth_perms_sync/shared/run_context.py index 77efa1f..817e269 100644 --- a/src/src_auth_perms_sync/shared/run_context.py +++ b/src/src_auth_perms_sync/shared/run_context.py @@ -4,19 +4,19 @@ import logging import time -from collections.abc import Callable, Generator, Iterable +from collections.abc import Callable, Generator, Iterable, Sized from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import contextmanager from dataclasses import dataclass -from typing import TypeVar, cast +from typing import Generic, TypeVar, cast import src_py_lib as src from . import types as shared_types log = logging.getLogger(__name__) -Input = TypeVar("Input") -Output = TypeVar("Output") +InputValue = TypeVar("InputValue") +OutputValue = TypeVar("OutputValue") @dataclass(frozen=True) @@ -27,6 +27,23 @@ class CommandData: saml_group_users: list[shared_types.SamlGroupUser] | None = None +@dataclass(frozen=True) +class ParallelResult(Generic[InputValue, OutputValue]): + """One completed parallel item, carrying either a value or an exception.""" + + item: InputValue + value: OutputValue | None = None + exception: Exception | None = None + + +@dataclass(frozen=True) +class ParallelSummary: + """Submission counts from a bounded parallel run.""" + + submitted_count: int + unsubmitted_count: int + + @contextmanager def thread_pool( parallelism: int, @@ -44,13 +61,13 @@ def thread_pool( def parallel_map( - function: Callable[[Input], Output], - items: Iterable[Input], + function: Callable[[InputValue], OutputValue], + items: Iterable[InputValue], *, parallelism: int, worker_pool: ThreadPoolExecutor | None = None, progress_label: str | None = None, -) -> list[Output]: +) -> list[OutputValue]: """Map `function` over `items` using the run worker pool, preserving order.""" values = list(items) total_count = len(values) @@ -58,46 +75,166 @@ def parallel_map( return [] started = time.perf_counter() - if parallelism <= 1: - results: list[Output] = [] - for completed, value in enumerate(values, start=1): - results.append(function(value)) - if progress_label is not None and _parallel_progress_due(completed, total_count): - _log_parallel_progress(progress_label, completed, total_count, started) - return results - - results_by_index: dict[int, Output] = {} - pending_futures: dict[Future[Output], int] = {} - next_index = 0 completed = 0 - max_pending = max(1, parallelism * 2) + results_by_index: dict[int, OutputValue] = {} + + def run_indexed(indexed_value: tuple[int, InputValue]) -> tuple[int, OutputValue]: + index, value = indexed_value + return index, function(value) + + def record_result( + result: ParallelResult[tuple[int, InputValue], tuple[int, OutputValue]], + ) -> None: + nonlocal completed + if result.exception is not None: + raise result.exception + if result.value is None: + raise RuntimeError("parallel map item returned no result") + index, value = result.value + results_by_index[index] = value + completed += 1 + if progress_label is not None and _parallel_progress_due(completed, total_count): + _log_parallel_progress(progress_label, completed, total_count, started) + + parallel_process( + run_indexed, + list(enumerate(values)), + parallelism=parallelism, + worker_pool=worker_pool, + handle_result=record_result, + ) + return [results_by_index[index] for index in range(total_count)] + + +def parallel_process( + function: Callable[[InputValue], OutputValue], + items: Iterable[InputValue], + *, + parallelism: int, + worker_pool: ThreadPoolExecutor | None = None, + handle_result: Callable[[ParallelResult[InputValue, OutputValue]], None], + should_stop: Callable[[], bool] | None = None, + max_pending: int | None = None, +) -> ParallelSummary: + """Process items in parallel, letting callers handle each success or failure. + + Unlike `parallel_map`, this does not raise the first worker exception by + default. The caller receives every completed item and can decide how to + count, log, or stop after it. Work is bounded to avoid queueing thousands + of futures at once. + """ + known_total_count = len(items) if isinstance(items, Sized) else None + if known_total_count == 0: + return ParallelSummary(submitted_count=0, unsubmitted_count=0) + + item_iterator = iter(items) + + if parallelism <= 1: + return _process_sequentially( + function, + item_iterator, + known_total_count=known_total_count, + handle_result=handle_result, + should_stop=should_stop, + ) + + submitted_count = 0 + input_exhausted = False + stop_submissions = False + pending_futures: dict[Future[OutputValue], InputValue] = {} + pending_limit = max_pending or max(1, parallelism * 2) + + def stop_requested() -> bool: + return bool(stop_submissions or (should_stop is not None and should_stop())) + + def cancel_pending() -> None: + for pending_future in pending_futures: + if not pending_future.done(): + pending_future.cancel() def submit_next(executor: ThreadPoolExecutor) -> None: - nonlocal next_index - while next_index < total_count and len(pending_futures) < max_pending: - value = values[next_index] + nonlocal input_exhausted, submitted_count, stop_submissions + while not input_exhausted and len(pending_futures) < pending_limit and not stop_requested(): + try: + value = next(item_iterator) + except StopIteration: + input_exhausted = True + return future = cast( - Future[Output], + Future[OutputValue], src.submit_with_log_context(executor, function, value), ) - pending_futures[future] = next_index - next_index += 1 + pending_futures[future] = value + submitted_count += 1 + if stop_requested(): + stop_submissions = True + cancel_pending() with thread_pool(parallelism, worker_pool) as executor: - submit_next(executor) - while pending_futures: - done_futures, _ = wait(pending_futures, return_when=FIRST_COMPLETED) - for future in done_futures: - index = pending_futures.pop(future) - results_by_index[index] = future.result() - completed += 1 - if progress_label is not None and _parallel_progress_due( - completed, - total_count, - ): - _log_parallel_progress(progress_label, completed, total_count, started) + try: submit_next(executor) - return [results_by_index[index] for index in range(total_count)] + while pending_futures: + done_futures, _ = wait(pending_futures, return_when=FIRST_COMPLETED) + for future in done_futures: + value = pending_futures.pop(future) + try: + result = ParallelResult[InputValue, OutputValue]( + item=value, + value=future.result(), + ) + except Exception as exception: + result = ParallelResult[InputValue, OutputValue]( + item=value, + exception=exception, + ) + handle_result(result) + if should_stop is not None and should_stop(): + stop_submissions = True + cancel_pending() + submit_next(executor) + except BaseException: + cancel_pending() + raise + return ParallelSummary( + submitted_count=submitted_count, + unsubmitted_count=_unsubmitted_count(known_total_count, submitted_count), + ) + + +def _process_sequentially( + function: Callable[[InputValue], OutputValue], + values: Iterable[InputValue], + *, + known_total_count: int | None, + handle_result: Callable[[ParallelResult[InputValue, OutputValue]], None], + should_stop: Callable[[], bool] | None, +) -> ParallelSummary: + submitted_count = 0 + for value in values: + if should_stop is not None and should_stop(): + break + submitted_count += 1 + try: + result = ParallelResult[InputValue, OutputValue]( + item=value, + value=function(value), + ) + except Exception as exception: + result = ParallelResult[InputValue, OutputValue]( + item=value, + exception=exception, + ) + handle_result(result) + return ParallelSummary( + submitted_count=submitted_count, + unsubmitted_count=_unsubmitted_count(known_total_count, submitted_count), + ) + + +def _unsubmitted_count(known_total_count: int | None, submitted_count: int) -> int: + if known_total_count is None: + return 0 + return max(known_total_count - submitted_count, 0) def _parallel_progress_due(completed: int, total_count: int) -> bool: diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index f8e195e..da28a85 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -95,7 +95,7 @@ def test_capture_explicit_grants_bounds_pending_batches(self) -> None: {"id": f"user-{index}", "username": f"user-{index}"} for index in range(9) ] pending_counts: list[int] = [] - real_wait = permission_snapshot.wait + real_wait = permission_snapshot.run_context.wait def recording_wait(futures: Iterable[Future[Any]], **kwargs: Any) -> Any: futures_list = list(futures) @@ -121,7 +121,7 @@ def list_repo_ids( "list_repositories_by_ids", return_value={}, ), - patch.object(permission_snapshot, "wait", side_effect=recording_wait), + patch.object(permission_snapshot.run_context, "wait", side_effect=recording_wait), ): _, scanned_user_count = permission_snapshot.capture_explicit_grants( cast(src.SourcegraphClient, object()), From 4c7f0f0bcb939a2a79c8721e582797ffee571485 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:59:44 -0600 Subject: [PATCH 08/17] Slim explicit permission existence query Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- .../permissions/sourcegraph.py | 29 +++++++++++++++++-- tests/unit/test_permissions_sourcegraph.py | 13 +++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index f7355d6..18c5970 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -246,7 +246,7 @@ def _user_ids_with_explicit_repos_batch( user_ids: Sequence[str], ) -> set[str]: data = client.graphql( - _user_explicit_repos_batch_query(len(user_ids)), + _user_explicit_repo_exists_batch_query(len(user_ids)), _user_explicit_repo_exists_batch_variables(user_ids), follow_pages=False, ) @@ -259,10 +259,9 @@ def _user_ids_with_explicit_repos_batch( def _user_explicit_repo_exists_batch_variables(user_ids: Sequence[str]) -> src.JSONDict: - variables: src.JSONDict = {"first": 1} + variables: src.JSONDict = {} for index, user_id in enumerate(user_ids): variables[f"user{index}"] = user_id - variables[f"after{index}"] = None return variables @@ -411,6 +410,30 @@ def _user_explicit_repos_batch_query(batch_size: int) -> str: return "query UserExplicitReposBatch(" + ", ".join(variables) + ") {" + "".join(fields) + "\n}" +def _user_explicit_repo_exists_batch_query(batch_size: int) -> str: + variables = [f"$user{index}: ID!" for index in range(batch_size)] + fields = [ + f""" + user{index}: node(id: $user{index}) {{ + ... on User {{ + permissionsInfo {{ + repositories(source: API, first: 1) {{ + nodes {{ id }} + }} + }} + }} + }}""" + for index in range(batch_size) + ] + return ( + "query UserExplicitRepoExistsBatch(" + + ", ".join(variables) + + ") {" + + "".join(fields) + + "\n}" + ) + + def _user_explicit_repos_batch_variables( batch: Sequence[tuple[str, str | None]], ) -> src.JSONDict: diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py index 990a5d1..cc5f971 100644 --- a/tests/unit/test_permissions_sourcegraph.py +++ b/tests/unit/test_permissions_sourcegraph.py @@ -54,6 +54,7 @@ class _ExplicitReposClient: def __init__(self, explicit_user_ids: set[str]) -> None: self.explicit_user_ids = explicit_user_ids self.calls: list[src.JSONDict] = [] + self.queries: list[str] = [] def graphql( self, @@ -62,12 +63,12 @@ def graphql( *, follow_pages: bool = True, ) -> src.JSONDict: - del query if variables is None: raise AssertionError("expected explicit-repo variables") if follow_pages: raise AssertionError("existence batch should not ask the client to follow pages") self.calls.append(dict(variables)) + self.queries.append(query) response: dict[str, object] = {} for variable_name, variable_value in variables.items(): @@ -83,7 +84,6 @@ def graphql( "permissionsInfo": { "repositories": { "nodes": permission_nodes, - "pageInfo": {"hasNextPage": False, "endCursor": None}, } } } @@ -118,11 +118,18 @@ def test_user_ids_with_explicit_repos_batches_existence_checks(self) -> None: ) self.assertEqual(explicit_user_ids, {"user-2", "user-3"}) - self.assertEqual([call["first"] for call in client.calls], [1, 1]) + for query in client.queries: + self.assertIn("query UserExplicitRepoExistsBatch", query) + self.assertIn("repositories(source: API, first: 1)", query) + self.assertNotIn("pageInfo", query) + self.assertNotIn("after", query) self.assertEqual( [[call.get("user0"), call.get("user1")] for call in client.calls], [["user-1", "user-2"], ["user-3", None]], ) + for call in client.calls: + self.assertNotIn("first", call) + self.assertFalse(any(variable_name.startswith("after") for variable_name in call)) def _site_users_call_page_args(calls: list[src.JSONDict]) -> tuple[list[int], list[int]]: From caa60792f09615c116d177e019298ecbf42ff408 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:09:13 -0600 Subject: [PATCH 09/17] Update log message --- src/src_auth_perms_sync/permissions/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index 20057dd..c973d9e 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -43,7 +43,7 @@ def load_discovery( providers = shared_sourcegraph.list_auth_providers(client) log.info("Received %d auth providers.", len(providers)) - log.info("Loading code hosts from %s ...", client.endpoint) + log.info("Querying code hosts from %s ...", client.endpoint) services = permissions_sourcegraph.list_external_services(client) log.info("Received %d code hosts.", len(services)) From 68aa137ac58a8b2c9915ca77ae8bd28e903934ec Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:09:42 -0600 Subject: [PATCH 10/17] Expand Sourcegraph bulk permission API request Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- dev/memory-efficiency.md | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/dev/memory-efficiency.md b/dev/memory-efficiency.md index aaf90a8..dbee6be 100644 --- a/dev/memory-efficiency.md +++ b/dev/memory-efficiency.md @@ -298,6 +298,39 @@ Important requirements: Expected benefit: replace hundreds or thousands of per-repo resolver SQL spans per request with one indexed `user_repo_permissions` join per user batch. +The `get --users-without-explicit-perms` path also needs a cheaper presence +check. Today it has to ask +`User.permissionsInfo.repositories(source: API, first: 1)` for every candidate +user, in aliased batches. Recent test runs show the client can parallelize +those batches, but the Sourcegraph frontend / load balancer can still return +502/503s under that resolver load. Add one or both direct APIs: + +```graphql +type ExplicitRepositoryPermissionPresence { + userID: ID! + hasExplicitRepositoryPermissions: Boolean! +} + +extend type Query { + explicitRepositoryPermissionPresenceForUsers( + userIDs: [ID!]! + source: PermissionSource = API + ): [ExplicitRepositoryPermissionPresence!]! + + usersWithoutExplicitRepositoryPermissions( + createdAt: DateTimeFilter + source: PermissionSource = API + first: Int + after: String + ): UserConnection! +} +``` + +Expected benefit: `src-auth-perms-sync get --users-without-explicit-perms` +can either check explicit-permission presence for candidate users in one indexed +batch query, or ask Sourcegraph for the filtered user set directly instead of +probing every user through the expensive permissions connection resolver. + The stress profile also needs attention on the write path. A purpose-built bulk overwrite API that accepts many repo/user edges at once, streams or stages the input server-side, and avoids repeated per-repo permission reconciliation @@ -324,7 +357,10 @@ Request: add a bulk explicit-permissions read API that accepts many user IDs and returns compact permission edges (`userID`, `repositoryID`, `repositoryName`, `updatedAt`) for `source: API`, without resolving full `Repository` GraphQL objects. A single indexed query over `user_repo_permissions` joined to `repo` -should be enough for each user batch. +should be enough for each user batch. Also add a cheaper presence/filter path +for `get --users-without-explicit-perms`: either `userID -> has explicit API +repo permissions` for many users, or a direct query for users without explicit +API repo permissions, optionally filtered by `createdAt`. Acceptance criteria: @@ -336,6 +372,9 @@ Acceptance criteria: latency visible. - `src-auth-perms-sync` can replace its aliased `User.permissionsInfo.repositories(source: API)` calls with this API. +- `src-auth-perms-sync get --users-without-explicit-perms` can stop probing + every candidate user through `User.permissionsInfo.repositories(source: API, + first: 1)`. - Follow-up: evaluate a bulk overwrite API for large full-set applies. The stress run planned roughly 10 million grants and observed `permsStore.upsertUserRepoPermissions-range1` averaging about 2.5s per call. From 234a67f4eb8852c73866c8e374b171c7bc45ef11 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:24:45 -0600 Subject: [PATCH 11/17] Pipeline explicit permission candidate checks Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- .../permissions/command.py | 72 ++--- .../permissions/sourcegraph.py | 251 +++++++++++++++++- src/src_auth_perms_sync/shared/run_context.py | 10 +- tests/unit/test_permissions_sourcegraph.py | 90 +++++++ 4 files changed, 373 insertions(+), 50 deletions(-) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 27536ac..750d2a5 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -263,36 +263,31 @@ def _load_get_users( created_after_filter = sourcegraph_datetime_filter( parse_cli_date(user_created_after, "--created-after") ) - candidates = permissions_sourcegraph.list_site_user_candidates( - client, - created_after_filter, - parallelism=parallelism, - worker_pool=worker_pool, - ) - log.info("Loaded %d active user candidate(s).", len(candidates)) if users_without_explicit_perms: + candidate_selection = ( + permissions_sourcegraph.list_site_user_candidates_without_explicit_repos( + client, + created_after_filter, + batch_size=explicit_permissions_batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + ) + ) + candidates = candidate_selection.candidates log.info( - "Checking %d active user candidate(s) for existing explicit repo permissions " - "in batches of %d ...", + "Selected %d active user candidate(s) without explicit repo permissions; " + "skipped %d with existing explicit permissions.", len(candidates), - explicit_permissions_batch_size, + candidate_selection.explicit_user_count, ) - explicit_user_ids = permissions_sourcegraph.user_ids_with_explicit_repos( + else: + candidates = permissions_sourcegraph.list_site_user_candidates( client, - [candidate["id"] for candidate in candidates], - batch_size=explicit_permissions_batch_size, + created_after_filter, parallelism=parallelism, worker_pool=worker_pool, ) - candidates = [ - candidate for candidate in candidates if candidate["id"] not in explicit_user_ids - ] - log.info( - "Selected %d active user candidate(s) without explicit repo permissions; " - "skipped %d with existing explicit permissions.", - len(candidates), - len(explicit_user_ids), - ) + log.info("Loaded %d active user candidate(s).", len(candidates)) users = _hydrate_site_user_candidates( client, @@ -573,34 +568,21 @@ def cmd_set_additive_users_without_explicit_perms( context.mapping_rules ) resolved_mappings = resolve_additive_mappings(context) - candidates = permissions_sourcegraph.list_site_user_candidates( - client, - created_after_filter, - parallelism=parallelism, - worker_pool=worker_pool, - ) - log.info("Loaded %d active user candidate(s).", len(candidates)) - log.info( - "Checking %d active user candidate(s) for existing explicit repo permissions, " - "in batches of %d ...", - len(candidates), - explicit_permissions_batch_size, - ) - explicit_user_ids = permissions_sourcegraph.user_ids_with_explicit_repos( - client, - [candidate["id"] for candidate in candidates], - batch_size=explicit_permissions_batch_size, - parallelism=parallelism, - worker_pool=worker_pool, + candidate_selection = ( + permissions_sourcegraph.list_site_user_candidates_without_explicit_repos( + client, + created_after_filter, + batch_size=explicit_permissions_batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + ) ) - candidates = [ - candidate for candidate in candidates if candidate["id"] not in explicit_user_ids - ] + candidates = candidate_selection.candidates log.info( "Selected %d active user candidate(s) without explicit repo permissions; " "skipped %d with existing explicit permissions.", len(candidates), - len(explicit_user_ids), + candidate_selection.explicit_user_count, ) users = _hydrate_site_user_candidates( diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index 18c5970..6f51c46 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -4,8 +4,10 @@ import logging import time +from collections import deque from collections.abc import Iterable, Iterator, Sequence -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from dataclasses import dataclass from typing import Any, cast import src_py_lib as src @@ -20,6 +22,20 @@ SITE_USER_CANDIDATE_PAGE_SIZE = 1000 +@dataclass(frozen=True) +class SiteUserCandidateSelection: + """Active user candidates after filtering explicit repo-permission owners.""" + + candidates: list[shared_types.SiteUserCandidate] + explicit_user_count: int + + +@dataclass(frozen=True) +class _SiteUserCandidatePage: + offset: int + candidates: list[shared_types.SiteUserCandidate] + + def list_external_services(client: src.SourcegraphClient) -> list[permission_types.ExternalService]: return [ cast(permission_types.ExternalService, node) @@ -156,6 +172,93 @@ def fetch_page(offset: int) -> tuple[int, list[shared_types.SiteUserCandidate]]: return candidates +def list_site_user_candidates_without_explicit_repos( + client: src.SourcegraphClient, + created_after: str | None, + *, + batch_size: int, + parallelism: int, + worker_pool: ThreadPoolExecutor | None = None, +) -> SiteUserCandidateSelection: + """Return active site users that do not already have explicit API grants. + + Candidate pages and explicit-permission checks are pipelined so the slow + permission checks can start as soon as the first candidate page yields a + full batch of users. + """ + if batch_size < 1: + raise ValueError("batch_size must be at least 1") + + created_filter = {"gte": created_after} if created_after is not None else None + created_filter_label = f" created on or after {created_after}" if created_after else "" + log.info("Querying active Sourcegraph user candidates%s ...", created_filter_label) + started = time.perf_counter() + first_page, total_count = _site_user_candidate_page( + client, + created_filter, + offset=0, + page_size=SITE_USER_CANDIDATE_PAGE_SIZE, + ) + if not first_page: + return SiteUserCandidateSelection(candidates=[], explicit_user_count=0) + + if len(first_page) >= total_count or parallelism <= 1: + _log_user_candidate_load_progress(len(first_page), total_count, started) + log.info( + "Checking %d active user candidate(s)%s for existing explicit repo permissions " + "in batches of %d ...", + len(first_page), + created_filter_label, + batch_size, + ) + explicit_user_ids = user_ids_with_explicit_repos( + client, + [candidate["id"] for candidate in first_page], + batch_size=batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + ) + return SiteUserCandidateSelection( + candidates=[ + candidate for candidate in first_page if candidate["id"] not in explicit_user_ids + ], + explicit_user_count=len(explicit_user_ids), + ) + + page_size = len(first_page) + page_count = (total_count + page_size - 1) // page_size + log.info( + "Loading %d active Sourcegraph user candidate(s)%s across %d page(s) " + "of %d users/page, while checking explicit repo permissions in batches " + "of %d with parallelism=%d ...", + total_count, + created_filter_label, + page_count, + page_size, + batch_size, + parallelism, + ) + pages, explicit_user_ids = _load_candidate_pages_and_explicit_user_ids( + client, + created_filter, + first_page, + total_count=total_count, + page_size=page_size, + batch_size=batch_size, + parallelism=parallelism, + worker_pool=worker_pool, + started=started, + ) + candidates = _dedupe_site_user_candidate_pages(pages) + _log_user_candidate_load_progress(len(candidates), total_count, started) + return SiteUserCandidateSelection( + candidates=[ + candidate for candidate in candidates if candidate["id"] not in explicit_user_ids + ], + explicit_user_count=len(explicit_user_ids), + ) + + def _site_user_candidate_page( client: src.SourcegraphClient, created_filter: dict[str, str] | None, @@ -183,6 +286,152 @@ def _site_user_candidate_page( return nodes, total_count +def _load_candidate_pages_and_explicit_user_ids( + client: src.SourcegraphClient, + created_filter: dict[str, str] | None, + first_page: list[shared_types.SiteUserCandidate], + *, + total_count: int, + page_size: int, + batch_size: int, + parallelism: int, + worker_pool: ThreadPoolExecutor | None, + started: float, +) -> tuple[list[tuple[int, list[shared_types.SiteUserCandidate]]], set[str]]: + pages: list[tuple[int, list[shared_types.SiteUserCandidate]]] = [(0, first_page)] + explicit_user_ids: set[str] = set() + queued_user_ids: set[str] = set() + candidate_batch_buffer: list[str] = [] + ready_user_batches = deque[tuple[str, ...]]() + page_offsets = iter(range(page_size, total_count, page_size)) + page_count = (total_count + page_size - 1) // page_size + total_batch_count = (total_count + batch_size - 1) // batch_size + completed_page_count = 1 + completed_batch_count = 0 + pages_exhausted = False + page_pending_limit = max(1, parallelism // 2) + early_permission_pending_limit = max(1, parallelism - page_pending_limit) + pending_page_futures: dict[Future[_SiteUserCandidatePage], int] = {} + pending_permission_futures: dict[Future[set[str]], tuple[str, ...]] = {} + + def queue_user_batches(candidates: Sequence[shared_types.SiteUserCandidate]) -> None: + for candidate in candidates: + user_id = candidate["id"] + if user_id in queued_user_ids: + continue + queued_user_ids.add(user_id) + candidate_batch_buffer.append(user_id) + if len(candidate_batch_buffer) == batch_size: + ready_user_batches.append(tuple(candidate_batch_buffer)) + candidate_batch_buffer.clear() + + def fetch_page(offset: int) -> _SiteUserCandidatePage: + candidates, _ = _site_user_candidate_page( + client, + created_filter, + offset=offset, + page_size=SITE_USER_CANDIDATE_PAGE_SIZE, + ) + return _SiteUserCandidatePage(offset=offset, candidates=candidates) + + def submit_candidate_pages(executor: ThreadPoolExecutor) -> None: + nonlocal pages_exhausted + while not pages_exhausted and len(pending_page_futures) < page_pending_limit: + try: + offset = next(page_offsets) + except StopIteration: + pages_exhausted = True + return + future = cast( + Future[_SiteUserCandidatePage], + src.submit_with_log_context(executor, fetch_page, offset), + ) + pending_page_futures[future] = offset + + def flush_final_user_batch() -> None: + if pages_exhausted and not pending_page_futures and candidate_batch_buffer: + ready_user_batches.append(tuple(candidate_batch_buffer)) + candidate_batch_buffer.clear() + + def permission_pending_limit() -> int: + if pages_exhausted and not pending_page_futures: + return parallelism + return early_permission_pending_limit + + def submit_permission_batches(executor: ThreadPoolExecutor) -> None: + while ready_user_batches and len(pending_permission_futures) < permission_pending_limit(): + user_batch = ready_user_batches.popleft() + future = cast( + Future[set[str]], + src.submit_with_log_context( + executor, + _user_ids_with_explicit_repos_batch, + client, + user_batch, + ), + ) + pending_permission_futures[future] = user_batch + + def cancel_pending_futures() -> None: + for future in list(pending_page_futures) + list(pending_permission_futures): + future.cancel() + + queue_user_batches(first_page) + with run_context.thread_pool(parallelism, worker_pool) as executor: + try: + submit_candidate_pages(executor) + flush_final_user_batch() + submit_permission_batches(executor) + while pending_page_futures or pending_permission_futures: + pending_futures: set[Future[object]] = { + cast(Future[object], future) for future in pending_page_futures + } + pending_futures.update( + cast(Future[object], future) for future in pending_permission_futures + ) + completed_futures, _ = wait( + pending_futures, + return_when=FIRST_COMPLETED, + ) + for completed_future in completed_futures: + page_future = cast(Future[_SiteUserCandidatePage], completed_future) + if page_future in pending_page_futures: + page = page_future.result() + pending_page_futures.pop(page_future) + pages.append((page.offset, page.candidates)) + completed_page_count += 1 + if run_context.parallel_progress_due(completed_page_count, page_count): + run_context.log_parallel_progress( + "Loaded active Sourcegraph user candidate pages", + completed_page_count, + page_count, + started, + ) + queue_user_batches(page.candidates) + else: + permission_future = cast(Future[set[str]], completed_future) + pending_permission_futures.pop(permission_future) + explicit_user_ids.update(permission_future.result()) + completed_batch_count += 1 + if run_context.parallel_progress_due( + completed_batch_count, + total_batch_count, + ): + run_context.log_parallel_progress( + "Checked explicit repo permissions for user batches", + completed_batch_count, + total_batch_count, + started, + ) + submit_candidate_pages(executor) + flush_final_user_batch() + submit_permission_batches(executor) + except BaseException: + cancel_pending_futures() + raise + return pages, explicit_user_ids + + def _dedupe_site_user_candidate_pages( pages: Iterable[tuple[int, Sequence[shared_types.SiteUserCandidate]]], ) -> list[shared_types.SiteUserCandidate]: diff --git a/src/src_auth_perms_sync/shared/run_context.py b/src/src_auth_perms_sync/shared/run_context.py index 817e269..45369e2 100644 --- a/src/src_auth_perms_sync/shared/run_context.py +++ b/src/src_auth_perms_sync/shared/run_context.py @@ -93,8 +93,8 @@ def record_result( index, value = result.value results_by_index[index] = value completed += 1 - if progress_label is not None and _parallel_progress_due(completed, total_count): - _log_parallel_progress(progress_label, completed, total_count, started) + if progress_label is not None and parallel_progress_due(completed, total_count): + log_parallel_progress(progress_label, completed, total_count, started) parallel_process( run_indexed, @@ -237,11 +237,13 @@ def _unsubmitted_count(known_total_count: int | None, submitted_count: int) -> i return max(known_total_count - submitted_count, 0) -def _parallel_progress_due(completed: int, total_count: int) -> bool: +def parallel_progress_due(completed: int, total_count: int) -> bool: + """Return whether a bounded parallel run should log progress now.""" + return completed == total_count or completed % max(1, total_count // 10) == 0 -def _log_parallel_progress( +def log_parallel_progress( progress_label: str, completed: int, total_count: int, diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py index cc5f971..a4849c5 100644 --- a/tests/unit/test_permissions_sourcegraph.py +++ b/tests/unit/test_permissions_sourcegraph.py @@ -90,6 +90,81 @@ def graphql( return cast(src.JSONDict, response) +class _PipelinedCandidateClient: + def __init__(self) -> None: + self.total_count = 1001 + self.explicit_user_ids = {"user-10"} + self.release_second_page = threading.Event() + self.second_page_returned = threading.Event() + self.explicit_started_before_second_page_returned = False + + def graphql( + self, + query: str, + variables: src.JSONDict | None = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + if variables is None: + raise AssertionError("expected variables") + if "query SiteUsers" in query: + return self._site_users(variables) + if "query UserExplicitRepoExistsBatch" in query: + if not follow_pages: + self.explicit_started_before_second_page_returned = bool( + self.explicit_started_before_second_page_returned + or not self.second_page_returned.is_set() + ) + self.release_second_page.set() + return self._explicit_repos(variables) + raise AssertionError("existence batch should not ask the client to follow pages") + raise AssertionError(f"unexpected query: {query[:80]}") + + def _site_users(self, variables: src.JSONDict) -> src.JSONDict: + limit_value = variables.get("limit") + offset_value = variables.get("offset") + if not isinstance(limit_value, int) or not isinstance(offset_value, int): + raise AssertionError("expected integer limit and offset") + if offset_value == 1000: + if not self.release_second_page.wait(timeout=5): + raise AssertionError( + "explicit permission lookup did not start before page load finished" + ) + self.second_page_returned.set() + + page_nodes: list[dict[str, object]] = [] + for user_number in range(offset_value, min(offset_value + limit_value, self.total_count)): + page_nodes.append( + { + "id": f"user-{user_number}", + "username": f"user-{user_number}", + "email": None, + "createdAt": "2026-06-09T00:00:00Z", + "deletedAt": None, + } + ) + return cast( + src.JSONDict, + {"site": {"users": {"totalCount": self.total_count, "nodes": page_nodes}}}, + ) + + def _explicit_repos(self, variables: src.JSONDict) -> src.JSONDict: + response: dict[str, object] = {} + for variable_name, variable_value in variables.items(): + if not variable_name.startswith("user"): + continue + if not isinstance(variable_value, str): + raise AssertionError("expected user ID variable") + user_index = int(variable_name.removeprefix("user")) + permission_nodes: list[dict[str, str]] = [] + if variable_value in self.explicit_user_ids: + permission_nodes.append({"id": "repo-with-explicit-grant"}) + response[f"user{user_index}"] = { + "permissionsInfo": {"repositories": {"nodes": permission_nodes}} + } + return cast(src.JSONDict, response) + + class PermissionsSourcegraphTests(unittest.TestCase): def test_list_site_user_candidates_uses_larger_pages(self) -> None: client = _SiteUsersClient(total_count=2500) @@ -131,6 +206,21 @@ def test_user_ids_with_explicit_repos_batches_existence_checks(self) -> None: self.assertNotIn("first", call) self.assertFalse(any(variable_name.startswith("after") for variable_name in call)) + def test_candidates_without_explicit_repos_pipelines_checks_after_first_page(self) -> None: + client = _PipelinedCandidateClient() + + selection = permissions_sourcegraph.list_site_user_candidates_without_explicit_repos( + cast(src.SourcegraphClient, client), + None, + batch_size=1000, + parallelism=2, + ) + + self.assertTrue(client.explicit_started_before_second_page_returned) + self.assertEqual(selection.explicit_user_count, 1) + self.assertEqual(len(selection.candidates), 1000) + self.assertNotIn("user-10", {candidate["id"] for candidate in selection.candidates}) + def _site_users_call_page_args(calls: list[src.JSONDict]) -> tuple[list[int], list[int]]: limits: list[int] = [] From 22d0430e5b674939fedadd3f4bb8fac70d56dcc3 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:35:20 -0600 Subject: [PATCH 12/17] Batch selected user hydration Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- .../permissions/command.py | 26 +++++--- .../permissions/sourcegraph.py | 41 +++++++++++++ tests/unit/test_permissions_sourcegraph.py | 60 +++++++++++++++++++ 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 750d2a5..a7ff7bc 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -348,26 +348,36 @@ def _hydrate_site_user_candidates( if not candidates: return [] + batch_size = shared_sourcegraph.DEFAULT_PAGE_SIZE log.info( - "Hydrating Sourcegraph metadata for %d selected user candidate(s) with parallelism=%d ...", + "Hydrating Sourcegraph metadata for %d selected user candidate(s) " + "in batches of %d with parallelism=%d ...", len(candidates), + batch_size, parallelism, ) + candidate_batches = [ + candidates[start_index : start_index + batch_size] + for start_index in range(0, len(candidates), batch_size) + ] - def hydrate_user(candidate: shared_types.SiteUserCandidate) -> shared_types.User | None: - return permissions_sourcegraph.get_user_by_id( + def hydrate_users( + candidate_batch: list[shared_types.SiteUserCandidate], + ) -> list[shared_types.User | None]: + return permissions_sourcegraph.get_users_by_ids( client, - candidate["id"], + [candidate["id"] for candidate in candidate_batch], include_emails=include_emails, ) - hydrated_users = run_context.parallel_map( - hydrate_user, - candidates, + hydrated_user_batches = run_context.parallel_map( + hydrate_users, + candidate_batches, parallelism=parallelism, worker_pool=worker_pool, - progress_label="Hydrated selected Sourcegraph user metadata", + progress_label="Hydrated selected Sourcegraph user metadata batches", ) + hydrated_users = [user for batch in hydrated_user_batches for user in batch] users = [user for user in hydrated_users if user is not None] missing_user_count = len(hydrated_users) - len(users) if missing_user_count: diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index 6f51c46..57051b4 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -112,6 +112,25 @@ def get_user_by_id( return cast(shared_types.User | None, data.get("node")) +def get_users_by_ids( + client: src.SourcegraphClient, + user_ids: Sequence[str], + *, + include_emails: bool = False, +) -> list[shared_types.User | None]: + """Hydrate User nodes by GraphQL ID, preserving caller order.""" + if not user_ids: + return [] + data = client.graphql( + _users_by_id_batch_query(len(user_ids), include_emails=include_emails), + _users_by_id_batch_variables(user_ids), + follow_pages=False, + ) + return [ + cast(shared_types.User | None, data.get(f"user{index}")) for index in range(len(user_ids)) + ] + + def list_site_user_candidates( client: src.SourcegraphClient, created_after: str | None, @@ -636,6 +655,28 @@ def _batches(values: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]: yield values[start_index : start_index + batch_size] +def _users_by_id_batch_query(batch_size: int, *, include_emails: bool = False) -> str: + variables = [f"$user{index}: ID!" for index in range(batch_size)] + user_fields = queries.user_fields(include_emails=include_emails) + fields = [ + f""" + user{index}: node(id: $user{index}) {{ + ... on User {{ + {user_fields} + }} + }}""" + for index in range(batch_size) + ] + return "query UsersByIDBatch(" + ", ".join(variables) + ") {" + "".join(fields) + "\n}" + + +def _users_by_id_batch_variables(user_ids: Sequence[str]) -> src.JSONDict: + variables: src.JSONDict = {} + for index, user_id in enumerate(user_ids): + variables[f"user{index}"] = user_id + return variables + + def _user_explicit_repos_batch_query(batch_size: int) -> str: variables = ["$first: Int!"] fields: list[str] = [] diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py index a4849c5..f5b665c 100644 --- a/tests/unit/test_permissions_sourcegraph.py +++ b/tests/unit/test_permissions_sourcegraph.py @@ -90,6 +90,45 @@ def graphql( return cast(src.JSONDict, response) +class _UsersByIDClient: + def __init__(self, missing_user_ids: set[str] | None = None) -> None: + self.missing_user_ids = missing_user_ids or set() + self.calls: list[src.JSONDict] = [] + self.queries: list[str] = [] + + def graphql( + self, + query: str, + variables: src.JSONDict | None = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + if variables is None: + raise AssertionError("expected user variables") + if follow_pages: + raise AssertionError("user batch should not ask the client to follow pages") + self.calls.append(dict(variables)) + self.queries.append(query) + response: dict[str, object] = {} + for variable_name, variable_value in variables.items(): + if not variable_name.startswith("user"): + continue + if not isinstance(variable_value, str): + raise AssertionError("expected user ID variable") + user_index = int(variable_name.removeprefix("user")) + response[f"user{user_index}"] = ( + None + if variable_value in self.missing_user_ids + else { + "id": variable_value, + "username": variable_value.replace("user-", "username-"), + "builtinAuth": False, + "externalAccounts": {"nodes": []}, + } + ) + return cast(src.JSONDict, response) + + class _PipelinedCandidateClient: def __init__(self) -> None: self.total_count = 1001 @@ -206,6 +245,27 @@ def test_user_ids_with_explicit_repos_batches_existence_checks(self) -> None: self.assertNotIn("first", call) self.assertFalse(any(variable_name.startswith("after") for variable_name in call)) + def test_get_users_by_ids_batches_user_hydration(self) -> None: + client = _UsersByIDClient(missing_user_ids={"user-2"}) + + users = permissions_sourcegraph.get_users_by_ids( + cast(src.SourcegraphClient, client), + ["user-1", "user-2", "user-3"], + ) + + self.assertEqual( + [user["id"] if user else None for user in users], + ["user-1", None, "user-3"], + ) + self.assertEqual( + client.calls, + [{"user0": "user-1", "user1": "user-2", "user2": "user-3"}], + ) + self.assertEqual(len(client.queries), 1) + self.assertIn("query UsersByIDBatch", client.queries[0]) + self.assertIn("externalAccounts(first: 50)", client.queries[0]) + self.assertNotIn("emails {", client.queries[0]) + def test_candidates_without_explicit_repos_pipelines_checks_after_first_page(self) -> None: client = _PipelinedCandidateClient() From 18a69f28aa541d29d7b7f75635885cffb8ce82ba Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:44:19 -0600 Subject: [PATCH 13/17] Revert "Batch selected user hydration" This reverts commit 22d0430e5b674939fedadd3f4bb8fac70d56dcc3. --- .../permissions/command.py | 26 +++----- .../permissions/sourcegraph.py | 41 ------------- tests/unit/test_permissions_sourcegraph.py | 60 ------------------- 3 files changed, 8 insertions(+), 119 deletions(-) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index a7ff7bc..750d2a5 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -348,36 +348,26 @@ def _hydrate_site_user_candidates( if not candidates: return [] - batch_size = shared_sourcegraph.DEFAULT_PAGE_SIZE log.info( - "Hydrating Sourcegraph metadata for %d selected user candidate(s) " - "in batches of %d with parallelism=%d ...", + "Hydrating Sourcegraph metadata for %d selected user candidate(s) with parallelism=%d ...", len(candidates), - batch_size, parallelism, ) - candidate_batches = [ - candidates[start_index : start_index + batch_size] - for start_index in range(0, len(candidates), batch_size) - ] - def hydrate_users( - candidate_batch: list[shared_types.SiteUserCandidate], - ) -> list[shared_types.User | None]: - return permissions_sourcegraph.get_users_by_ids( + def hydrate_user(candidate: shared_types.SiteUserCandidate) -> shared_types.User | None: + return permissions_sourcegraph.get_user_by_id( client, - [candidate["id"] for candidate in candidate_batch], + candidate["id"], include_emails=include_emails, ) - hydrated_user_batches = run_context.parallel_map( - hydrate_users, - candidate_batches, + hydrated_users = run_context.parallel_map( + hydrate_user, + candidates, parallelism=parallelism, worker_pool=worker_pool, - progress_label="Hydrated selected Sourcegraph user metadata batches", + progress_label="Hydrated selected Sourcegraph user metadata", ) - hydrated_users = [user for batch in hydrated_user_batches for user in batch] users = [user for user in hydrated_users if user is not None] missing_user_count = len(hydrated_users) - len(users) if missing_user_count: diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index 57051b4..6f51c46 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -112,25 +112,6 @@ def get_user_by_id( return cast(shared_types.User | None, data.get("node")) -def get_users_by_ids( - client: src.SourcegraphClient, - user_ids: Sequence[str], - *, - include_emails: bool = False, -) -> list[shared_types.User | None]: - """Hydrate User nodes by GraphQL ID, preserving caller order.""" - if not user_ids: - return [] - data = client.graphql( - _users_by_id_batch_query(len(user_ids), include_emails=include_emails), - _users_by_id_batch_variables(user_ids), - follow_pages=False, - ) - return [ - cast(shared_types.User | None, data.get(f"user{index}")) for index in range(len(user_ids)) - ] - - def list_site_user_candidates( client: src.SourcegraphClient, created_after: str | None, @@ -655,28 +636,6 @@ def _batches(values: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]: yield values[start_index : start_index + batch_size] -def _users_by_id_batch_query(batch_size: int, *, include_emails: bool = False) -> str: - variables = [f"$user{index}: ID!" for index in range(batch_size)] - user_fields = queries.user_fields(include_emails=include_emails) - fields = [ - f""" - user{index}: node(id: $user{index}) {{ - ... on User {{ - {user_fields} - }} - }}""" - for index in range(batch_size) - ] - return "query UsersByIDBatch(" + ", ".join(variables) + ") {" + "".join(fields) + "\n}" - - -def _users_by_id_batch_variables(user_ids: Sequence[str]) -> src.JSONDict: - variables: src.JSONDict = {} - for index, user_id in enumerate(user_ids): - variables[f"user{index}"] = user_id - return variables - - def _user_explicit_repos_batch_query(batch_size: int) -> str: variables = ["$first: Int!"] fields: list[str] = [] diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py index f5b665c..a4849c5 100644 --- a/tests/unit/test_permissions_sourcegraph.py +++ b/tests/unit/test_permissions_sourcegraph.py @@ -90,45 +90,6 @@ def graphql( return cast(src.JSONDict, response) -class _UsersByIDClient: - def __init__(self, missing_user_ids: set[str] | None = None) -> None: - self.missing_user_ids = missing_user_ids or set() - self.calls: list[src.JSONDict] = [] - self.queries: list[str] = [] - - def graphql( - self, - query: str, - variables: src.JSONDict | None = None, - *, - follow_pages: bool = True, - ) -> src.JSONDict: - if variables is None: - raise AssertionError("expected user variables") - if follow_pages: - raise AssertionError("user batch should not ask the client to follow pages") - self.calls.append(dict(variables)) - self.queries.append(query) - response: dict[str, object] = {} - for variable_name, variable_value in variables.items(): - if not variable_name.startswith("user"): - continue - if not isinstance(variable_value, str): - raise AssertionError("expected user ID variable") - user_index = int(variable_name.removeprefix("user")) - response[f"user{user_index}"] = ( - None - if variable_value in self.missing_user_ids - else { - "id": variable_value, - "username": variable_value.replace("user-", "username-"), - "builtinAuth": False, - "externalAccounts": {"nodes": []}, - } - ) - return cast(src.JSONDict, response) - - class _PipelinedCandidateClient: def __init__(self) -> None: self.total_count = 1001 @@ -245,27 +206,6 @@ def test_user_ids_with_explicit_repos_batches_existence_checks(self) -> None: self.assertNotIn("first", call) self.assertFalse(any(variable_name.startswith("after") for variable_name in call)) - def test_get_users_by_ids_batches_user_hydration(self) -> None: - client = _UsersByIDClient(missing_user_ids={"user-2"}) - - users = permissions_sourcegraph.get_users_by_ids( - cast(src.SourcegraphClient, client), - ["user-1", "user-2", "user-3"], - ) - - self.assertEqual( - [user["id"] if user else None for user in users], - ["user-1", None, "user-3"], - ) - self.assertEqual( - client.calls, - [{"user0": "user-1", "user1": "user-2", "user2": "user-3"}], - ) - self.assertEqual(len(client.queries), 1) - self.assertIn("query UsersByIDBatch", client.queries[0]) - self.assertIn("externalAccounts(first: 50)", client.queries[0]) - self.assertNotIn("emails {", client.queries[0]) - def test_candidates_without_explicit_repos_pipelines_checks_after_first_page(self) -> None: client = _PipelinedCandidateClient() From afc5c3eef4b0c334e4d225885ad52eee72132e32 Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 04:14:52 -0600 Subject: [PATCH 14/17] Make queries much more efficient --- dev/TODO.md | 2 - src/src_auth_perms_sync/orgs/sync.py | 5 +- .../permissions/command.py | 284 +++++++++++++++--- .../permissions/full_set.py | 11 + .../permissions/mapping.py | 29 ++ .../permissions/queries.py | 46 ++- .../permissions/restore.py | 13 +- .../permissions/snapshot.py | 70 ++++- .../permissions/sourcegraph.py | 21 +- .../permissions/workflow.py | 57 ++++ src/src_auth_perms_sync/shared/queries.py | 26 +- src/src_auth_perms_sync/shared/saml_groups.py | 4 +- src/src_auth_perms_sync/shared/sourcegraph.py | 12 +- tests/unit/test_command_additive.py | 280 +++++++++++++++++ tests/unit/test_maps.py | 72 +++++ tests/unit/test_permissions_sourcegraph.py | 33 ++ tests/unit/test_snapshot.py | 44 ++- 17 files changed, 919 insertions(+), 90 deletions(-) create mode 100644 tests/unit/test_command_additive.py diff --git a/dev/TODO.md b/dev/TODO.md index f46d1f0..dfae3b3 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -4,8 +4,6 @@ ### Fast -- Additive modes, to add new users’ perms quickly, - without the extraneous load on the database of a full sync - Query the instance for all new repos, which do not yet have explicit perms ### Full: Overwrite all perms diff --git a/src/src_auth_perms_sync/orgs/sync.py b/src/src_auth_perms_sync/orgs/sync.py index 8fd228f..e52a466 100644 --- a/src/src_auth_perms_sync/orgs/sync.py +++ b/src/src_auth_perms_sync/orgs/sync.py @@ -403,7 +403,10 @@ def _collect_target_organizations( progress_step = 1000 if saml_group_users is None: log.info("Streaming users once and extracting SAML group memberships ...") - for completed, user in enumerate(shared_sourcegraph.list_users_streaming(client), start=1): + for completed, user in enumerate( + shared_sourcegraph.list_users_streaming(client, include_account_data=True), + start=1, + ): compact_user = saml_groups.compact_saml_group_user( user, providers_by_account_key, diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 750d2a5..1fb4055 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -25,7 +25,9 @@ from . import types as permission_types from .workflow import ( load_discovery, - load_mapping_context, + load_mapping_context_discovery, + load_mapping_rules, + load_repos_for_mapping_context, parse_cli_date, snapshot_path, sourcegraph_datetime_filter, @@ -79,6 +81,54 @@ def resolve_additive_mappings(context: permission_types.MappingContext) -> list[ return resolved +def _mapping_context_with_rules( + context: permission_types.MappingContext, + mapping_rules: list[permission_types.MappingRule], +) -> permission_types.MappingContext: + return permission_types.MappingContext( + mapping_rules=mapping_rules, + providers=context.providers, + saml_groups_attribute_names=context.saml_groups_attribute_names, + services_by_id=context.services_by_id, + repos_by_external_service_id=context.repos_by_external_service_id, + all_repos_by_id=context.all_repos_by_id, + ) + + +def _mapping_rules_matching_selected_users( + context: permission_types.MappingContext, + users: list[shared_types.User], +) -> list[permission_types.MappingRule]: + matching_rules: list[permission_types.MappingRule] = [] + for mapping_rule in context.mapping_rules: + if any( + permissions_mapping.user_matches_user_selector( + mapping_rule["users"], + user, + context.providers, + context.saml_groups_attribute_names, + ) + for user in users + ): + matching_rules.append(mapping_rule) + return matching_rules + + +def _service_ids_required_by_mapping_rules( + context: permission_types.MappingContext, + mapping_rules: list[permission_types.MappingRule], +) -> set[int]: + return permissions_mapping.service_ids_required_by_repository_selectors( + context.services_by_id, + [mapping_rule["repos"] for mapping_rule in mapping_rules], + ) + + +def _providers_need_saml_account_data(providers: list[shared_types.AuthProvider]) -> bool: + """Return whether output needs SAML accountData-derived group counts.""" + return any(provider["serviceType"] == saml_groups.SAML_SERVICE_TYPE for provider in providers) + + def cmd_get( client: src.SourcegraphClient, code_hosts_path: Path, @@ -133,6 +183,7 @@ def cmd_get( services = [permissions_maps.external_service_to_yaml(service) for service in raw_services] cmd_event["auth_provider_count"] = len(raw_providers) cmd_event["external_service_count"] = len(services) + include_user_account_data = _providers_need_saml_account_data(raw_providers) users = _load_get_users( client, @@ -141,6 +192,7 @@ def cmd_get( user_created_after=user_created_after, parallelism=parallelism, explicit_permissions_batch_size=explicit_permissions_batch_size, + include_account_data=include_user_account_data, worker_pool=worker_pool, ) counts = permissions_maps.count_users_per_provider(users) @@ -237,11 +289,16 @@ def _load_get_users( user_created_after: str | None, parallelism: int, explicit_permissions_batch_size: int, + include_account_data: bool, worker_pool: ThreadPoolExecutor | None, ) -> list[shared_types.User]: """Load the Sourcegraph users selected by get/set-compatible user filters.""" if user_identifiers: - users = _resolve_user_identifiers(client, user_identifiers) + users = _resolve_user_identifiers( + client, + user_identifiers, + include_account_data=include_account_data, + ) if user_created_after is None: return users candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) @@ -292,16 +349,21 @@ def _load_get_users( users = _hydrate_site_user_candidates( client, candidates, + include_account_data=include_account_data, parallelism=parallelism, worker_pool=worker_pool, ) log.info("Selected %d user(s) for get output.", len(users)) return users - return _load_all_get_users(client) + return _load_all_get_users(client, include_account_data=include_account_data) -def _load_all_get_users(client: src.SourcegraphClient) -> list[shared_types.User]: +def _load_all_get_users( + client: src.SourcegraphClient, + *, + include_account_data: bool, +) -> list[shared_types.User]: """Load all users for get output, with progress logs for large instances.""" total_users = shared_sourcegraph.count_users(client) page_count = ( @@ -316,7 +378,13 @@ def _load_all_get_users(client: src.SourcegraphClient) -> list[shared_types.User users: list[shared_types.User] = [] load_started = time.perf_counter() progress_step = max(1, total_users // 10) - for completed, user in enumerate(shared_sourcegraph.list_users_streaming(client), start=1): + for completed, user in enumerate( + shared_sourcegraph.list_users_streaming( + client, + include_account_data=include_account_data, + ), + start=1, + ): users.append(user) if completed % progress_step == 0 or completed == total_users: elapsed = time.perf_counter() - load_started @@ -341,6 +409,7 @@ def _hydrate_site_user_candidates( candidates: list[shared_types.SiteUserCandidate], *, include_emails: bool = False, + include_account_data: bool = True, parallelism: int, worker_pool: ThreadPoolExecutor | None, ) -> list[shared_types.User]: @@ -359,6 +428,7 @@ def hydrate_user(candidate: shared_types.SiteUserCandidate) -> shared_types.User client, candidate["id"], include_emails=include_emails, + include_account_data=include_account_data, ) hydrated_users = run_context.parallel_map( @@ -483,16 +553,24 @@ def cmd_set_additive_users( parallelism=parallelism, do_backup=do_backup, ): - context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) - if context is None: + mapping_rules = load_mapping_rules(input_path) + if not mapping_rules: + log.warning("No maps defined in %s — nothing to do.", input_path) return run_context.CommandData() - include_user_emails = permissions_mapping.mapping_rules_need_user_emails( - context.mapping_rules + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) + include_user_account_data = permissions_mapping.mapping_rules_need_saml_account_data( + mapping_rules ) users = _resolve_user_identifiers( client, user_identifiers, include_emails=include_user_emails, + include_account_data=include_user_account_data, + ) + context = load_mapping_context_discovery( + client, + mapping_rules, + saml_groups_attribute_name_by_config_id, ) if user_created_after is not None: candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) @@ -509,15 +587,57 @@ def cmd_set_additive_users( users = selected_users if not users: return run_context.CommandData(auth_providers=context.providers) + + matching_rules = _mapping_rules_matching_selected_users(context, users) + log.info( + "%d / %d mapping rule(s) match the selected user(s).", + len(matching_rules), + len(context.mapping_rules), + ) + if not matching_rules: + _run_additive_apply( + client, + input_path, + users, + [], + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users", + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + service_ids = _service_ids_required_by_mapping_rules(context, matching_rules) + log.info( + "Selected mapping rule(s) require repo scans for %d / %d code host connection(s).", + len(service_ids), + len(context.services_by_id), + ) + context = load_repos_for_mapping_context( + client, + _mapping_context_with_rules(context, matching_rules), + service_ids, + ) resolved_mappings = resolve_additive_mappings(context) additions: list[permissions_apply.PermissionAddition] = [] + existing_repos_by_user_id = ( + _load_selected_user_explicit_repos(client, users) if do_backup else None + ) for user in users: + existing_repo_ids = None + if existing_repos_by_user_id is not None: + existing_repo_ids = { + repository["id"] for repository in existing_repos_by_user_id[user["id"]] + } additions.extend( _plan_additions_for_user( client, context, resolved_mappings, user, + existing_repo_ids=existing_repo_ids, ) ) _run_additive_apply( @@ -530,6 +650,7 @@ def cmd_set_additive_users( bind_id_mode=bind_id_mode, do_backup=do_backup, command_name="set-add-users", + existing_repos_by_user_id=existing_repos_by_user_id, worker_pool=worker_pool, ) return run_context.CommandData(auth_providers=context.providers) @@ -561,13 +682,19 @@ def cmd_set_additive_users_without_explicit_perms( parallelism=parallelism, do_backup=do_backup, ): - context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) - if context is None: + mapping_rules = load_mapping_rules(input_path) + if not mapping_rules: + log.warning("No maps defined in %s — nothing to do.", input_path) return run_context.CommandData() - include_user_emails = permissions_mapping.mapping_rules_need_user_emails( - context.mapping_rules + context = load_mapping_context_discovery( + client, + mapping_rules, + saml_groups_attribute_name_by_config_id, + ) + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) + include_user_account_data = permissions_mapping.mapping_rules_need_saml_account_data( + mapping_rules ) - resolved_mappings = resolve_additive_mappings(context) candidate_selection = ( permissions_sourcegraph.list_site_user_candidates_without_explicit_repos( client, @@ -589,9 +716,58 @@ def cmd_set_additive_users_without_explicit_perms( client, candidates, include_emails=include_user_emails, + include_account_data=include_user_account_data, parallelism=parallelism, worker_pool=worker_pool, ) + if not users: + _run_additive_apply( + client, + input_path, + users, + [], + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users-without-explicit-perms", + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + matching_rules = _mapping_rules_matching_selected_users(context, users) + log.info( + "%d / %d mapping rule(s) match the selected user(s).", + len(matching_rules), + len(context.mapping_rules), + ) + if not matching_rules: + _run_additive_apply( + client, + input_path, + users, + [], + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users-without-explicit-perms", + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + service_ids = _service_ids_required_by_mapping_rules(context, matching_rules) + log.info( + "Selected mapping rule(s) require repo scans for %d / %d code host connection(s).", + len(service_ids), + len(context.services_by_id), + ) + context = load_repos_for_mapping_context( + client, + _mapping_context_with_rules(context, matching_rules), + service_ids, + ) + resolved_mappings = resolve_additive_mappings(context) additions: list[permissions_apply.PermissionAddition] = [] started = time.perf_counter() progress_step = max(1, len(users) // 10) if users else 1 @@ -640,6 +816,7 @@ def _resolve_user_identifiers( user_identifiers: tuple[str, ...], *, include_emails: bool = False, + include_account_data: bool = True, ) -> list[shared_types.User]: """Resolve username/email inputs to distinct Sourcegraph users in caller order.""" users: list[shared_types.User] = [] @@ -649,6 +826,7 @@ def _resolve_user_identifiers( client, user_identifier, include_emails=include_emails, + include_account_data=include_account_data, ) if user["id"] in seen_user_ids: continue @@ -662,6 +840,7 @@ def _resolve_user_identifier( user_identifier: str, *, include_emails: bool = False, + include_account_data: bool = True, ) -> shared_types.User: """Resolve username/email input to one Sourcegraph user.""" user: shared_types.User | None @@ -670,20 +849,24 @@ def _resolve_user_identifier( client, user_identifier, include_emails=include_emails, + include_account_data=include_account_data, ) or permissions_sourcegraph.get_user_by_username( client, user_identifier, include_emails=include_emails, + include_account_data=include_account_data, ) else: user = permissions_sourcegraph.get_user_by_username( client, user_identifier, include_emails=include_emails, + include_account_data=include_account_data, ) or permissions_sourcegraph.get_user_by_email( client, user_identifier, include_emails=include_emails, + include_account_data=include_account_data, ) if user is None: raise SystemExit(f"No Sourcegraph user found for {user_identifier!r}.") @@ -692,6 +875,23 @@ def _resolve_user_identifier( return user +def _load_selected_user_explicit_repos( + client: src.SourcegraphClient, + users: list[shared_types.User], +) -> dict[str, list[permission_types.Repository]]: + """Fetch selected users' explicit repos once for planning and snapshots.""" + with src.span("load_selected_user_explicit_repos", user_count=len(users)) as load_event: + repos_by_user_id = { + user["id"]: permissions_sourcegraph.list_user_explicit_repos( + client, + user["id"], + ) + for user in users + } + load_event["total_grants"] = sum(len(repos) for repos in repos_by_user_id.values()) + return repos_by_user_id + + def _plan_additions_for_user( client: src.SourcegraphClient, context: permission_types.MappingContext, @@ -712,6 +912,10 @@ def _plan_additions_for_user( for repository in resolved_mapping.repos: desired_repos[repository["id"]] = repository + if not desired_repos: + log.info("User %s: no desired repo grants.", user["username"]) + return [] + if existing_repo_ids is None: existing_repo_ids = set( permissions_sourcegraph.list_user_explicit_repo_ids(client, user["id"]) @@ -752,17 +956,27 @@ def _write_additive_initial_artifacts( parallelism: int, bind_id_mode: str, command_name: str, + existing_repos_by_user_id: dict[str, list[permission_types.Repository]] | None = None, worker_pool: ThreadPoolExecutor | None = None, ) -> permission_snapshot.UserScopedSnapshot: """Capture before-snapshot and write dry-run/no-op additive artifacts.""" - before_snapshot = permission_snapshot.build_user_scoped_snapshot( - client, - snapshot_users, - parallelism, - bind_id_mode, - input_path, - worker_pool=worker_pool, - ) + if existing_repos_by_user_id is None: + before_snapshot = permission_snapshot.build_user_scoped_snapshot( + client, + snapshot_users, + parallelism, + bind_id_mode, + input_path, + worker_pool=worker_pool, + ) + else: + before_snapshot = permission_snapshot.build_user_scoped_snapshot_from_repos( + client, + snapshot_users, + existing_repos_by_user_id, + bind_id_mode, + input_path, + ) run_label = _additive_run_label(command_name, dry_run) before_path = snapshot_path(input_path, timestamp, client.endpoint, run_label, "before") after_path = snapshot_path(input_path, timestamp, client.endpoint, run_label, "after") @@ -906,6 +1120,7 @@ def _run_additive_apply( bind_id_mode: str, do_backup: bool, command_name: str, + existing_repos_by_user_id: dict[str, list[permission_types.Repository]] | None = None, worker_pool: ThreadPoolExecutor | None = None, ) -> None: """Snapshot, dry-run, apply, and validate an additive permission plan.""" @@ -914,9 +1129,10 @@ def _run_additive_apply( return snapshot_users = _snapshot_users_from_users(users) - timestamp = backups.backup_timestamp() before_snapshot: permission_snapshot.UserScopedSnapshot | None = None - if dry_run or do_backup: + timestamp: str | None = None + if do_backup: + timestamp = backups.backup_timestamp() before_snapshot = _write_additive_initial_artifacts( client, input_path, @@ -927,6 +1143,7 @@ def _run_additive_apply( parallelism=parallelism, bind_id_mode=bind_id_mode, command_name=command_name, + existing_repos_by_user_id=existing_repos_by_user_id, worker_pool=worker_pool, ) @@ -943,6 +1160,7 @@ def _run_additive_apply( if do_backup: assert before_snapshot is not None + assert timestamp is not None _finish_additive_apply_with_backup( client, input_path, @@ -979,13 +1197,11 @@ def _user_scoped_snapshot_with_additions( for addition in additions: user_snapshot = users.setdefault( addition.username, - {"id": addition.user_id, "explicit_repositories": []}, + {"id": addition.user_id, "repos": []}, ) - repositories = { - repository["id"]: repository for repository in user_snapshot["explicit_repositories"] - } + repositories = {repository["id"]: repository for repository in user_snapshot["repos"]} repositories[addition.repo_id] = {"id": addition.repo_id, "name": addition.repo_name} - user_snapshot["explicit_repositories"] = sorted( + user_snapshot["repos"] = sorted( repositories.values(), key=lambda repository: repository["name"], ) @@ -998,7 +1214,7 @@ def _copy_user_scoped_users( return { username: { "id": user_snapshot["id"], - "explicit_repositories": list(user_snapshot["explicit_repositories"]), + "repos": list(user_snapshot["repos"]), } for username, user_snapshot in snapshot["users"].items() } @@ -1008,9 +1224,7 @@ def _copy_user_scoped_snapshot_with_users( snapshot: permission_snapshot.UserScopedSnapshot, users: dict[str, permission_snapshot.UserScopedUserSnapshot], ) -> permission_snapshot.UserScopedSnapshot: - total_grants = sum( - len(user_snapshot["explicit_repositories"]) for user_snapshot in users.values() - ) + total_grants = sum(len(user_snapshot["repos"]) for user_snapshot in users.values()) return { "schema_version": snapshot["schema_version"], "snapshot_kind": snapshot["snapshot_kind"], @@ -1022,7 +1236,7 @@ def _copy_user_scoped_snapshot_with_users( "stats": { "total_users_scanned": len(users), "users_with_explicit_grants": sum( - 1 for user_snapshot in users.values() if user_snapshot["explicit_repositories"] + 1 for user_snapshot in users.values() if user_snapshot["repos"] ), "total_grants": total_grants, }, @@ -1037,7 +1251,7 @@ def _validate_additive_after( """Validate that every requested additive edge exists after apply.""" missing: list[permissions_apply.PermissionAddition] = [] repos_by_username = { - username: {repository["id"] for repository in user_snapshot["explicit_repositories"]} + username: {repository["id"] for repository in user_snapshot["repos"]} for username, user_snapshot in after_snapshot["users"].items() } for addition in additions: diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 00857fb..207c948 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -99,6 +99,7 @@ def _capture_full_set_snapshot_state( bind_id_mode: str, worker_pool: ThreadPoolExecutor | None = None, include_user_emails: bool = False, + include_user_account_data: bool = True, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" expected_user_count = shared_sourcegraph.count_users(client) @@ -115,6 +116,7 @@ def _capture_full_set_snapshot_state( client, collect_into=users, include_emails=include_user_emails, + include_account_data=include_user_account_data, ), parallelism, bind_id_mode, @@ -146,6 +148,7 @@ def _load_full_set_snapshot_state( capture_before: bool, worker_pool: ThreadPoolExecutor | None = None, include_user_emails: bool = False, + include_user_account_data: bool = True, ) -> _FullSetUserState: """Load all users, optionally with a before-snapshot.""" if capture_before: @@ -157,12 +160,14 @@ def _load_full_set_snapshot_state( bind_id_mode, worker_pool, include_user_emails=include_user_emails, + include_user_account_data=include_user_account_data, ) log.info("Loading users from %s ...", client.endpoint) users = shared_sourcegraph.list_users_with_accounts( client, include_emails=include_user_emails, + include_account_data=include_user_account_data, ) log.info("Received %d total users.", len(users)) return _FullSetUserState(users=users) @@ -656,6 +661,7 @@ def _finish_empty_full_set_mapping_rules( explicit_permissions_batch_size, bind_id_mode, worker_pool, + include_user_account_data=False, ) _write_noop_full_set_artifacts( input_path, @@ -683,6 +689,10 @@ def _load_full_set_plan( worker_pool: ThreadPoolExecutor | None = None, ) -> _FullSetLoadedPlan: include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) + include_user_account_data = ( + permissions_mapping.mapping_rules_need_saml_account_data(mapping_rules) + or retain_saml_group_users + ) user_state = _load_full_set_snapshot_state( client, input_path, @@ -692,6 +702,7 @@ def _load_full_set_plan( capture_before=capture_before, worker_pool=worker_pool, include_user_emails=include_user_emails, + include_user_account_data=include_user_account_data, ) before_path: Path | None = None if capture_before: diff --git a/src/src_auth_perms_sync/permissions/mapping.py b/src/src_auth_perms_sync/permissions/mapping.py index 69d78f1..e176f2a 100644 --- a/src/src_auth_perms_sync/permissions/mapping.py +++ b/src/src_auth_perms_sync/permissions/mapping.py @@ -135,6 +135,15 @@ def mapping_rules_need_user_emails(mapping_rules: list[permission_types.MappingR ) +def mapping_rules_need_saml_account_data( + mapping_rules: list[permission_types.MappingRule], +) -> bool: + """Return whether any mapping rule filters users by SAML group claims.""" + return any( + bool(mapping["users"].get("authProvider", {}).get("samlGroup")) for mapping in mapping_rules + ) + + def _validate_mapping_name(value: object, prefix: str) -> list[str]: """Validate the required human-readable mapping name.""" if value is None: @@ -772,6 +781,26 @@ def _repos_matching_code_host_connection( return list(matched_repos.values()) +def service_ids_required_by_repository_selectors( + services_by_id: dict[int, permission_types.ExternalService], + selectors: Sequence[permission_types.RepositorySelector], +) -> set[int]: + """Return code-host service IDs whose repos may match the selectors. + + A selector without `codeHostConnection` can match any code host, so the + caller must load every service. Selectors with `codeHostConnection` narrow + the repo scan to only services matching that matcher. + """ + required_service_ids: set[int] = set() + for selector in selectors: + matcher = selector.get("codeHostConnection") + if matcher is None: + return set(services_by_id) + for service in _services_matching(services_by_id, matcher): + required_service_ids.add(src.decode_external_service_id(service["id"])) + return required_service_ids + + def _repo_name_matches( repository_name: str, exact_values: set[str], patterns: list[re.Pattern[str]] ) -> bool: diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index 71f7a58..e643090 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -75,7 +75,7 @@ } """ -USER_FIELDS = """ +USER_BASE_FIELDS = """ id username builtinAuth @@ -84,11 +84,13 @@ serviceType serviceID clientID - accountData +__ACCOUNT_DATA_FIELD__ } } """ +USER_ACCOUNT_DATA_FIELD = " accountData" + USER_EMAIL_FIELDS = """ emails { email @@ -97,39 +99,59 @@ """ -def user_fields(*, include_emails: bool = False) -> str: - """Return user fields, adding emails only when downstream matching needs them.""" +def user_fields( + *, + include_emails: bool = False, + include_account_data: bool = True, +) -> str: + """Return user fields, adding heavier fields only when downstream needs them.""" + fields = USER_BASE_FIELDS.replace( + "__ACCOUNT_DATA_FIELD__", + USER_ACCOUNT_DATA_FIELD if include_account_data else "", + ) if include_emails: - return f"{USER_FIELDS}\n{USER_EMAIL_FIELDS}" - return USER_FIELDS + return f"{fields}\n{USER_EMAIL_FIELDS}" + return fields -def query_user_by_username(*, include_emails: bool = False) -> str: +def query_user_by_username( + *, + include_emails: bool = False, + include_account_data: bool = True, +) -> str: return f""" query UserByUsername($username: String!) {{ user(username: $username) {{ - {user_fields(include_emails=include_emails)} + {user_fields(include_emails=include_emails, include_account_data=include_account_data)} }} }} """ -def query_user_by_email(*, include_emails: bool = False) -> str: +def query_user_by_email( + *, + include_emails: bool = False, + include_account_data: bool = True, +) -> str: return f""" query UserByEmail($email: String!) {{ user(email: $email) {{ - {user_fields(include_emails=include_emails)} + {user_fields(include_emails=include_emails, include_account_data=include_account_data)} }} }} """ -def query_user_by_id(*, include_emails: bool = False) -> str: +def query_user_by_id( + *, + include_emails: bool = False, + include_account_data: bool = True, +) -> str: return f""" query UserByID($id: ID!) {{ node(id: $id) {{ ... on User {{ - {user_fields(include_emails=include_emails)} + {user_fields(include_emails=include_emails, include_account_data=include_account_data)} }} }} }} diff --git a/src/src_auth_perms_sync/permissions/restore.py b/src/src_auth_perms_sync/permissions/restore.py index b671b00..541102d 100644 --- a/src/src_auth_perms_sync/permissions/restore.py +++ b/src/src_auth_perms_sync/permissions/restore.py @@ -176,12 +176,9 @@ def _plan_user_scoped_restore( current_user = current_snapshot["users"].get(username) current_repos = { repository["id"]: repository["name"] - for repository in (current_user["explicit_repositories"] if current_user else []) - } - target_repos = { - repository["id"]: repository["name"] - for repository in target_user["explicit_repositories"] + for repository in (current_user["repos"] if current_user else []) } + target_repos = {repository["id"]: repository["name"] for repository in target_user["repos"]} for repo_id in sorted( set(target_repos) - set(current_repos), key=lambda value: target_repos[value], @@ -554,7 +551,11 @@ def _capture_restore_snapshot_state( users: list[shared_types.User] = [] current_snapshot = permission_snapshot.build_snapshot( client, - shared_sourcegraph.list_users_streaming(client, collect_into=users), + shared_sourcegraph.list_users_streaming( + client, + collect_into=users, + include_account_data=False, + ), parallelism, bind_id_mode, snapshot_path, diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index 61d602a..365bfb6 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -62,7 +62,7 @@ def compact_snapshot_users(users: Iterable[shared_types.User]) -> list[SnapshotU class UserScopedUserSnapshot(TypedDict): id: str - explicit_repositories: list[permission_types.Repository] + repos: list[permission_types.Repository] class UserScopedSnapshotStats(TypedDict): @@ -152,7 +152,7 @@ class UserScopedSnapshotDiff(TypedDict): users: list[UserScopedSnapshotDiffEntry] -SNAPSHOT_SCHEMA_VERSION: int = 4 +SNAPSHOT_SCHEMA_VERSION: int = 5 USER_SCOPED_SNAPSHOT_KIND = "user_scope" SNAPSHOT_DIFF_SCHEMA_VERSION: int = 1 @@ -521,11 +521,11 @@ def _fetch_or_empty( ): scoped_users[fetched_user["username"]] = { "id": fetched_user["id"], - "explicit_repositories": sorted(repos, key=lambda repo: repo["name"]), + "repos": sorted(repos, key=lambda repo: repo["name"]), } capture_event["scanned_user_count"] = len(scoped_users) capture_event["total_grants"] = sum( - len(user_snapshot["explicit_repositories"]) for user_snapshot in scoped_users.values() + len(user_snapshot["repos"]) for user_snapshot in scoped_users.values() ) return dict(sorted(scoped_users.items())) @@ -550,11 +550,9 @@ def build_user_scoped_snapshot( if config_path is not None and config_path.exists(): config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest() - total_grants = sum( - len(user_snapshot["explicit_repositories"]) for user_snapshot in scoped_users.values() - ) + total_grants = sum(len(user_snapshot["repos"]) for user_snapshot in scoped_users.values()) users_with_explicit_grants = sum( - 1 for user_snapshot in scoped_users.values() if user_snapshot["explicit_repositories"] + 1 for user_snapshot in scoped_users.values() if user_snapshot["repos"] ) build_event["scanned_user_count"] = len(scoped_users) build_event["users_with_explicit_grants"] = users_with_explicit_grants @@ -577,6 +575,49 @@ def build_user_scoped_snapshot( } +def build_user_scoped_snapshot_from_repos( + client: src.SourcegraphClient, + users: Iterable[SnapshotUser], + repos_by_user_id: dict[str, list[permission_types.Repository]], + bind_id_mode: str, + config_path: Path | None = None, +) -> UserScopedSnapshot: + """Build a user-scoped snapshot from explicit repos already fetched.""" + scoped_users: dict[str, UserScopedUserSnapshot] = { + user["username"]: { + "id": user["id"], + "repos": sorted( + repos_by_user_id.get(user["id"], []), + key=lambda repo: repo["name"], + ), + } + for user in users + } + config_sha: str | None = None + if config_path is not None and config_path.exists(): + config_sha = hashlib.sha256(config_path.read_bytes()).hexdigest() + + total_grants = sum(len(user_snapshot["repos"]) for user_snapshot in scoped_users.values()) + users_with_explicit_grants = sum( + 1 for user_snapshot in scoped_users.values() if user_snapshot["repos"] + ) + return { + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "snapshot_kind": USER_SCOPED_SNAPSHOT_KIND, + "captured_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "endpoint": client.endpoint, + "bindID_mode": bind_id_mode, + "config_file": str(config_path.resolve()) if config_path else None, + "config_sha256": config_sha, + "stats": { + "total_users_scanned": len(scoped_users), + "users_with_explicit_grants": users_with_explicit_grants, + "total_grants": total_grants, + }, + "users": dict(sorted(scoped_users.items())), + } + + def _write_pretty_json(path: Path, value: Any) -> int: """Write pretty JSON without materializing the encoded string first.""" with path.open("w", encoding="utf-8") as output: @@ -660,8 +701,8 @@ def _write_user_scoped_snapshot_value( output.write(f'{field_indent}"id": ') json.dump(user_snapshot["id"], output) output.write(",\n") - output.write(f'{field_indent}"explicit_repositories": ') - _write_repository_list(output, user_snapshot["explicit_repositories"], indent + 2) + output.write(f'{field_indent}"repos": ') + _write_repository_list(output, user_snapshot["repos"], indent + 2) output.write("\n" + " " * indent + "}") @@ -846,12 +887,12 @@ def _encode_user_scoped_snapshot_raw(path: Path, raw: dict[str, Any]) -> UserSco raw["users"] = { username: { "id": user_snapshot["id"], - "explicit_repositories": [ + "repos": [ { "id": src.encode_repository_id(int(repo["id"])), "name": cast(str, repo["name"]), } - for repo in cast(list[dict[str, Any]], user_snapshot["explicit_repositories"]) + for repo in cast(list[dict[str, Any]], user_snapshot["repos"]) ], } for username, user_snapshot in on_disk_users.items() @@ -1494,10 +1535,7 @@ def _repositories_by_id( ) -> dict[str, str]: if user_snapshot is None: return {} - return { - repository["id"]: repository["name"] - for repository in user_snapshot["explicit_repositories"] - } + return {repository["id"]: repository["name"] for repository in user_snapshot["repos"]} def _permission_count(repo_snapshot: RepoSnapshot | None) -> int: diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index 6f51c46..2f01660 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -20,6 +20,7 @@ log = logging.getLogger(__name__) SITE_USER_CANDIDATE_PAGE_SIZE = 1000 +REPOSITORY_PAGE_SIZE = 1000 @dataclass(frozen=True) @@ -56,7 +57,7 @@ def list_repos_for_external_service( queries.QUERY_REPOS_BY_EXTERNAL_SERVICE, {"esID": external_service_id}, connection_path=("repositories",), - page_size=shared_sourcegraph.DEFAULT_PAGE_SIZE, + page_size=REPOSITORY_PAGE_SIZE, ) ] @@ -66,12 +67,16 @@ def get_user_by_username( username: str, *, include_emails: bool = False, + include_account_data: bool = True, ) -> shared_types.User | None: """Return the exact Sourcegraph user for `username`, if it exists.""" data = cast( dict[str, Any], client.graphql( - queries.query_user_by_username(include_emails=include_emails), + queries.query_user_by_username( + include_emails=include_emails, + include_account_data=include_account_data, + ), cast(src.JSONDict, {"username": username}), ), ) @@ -83,12 +88,16 @@ def get_user_by_email( email: str, *, include_emails: bool = False, + include_account_data: bool = True, ) -> shared_types.User | None: """Return the user owning the verified email address, if it exists.""" data = cast( dict[str, Any], client.graphql( - queries.query_user_by_email(include_emails=include_emails), + queries.query_user_by_email( + include_emails=include_emails, + include_account_data=include_account_data, + ), cast(src.JSONDict, {"email": email}), ), ) @@ -100,12 +109,16 @@ def get_user_by_id( user_id: str, *, include_emails: bool = False, + include_account_data: bool = True, ) -> shared_types.User | None: """Hydrate a User node by GraphQL ID.""" data = cast( dict[str, Any], client.graphql( - queries.query_user_by_id(include_emails=include_emails), + queries.query_user_by_id( + include_emails=include_emails, + include_account_data=include_account_data, + ), cast(src.JSONDict, {"id": user_id}), ), ) diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index c973d9e..f49e4e5 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -153,6 +153,63 @@ def load_mapping_context_for_rules( ) +def load_mapping_context_discovery( + client: src.SourcegraphClient, + mapping_rules: list[permission_types.MappingRule], + saml_groups_attribute_name_by_config_id: dict[str, str], +) -> permission_types.MappingContext: + """Load provider and code-host metadata without scanning repos.""" + providers, services, saml_groups_attribute_names = load_discovery( + client, saml_groups_attribute_name_by_config_id + ) + services_by_id: dict[int, permission_types.ExternalService] = { + src.decode_external_service_id(service["id"]): service for service in services + } + return permission_types.MappingContext( + mapping_rules=mapping_rules, + providers=providers, + saml_groups_attribute_names=saml_groups_attribute_names, + services_by_id=services_by_id, + repos_by_external_service_id={}, + all_repos_by_id={}, + ) + + +def load_repos_for_mapping_context( + client: src.SourcegraphClient, + context: permission_types.MappingContext, + service_ids: set[int] | None = None, +) -> permission_types.MappingContext: + """Return context with repos loaded for all or selected code hosts.""" + services_by_id = ( + context.services_by_id + if service_ids is None + else { + service_id: context.services_by_id[service_id] + for service_id in sorted(service_ids) + if service_id in context.services_by_id + } + ) + repos_by_external_service_id = { + **context.repos_by_external_service_id, + **load_repos_by_external_service(client, services_by_id), + } + all_repos_by_id = index_repos_by_id(repos_by_external_service_id) + log.info( + "Received %d unique repo(s) across %d loaded code host connection(s).", + len(all_repos_by_id), + len(repos_by_external_service_id), + ) + return permission_types.MappingContext( + mapping_rules=context.mapping_rules, + providers=context.providers, + saml_groups_attribute_names=context.saml_groups_attribute_names, + services_by_id=context.services_by_id, + repos_by_external_service_id=repos_by_external_service_id, + all_repos_by_id=all_repos_by_id, + ) + + def snapshot_path( input_path: Path, timestamp: str, diff --git a/src/src_auth_perms_sync/shared/queries.py b/src/src_auth_perms_sync/shared/queries.py index c833e42..0a8126d 100644 --- a/src/src_auth_perms_sync/shared/queries.py +++ b/src/src_auth_perms_sync/shared/queries.py @@ -44,10 +44,24 @@ } """ +USER_ACCOUNT_DATA_FIELD = """ # accountData is the parsed gosaml2 + # AssertionInfo JSON for SAML + # accounts (used by saml_groups extraction). The server gates + # it on Site Admin for SAML/OIDC; we already require Site + # Admin. Returns null for serviceType where the resolver does + # not expose data (e.g. plain GitHub OAuth without SSO). + accountData +""" -def query_users(*, include_emails: bool = False) -> str: - """Return the users page query, adding email fields only when requested.""" + +def query_users( + *, + include_emails: bool = False, + include_account_data: bool = True, +) -> str: + """Return the users page query, adding heavier fields only when requested.""" email_fields = USER_EMAIL_FIELDS if include_emails else "" + account_data_field = USER_ACCOUNT_DATA_FIELD if include_account_data else "" return f""" query ListUsers($first: Int!, $after: String) {{ users(first: $first, after: $after) {{ @@ -60,13 +74,7 @@ def query_users(*, include_emails: bool = False) -> str: serviceType serviceID clientID - # accountData is the parsed gosaml2 AssertionInfo JSON for SAML - # accounts (used by saml_groups extraction). The server gates - # it on Site Admin for SAML/OIDC; we already require Site - # Admin. Returns null for serviceType where the resolver does - # not expose data (e.g. plain GitHub OAuth without SSO). - accountData - }} +{account_data_field} }} }} }} pageInfo {{ hasNextPage endCursor }} diff --git a/src/src_auth_perms_sync/shared/saml_groups.py b/src/src_auth_perms_sync/shared/saml_groups.py index cec3f74..1dcc1dd 100644 --- a/src/src_auth_perms_sync/shared/saml_groups.py +++ b/src/src_auth_perms_sync/shared/saml_groups.py @@ -6,8 +6,8 @@ attribute named by the provider's `groupsAttributeName` site config (default `"groups"`). -This module does NOT fetch — it only parses what `list_users_with_accounts` -already pulled. Two on-disk shapes are handled defensively: +This module does NOT fetch — it only parses user rows fetched with +`include_account_data=True`. Two on-disk shapes are handled defensively: 1. Raw `*saml2.AssertionInfo`: accountData["Assertions"][i]["AttributeStatement"]["Attributes"][j] diff --git a/src/src_auth_perms_sync/shared/sourcegraph.py b/src/src_auth_perms_sync/shared/sourcegraph.py index f138e2d..e0ba123 100644 --- a/src/src_auth_perms_sync/shared/sourcegraph.py +++ b/src/src_auth_perms_sync/shared/sourcegraph.py @@ -36,11 +36,15 @@ def list_users_with_accounts( client: src.SourcegraphClient, *, include_emails: bool = False, + include_account_data: bool = True, ) -> list[shared_types.User]: return [ cast(shared_types.User, node) for node in client.stream_connection_nodes( - queries.query_users(include_emails=include_emails), + queries.query_users( + include_emails=include_emails, + include_account_data=include_account_data, + ), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ) @@ -52,6 +56,7 @@ def list_users_streaming( collect_into: list[shared_types.User] | None = None, *, include_emails: bool = False, + include_account_data: bool = True, ) -> Iterator[shared_types.User]: """Stream ListUsers pages one at a time, yielding each User as it arrives. @@ -65,7 +70,10 @@ def list_users_streaming( streaming benefit in one pass — no double-pagination. """ for node in client.stream_connection_nodes( - queries.query_users(include_emails=include_emails), + queries.query_users( + include_emails=include_emails, + include_account_data=include_account_data, + ), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ): diff --git a/tests/unit/test_command_additive.py b/tests/unit/test_command_additive.py new file mode 100644 index 0000000..ceac1cf --- /dev/null +++ b/tests/unit/test_command_additive.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import base64 +import tempfile +import unittest +from pathlib import Path +from typing import Any, cast + +import src_py_lib as src + +from src_auth_perms_sync.permissions import command +from src_auth_perms_sync.permissions import types as permission_types +from src_auth_perms_sync.shared import backups +from src_auth_perms_sync.shared import types as shared_types + + +class _AdditiveCommandClient: + endpoint = "https://sourcegraph.example.com" + + def __init__( + self, + *, + services: list[permission_types.ExternalService], + repos_by_service_id: dict[str, list[permission_types.Repository]], + users_by_username: dict[str, shared_types.User], + explicit_repo_ids_by_user_id: dict[str, list[str]] | None = None, + ) -> None: + self.services = services + self.repos_by_service_id = repos_by_service_id + self.users_by_username = users_by_username + self.explicit_repo_ids_by_user_id = explicit_repo_ids_by_user_id or {} + self.repo_service_ids: list[str] = [] + self.explicit_repo_fetch_count = 0 + + def graphql( + self, + query: str, + variables: src.JSONDict | None = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + del follow_pages + if "authProviders" in query: + return cast(src.JSONDict, {"site": {"authProviders": {"nodes": []}}}) + if "query UserByUsername" in query: + if variables is None: + raise AssertionError("expected username variables") + username = variables.get("username") + return cast(src.JSONDict, {"user": self.users_by_username.get(str(username))}) + if "query UserByEmail" in query: + return cast(src.JSONDict, {"user": None}) + if "query RepositoryNamesByID" in query: + if variables is None: + raise AssertionError("expected repository variables") + return cast(src.JSONDict, self._repositories_by_alias(variables)) + raise AssertionError(f"unexpected query: {query[:80]}") + + def stream_connection_nodes( + self, + query: str, + variables: src.JSONDict | None = None, + *, + connection_path: tuple[str, ...], + page_size: int, + ) -> list[dict[str, Any]]: + del connection_path, page_size + if "externalServices" in query: + return cast(list[dict[str, Any]], self.services) + if "query ReposByExternalService" in query: + service_id_value = None if variables is None else variables.get("esID") + if not isinstance(service_id_value, str): + raise AssertionError("expected external service ID") + self.repo_service_ids.append(service_id_value) + return cast( + list[dict[str, Any]], + self.repos_by_service_id.get(service_id_value, []), + ) + if "query UserExplicitRepos" in query: + user_id_value = None if variables is None else variables.get("id") + if not isinstance(user_id_value, str): + raise AssertionError("expected user ID") + self.explicit_repo_fetch_count += 1 + return [ + {"id": repository_id} + for repository_id in self.explicit_repo_ids_by_user_id.get(user_id_value, []) + ] + raise AssertionError(f"unexpected stream query: {query[:80]}") + + def _repositories_by_alias(self, variables: src.JSONDict) -> dict[str, object]: + repos_by_id = { + repository["id"]: repository + for repositories in self.repos_by_service_id.values() + for repository in repositories + } + response: dict[str, object] = {} + for variable_name, repository_id in variables.items(): + if not variable_name.startswith("repo") or not isinstance(repository_id, str): + continue + response[variable_name] = repos_by_id.get(repository_id) + return response + + +class AdditiveCommandTests(unittest.TestCase): + def test_no_backup_dry_run_skips_artifacts_and_repo_load_when_no_rule_matches( + self, + ) -> None: + service = make_external_service(1, "GitHub Enterprise") + client = _AdditiveCommandClient( + services=[service], + repos_by_service_id={service["id"]: [make_repository(1, "github.com/example/repo")]}, + users_by_username={"marc": make_user("user-1", "marc")}, + ) + + with tempfile.TemporaryDirectory() as directory_name: + directory = Path(directory_name) + maps_path = directory / "maps.yaml" + maps_path.write_text( + """ +maps: + - name: alice repos + users: + usernames: [alice] + repos: + codeHostConnection: + displayName: GitHub Enterprise + names: [github.com/example/repo] +""".lstrip(), + encoding="utf-8", + ) + run_directory = directory / "run-artifacts" + + with backups.run_artifacts_context(run_directory, "2026-06-09-10-00-00"): + command.cmd_set_additive_users( + cast(src.SourcegraphClient, client), + maps_path, + ("marc",), + None, + dry_run=True, + parallelism=1, + bind_id_mode="USERNAME", + saml_groups_attribute_name_by_config_id={}, + do_backup=False, + ) + + self.assertFalse(run_directory.exists()) + self.assertEqual([], client.repo_service_ids) + self.assertEqual(0, client.explicit_repo_fetch_count) + + def test_additive_users_loads_only_referenced_code_hosts(self) -> None: + first_service = make_external_service(1, "GitHub Enterprise") + second_service = make_external_service(2, "GitLab") + client = _AdditiveCommandClient( + services=[first_service, second_service], + repos_by_service_id={ + first_service["id"]: [make_repository(1, "github.com/example/repo")], + second_service["id"]: [make_repository(2, "gitlab.example.com/example/repo")], + }, + users_by_username={"alice": make_user("user-1", "alice")}, + ) + + with tempfile.TemporaryDirectory() as directory_name: + maps_path = Path(directory_name) / "maps.yaml" + maps_path.write_text( + """ +maps: + - name: alice repos + users: + usernames: [alice] + repos: + codeHostConnection: + displayName: GitHub Enterprise + names: [github.com/example/repo] +""".lstrip(), + encoding="utf-8", + ) + + command.cmd_set_additive_users( + cast(src.SourcegraphClient, client), + maps_path, + ("alice",), + None, + dry_run=True, + parallelism=1, + bind_id_mode="USERNAME", + saml_groups_attribute_name_by_config_id={}, + do_backup=False, + ) + + self.assertEqual([first_service["id"]], client.repo_service_ids) + self.assertEqual(1, client.explicit_repo_fetch_count) + + def test_backup_dry_run_reuses_planning_explicit_repo_read_for_snapshot(self) -> None: + service = make_external_service(1, "GitHub Enterprise") + client = _AdditiveCommandClient( + services=[service], + repos_by_service_id={service["id"]: [make_repository(1, "github.com/example/repo")]}, + users_by_username={"alice": make_user("user-1", "alice")}, + ) + + with tempfile.TemporaryDirectory() as directory_name: + directory = Path(directory_name) + maps_path = directory / "maps.yaml" + maps_path.write_text( + """ +maps: + - name: alice repos + users: + usernames: [alice] + repos: + codeHostConnection: + displayName: GitHub Enterprise + names: [github.com/example/repo] +""".lstrip(), + encoding="utf-8", + ) + run_directory = directory / "run-artifacts" + + with backups.run_artifacts_context(run_directory, "2026-06-09-10-00-00"): + command.cmd_set_additive_users( + cast(src.SourcegraphClient, client), + maps_path, + ("alice",), + None, + dry_run=True, + parallelism=1, + bind_id_mode="USERNAME", + saml_groups_attribute_name_by_config_id={}, + do_backup=True, + ) + + self.assertTrue((run_directory / "before.json").exists()) + self.assertTrue((run_directory / "after.json").exists()) + self.assertTrue((run_directory / "diff.json").exists()) + self.assertTrue((run_directory / "maps.yaml").exists()) + + self.assertEqual(1, client.explicit_repo_fetch_count) + + +def make_graphql_id(kind: str, identifier: int) -> str: + return base64.b64encode(f"{kind}:{identifier}".encode()).decode() + + +def make_user(user_id: str, username: str) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": True, + "externalAccounts": {"nodes": []}, + } + + +def make_repository(identifier: int, name: str) -> permission_types.Repository: + return {"id": make_graphql_id("Repository", identifier), "name": name} + + +def make_external_service(identifier: int, display_name: str) -> permission_types.ExternalService: + return { + "id": make_graphql_id("ExternalService", identifier), + "kind": "GITHUB", + "displayName": display_name, + "url": f"https://code-host-{identifier}.example.com", + "repoCount": 1, + "createdAt": "2026-06-09T00:00:00Z", + "updatedAt": "2026-06-09T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": "{}", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index d0d74f3..a57e120 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -135,6 +135,33 @@ def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: self.assertFalse(mapping.mapping_rules_need_user_emails(rules_without_email_filters)) self.assertTrue(mapping.mapping_rules_need_user_emails(rules_with_email_filters)) + def test_mapping_rules_need_saml_account_data_tracks_saml_group_filters(self) -> None: + rules_without_saml_group_filters = cast( + list[permission_types.MappingRule], + [ + { + "name": "provider only", + "users": {"authProvider": {"type": "saml"}}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + rules_with_saml_group_filters = cast( + list[permission_types.MappingRule], + [ + { + "name": "saml group", + "users": {"authProvider": {"type": "saml", "samlGroup": "eng"}}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + + self.assertFalse( + mapping.mapping_rules_need_saml_account_data(rules_without_saml_group_filters) + ) + self.assertTrue(mapping.mapping_rules_need_saml_account_data(rules_with_saml_group_filters)) + def test_user_filter_matchers_intersect_without_expanding_selection(self) -> None: providers: list[shared_types.AuthProvider] = [ { @@ -265,6 +292,39 @@ def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> Non ), ) + def test_service_ids_required_by_repository_selectors_uses_code_host_filter(self) -> None: + services_by_id = { + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise", "enterprise-sync"), + 2: self.make_external_service(2, "GITHUB", "GitHub Cloud", "cloud-sync"), + } + + service_ids = mapping.service_ids_required_by_repository_selectors( + services_by_id, + [ + cast( + permission_types.RepositorySelector, + {"codeHostConnection": {"displayName": "GitHub Enterprise"}}, + ) + ], + ) + + self.assertEqual({1}, service_ids) + + def test_service_ids_required_by_repository_selectors_loads_all_for_global_filter( + self, + ) -> None: + services_by_id = { + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise"), + 2: self.make_external_service(2, "GITLAB", "GitLab"), + } + + service_ids = mapping.service_ids_required_by_repository_selectors( + services_by_id, + [cast(permission_types.RepositorySelector, {"nameRegexes": [".*"]})], + ) + + self.assertEqual({1, 2}, service_ids) + def test_validate_mapping_rules_accepts_flat_text_selector_lists(self) -> None: mapping.validate_mapping_rules( cast( @@ -538,3 +598,15 @@ def test_user_email_fields_are_opt_in(self) -> None: self.assertNotIn("emails {", permission_queries.QUERY_USER_BY_ID) self.assertNotIn("emails {", permission_queries.query_user_by_id()) self.assertIn("emails {", permission_queries.query_user_by_id(include_emails=True)) + + def test_account_data_fields_are_opt_out(self) -> None: + self.assertIn("accountData", shared_queries.QUERY_USERS) + self.assertIn("accountData", shared_queries.query_users()) + self.assertNotIn("accountData", shared_queries.query_users(include_account_data=False)) + + self.assertIn("accountData", permission_queries.QUERY_USER_BY_ID) + self.assertIn("accountData", permission_queries.query_user_by_id()) + self.assertNotIn( + "accountData", + permission_queries.query_user_by_id(include_account_data=False), + ) diff --git a/tests/unit/test_permissions_sourcegraph.py b/tests/unit/test_permissions_sourcegraph.py index a4849c5..0ec3ae8 100644 --- a/tests/unit/test_permissions_sourcegraph.py +++ b/tests/unit/test_permissions_sourcegraph.py @@ -90,6 +90,27 @@ def graphql( return cast(src.JSONDict, response) +class _RepoConnectionClient: + def __init__(self) -> None: + self.page_sizes: list[int] = [] + self.variables: list[src.JSONDict | None] = [] + + def stream_connection_nodes( + self, + query: str, + variables: src.JSONDict | None = None, + *, + connection_path: tuple[str, ...], + page_size: int, + ) -> list[dict[str, str]]: + del connection_path + if "query ReposByExternalService" not in query: + raise AssertionError(f"unexpected query: {query[:80]}") + self.page_sizes.append(page_size) + self.variables.append(variables) + return [{"id": "repo-1", "name": "github.com/example/repo"}] + + class _PipelinedCandidateClient: def __init__(self) -> None: self.total_count = 1001 @@ -166,6 +187,18 @@ def _explicit_repos(self, variables: src.JSONDict) -> src.JSONDict: class PermissionsSourcegraphTests(unittest.TestCase): + def test_list_repos_for_external_service_uses_larger_pages(self) -> None: + client = _RepoConnectionClient() + + repos = permissions_sourcegraph.list_repos_for_external_service( + cast(src.SourcegraphClient, client), + "external-service-1", + ) + + self.assertEqual(repos, [{"id": "repo-1", "name": "github.com/example/repo"}]) + self.assertEqual(client.page_sizes, [1000]) + self.assertEqual(client.variables, [{"esID": "external-service-1"}]) + def test_list_site_user_candidates_uses_larger_pages(self) -> None: client = _SiteUsersClient(total_count=2500) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index da28a85..3699426 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -245,6 +245,48 @@ def test_write_snapshot_uses_short_users_key_for_explicit_permissions(self) -> N ) self.assertEqual({"name", "users"}, set(on_disk["repos"]["1"])) + def test_write_user_scoped_snapshot_uses_short_repos_key(self) -> None: + repo_id = src.encode_repository_id(1) + snapshot: permission_snapshot.UserScopedSnapshot = { + "schema_version": permission_snapshot.SNAPSHOT_SCHEMA_VERSION, + "snapshot_kind": permission_snapshot.USER_SCOPED_SNAPSHOT_KIND, + "captured_at": "2026-05-26T00:00:00+00:00", + "endpoint": "https://sourcegraph.example.com", + "bindID_mode": "USERNAME", + "config_file": None, + "config_sha256": None, + "stats": { + "total_users_scanned": 1, + "users_with_explicit_grants": 1, + "total_grants": 1, + }, + "users": { + "alice": { + "id": "user-1", + "repos": [ + { + "id": repo_id, + "name": "github.com/sourcegraph/example", + } + ], + } + }, + } + + with tempfile.TemporaryDirectory() as directory_name: + snapshot_path = Path(directory_name) / "before.json" + + permission_snapshot.write_user_scoped_snapshot(snapshot_path, snapshot) + on_disk = json.loads(snapshot_path.read_text()) + loaded_snapshot = permission_snapshot.read_user_scoped_snapshot(snapshot_path) + + self.assertEqual( + [{"id": 1, "name": "github.com/sourcegraph/example"}], + on_disk["users"]["alice"]["repos"], + ) + self.assertNotIn("explicit_repositories", on_disk["users"]["alice"]) + self.assertEqual(repo_id, loaded_snapshot["users"]["alice"]["repos"][0]["id"]) + def test_snapshot_diff_omits_unchanged_users(self) -> None: before = self.make_snapshot() after = self.make_snapshot() @@ -335,7 +377,7 @@ def test_read_snapshot_rejects_old_schema_versions(self) -> None: with self.assertRaises(SystemExit) as exit_context: permission_snapshot.read_snapshot(snapshot_path) - self.assertIn("expected 4", str(exit_context.exception)) + self.assertIn("expected 5", str(exit_context.exception)) def make_snapshot(self) -> permission_snapshot.Snapshot: return { From 399862d4faeaa0ba66bba40de8e34dc3c44fd86e Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 04:16:46 -0600 Subject: [PATCH 15/17] Move verbose output to debug --- src/src_auth_perms_sync/permissions/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 77a1938..2951d61 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -230,7 +230,7 @@ def record_result(result: run_context.ParallelResult[PermissionChange, None]) -> if exception is None: succeeded += 1 breaker.record(success=True) - log.info( + log.debug( " OK %s %s → %s (id=%d).", action, change.username, From a0aced34a0218485ee2591a2e1c8d23026e8ab7b Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 04:24:16 -0600 Subject: [PATCH 16/17] Make --created-after additive instead of overwriting --- dev/TODO.md | 13 +- dev/test-end-to-end.py | 13 +- src/src_auth_perms_sync/cli.py | 23 ++- .../permissions/command.py | 151 ++++++++++++++++++ src/src_auth_perms_sync/permissions/types.py | 1 + tests/unit/test_cli_config.py | 23 ++- 6 files changed, 205 insertions(+), 19 deletions(-) diff --git a/dev/TODO.md b/dev/TODO.md index dfae3b3..d55f19a 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -1,15 +1,5 @@ # TODO -## High priority: Sync modes - -### Fast - -- Query the instance for all new repos, which do not yet have explicit perms - -### Full: Overwrite all perms - -- Separate full sync mode with an arg - ## High priority: Remote trigger on demand - Sourcegraph webhook for new user coming in v7.4.0 @@ -87,6 +77,9 @@ If/when we revisit: 3. Add a CLI flag (e.g. `--cross-check-capture`) gated behind a clear "this doubles capture cost" warning. +Also, create a fast mode to query the instance for all new repos, +which do not yet have explicit perms + ## Low priority: Grouped full-set plan if memory is still too high Phase 1 now avoids per-repo username sets for non-overlapping full-set maps. diff --git a/dev/test-end-to-end.py b/dev/test-end-to-end.py index 5060df9..b6952cd 100755 --- a/dev/test-end-to-end.py +++ b/dev/test-end-to-end.py @@ -1575,6 +1575,12 @@ def invalid_configuration_cases(config: EndToEndConfig) -> list[CommandCase]: expected_exit_code=2, must_contain=("choose at most one",), ), + CommandCase( + name="invalid-set-full-and-created-after", + arguments=("set", "--full", "--created-after", config.future_date), + expected_exit_code=2, + must_contain=("--full cannot be combined with --created-after",), + ), CommandCase( name="invalid-user-filter-conflict", arguments=("get", "--users", config.user, "--users-without-explicit-perms"), @@ -1682,10 +1688,9 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: def run_safe_set_cases(config: EndToEndConfig, runner: CommandPermutationRunner) -> None: runner.run( CommandCase( - name="set-explicit-full-no-op-apply", + name="set-created-after-no-op-apply", arguments=( "set", - "--full", "--created-after", config.future_date, "--apply", @@ -1693,8 +1698,8 @@ def run_safe_set_cases(config: EndToEndConfig, runner: CommandPermutationRunner) "--parallelism", str(config.parallelism), ), - expected_log_command="set_full", - must_contain=("No repos resolved across any mapping",), + expected_log_command="set_created_after", + must_contain=("No users selected",), ) ) diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 96157c1..c2bed4b 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -80,27 +80,32 @@ "set_full", "set_users", "set_users_without_explicit_perms", + "set_created_after", "restore", "sync_saml_orgs", "set_full_sync_saml_orgs", "set_users_sync_saml_orgs", "set_users_without_explicit_perms_sync_saml_orgs", + "set_created_after_sync_saml_orgs", ] SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { "full": "set_full", "users": "set_users", "users_without_explicit_perms": "set_users_without_explicit_perms", + "created_after": "set_created_after", } SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-{run_mode}", "users": "set-add-users-{run_mode}", "users_without_explicit_perms": "set-add-users-without-explicit-perms-{run_mode}", + "created_after": "set-add-users-created-after-{run_mode}", } SYNC_SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { "full": "set_full_sync_saml_orgs", "users": "set_users_sync_saml_orgs", "users_without_explicit_perms": "set_users_without_explicit_perms_sync_saml_orgs", + "created_after": "set_created_after_sync_saml_orgs", } SYNC_SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-sync-saml-orgs-{run_mode}", @@ -108,6 +113,7 @@ "users_without_explicit_perms": ( "set-add-users-without-explicit-perms-sync-saml-orgs-{run_mode}" ), + "created_after": "set-add-users-created-after-sync-saml-orgs-{run_mode}", } @@ -179,7 +185,10 @@ class Config(src.SourcegraphClientConfig, src.LoggingConfig, src.OpenTelemetryCo env_var="SRC_AUTH_PERMS_SYNC_FULL", cli_flag="--full", cli_action="store_true", - help="With the set command: run the full overwrite reconciliation mode (default)", + help=( + "With the set command: run full overwrite reconciliation " + "(default only when no user filter is set)" + ), help_group="Permission sync", ) users: tuple[str, ...] = src.config_field( @@ -367,6 +376,12 @@ def validate_set_mode_selection(command_name: CommandName, config: Config) -> No if command_name != "set": return + if config.full and config.created_after is not None: + config_error( + "--full cannot be combined with --created-after because full mode " + "overwrites mapped repos; omit --full to add grants for new users" + ) + if sum((config.full, bool(config.users), config.users_without_explicit_perms)) > 1: config_error( "with set, choose at most one of --full, --users, or --users-without-explicit-perms" @@ -386,9 +401,13 @@ def set_command_options(config: Config) -> permission_types.SetCommandOptions: mode="users_without_explicit_perms", user_created_after=config.created_after, ) + if config.created_after is not None: + return permission_types.SetCommandOptions( + mode="created_after", + user_created_after=config.created_after, + ) return permission_types.SetCommandOptions( mode="full", - user_created_after=config.created_after, ) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 1fb4055..91932ba 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -528,6 +528,19 @@ def cmd_set( do_backup, worker_pool, ) + if options.mode == "created_after": + assert options.user_created_after is not None + return cmd_set_additive_created_after( + client, + input_path, + options.user_created_after, + dry_run, + parallelism, + bind_id_mode, + saml_groups_attribute_name_by_config_id, + do_backup, + worker_pool, + ) return run_context.CommandData() @@ -811,6 +824,144 @@ def cmd_set_additive_users_without_explicit_perms( return run_context.CommandData(auth_providers=context.providers) +def cmd_set_additive_created_after( + client: src.SourcegraphClient, + input_path: Path, + user_created_after: str, + dry_run: bool, + parallelism: int, + bind_id_mode: str, + saml_groups_attribute_name_by_config_id: dict[str, str], + do_backup: bool, + worker_pool: ThreadPoolExecutor | None = None, +) -> run_context.CommandData: + """Add missing mapped permissions for users created on or after a date.""" + created_after_filter = sourcegraph_datetime_filter( + parse_cli_date(user_created_after, "--created-after") + ) + with src.span( + "cmd_set_additive_created_after", + input_path=str(input_path), + user_created_after=user_created_after, + dry_run=dry_run, + parallelism=parallelism, + do_backup=do_backup, + ): + mapping_rules = load_mapping_rules(input_path) + if not mapping_rules: + log.warning("No maps defined in %s — nothing to do.", input_path) + return run_context.CommandData() + context = load_mapping_context_discovery( + client, + mapping_rules, + saml_groups_attribute_name_by_config_id, + ) + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) + include_user_account_data = permissions_mapping.mapping_rules_need_saml_account_data( + mapping_rules + ) + candidates = permissions_sourcegraph.list_site_user_candidates( + client, + created_after_filter, + parallelism=parallelism, + worker_pool=worker_pool, + ) + log.info( + "Selected %d active user candidate(s) created on or after %s.", + len(candidates), + user_created_after, + ) + users = _hydrate_site_user_candidates( + client, + candidates, + include_emails=include_user_emails, + include_account_data=include_user_account_data, + parallelism=parallelism, + worker_pool=worker_pool, + ) + if not users: + _run_additive_apply( + client, + input_path, + users, + [], + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users-created-after", + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + matching_rules = _mapping_rules_matching_selected_users(context, users) + log.info( + "%d / %d mapping rule(s) match the selected user(s).", + len(matching_rules), + len(context.mapping_rules), + ) + if not matching_rules: + _run_additive_apply( + client, + input_path, + users, + [], + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users-created-after", + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + service_ids = _service_ids_required_by_mapping_rules(context, matching_rules) + log.info( + "Selected mapping rule(s) require repo scans for %d / %d code host connection(s).", + len(service_ids), + len(context.services_by_id), + ) + context = load_repos_for_mapping_context( + client, + _mapping_context_with_rules(context, matching_rules), + service_ids, + ) + resolved_mappings = resolve_additive_mappings(context) + additions: list[permissions_apply.PermissionAddition] = [] + existing_repos_by_user_id = ( + _load_selected_user_explicit_repos(client, users) if do_backup else None + ) + for user in users: + existing_repo_ids = None + if existing_repos_by_user_id is not None: + existing_repo_ids = { + repository["id"] for repository in existing_repos_by_user_id[user["id"]] + } + additions.extend( + _plan_additions_for_user( + client, + context, + resolved_mappings, + user, + existing_repo_ids=existing_repo_ids, + ) + ) + _run_additive_apply( + client, + input_path, + users, + additions, + dry_run=dry_run, + parallelism=parallelism, + bind_id_mode=bind_id_mode, + do_backup=do_backup, + command_name="set-add-users-created-after", + existing_repos_by_user_id=existing_repos_by_user_id, + worker_pool=worker_pool, + ) + return run_context.CommandData(auth_providers=context.providers) + + def _resolve_user_identifiers( client: src.SourcegraphClient, user_identifiers: tuple[str, ...], diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index 4f57eed..9f81885 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -11,6 +11,7 @@ "full", "users", "users_without_explicit_perms", + "created_after", ] diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index 53c099a..28e2063 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -162,11 +162,11 @@ def test_set_command_options_match_each_incremental_mode(self) -> None: ) self.assertEqual("users_without_explicit_perms", users_without_permissions.mode) self.assertEqual("2026-01-01", users_without_permissions.user_created_after) - filtered_full = cli.set_command_options( + created_after = cli.set_command_options( make_config(maps_path=Path("maps.yaml"), created_after="2026-01-01") ) - self.assertEqual("full", filtered_full.mode) - self.assertEqual("2026-01-01", filtered_full.user_created_after) + self.assertEqual("created_after", created_after.mode) + self.assertEqual("2026-01-01", created_after.user_created_after) def test_resolve_command_includes_set_mode_names(self) -> None: users_command = cli.resolve_command( @@ -174,12 +174,22 @@ def test_resolve_command_includes_set_mode_names(self) -> None: make_config(maps_path=Path("maps.yaml"), users=("alice",), apply=True), ) full_command = cli.resolve_command("set", make_config(maps_path=Path("maps.yaml"))) + created_after_command = cli.resolve_command( + "set", + make_config(maps_path=Path("maps.yaml"), created_after="2026-01-01"), + ) self.assertEqual("set_users", users_command.log_name) self.assertEqual("set-add-users-apply", users_command.artifact_name) self.assertEqual("users", users_command.set_mode) self.assertEqual("set_full", full_command.log_name) self.assertEqual("set-dry-run", full_command.artifact_name) + self.assertEqual("set_created_after", created_after_command.log_name) + self.assertEqual( + "set-add-users-created-after-dry-run", + created_after_command.artifact_name, + ) + self.assertEqual("created_after", created_after_command.set_mode) def test_resolve_command_includes_combined_set_sync_names(self) -> None: set_command = cli.resolve_command( @@ -376,6 +386,13 @@ def test_validate_config_rejects_multiple_set_modes(self) -> None: "choose at most one", ) + def test_validate_config_rejects_full_created_after(self) -> None: + self.assert_config_error( + "set", + make_config(maps_path=Path("maps.yaml"), full=True, created_after="2026-01-01"), + "--full cannot be combined with --created-after", + ) + def test_require_set_input_file_reports_missing_maps_file(self) -> None: with tempfile.TemporaryDirectory() as directory: existing_path = Path(directory) / "maps.yaml" From 95a71da89390e30757478bfde7c1cc34ef7bc20b Mon Sep 17 00:00:00 2001 From: Marc LeBlanc <7050295+marcleblanc2@users.noreply.github.com> Date: Tue, 9 Jun 2026 04:52:17 -0600 Subject: [PATCH 17/17] Add repo-scoped permission filters Amp-Thread-ID: https://ampcode.com/threads/T-019eab3a-95ea-74bb-8a8a-5950a3e3a9c1 Co-authored-by: Amp --- dev/TODO.md | 3 - src/src_auth_perms_sync/cli.py | 121 +++++++++++- .../permissions/command.py | 160 +++++++++++++++- .../permissions/full_set.py | 172 +++++++++++++++++- .../permissions/queries.py | 61 +++++++ .../permissions/snapshot.py | 45 ++++- .../permissions/sourcegraph.py | 84 +++++++++ src/src_auth_perms_sync/permissions/types.py | 5 + .../permissions/workflow.py | 66 +++++++ tests/unit/test_cli_config.py | 102 +++++++++++ tests/unit/test_snapshot.py | 19 ++ 11 files changed, 814 insertions(+), 24 deletions(-) diff --git a/dev/TODO.md b/dev/TODO.md index d55f19a..33914b2 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -77,9 +77,6 @@ If/when we revisit: 3. Add a CLI flag (e.g. `--cross-check-capture`) gated behind a clear "this doubles capture cost" warning. -Also, create a fast mode to query the instance for all new repos, -which do not yet have explicit perms - ## Low priority: Grouped full-set plan if memory is still too high Phase 1 now avoids per-repo username sets for non-overlapping full-set maps. diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index c2bed4b..f3b63b2 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -47,6 +47,9 @@ "users", "users_without_explicit_perms", "created_after", + "repos", + "repos_without_explicit_perms", + "repos_created_after", "no_backup", "explicit_permissions_batch_size", *COMMON_CONFIG_FIELDS, @@ -57,6 +60,9 @@ "users", "users_without_explicit_perms", "created_after", + "repos", + "repos_without_explicit_perms", + "repos_created_after", "sync_saml_organizations", "apply", "no_backup", @@ -81,12 +87,18 @@ "set_users", "set_users_without_explicit_perms", "set_created_after", + "set_repos", + "set_repos_without_explicit_perms", + "set_repos_created_after", "restore", "sync_saml_orgs", "set_full_sync_saml_orgs", "set_users_sync_saml_orgs", "set_users_without_explicit_perms_sync_saml_orgs", "set_created_after_sync_saml_orgs", + "set_repos_sync_saml_orgs", + "set_repos_without_explicit_perms_sync_saml_orgs", + "set_repos_created_after_sync_saml_orgs", ] SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { @@ -94,18 +106,27 @@ "users": "set_users", "users_without_explicit_perms": "set_users_without_explicit_perms", "created_after": "set_created_after", + "repos": "set_repos", + "repos_without_explicit_perms": "set_repos_without_explicit_perms", + "repos_created_after": "set_repos_created_after", } SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-{run_mode}", "users": "set-add-users-{run_mode}", "users_without_explicit_perms": "set-add-users-without-explicit-perms-{run_mode}", "created_after": "set-add-users-created-after-{run_mode}", + "repos": "set-repos-{run_mode}", + "repos_without_explicit_perms": "set-repos-without-explicit-perms-{run_mode}", + "repos_created_after": "set-repos-created-after-{run_mode}", } SYNC_SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { "full": "set_full_sync_saml_orgs", "users": "set_users_sync_saml_orgs", "users_without_explicit_perms": "set_users_without_explicit_perms_sync_saml_orgs", "created_after": "set_created_after_sync_saml_orgs", + "repos": "set_repos_sync_saml_orgs", + "repos_without_explicit_perms": "set_repos_without_explicit_perms_sync_saml_orgs", + "repos_created_after": "set_repos_created_after_sync_saml_orgs", } SYNC_SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-sync-saml-orgs-{run_mode}", @@ -114,6 +135,9 @@ "set-add-users-without-explicit-perms-sync-saml-orgs-{run_mode}" ), "created_after": "set-add-users-created-after-sync-saml-orgs-{run_mode}", + "repos": "set-repos-sync-saml-orgs-{run_mode}", + "repos_without_explicit_perms": ("set-repos-without-explicit-perms-sync-saml-orgs-{run_mode}"), + "repos_created_after": "set-repos-created-after-sync-saml-orgs-{run_mode}", } @@ -216,6 +240,31 @@ class Config(src.SourcegraphClientConfig, src.LoggingConfig, src.OpenTelemetryCo help="Process Sourcegraph users created on or after this date", help_group="User filters", ) + repos: tuple[str, ...] = src.config_field( + default=(), + env_var="SRC_AUTH_PERMS_SYNC_REPOS", + cli_flag="--repos", + metavar="REPOS", + help="Process comma-delimited Sourcegraph repository names", + help_group="Repo filters", + ) + repos_without_explicit_perms: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_REPOS_WITHOUT_EXPLICIT_PERMS", + cli_flag="--repos-without-explicit-perms", + cli_action="store_true", + help="Process Sourcegraph repositories without explicit permissions", + help_group="Repo filters", + ) + repos_created_after: str | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_REPOS_CREATED_AFTER", + cli_flag="--repos-created-after", + metavar="YYYY-MM-DD", + pattern=r"^\d{4}-\d{2}-\d{2}$", + help="Process Sourcegraph repositories created on or after this date", + help_group="Repo filters", + ) sync_saml_organizations: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_SYNC_SAML_ORGS", @@ -237,7 +286,7 @@ class Config(src.SourcegraphClientConfig, src.LoggingConfig, src.OpenTelemetryCo env_var="SRC_AUTH_PERMS_SYNC_NO_BACKUP", cli_flag="--no-backup", cli_action="store_true", - help="With mutating commands: skip before/after snapshots and validation", + help="Skip before/after snapshot artifacts and validation where supported", help_group="Mutation", ) parallelism: int = src.config_field( @@ -339,6 +388,7 @@ def validate_config(command_name: CommandName, config: Config) -> None: """Validate cross-field CLI/config constraints.""" validate_command_options(command_name, config) validate_user_filter_selection(command_name, config) + validate_repository_filter_selection(command_name, config) validate_set_mode_selection(command_name, config) @@ -368,6 +418,34 @@ def validate_user_filter_selection(command_name: CommandName, config: Config) -> ) +def validate_repository_filter_selection(command_name: CommandName, config: Config) -> None: + """Validate repo-scope filters and their compatible commands.""" + repository_filter_count = sum( + ( + bool(config.repos), + config.repos_without_explicit_perms, + config.repos_created_after is not None, + ) + ) + if repository_filter_count > 1: + config_error( + "choose only one of --repos, --repos-without-explicit-perms, or --repos-created-after" + ) + + repository_filter_selected = repository_filter_count > 0 + repository_filter_allowed = command_name in {"get", "set"} + if repository_filter_selected and not repository_filter_allowed: + config_error( + "--repos, --repos-without-explicit-perms, and --repos-created-after require get or set" + ) + + user_filter_selected = any( + (bool(config.users), config.users_without_explicit_perms, config.created_after is not None) + ) + if repository_filter_selected and user_filter_selected: + config_error("choose either user filters or repo filters, not both") + + def validate_set_mode_selection(command_name: CommandName, config: Config) -> None: """Validate set command mode flags.""" if config.full and command_name != "set": @@ -382,9 +460,23 @@ def validate_set_mode_selection(command_name: CommandName, config: Config) -> No "overwrites mapped repos; omit --full to add grants for new users" ) - if sum((config.full, bool(config.users), config.users_without_explicit_perms)) > 1: + if ( + sum( + ( + config.full, + bool(config.users), + config.users_without_explicit_perms, + bool(config.repos), + config.repos_without_explicit_perms, + config.repos_created_after is not None, + ) + ) + > 1 + ): config_error( - "with set, choose at most one of --full, --users, or --users-without-explicit-perms" + "with set, choose at most one of --full, --users, " + "--users-without-explicit-perms, --repos, " + "--repos-without-explicit-perms, or --repos-created-after" ) @@ -406,6 +498,20 @@ def set_command_options(config: Config) -> permission_types.SetCommandOptions: mode="created_after", user_created_after=config.created_after, ) + if config.repos: + return permission_types.SetCommandOptions( + mode="repos", + repository_names=config.repos, + ) + if config.repos_without_explicit_perms: + return permission_types.SetCommandOptions( + mode="repos_without_explicit_perms", + ) + if config.repos_created_after is not None: + return permission_types.SetCommandOptions( + mode="repos_created_after", + repository_created_after=config.repos_created_after, + ) return permission_types.SetCommandOptions( mode="full", ) @@ -548,6 +654,12 @@ def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[ fields["sync_saml_orgs"] = True if config.created_after is not None: fields["created_after"] = config.created_after + if config.repos: + fields["repos"] = config.repos + if config.repos_without_explicit_perms: + fields["repos_without_explicit_perms"] = True + if config.repos_created_after is not None: + fields["repos_created_after"] = config.repos_created_after return fields @@ -705,6 +817,9 @@ def run_get( user_identifiers=config.users, users_without_explicit_perms=config.users_without_explicit_perms, user_created_after=config.created_after, + repository_names=config.repos, + repositories_without_explicit_perms=config.repos_without_explicit_perms, + repository_created_after=config.repos_created_after, parallelism=config.parallelism, explicit_permissions_batch_size=config.explicit_permissions_batch_size, bind_id_mode=sourcegraph_site_config.bind_id_mode, diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 91932ba..42f4b28 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -28,6 +28,8 @@ load_mapping_context_discovery, load_mapping_rules, load_repos_for_mapping_context, + load_repository_candidates_by_names, + load_repository_candidates_created_on_or_after, parse_cli_date, snapshot_path, sourcegraph_datetime_filter, @@ -129,6 +131,67 @@ def _providers_need_saml_account_data(providers: list[shared_types.AuthProvider] return any(provider["serviceType"] == saml_groups.SAML_SERVICE_TYPE for provider in providers) +def _repository_filter_selected( + repository_names: tuple[str, ...], + repositories_without_explicit_perms: bool, + repository_created_after: str | None, +) -> bool: + return any( + ( + bool(repository_names), + repositories_without_explicit_perms, + repository_created_after is not None, + ) + ) + + +def _repository_ids(candidates: list[permissions_sourcegraph.RepositoryCandidate]) -> set[str]: + return {candidate.repository["id"] for candidate in candidates} + + +def _load_get_repository_filter_ids( + client: src.SourcegraphClient, + *, + repository_names: tuple[str, ...], + repository_created_after: str | None, +) -> set[str] | None: + """Return selected repo IDs for get snapshot filtering when known up front.""" + if repository_names: + return _repository_ids(load_repository_candidates_by_names(client, repository_names)) + if repository_created_after is not None: + return _repository_ids( + load_repository_candidates_created_on_or_after( + client, + repository_created_after, + "--repos-created-after", + ) + ) + return None + + +def _filter_get_snapshot_to_repositories_without_explicit_perms( + client: src.SourcegraphClient, + before_snapshot: permission_snapshot.Snapshot, +) -> permission_snapshot.Snapshot: + """Return a get snapshot scoped to repos with no explicit API grants.""" + candidates = permissions_sourcegraph.list_repository_candidates(client) + explicit_repository_ids = set(before_snapshot["repos"]) + selected_repository_ids = { + candidate.repository["id"] + for candidate in candidates + if candidate.repository["id"] not in explicit_repository_ids + } + log.info( + "Selected %d / %d repo(s) without explicit repo permissions.", + len(selected_repository_ids), + len(candidates), + ) + return permission_snapshot.snapshot_with_repository_filter( + before_snapshot, + selected_repository_ids, + ) + + def cmd_get( client: src.SourcegraphClient, code_hosts_path: Path, @@ -138,6 +201,9 @@ def cmd_get( user_identifiers: tuple[str, ...], users_without_explicit_perms: bool, user_created_after: str | None, + repository_names: tuple[str, ...], + repositories_without_explicit_perms: bool, + repository_created_after: str | None, parallelism: int, explicit_permissions_batch_size: int, bind_id_mode: str, @@ -173,6 +239,12 @@ def cmd_get( cmd_fields["users_without_explicit_perms"] = True if user_created_after is not None: cmd_fields["created_after"] = user_created_after + if repository_names: + cmd_fields["repositories"] = repository_names + if repositories_without_explicit_perms: + cmd_fields["repositories_without_explicit_perms"] = True + if repository_created_after is not None: + cmd_fields["repositories_created_after"] = repository_created_after if not do_backup: cmd_fields["backup"] = False @@ -238,6 +310,11 @@ def cmd_get( log.info("Wrote %s and %s", code_hosts_path, auth_providers_path) if do_backup: + selected_repository_ids = _load_get_repository_filter_ids( + client, + repository_names=repository_names, + repository_created_after=repository_created_after, + ) timestamp = backups.backup_timestamp() before_snapshot = permission_snapshot.build_snapshot( client, @@ -248,7 +325,13 @@ def cmd_get( expected_user_count=len(users), explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, + selected_repository_ids=selected_repository_ids, ) + if repositories_without_explicit_perms: + before_snapshot = _filter_get_snapshot_to_repositories_without_explicit_perms( + client, + before_snapshot, + ) before_path = snapshot_path(maps_path, timestamp, client.endpoint, "get", "before") permission_snapshot.write_snapshot(before_path, before_snapshot) cmd_event["before_snapshot_path"] = str(before_path) @@ -272,6 +355,11 @@ def cmd_get( if not user_identifiers and not users_without_explicit_perms and user_created_after is None + and not _repository_filter_selected( + repository_names, + repositories_without_explicit_perms, + repository_created_after, + ) and retain_saml_group_users else None ) @@ -492,14 +580,70 @@ def cmd_set( client, input_path, options.user_created_after, - dry_run, - parallelism, - explicit_permissions_batch_size, - bind_id_mode, - saml_groups_attribute_name_by_config_id, - do_backup, - retain_saml_group_users, - worker_pool, + repository_names=(), + repositories_without_explicit_perms=False, + repository_created_after=None, + dry_run=dry_run, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + bind_id_mode=bind_id_mode, + saml_groups_attribute_name_by_config_id=saml_groups_attribute_name_by_config_id, + do_backup=do_backup, + retain_saml_group_users=retain_saml_group_users, + worker_pool=worker_pool, + ) + if options.mode == "repos": + assert options.repository_names + return permissions_full_set.cmd_set_full( + client, + input_path, + None, + repository_names=options.repository_names, + repositories_without_explicit_perms=False, + repository_created_after=None, + dry_run=dry_run, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + bind_id_mode=bind_id_mode, + saml_groups_attribute_name_by_config_id=saml_groups_attribute_name_by_config_id, + do_backup=do_backup, + retain_saml_group_users=retain_saml_group_users, + worker_pool=worker_pool, + ) + if options.mode == "repos_without_explicit_perms": + return permissions_full_set.cmd_set_full( + client, + input_path, + None, + repository_names=(), + repositories_without_explicit_perms=True, + repository_created_after=None, + dry_run=dry_run, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + bind_id_mode=bind_id_mode, + saml_groups_attribute_name_by_config_id=saml_groups_attribute_name_by_config_id, + do_backup=do_backup, + retain_saml_group_users=retain_saml_group_users, + worker_pool=worker_pool, + ) + if options.mode == "repos_created_after": + assert options.repository_created_after is not None + return permissions_full_set.cmd_set_full( + client, + input_path, + None, + repository_names=(), + repositories_without_explicit_perms=False, + repository_created_after=options.repository_created_after, + dry_run=dry_run, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + bind_id_mode=bind_id_mode, + saml_groups_attribute_name_by_config_id=saml_groups_attribute_name_by_config_id, + do_backup=do_backup, + retain_saml_group_users=retain_saml_group_users, + worker_pool=worker_pool, ) if options.mode == "users": assert options.user_identifiers diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 207c948..c6945af 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -16,10 +16,15 @@ from . import apply as permissions_apply from . import mapping as permissions_mapping from . import snapshot as permission_snapshot +from . import sourcegraph as permissions_sourcegraph from . import types as permission_types from .workflow import ( + load_mapping_context_discovery, load_mapping_context_for_rules, load_mapping_rules, + load_repository_candidates_by_names, + load_repository_candidates_created_on_or_after, + mapping_context_with_repository_candidates, render_projected_snapshot_diff, snapshot_path, user_ids_created_on_or_after, @@ -50,6 +55,7 @@ class _FullSetSnapshotState: users: list[permission_snapshot.SnapshotUser] before_snapshot: permission_snapshot.Snapshot | None = None before_timestamp: str | None = None + selected_repository_ids: set[str] | None = None @dataclass(frozen=True) @@ -100,6 +106,7 @@ def _capture_full_set_snapshot_state( worker_pool: ThreadPoolExecutor | None = None, include_user_emails: bool = False, include_user_account_data: bool = True, + selected_repository_ids: set[str] | None = None, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" expected_user_count = shared_sourcegraph.count_users(client) @@ -124,6 +131,7 @@ def _capture_full_set_snapshot_state( expected_user_count=expected_user_count, explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, + selected_repository_ids=selected_repository_ids, ) log.info( "Received %d total users; before-snapshot has %d repo(s) " @@ -149,6 +157,7 @@ def _load_full_set_snapshot_state( worker_pool: ThreadPoolExecutor | None = None, include_user_emails: bool = False, include_user_account_data: bool = True, + selected_repository_ids: set[str] | None = None, ) -> _FullSetUserState: """Load all users, optionally with a before-snapshot.""" if capture_before: @@ -161,6 +170,7 @@ def _load_full_set_snapshot_state( worker_pool, include_user_emails=include_user_emails, include_user_account_data=include_user_account_data, + selected_repository_ids=selected_repository_ids, ) log.info("Loading users from %s ...", client.endpoint) @@ -193,15 +203,78 @@ def _filter_full_set_users_by_created_at( return filtered_users +def _repository_ids( + candidates: list[permissions_sourcegraph.RepositoryCandidate], +) -> set[str]: + """Return Sourcegraph repository node IDs from candidates.""" + return {candidate.repository["id"] for candidate in candidates} + + +def _load_pre_snapshot_repository_candidates( + client: src.SourcegraphClient, + repository_names: tuple[str, ...], + repository_created_after: str | None, +) -> list[permissions_sourcegraph.RepositoryCandidate] | None: + """Load repo filters that do not depend on current explicit grants.""" + if repository_names: + return load_repository_candidates_by_names(client, repository_names) + if repository_created_after is not None: + return load_repository_candidates_created_on_or_after( + client, + repository_created_after, + "--repos-created-after", + ) + return None + + +def _load_repositories_without_explicit_permissions( + client: src.SourcegraphClient, + before_snapshot: permission_snapshot.Snapshot, +) -> list[permissions_sourcegraph.RepositoryCandidate]: + """Load repo candidates without any explicit API grants.""" + candidates = permissions_sourcegraph.list_repository_candidates(client) + explicit_repository_ids = set(before_snapshot["repos"]) + selected_candidates = [ + candidate + for candidate in candidates + if candidate.repository["id"] not in explicit_repository_ids + ] + log.info( + "Selected %d / %d repo(s) without explicit repo permissions.", + len(selected_candidates), + len(candidates), + ) + return selected_candidates + + +def _filter_full_set_user_state_snapshot( + snapshot_state: _FullSetUserState, + selected_repository_ids: set[str] | None, +) -> _FullSetUserState: + """Return user state with before-snapshot scoped to selected repos.""" + if snapshot_state.before_snapshot is None or selected_repository_ids is None: + return snapshot_state + return _FullSetUserState( + users=snapshot_state.users, + before_snapshot=permission_snapshot.snapshot_with_repository_filter( + snapshot_state.before_snapshot, + selected_repository_ids, + ), + before_timestamp=snapshot_state.before_timestamp, + ) + + def _compact_full_set_snapshot_state( snapshot_state: _FullSetUserState, users: list[shared_types.User], + selected_repository_ids: set[str] | None = None, ) -> _FullSetSnapshotState: """Return snapshot state with only fields needed for later capture.""" return _FullSetSnapshotState( users=permission_snapshot.compact_snapshot_users(users), before_snapshot=snapshot_state.before_snapshot, before_timestamp=snapshot_state.before_timestamp, + selected_repository_ids=selected_repository_ids, ) @@ -565,6 +638,7 @@ def _finish_full_set_apply_with_backup( expected_user_count=len(snapshot_state.users), explicit_permissions_batch_size=explicit_permissions_batch_size, worker_pool=worker_pool, + selected_repository_ids=snapshot_state.selected_repository_ids, ) after_path = snapshot_path(input_path, timestamp, client.endpoint, "set-apply", "after") @@ -643,6 +717,9 @@ def _finish_empty_full_set_mapping_rules( input_path: Path, command_name: str, dry_run: bool, + repository_names: tuple[str, ...], + repositories_without_explicit_perms: bool, + repository_created_after: str | None, parallelism: int, explicit_permissions_batch_size: int, bind_id_mode: str, @@ -654,6 +731,16 @@ def _finish_empty_full_set_mapping_rules( if not (dry_run or do_backup): return + selected_repository_candidates = _load_pre_snapshot_repository_candidates( + client, + repository_names, + repository_created_after, + ) + selected_repository_ids = ( + _repository_ids(selected_repository_candidates) + if selected_repository_candidates is not None + else None + ) snapshot_state = _capture_full_set_snapshot_state( client, input_path, @@ -662,7 +749,18 @@ def _finish_empty_full_set_mapping_rules( bind_id_mode, worker_pool, include_user_account_data=False, + selected_repository_ids=selected_repository_ids, ) + if repositories_without_explicit_perms: + before_snapshot, _ = _require_before_snapshot(snapshot_state) + selected_repository_candidates = _load_repositories_without_explicit_permissions( + client, + before_snapshot, + ) + snapshot_state = _filter_full_set_user_state_snapshot( + snapshot_state, + _repository_ids(selected_repository_candidates), + ) _write_noop_full_set_artifacts( input_path, client.endpoint, @@ -678,11 +776,15 @@ def _load_full_set_plan( input_path: Path, mapping_rules: list[permission_types.MappingRule], user_created_after: str | None, + repository_names: tuple[str, ...], + repositories_without_explicit_perms: bool, + repository_created_after: str | None, parallelism: int, explicit_permissions_batch_size: int, bind_id_mode: str, saml_groups_attribute_name_by_config_id: dict[str, str], capture_before: bool, + write_before_snapshot: bool, command_name: str, command_event: dict[str, Any], retain_saml_group_users: bool, @@ -693,6 +795,16 @@ def _load_full_set_plan( permissions_mapping.mapping_rules_need_saml_account_data(mapping_rules) or retain_saml_group_users ) + selected_repository_candidates = _load_pre_snapshot_repository_candidates( + client, + repository_names, + repository_created_after, + ) + selected_repository_ids = ( + _repository_ids(selected_repository_candidates) + if selected_repository_candidates is not None + else None + ) user_state = _load_full_set_snapshot_state( client, input_path, @@ -703,9 +815,22 @@ def _load_full_set_plan( worker_pool=worker_pool, include_user_emails=include_user_emails, include_user_account_data=include_user_account_data, + selected_repository_ids=selected_repository_ids, ) + if repositories_without_explicit_perms: + before_snapshot, _ = _require_before_snapshot(user_state) + selected_repository_candidates = _load_repositories_without_explicit_permissions( + client, + before_snapshot, + ) + selected_repository_ids = _repository_ids(selected_repository_candidates) + user_state = _filter_full_set_user_state_snapshot( + user_state, + selected_repository_ids, + ) + before_path: Path | None = None - if capture_before: + if write_before_snapshot: before_snapshot, before_timestamp = _require_before_snapshot(user_state) before_path = _write_full_set_before_snapshot( input_path, @@ -716,18 +841,36 @@ def _load_full_set_plan( command_event, ) - context = load_mapping_context_for_rules( - client, - mapping_rules, - saml_groups_attribute_name_by_config_id, - ) + if selected_repository_candidates is None: + context = load_mapping_context_for_rules( + client, + mapping_rules, + saml_groups_attribute_name_by_config_id, + ) + else: + context = mapping_context_with_repository_candidates( + load_mapping_context_discovery( + client, + mapping_rules, + saml_groups_attribute_name_by_config_id, + ), + selected_repository_candidates, + ) + + if selected_repository_ids is not None: + command_event["selected_repo_count"] = len(selected_repository_ids) + users = _filter_full_set_users_by_created_at( client, user_state.users, user_created_after, ) plan = plan_full_set_permissions(context, users) - snapshot_state = _compact_full_set_snapshot_state(user_state, users) + snapshot_state = _compact_full_set_snapshot_state( + user_state, + users, + selected_repository_ids, + ) saml_group_users = ( saml_groups.compact_saml_group_users( user_state.users, @@ -828,6 +971,9 @@ def cmd_set_full( client: src.SourcegraphClient, input_path: Path, user_created_after: str | None, + repository_names: tuple[str, ...], + repositories_without_explicit_perms: bool, + repository_created_after: str | None, dry_run: bool, parallelism: int, explicit_permissions_batch_size: int, @@ -842,6 +988,9 @@ def cmd_set_full( "cmd_set", input_path=str(input_path), user_created_after=user_created_after, + repository_names=repository_names or None, + repositories_without_explicit_perms=(True if repositories_without_explicit_perms else None), + repository_created_after=repository_created_after, dry_run=dry_run, parallelism=parallelism, do_backup=do_backup, @@ -854,6 +1003,9 @@ def cmd_set_full( input_path, command_name, dry_run, + repository_names, + repositories_without_explicit_perms, + repository_created_after, parallelism, explicit_permissions_batch_size, bind_id_mode, @@ -868,11 +1020,15 @@ def cmd_set_full( input_path, mapping_rules, user_created_after, + repository_names, + repositories_without_explicit_perms, + repository_created_after, parallelism, explicit_permissions_batch_size, bind_id_mode, saml_groups_attribute_name_by_config_id, - capture_before=dry_run or do_backup, + capture_before=dry_run or do_backup or repositories_without_explicit_perms, + write_before_snapshot=dry_run or do_backup, command_name=command_name, command_event=command_event, retain_saml_group_users=retain_saml_group_users, diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index e643090..a08d552 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -48,6 +48,67 @@ } """ +REPOSITORY_CANDIDATE_FIELDS = """ +id +name +createdAt +externalServices(first: 50) { + nodes { id } +} +""" + +QUERY_REPOSITORIES_BY_NAMES = f""" +query RepositoriesByNames($names: [String!]!, $first: Int!, $after: String) {{ + repositories( + names: $names + first: $first + after: $after + cloned: true + notCloned: true + ) {{ + nodes {{ + {REPOSITORY_CANDIDATE_FIELDS} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} +""" + +QUERY_REPOSITORY_CANDIDATES = f""" +query RepositoryCandidates($first: Int!, $after: String) {{ + repositories( + first: $first + after: $after + cloned: true + notCloned: true + orderBy: REPOSITORY_NAME + ) {{ + nodes {{ + {REPOSITORY_CANDIDATE_FIELDS} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} +""" + +QUERY_REPOSITORY_CANDIDATES_BY_CREATED_AT = f""" +query RepositoryCandidatesByCreatedAt($first: Int!, $after: String) {{ + repositories( + first: $first + after: $after + cloned: true + notCloned: true + orderBy: REPO_CREATED_AT + descending: true + ) {{ + nodes {{ + {REPOSITORY_CANDIDATE_FIELDS} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} +""" + MUTATION_SET_REPO_PERMISSIONS = """ mutation SetRepoPerms($repo: ID!, $userPerms: [UserPermissionInput!]!) { setRepositoryPermissionsForUsers(repository: $repo, userPermissions: $userPerms) { diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index 365bfb6..3b22b09 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -164,6 +164,7 @@ def capture_explicit_grants( explicit_permissions_batch_size: int, expected_user_count: int | None = None, worker_pool: ThreadPoolExecutor | None = None, + selected_repository_ids: set[str] | None = None, ) -> tuple[dict[str, RepoSnapshot], int]: """Build the per-repo inverse index of explicit-API grants. @@ -348,8 +349,13 @@ def _record_completed_batch( raise RuntimeError("explicit-grant batch fetch returned no result") repository_ids_by_username, failures = result.value capture_failures += failures - for username, repository_ids in repository_ids_by_username.items(): - for repository_id in repository_ids: + for username, user_repository_ids in repository_ids_by_username.items(): + for repository_id in user_repository_ids: + if ( + selected_repository_ids is not None + and repository_id not in selected_repository_ids + ): + continue usernames_by_repository_id.setdefault( repository_id, [], @@ -422,6 +428,7 @@ def build_snapshot( expected_user_count: int | None = None, explicit_permissions_batch_size: int, worker_pool: ThreadPoolExecutor | None = None, + selected_repository_ids: set[str] | None = None, ) -> Snapshot: """Capture a full Snapshot: explicit grants + pending-bindIDs + metadata. @@ -441,6 +448,7 @@ def build_snapshot( explicit_permissions_batch_size, expected_user_count=expected_user_count, worker_pool=worker_pool, + selected_repository_ids=selected_repository_ids, ) pending = permissions_sourcegraph.list_pending_bind_ids(client) @@ -479,6 +487,39 @@ def build_snapshot( } +def snapshot_with_repository_filter( + snapshot: Snapshot, + selected_repository_ids: set[str], +) -> Snapshot: + """Return a snapshot containing only selected repository entries.""" + repos = { + repository_id: repo + for repository_id, repo in snapshot["repos"].items() + if repository_id in selected_repository_ids + } + distinct_users: set[str] = set() + total_grants = 0 + for repo in repos.values(): + distinct_users.update(repo["users"]) + total_grants += len(repo["users"]) + return { + "schema_version": snapshot["schema_version"], + "captured_at": snapshot["captured_at"], + "endpoint": snapshot["endpoint"], + "bindID_mode": snapshot["bindID_mode"], + "config_file": snapshot["config_file"], + "config_sha256": snapshot["config_sha256"], + "pending_bindIDs": list(snapshot["pending_bindIDs"]), + "stats": { + "total_users_scanned": snapshot["stats"]["total_users_scanned"], + "users_with_explicit_grants": len(distinct_users), + "repos_with_explicit_grants": len(repos), + "total_grants": total_grants, + }, + "repos": dict(sorted(repos.items())), + } + + def capture_user_scoped_explicit_grants( client: src.SourcegraphClient, users: Iterable[SnapshotUser], diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index 2f01660..5c8edb3 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -2,6 +2,7 @@ from __future__ import annotations +import datetime import logging import time from collections import deque @@ -37,6 +38,15 @@ class _SiteUserCandidatePage: candidates: list[shared_types.SiteUserCandidate] +@dataclass(frozen=True) +class RepositoryCandidate: + """Repository selected for repo-scoped get/set work.""" + + repository: permission_types.Repository + created_at: str + external_service_ids: tuple[str, ...] + + def list_external_services(client: src.SourcegraphClient) -> list[permission_types.ExternalService]: return [ cast(permission_types.ExternalService, node) @@ -62,6 +72,80 @@ def list_repos_for_external_service( ] +def list_repository_candidates_by_names( + client: src.SourcegraphClient, + repository_names: Sequence[str], +) -> list[RepositoryCandidate]: + """Return repositories whose names exactly match `repository_names`.""" + unique_names = tuple(dict.fromkeys(repository_names)) + if not unique_names: + return [] + return [ + _repository_candidate_from_node(node) + for node in client.stream_connection_nodes( + queries.QUERY_REPOSITORIES_BY_NAMES, + {"names": list(unique_names)}, + connection_path=("repositories",), + page_size=REPOSITORY_PAGE_SIZE, + ) + ] + + +def list_repository_candidates(client: src.SourcegraphClient) -> list[RepositoryCandidate]: + """Return all repositories with enough metadata for repo filtering.""" + return [ + _repository_candidate_from_node(node) + for node in client.stream_connection_nodes( + queries.QUERY_REPOSITORY_CANDIDATES, + connection_path=("repositories",), + page_size=REPOSITORY_PAGE_SIZE, + ) + ] + + +def list_repository_candidates_created_on_or_after( + client: src.SourcegraphClient, + created_after: str, +) -> list[RepositoryCandidate]: + """Return repositories with Sourcegraph rows created on or after a timestamp.""" + threshold = _parse_sourcegraph_datetime(created_after) + candidates: list[RepositoryCandidate] = [] + for node in client.stream_connection_nodes( + queries.QUERY_REPOSITORY_CANDIDATES_BY_CREATED_AT, + connection_path=("repositories",), + page_size=REPOSITORY_PAGE_SIZE, + ): + candidate = _repository_candidate_from_node(node) + if _parse_sourcegraph_datetime(candidate.created_at) < threshold: + break + candidates.append(candidate) + return candidates + + +def _repository_candidate_from_node(node: dict[str, Any]) -> RepositoryCandidate: + repository_id = src.json_str(node, "id") + repository_name = src.json_str(node, "name") + created_at = src.json_str(node, "createdAt") + external_services = src.json_dict(node.get("externalServices")) + external_service_ids = tuple( + external_service_id + for external_service_id in ( + src.json_dict(external_service).get("id") + for external_service in src.json_list(external_services.get("nodes")) + ) + if isinstance(external_service_id, str) + ) + return RepositoryCandidate( + repository={"id": repository_id, "name": repository_name}, + created_at=created_at, + external_service_ids=external_service_ids, + ) + + +def _parse_sourcegraph_datetime(value: str) -> datetime.datetime: + return datetime.datetime.fromisoformat(value.replace("Z", "+00:00")) + + def get_user_by_username( client: src.SourcegraphClient, username: str, diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index 9f81885..c764717 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -12,6 +12,9 @@ "users", "users_without_explicit_perms", "created_after", + "repos", + "repos_without_explicit_perms", + "repos_created_after", ] @@ -22,6 +25,8 @@ class SetCommandOptions: mode: SetCommandMode user_identifiers: tuple[str, ...] = () user_created_after: str | None = None + repository_names: tuple[str, ...] = () + repository_created_after: str | None = None class UserRef(TypedDict): diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index f49e4e5..a1fc18d 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -210,6 +210,72 @@ def load_repos_for_mapping_context( ) +def mapping_context_with_repository_candidates( + context: permission_types.MappingContext, + candidates: list[permissions_sourcegraph.RepositoryCandidate], +) -> permission_types.MappingContext: + """Return context limited to selected repository candidates.""" + repos_by_external_service_id: dict[int, list[permission_types.Repository]] = {} + all_repos_by_id: dict[str, permission_types.Repository] = {} + for candidate in candidates: + repository = candidate.repository + all_repos_by_id[repository["id"]] = repository + for external_service_id in candidate.external_service_ids: + decoded_service_id = src.decode_external_service_id(external_service_id) + if decoded_service_id not in context.services_by_id: + continue + repos_by_external_service_id.setdefault(decoded_service_id, []).append(repository) + log.info( + "Selected %d repo(s) across %d code host connection(s).", + len(all_repos_by_id), + len(repos_by_external_service_id), + ) + return permission_types.MappingContext( + mapping_rules=context.mapping_rules, + providers=context.providers, + saml_groups_attribute_names=context.saml_groups_attribute_names, + services_by_id=context.services_by_id, + repos_by_external_service_id=repos_by_external_service_id, + all_repos_by_id=all_repos_by_id, + ) + + +def load_repository_candidates_by_names( + client: src.SourcegraphClient, + repository_names: tuple[str, ...], +) -> list[permissions_sourcegraph.RepositoryCandidate]: + """Load exact repository-name matches or exit with missing names.""" + candidates = permissions_sourcegraph.list_repository_candidates_by_names( + client, + repository_names, + ) + found_names = {candidate.repository["name"] for candidate in candidates} + missing_names = [name for name in repository_names if name not in found_names] + if missing_names: + raise SystemExit("No Sourcegraph repo found for: " + ", ".join(sorted(missing_names))) + log.info("Selected %d repo(s) by exact name.", len(candidates)) + return candidates + + +def load_repository_candidates_created_on_or_after( + client: src.SourcegraphClient, + value: str, + flag_name: str, +) -> list[permissions_sourcegraph.RepositoryCandidate]: + """Load repositories whose Sourcegraph row was created on or after a CLI date.""" + filter_value = sourcegraph_datetime_filter(parse_cli_date(value, flag_name)) + candidates = permissions_sourcegraph.list_repository_candidates_created_on_or_after( + client, + filter_value, + ) + log.info( + "Selected %d Sourcegraph repo(s) created on or after %s.", + len(candidates), + value, + ) + return candidates + + def snapshot_path( input_path: Path, timestamp: str, diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index 28e2063..c071505 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -17,6 +17,7 @@ import src_auth_perms_sync from src_auth_perms_sync import cli from src_auth_perms_sync.permissions import command as permissions_command +from src_auth_perms_sync.permissions import types as permission_types from src_auth_perms_sync.shared import backups @@ -137,6 +138,16 @@ def test_users_config_loads_comma_delimited_values(self) -> None: self.assertEqual(("alice", "bob@example.com", "carol"), config.users) + def test_repos_config_loads_comma_delimited_values(self) -> None: + config = load_config_from_env( + SRC_AUTH_PERMS_SYNC_REPOS="github.com/sourcegraph/one, github.com/sourcegraph/two" + ) + + self.assertEqual( + ("github.com/sourcegraph/one", "github.com/sourcegraph/two"), + config.repos, + ) + def test_set_command_options_match_each_incremental_mode(self) -> None: self.assertEqual( "full", @@ -167,6 +178,23 @@ def test_set_command_options_match_each_incremental_mode(self) -> None: ) self.assertEqual("created_after", created_after.mode) self.assertEqual("2026-01-01", created_after.user_created_after) + repos = cli.set_command_options( + make_config( + maps_path=Path("maps.yaml"), + repos=("github.com/sourcegraph/one",), + ) + ) + self.assertEqual("repos", repos.mode) + self.assertEqual(("github.com/sourcegraph/one",), repos.repository_names) + repos_without_permissions = cli.set_command_options( + make_config(maps_path=Path("maps.yaml"), repos_without_explicit_perms=True) + ) + self.assertEqual("repos_without_explicit_perms", repos_without_permissions.mode) + repos_created_after = cli.set_command_options( + make_config(maps_path=Path("maps.yaml"), repos_created_after="2026-01-01") + ) + self.assertEqual("repos_created_after", repos_created_after.mode) + self.assertEqual("2026-01-01", repos_created_after.repository_created_after) def test_resolve_command_includes_set_mode_names(self) -> None: users_command = cli.resolve_command( @@ -178,6 +206,13 @@ def test_resolve_command_includes_set_mode_names(self) -> None: "set", make_config(maps_path=Path("maps.yaml"), created_after="2026-01-01"), ) + repos_command = cli.resolve_command( + "set", + make_config( + maps_path=Path("maps.yaml"), + repos=("github.com/sourcegraph/one",), + ), + ) self.assertEqual("set_users", users_command.log_name) self.assertEqual("set-add-users-apply", users_command.artifact_name) @@ -190,6 +225,9 @@ def test_resolve_command_includes_set_mode_names(self) -> None: created_after_command.artifact_name, ) self.assertEqual("created_after", created_after_command.set_mode) + self.assertEqual("set_repos", repos_command.log_name) + self.assertEqual("set-repos-dry-run", repos_command.artifact_name) + self.assertEqual("repos", repos_command.set_mode) def test_resolve_command_includes_combined_set_sync_names(self) -> None: set_command = cli.resolve_command( @@ -257,6 +295,11 @@ def test_validate_config_allows_get_user_filters_without_set(self) -> None: cli.validate_config("get", make_config(users_without_explicit_perms=True)) cli.validate_config("get", make_config(created_after="2026-01-01")) + def test_validate_config_allows_get_repo_filters_without_set(self) -> None: + cli.validate_config("get", make_config(repos=("github.com/sourcegraph/one",))) + cli.validate_config("get", make_config(repos_without_explicit_perms=True)) + cli.validate_config("get", make_config(repos_created_after="2026-01-01")) + def test_validate_config_rejects_get_user_filter_conflicts(self) -> None: self.assert_config_error( "get", @@ -271,6 +314,31 @@ def test_validate_config_rejects_user_filters_on_non_get_set_commands(self) -> N "require get or set", ) + def test_validate_config_rejects_repo_filter_conflicts(self) -> None: + self.assert_config_error( + "get", + make_config( + repos=("github.com/sourcegraph/one",), + repos_without_explicit_perms=True, + ), + "choose only one of --repos", + ) + self.assert_config_error( + "get", + make_config(users=("alice",), repos=("github.com/sourcegraph/one",)), + "choose either user filters or repo filters", + ) + + def test_validate_config_rejects_repo_filters_on_non_get_set_commands(self) -> None: + self.assert_config_error( + "restore", + make_config( + restore_path=Path("snapshot.json"), + repos=("github.com/sourcegraph/one",), + ), + "require get or set", + ) + def test_validate_config_allows_set_without_explicit_mode(self) -> None: cli.validate_config("set", make_config(maps_path=Path("maps.yaml"))) @@ -517,6 +585,9 @@ def test_cmd_get_no_backup_skips_snapshot_artifacts(self) -> None: user_identifiers=(), users_without_explicit_perms=False, user_created_after=None, + repository_names=(), + repositories_without_explicit_perms=False, + repository_created_after=None, parallelism=1, explicit_permissions_batch_size=25, bind_id_mode="USERNAME", @@ -528,6 +599,37 @@ def test_cmd_get_no_backup_skips_snapshot_artifacts(self) -> None: build_snapshot.assert_not_called() write_maps_backup.assert_not_called() + def test_cmd_set_dispatches_repo_filters_to_full_set(self) -> None: + client = cast(src.SourcegraphClient, object()) + options = permission_types.SetCommandOptions( + mode="repos", + repository_names=("github.com/sourcegraph/one",), + ) + + with mock.patch.object( + permissions_command.permissions_full_set, + "cmd_set_full", + return_value=cli.run_context.CommandData(), + ) as cmd_set_full: + permissions_command.cmd_set( + client, + Path("maps.yaml"), + options, + dry_run=True, + parallelism=1, + explicit_permissions_batch_size=25, + bind_id_mode="USERNAME", + saml_groups_attribute_name_by_config_id={}, + do_backup=True, + ) + + self.assertEqual( + ("github.com/sourcegraph/one",), + cmd_set_full.call_args.kwargs["repository_names"], + ) + self.assertFalse(cmd_set_full.call_args.kwargs["repositories_without_explicit_perms"]) + self.assertIsNone(cmd_set_full.call_args.kwargs["repository_created_after"]) + def test_run_command_passes_set_data_to_combined_sync(self) -> None: configuration = make_config(sync_saml_organizations=True) command = cli.resolve_command("set", configuration) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3699426..b8af9ce 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -379,6 +379,25 @@ def test_read_snapshot_rejects_old_schema_versions(self) -> None: self.assertIn("expected 5", str(exit_context.exception)) + def test_snapshot_with_repository_filter_recomputes_stats(self) -> None: + snapshot = self.make_snapshot() + second_repo_id = src.encode_repository_id(2) + snapshot["repos"][second_repo_id] = { + "name": "github.com/sourcegraph/second", + "users": ["alice", "carol"], + } + + filtered = permission_snapshot.snapshot_with_repository_filter( + snapshot, + {second_repo_id}, + ) + + self.assertEqual({second_repo_id}, set(filtered["repos"])) + self.assertEqual(2, filtered["stats"]["users_with_explicit_grants"]) + self.assertEqual(1, filtered["stats"]["repos_with_explicit_grants"]) + self.assertEqual(2, filtered["stats"]["total_grants"]) + self.assertEqual(2, filtered["stats"]["total_users_scanned"]) + def make_snapshot(self) -> permission_snapshot.Snapshot: return { "schema_version": permission_snapshot.SNAPSHOT_SCHEMA_VERSION,