diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cb083ef..43ca985 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,17 +6,17 @@ on: - "v*" workflow_dispatch: inputs: - tag: - description: "Existing release tag to publish, for example v0.1.0" + version: + description: "Package version to publish, for example 0.1.0 or v0.1.0" required: true type: string permissions: - contents: write + contents: read pull-requests: read concurrency: - group: release-${{ github.event.inputs.tag || github.ref_name }} + group: release-${{ github.event_name == 'workflow_dispatch' && inputs.version || github.ref_name }} cancel-in-progress: false defaults: @@ -24,15 +24,41 @@ defaults: shell: bash jobs: + release_ref: + name: Resolve release tag + runs-on: ubuntu-24.04 + outputs: + tag: ${{ steps.release.outputs.tag }} + version: ${{ steps.release.outputs.version }} + + steps: + - name: Resolve release tag + id: release + env: + RELEASE_INPUT: ${{ github.event_name == 'workflow_dispatch' && inputs.version || github.ref_name }} + run: | + release_input="${RELEASE_INPUT}" + if [[ ! "${release_input}" =~ ^v?[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "::error title=Invalid release version::Use MAJOR.MINOR.PATCH or vMAJOR.MINOR.PATCH, got '${release_input}'." + exit 1 + fi + + release_version="${release_input#v}" + release_tag="v${release_version}" + echo "tag=${release_tag}" >> "${GITHUB_OUTPUT}" + echo "version=${release_version}" >> "${GITHUB_OUTPUT}" + validate: name: Validate + needs: release_ref uses: ./.github/workflows/validate.yml with: - ref: ${{ github.event.inputs.tag || github.ref }} + ref: ${{ needs.release_ref.outputs.tag }} build-package: false wheelhouse: name: ${{ matrix.platform }}-py311 wheelhouse + needs: release_ref runs-on: ${{ matrix.runs_on }} strategy: fail-fast: false @@ -61,7 +87,7 @@ jobs: with: fetch-depth: 0 persist-credentials: false - ref: ${{ github.event.inputs.tag || github.ref }} + ref: ${{ needs.release_ref.outputs.tag }} - name: Set up Python uses: actions/setup-python@v6 @@ -81,8 +107,10 @@ jobs: - name: Validate release inputs id: release + env: + RELEASE_TAG: ${{ needs.release_ref.outputs.tag }} run: | - release_tag="${{ github.event.inputs.tag || github.ref_name }}" + release_tag="${RELEASE_TAG}" if [[ ! "${release_tag}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "::error title=Invalid release tag::Use a vMAJOR.MINOR.PATCH tag, got '${release_tag}'." exit 1 @@ -100,19 +128,8 @@ jobs: exit 1 fi - project_version=$(uv run --frozen python - <<'PY' - import tomllib - - with open("pyproject.toml", "rb") as pyproject_file: - print(tomllib.load(pyproject_file)["project"]["version"]) - PY - ) - if [[ "v${project_version}" != "${release_tag}" ]]; then - echo "::error title=Version mismatch::pyproject.toml version '${project_version}' does not match tag '${release_tag}'." - exit 1 - fi - echo "tag=${release_tag}" >> "${GITHUB_OUTPUT}" + echo "version=${release_tag#v}" >> "${GITHUB_OUTPUT}" - name: Validate runner architecture run: | @@ -131,6 +148,7 @@ jobs: id: build run: | release_tag="${{ steps.release.outputs.tag }}" + release_version="${{ steps.release.outputs.version }}" release_dir="build/release/${ASSET_BASENAME}" wheelhouse_dir="${release_dir}/wheelhouse" dist_dir="build/release/dist" @@ -158,6 +176,22 @@ jobs: project_wheel_name="$(basename "${project_wheel_path}")" project_source_distribution_path="${project_source_distributions[0]}" project_source_distribution_name="$(basename "${project_source_distribution_path}")" + case "${project_wheel_name}" in + src_auth_perms_sync-"${release_version}"-*.whl) + ;; + *) + echo "::error title=Wheel version mismatch::Expected wheel version ${release_version}, got '${project_wheel_name}'." + exit 1 + ;; + esac + case "${project_source_distribution_name}" in + src_auth_perms_sync-"${release_version}".tar.gz) + ;; + *) + echo "::error title=Source distribution version mismatch::Expected source distribution version ${release_version}, got '${project_source_distribution_name}'." + exit 1 + ;; + esac project_wheel_checksum_path="${project_wheel_path}.sha256" project_source_distribution_checksum_path="${project_source_distribution_path}.sha256" if [[ ! -f "${project_wheel_path}" ]]; then @@ -165,18 +199,27 @@ jobs: exit 1 fi - uv export \ - --no-dev \ - --no-emit-project \ - --no-hashes \ - --no-header \ - --no-annotate \ - --frozen \ - --output-file "${requirements_file}" + dependency_metadata_dir="$(mktemp -d)" + git clone --no-hardlinks . "${dependency_metadata_dir}" >/dev/null + ( + cd "${dependency_metadata_dir}" + git checkout --detach "${release_tag}" >/dev/null + mkdir -p "$(dirname "${requirements_file}")" + uv export \ + --no-sources \ + --no-dev \ + --no-emit-project \ + --no-hashes \ + --no-header \ + --no-annotate \ + --output-file "${requirements_file}" + ) + cp "${dependency_metadata_dir}/${requirements_file}" "${requirements_file}" cp "${requirements_file}" "${runtime_requirements_file}" - if grep -q '^\./' "${runtime_requirements_file}"; then + if grep -Eq '(^-e[[:space:]]|^(\.\.?/)|(^|[[:space:]])file:| @ (\.\.?/|file:))' "${runtime_requirements_file}"; then echo "::error title=Unexpected local dependency::Runtime requirements must resolve from PyPI." + cat "${runtime_requirements_file}" exit 1 fi @@ -356,8 +399,10 @@ jobs: github-release: name: Publish GitHub release assets - needs: [validate, wheelhouse] + needs: [release_ref, validate, wheelhouse] runs-on: ubuntu-24.04 + permissions: + contents: write steps: - name: Download wheelhouse artifacts @@ -377,8 +422,9 @@ jobs: env: GH_TOKEN: ${{ github.token }} GH_REPO: ${{ github.repository }} + RELEASE_TAG: ${{ needs.release_ref.outputs.tag }} run: | - release_tag="${{ github.event.inputs.tag || github.ref_name }}" + release_tag="${RELEASE_TAG}" notes_path="$(find release-notes -name release-notes.md -print -quit)" mapfile -t release_assets < <(find release-assets -type f | sort) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index f0e9b0f..6cbf417 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -157,6 +157,7 @@ jobs: - name: Check out code uses: actions/checkout@v6 with: + fetch-depth: 0 persist-credentials: false ref: ${{ inputs.ref || github.ref }} @@ -208,6 +209,7 @@ jobs: - name: Check out code uses: actions/checkout@v6 with: + fetch-depth: 0 persist-credentials: false ref: ${{ inputs.ref || github.ref }} @@ -227,8 +229,8 @@ jobs: - name: Install uv run: python -m pip install "uv==${UV_VERSION}" - - name: Build wheel - run: uv build --wheel --out-dir dist --no-create-gitignore + - name: Build distributions + run: uv build --wheel --sdist --out-dir dist --no-create-gitignore - name: Smoke test installed wheel run: | diff --git a/.gitignore b/.gitignore index 964fe48..078afe2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,9 +11,11 @@ __pycache__ *.gql *.py[cod] *.py[oc] +*.swp *.yaml build/ dist/ +logs*/ logs/ src-auth-perms-sync-runs/ wheels/ diff --git a/AGENTS.md b/AGENTS.md index c8e7acd..05440f8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -44,50 +44,27 @@ uv run src-auth-perms-sync --restore backups///before.json ## Release process -- The tagged source commit must already contain the package version it - releases. Do not make the customer release workflow edit `pyproject.toml`. -- Prepare the version bump on a branch. Set `VERSION`, then copy / paste: -- As part of every release bump, find old release-version literals in - `AGENTS.md`, `README.md`, and release snippets, and replace them with the - new version where they are meant to stay current. +- Package versions are derived from Git tags through `hatch-vcs`. +- `pyproject.toml` must use `dynamic = ["version"]`; do not add a hard-coded + `project.version` for releases. +- The release tag must be `vMAJOR.MINOR.PATCH` and point at a commit reachable + from `origin/main`. +- The release workflow builds from the tag and checks that wheel and source + distribution filenames match the tag version before publishing. +- Do not make the release workflow edit `pyproject.toml` or `uv.lock`. +- Validate the remote head of `main` before tagging it: ```bash set -euo pipefail -VERSION=0.2.1 -BRANCH="release-v${VERSION}" +VERSION_INPUT= +VERSION="${VERSION_INPUT#v}" +[[ "${VERSION_INPUT}" =~ ^v?[0-9]+\.[0-9]+\.[0-9]+$ ]] git fetch origin --tags --prune git switch main git pull --ff-only -git switch -c "${BRANCH}" - -uv run python - "${VERSION}" <<'PY' -from pathlib import Path -import re -import sys - -version = sys.argv[1] -path = Path("pyproject.toml") -text = path.read_text() -new_text = re.sub( - r'(?m)^version = "[^"]+"$', - f'version = "{version}"', - text, - count=1, -) -if new_text == text: - raise SystemExit("pyproject.toml version was not updated") -path.write_text(new_text) -PY - -uv lock -``` - -- Validate the release candidate before opening / merging the PR: - -```bash -set -euo pipefail +test "$(git rev-parse HEAD)" = "$(git rev-parse origin/main)" uv lock --check actionlint @@ -97,57 +74,24 @@ uv run pyright uv run python -m unittest discover -s tests uv run src-auth-perms-sync --help npx --yes markdownlint-cli2@0.22.1 -uv build --wheel --out-dir /tmp/src-auth-perms-sync-release-check --no-create-gitignore +uv build --wheel --sdist --out-dir /tmp/src-auth-perms-sync-release-check --no-create-gitignore rm -rf /tmp/src-auth-perms-sync-release-check ``` -- Commit, push, open the PR, wait for checks, then merge it. If review is - required, stop after `gh pr checks` and ask for review before merging. +- Tag the remote head of `main` directly: ```bash set -euo pipefail -VERSION=0.2.1 -BRANCH="release-v${VERSION}" +VERSION_INPUT= +VERSION="${VERSION_INPUT#v}" GH_REPO="sourcegraph/src-auth-perms-sync" -git add pyproject.toml uv.lock -git commit -m "Release v${VERSION}" -git push -u origin "${BRANCH}" - -gh pr create \ - --repo "${GH_REPO}" \ - --base main \ - --head "${BRANCH}" \ - --title "Release v${VERSION}" \ - --body "Bump src-auth-perms-sync package metadata to ${VERSION}." - -gh pr checks "${BRANCH}" --repo "${GH_REPO}" --watch --fail-fast -gh pr merge "${BRANCH}" --repo "${GH_REPO}" --squash --delete-branch -``` - -- Tag the merged `main` commit. Do not tag a feature branch commit. - -```bash -set -euo pipefail - -VERSION=0.2.1 - +[[ "${VERSION_INPUT}" =~ ^v?[0-9]+\.[0-9]+\.[0-9]+$ ]] git fetch origin --tags --prune -git switch main -git pull --ff-only -git tag "v${VERSION}" +MAIN_COMMIT="$(git rev-parse origin/main)" +git tag -a "v${VERSION}" "${MAIN_COMMIT}" -m "Release v${VERSION}" git push origin "v${VERSION}" -``` - -- Watch the customer release workflow and confirm the GitHub release assets - are uploaded: - -```bash -set -euo pipefail - -VERSION=0.2.1 -GH_REPO="sourcegraph/src-auth-perms-sync" RUN_ID="$( gh run list \ @@ -169,13 +113,13 @@ gh release view "v${VERSION}" --repo "${GH_REPO}" ```bash set -euo pipefail -VERSION=0.2.1 +VERSION_INPUT= +VERSION="${VERSION_INPUT#v}" GH_REPO="sourcegraph/src-auth-perms-sync" +[[ "${VERSION_INPUT}" =~ ^v?[0-9]+\.[0-9]+\.[0-9]+$ ]] git fetch origin --tags --prune -git switch main -git pull --ff-only -git tag -f "v${VERSION}" origin/main +git tag -f -a "v${VERSION}" origin/main -m "Release v${VERSION}" git push origin "refs/tags/v${VERSION}" --force RUN_ID="$( diff --git a/README.md b/README.md index bc31c32..49fbfdd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # src-auth-perms-sync - + src-auth-perms-sync automates Sourcegraph's Explicit Permissions GraphQL API, setting user-to-repo permissions based on mapping rules, for example: @@ -18,8 +18,8 @@ Feel free to open issues or PRs, but responses are best effort. - Release versions are `major.minor.patch` - Because this project is still major version 0: - - Minor version updates are breaking changes - - Patch version updates are not breaking changes + - Minor version updates are probably breaking changes + - Patch version updates are probably not breaking changes ## Principles @@ -31,7 +31,7 @@ Feel free to open issues or PRs, but responses are best effort. map casts a smaller net of users / repos. This can result in more maps, but they will be easier to understand and trust. -- Backup files are saved in `src-auth-perms-sync-runs//backups/`, +- Backup files are saved in `src-auth-perms-sync-runs//runs//`, unless the `--no-backup` arg is provided, so customers can review the changes made over time, and restore to a specific backup file, if needed @@ -45,7 +45,10 @@ Feel free to open issues or PRs, but responses are best effort. - One installation of this script can apply separate `maps.yaml` files on separate Sourcegraph instances - - Be sure to specify the path to the correct `maps.yaml` file for each run + - By default, each Sourcegraph instance gets its own generated `maps.yaml` + under `src-auth-perms-sync-runs//` + - If you pass `--maps-path`, relative paths are resolved from your current + working directory - Set the `SRC_ENDPOINT` and `SRC_ACCESS_TOKEN` environment variables correctly for each run ## Prerequisites @@ -89,102 +92,126 @@ Feel free to open issues or PRs, but responses are best effort. ## Install -- Requires Python 3.11 +- Requires Python >= 3.11 - Recommended: Use a Python virtual environment -### Install from a GitHub Release - -Use this when the VM can reach GitHub and PyPI during install: +### Install from PyPI ```bash -python3.11 -m venv .venv -. .venv/bin/activate -pip install \ - "https://github.com/sourcegraph/src-auth-perms-sync/releases/download/v0.1.0/src_auth_perms_sync-0.1.0-py3-none-any.whl" -``` +# Set up virtual environment +python3 -m venv .venv +source .venv/bin/activate +python -m pip install --upgrade pip + +# Install package from PyPI +python -m pip install src-auth-perms-sync -### Restricted/offline install from a GitHub Release +# Run the CLI +src-auth-perms-sync --help +``` -Use this when the VM cannot reach package indexes during install +### Restricted / offline install from a GitHub Release -Download the .tar.gz file from the GitHub release +Download the .tar.gz file from [a GitHub release](https://github.com/sourcegraph/src-auth-perms-sync/releases) ```bash tar -xzf src-auth-perms-sync-linux-x64.tar.gz -python3.11 -m venv .venv -. .venv/bin/activate + pip install --no-index --find-links ./wheelhouse src-auth-perms-sync + +# Run the CLI +src-auth-perms-sync --help ``` -After either install method, run the CLI from the activated virtual environment: +### Import into your own Python script -```bash -src-auth-perms-sync --help +```python +from pathlib import Path + +import src_auth_perms_sync as src + +config = src.Config( + src_endpoint="https://sourcegraph.example.com", + src_access_token="sgp_...", + maps_path=Path("/absolute/path/to/maps.yaml"), + apply=False, # Dry run (default), set to True to make changes +) + +succeeded = src.Set(config) + +# Other command wrappers: +# succeeded = src.Get(config) +# succeeded = src.Restore(config) +# succeeded = src.SyncSamlOrgs(config) ``` ## Inputs -- Environment variables +- Environment variables (CLI), or src.Config args (Python import) - `SRC_ENDPOINT` - `SRC_ACCESS_TOKEN` from a user with site-admin perms - - Supplied via the environment or a `.env` file - See [.env.example](./.env.example) -- YAML maps file `src-auth-perms-sync-runs//maps.yaml` +- YAML maps file + - By default: `src-auth-perms-sync-runs//maps.yaml` + - Or pass `--maps-path ./path/to/maps.yaml` - A list of mapping rules - Each mapping rule takes - - A list of filters for users - - A list of filters for repos + - A map of filters for users + - A map of filters for repos - See [maps-example.yaml](./maps-example.yaml) - - An empty maps.yaml file is created for you on the first --get run + - An empty maps.yaml file is created for you on the first `get` run ## Usage: Permission sync 1. **Get auth providers and code hosts** ```bash - uv run src-auth-perms-sync [--get] + src-auth-perms-sync get ``` - Queries the Sourcegraph instance for auth providers and code host connections - Writes generated reference files `auth-providers.yaml` and `code-hosts.yaml` under - `src-auth-perms-sync-runs//` + `src-auth-perms-sync-runs//` - Creates an empty `maps.yaml` if it doesn't exist - - Runs by default when no command is selected 2. **Configure mapping rules** - - Edit `maps.yaml` + - Edit `src-auth-perms-sync-runs//maps.yaml` - Add mapping rules under the `maps:` top level key - See [maps-example.yaml](./maps-example.yaml) 3. **Set: Dry run** ```bash - uv run src-auth-perms-sync --set maps.yaml --full + src-auth-perms-sync set --full ``` 4. **Set: Apply** ```bash - uv run src-auth-perms-sync --set maps.yaml --full --apply + src-auth-perms-sync set --full --apply ``` + - To use a maps file outside the generated endpoint directory, pass an + explicit path, for example `--maps-path ./maps.yaml` + 5. **Restore: Dry run** ```bash - uv run src-auth-perms-sync \ - --restore backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json + src-auth-perms-sync restore \ + --restore-path src-auth-perms-sync-runs//runs//before.json ``` - Roll back the explicit-permissions state on the instance to match a previously captured snapshot + - Relative `--restore-path` values are resolved from your current working directory 6. **Restore: Apply** ```bash - uv run src-auth-perms-sync \ - --restore backups/maps.yaml/2026-04-27-08-24-25-set-apply/before.json \ + src-auth-perms-sync restore \ + --restore-path src-auth-perms-sync-runs//runs//before.json \ --apply ``` @@ -193,7 +220,7 @@ src-auth-perms-sync --help 1. **Get user and org metadata** ```bash - uv run src-auth-perms-sync --sync-saml-orgs + src-auth-perms-sync sync-saml-orgs ``` - Queries the Sourcegraph instance for auth providers, users, users' SAML groups, and orgs @@ -202,45 +229,45 @@ src-auth-perms-sync --help 2. **Apply org sync** ```bash - uv run src-auth-perms-sync --sync-saml-orgs --apply + src-auth-perms-sync sync-saml-orgs --apply ``` - Creates the orgs if they don't exist, and sync the members from the SAML groups to the orgs - - `--sync-saml-orgs` can also be added to a `--set` run, to run both at the same time + - `--sync-saml-orgs` can also be added to a `set` run, to run both at the same time ## Options -Run `uv run src-auth-perms-sync --help` for options +Run `src-auth-perms-sync --help` for options ## File tree ```text -src-auth-perms-sync-runs/endpoint/ +src-auth-perms-sync-runs// ├── auth-providers.yaml ├── code-hosts.yaml ├── maps.yaml └── runs └── timestamp-command - ├── after.json ├── before.json + ├── after.json ├── diff.json ├── log.json └── maps.yaml ``` - The `src-auth-perms-sync-runs` dir is created under your current working directory -- The `endpoint` dir is created with the hostname from `SRC_ENDPOINT` +- The `` dir is created with the hostname from `SRC_ENDPOINT` - If `maps.yaml` doesn't exist already, it'll be created for you -- `auth-providers.yaml` and `code-hosts.yaml` are created / replaced by the `--get` command, +- `auth-providers.yaml` and `code-hosts.yaml` are created / replaced by the `get` command, for you to copy values from, to use in your `maps.yaml` -- Only one `maps.yaml` file can be used at a time per Sourcegraph instance, as each `--set --apply` +- Only one `maps.yaml` file can be used at a time per Sourcegraph instance, as each `set --apply` command resets the state on the Sourcegraph instance to the `maps.yaml` file which was used - Each run of the script creates a new `timestamp-command` dir under the `runs` dir, with: + - A `before.json` file, capturing the before state, which can be used in a restore run - A log file - A backup copy of the `maps.yaml` file which was used in that run - - A `before.json` file, capturing the before state, which can be restored from - Runs using `--apply` also create - An `after.json` file, capturing the new state - A `diff.json` file, a shorter, reviewable file containing the diffs between before and after - + diff --git a/dev/TODO.md b/dev/TODO.md index 7fe3c6a..53c7b2e 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -1,9 +1,42 @@ # TODO -## High priority: Bump src-py-lib after Node ID helper release +## High priority: Instrument with OpenTelemetry — in progress -- After releasing `src-py-lib` with Sourcegraph Node ID helpers, update - `pyproject.toml` and `uv.lock` to depend on that new version. +- [ ] Add OTel-native traces, metrics, and wide log events in `src-py-lib`. +- [ ] Add shared OTel bootstrap config/helpers with `--otel` and standard + `OTEL_*` env-var-backed CLI args. +- [ ] Replace custom trace-context propagation with OTel W3C propagation. +- [ ] Instrument shared HTTP and GraphQL clients manually, preserving safe + sanitized attributes and Sourcegraph-specific metadata. +- [ ] Rename Sourcegraph debug tracing from `--trace` to `--fetch-sg-traces`. +- [ ] Wire `src-auth-perms-sync` to the shared OTel bootstrap without doing + import-time logger/provider setup. +- [ ] Verify pyright, tests, and CLI help in both repos. + +## High priority: End to End test cases + +- Create test cases. Each test case should contain: + - Before state + - maps.yaml file + - Expected after state +- Script to run the script, and verify the after state matches the expected after state + +## High priority: Verify perms are updated when a user's SAML groups change + +- If a user gets added to a new SAML group, which hits a mapping, ensure they + get the new perms + +## High priority: Reduce worst-case full-permission sync load + +- Use the stress-run evidence in + [memory-efficiency.md](./memory-efficiency.md) + to request Sourcegraph bulk explicit-permission read and write APIs. +- Add an explicit destructive/performance-test mode to the e2e runner so giant + stress runs can skip or defer full restore cleanup when the goal is finding + the server-side breaking point. +- Revisit full snapshot capture once Sourcegraph exposes a bulk read path; + replace aliased `User.permissionsInfo.repositories(source: API)` calls before + raising concurrency further. ## Medium priority: Lightweight incremental updates @@ -25,11 +58,6 @@ - How do we avoid stampedes (e.g., bulk repo sync triggering thousands of re-runs)? -## Medium priority: Verify perms are updated when a user's SAML groups change - -- If a user gets added to a new SAML group, which hits a mapping, ensure they - get the new perms - ## Low priority: Repo-centric path, when users > repos, or for cross-checking We previously had a repo-centric capture path @@ -69,7 +97,14 @@ If/when we revisit: 3. Add a CLI flag (e.g. `--cross-check-capture`) gated behind a clear "this doubles capture cost" warning. -## Low priority: Expand group-membership filters beyond SAML +## Low priority: Grouped full-set plan if memory is still too high + +Phase 1 now avoids per-repo username sets for non-overlapping full-set maps. +If memory remains too high after re-measuring, implement the Phase 2 grouped +plan in [mapping-efficiency.md](./mapping-efficiency.md): combine map-entry +overlays into final groups of repos that share the same desired username tuple. + +## Low priority: Expand group-membership filters beyond SAML `allowGroups`-style enforcement exists on more than just SAML, but only SAML actually persists the group list. Recovery options for each: diff --git a/dev/dead-code-audit.md b/dev/audit-dead-code.md similarity index 100% rename from dev/dead-code-audit.md rename to dev/audit-dead-code.md diff --git a/dev/git-worktrees.md b/dev/git-worktrees.md deleted file mode 100644 index 6f83ed5..0000000 --- a/dev/git-worktrees.md +++ /dev/null @@ -1,89 +0,0 @@ -# Git worktrees - -Git worktrees let one repository have multiple checkout directories that share -the same object database. Each worktree can use a different branch, index, and -working tree. - -Use them when the human and one or more agents need to work in parallel without -mixing unrelated uncommitted changes in one checkout. - -## Why use worktrees for agent work - -Benefits: - -- Each task gets a clean branch and clean index. -- Agents cannot accidentally edit or stage the human's current local changes. -- VS Code review is clearer because Source Control shows one task's changes. -- Branches can be merged or rebased one at a time. -- Git prevents the same branch from being checked out in two worktrees. - -Worktrees do not remove real conflicts. If two branches edit the same lines or -the same behavior, the conflict still has to be resolved when those branches are -merged or rebased. Worktrees make that conflict explicit instead of silently -mixing edits in one dirty working tree. - -## Create a task worktree - -From the main checkout: - -```sh -git worktree add ../src-auth-perms-sync-backup-diffs \ - -b amp/backup-diff-files \ - HEAD -``` - -Then work from the new directory: - -```sh -cd ../src-auth-perms-sync-backup-diffs -``` - -If the task should start from a remote branch instead of the current commit, -replace `HEAD` with that branch, for example `origin/split-main-into-modules`. - -## Review in VS Code - -Open the task worktree directly: - -```sh -code ../src-auth-perms-sync-backup-diffs -``` - -VS Code treats it like a normal repo checkout. The Source Control view shows the -changes for that worktree only, without unrelated edits from other worktrees. - -## Merge conflict expectations - -Worktrees reduce accidental interference, not semantic overlap. - -Good parallelism: - -- One agent updates SAML org sync. -- Another agent updates packaging docs. -- The human edits a config example. - -Risky parallelism: - -- Two agents refactor `src/src_auth_perms_sync/cli.py` at the same time. -- One branch renames functions while another branch edits their call sites. -- Two branches change the same GraphQL mutation flow. - -For long-running tasks, rebase the task branch regularly: - -```sh -git fetch origin -git rebase origin/split-main-into-modules -``` - -Resolve any conflicts in the task worktree, run validation, then continue. - -## Clean up a finished worktree - -After the branch is merged or no longer needed: - -```sh -git worktree remove ../src-auth-perms-sync-backup-diffs -git branch -d amp/backup-diff-files -``` - -Use `git branch -D` only for an unmerged branch that is intentionally discarded. diff --git a/dev/mapping-efficiency.md b/dev/mapping-efficiency.md new file mode 100644 index 0000000..7e2d2a5 --- /dev/null +++ b/dev/mapping-efficiency.md @@ -0,0 +1,175 @@ +# Mapping efficiency + +## Rectangular maps example + +Input maps + +```yaml +maps: + - name: engineers get generated repos + users: + usernames: + - alice + - bob + - carol + repos: + names: + - repo-1 + - repo-2 + - repo-3 +``` + +### Original + +- Repo-centric plan, but every repo gets a full copy of the list of users, + so the memory storage size is truly users x repos +- If your list of users is 1,000 users, and 10 MB RAM, and you have 1,000 repos, + then this is 1,000,000 repo+user pairs, which is 1,000 x 10 MB RAM = 10 GB RAM +- This is a "full square" + +repo-1 -> (alice, bob, carol) +repo-2 -> (alice, bob, carol) +repo-3 -> (alice, bob, dan) +repo-4 -> (alice, bob, dan) + +### Current: Groups of users + +- We anticipate that many users will be grouped up into a small number of sets, + and that most repos' perms will be one of the sets +- This example cuts in ~half the amount memory consumed by lists of users as the Current example + +user-group-a = (alice, bob, carol) +user-group-b = (alice, bob, dan) + +repo-1 -> user-group-a +repo-2 -> user-group-a +repo-3 -> user-group-b +repo-4 -> user-group-b + +### Phase 2: Groups of users x Groups of repos + +- Realistically, we anticipate that many repos will also be grouped up into a small number of sets + +user-group-a = (alice, bob, carol) +user-group-b = (alice, bob, dan) + +repo-group-1 = (repo-1, repo-2) +repo-group-2 = (repo-3, repo-4) + +user-group-a -> repo-group-1 +user-group-b -> repo-group-2 + +## Current semantics + +Each `maps:` entry is naturally a grouped rule: + +```text +selected users × selected repos +``` + +The maps schema keeps this restrictive: `users:` and `repos:` are selector +maps, top-level selectors inside each map are ANDed together, and values inside +one selector list are ORed together. To OR across selectors, write more +top-level `maps:` entries. + +The full-set command must combine all entries before mutating Sourcegraph, +because `setRepositoryPermissionsForUsers` overwrites a repo's whole explicit +permission list. The required final state is: + +```text +desired_users(repo) = union(users_i for each map_i where repo is in repos_i) +``` + +Only after this union is known can the command safely apply per-repo overwrite +mutations. + +## Phase 1: lazy per-repo union sets + +The old full-set planner immediately expanded every map entry into: + +```text +repo_id -> set(username) +``` + +That is expensive for rectangular maps such as `10000 users × 1000 repos`: +the username strings are shared, but each repo owns a large Python set with one +hash-table entry per planned grant. + +Phase 1 keeps the existing downstream plan shape: + +```text +repo_id -> tuple(username) +``` + +but builds it more carefully: + +1. For a non-overlapping map entry, create one sorted username tuple and reuse + that same tuple for every matched repo. +2. If a later map entry touches a repo that already has users, promote only + that repo to a temporary set and union the usernames. +3. Convert only promoted repos back to sorted tuples after all map entries are + processed. + +This preserves the hard invariant while avoiding the large per-repo sets in +the common non-overlapping rectangular case. + +Measured on the sgdev test instance, the dry-run `10000x1000` case planned 10M +grants. Before Phase 1 it peaked at about 651 MiB RSS; after Phase 1 it peaked +at about 68 MiB RSS. + +## Phase 2: final grouped plan, if needed + +If Phase 1 is not enough, store the combined final plan as groups of repos that +share the same final user set: + +```text +tuple(username) -> tuple(repo_id) +``` + +This is not just one group per `maps:` entry. Map entries are input overlays; +final groups are the compressed result after every map entry has been unioned +onto the repo space. + +Example: + +```text +map A: alice,bob -> repo-1,repo-2 +map B: bob,chris -> repo-2,repo-3 + +final: +alice,bob -> repo-1 +alice,bob,chris -> repo-2 +bob,chris -> repo-3 +``` + +One practical data model would be: + +```python +@dataclass(frozen=True) +class RepositoryPermissionGroup: + usernames: tuple[str, ...] + repository_ids: tuple[str, ...] + + +@dataclass(frozen=True) +class FullSetPlan: + groups: tuple[RepositoryPermissionGroup, ...] + repo_names: dict[str, str] + repo_to_group_index: dict[str, int] + + def usernames_for_repo(self, repo_id: str) -> tuple[str, ...]: + return self.groups[self.repo_to_group_index[repo_id]].usernames +``` + +Apply still happens per repo: + +```text +for group in groups: + for repo_id in group.repository_ids: + setRepositoryPermissionsForUsers(repo_id, group.usernames) +``` + +Phase 2 touches more code than Phase 1: projected snapshots, diffs, +short-circuit filtering, apply iteration, and validation all currently expect +direct `repo_id -> usernames` lookups. Do it only if Phase 1 measurements still +show unacceptable memory use. diff --git a/dev/memory-efficiency-analyze.py b/dev/memory-efficiency-analyze.py new file mode 100755 index 0000000..f9c024a --- /dev/null +++ b/dev/memory-efficiency-analyze.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python3 +"""Fit a Sourcegraph permissions memory model from e2e memory result JSON. + +The model is intentionally small and dependency-free: + + peak RSS MiB = intercept + users*b1 + repos*b2 + grants*b3 + +Use one command mode per fit. Mixing backup, no-backup, get, set, and restore +runs makes the per-grant coefficient much less useful. +""" + +from __future__ import annotations + +import argparse +import json +import math +import re +import statistics +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +FEATURE_NAMES = ("users", "repos", "grants") +COEFFICIENT_SCALE = { + "users": "bytes/user", + "repos": "bytes/repo", + "grants": "bytes/grant", +} + + +@dataclass(frozen=True) +class WorkloadDimensions: + """Canonical workload dimensions used by the memory model.""" + + users: float | None + repos: float | None + grants: float | None + + +@dataclass(frozen=True) +class MemoryObservation: + """One e2e command result with peak memory and workload dimensions.""" + + source_path: str + variant: str + case_name: str + command: str + iteration: int + peak_resident_megabytes: float + dimensions: WorkloadDimensions + + +@dataclass(frozen=True) +class MemoryModel: + """Fitted linear memory model.""" + + feature_names: tuple[str, ...] + coefficients_megabytes: dict[str, float] + observation_count: int + r_squared: float | None + mean_absolute_error_megabytes: float + p95_absolute_error_megabytes: float + max_absolute_error_megabytes: float + + +@dataclass(frozen=True) +class MemoryEstimate: + """Predicted memory for a proposed users x repos workload.""" + + dimensions: WorkloadDimensions + peak_resident_megabytes: float + peak_resident_megabytes_with_headroom: float + headroom_percent: float + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Fit a fixed + users + repos + grants memory model from e2e JSON.", + ) + parser.add_argument( + "results_json", + nargs="+", + type=Path, + help="One or more JSON files written by dev/test-end-to-end.py --results-json.", + ) + parser.add_argument( + "--variant", + help="Only include one variant, e.g. candidate or baseline.", + ) + parser.add_argument( + "--command", + help="Only include one structured command, e.g. set_full or get.", + ) + parser.add_argument( + "--case-regex", + help="Only include cases whose e2e case name matches this regular expression.", + ) + parser.add_argument( + "--features", + default="users,repos,grants", + help="Comma-separated model features from users,repos,grants (default: all).", + ) + parser.add_argument( + "--min-grants", + type=float, + default=1.0, + help="Drop observations below this grant count (default: 1).", + ) + parser.add_argument( + "--estimate-users", + type=float, + help="Estimate memory for this many users.", + ) + parser.add_argument( + "--estimate-repos", + type=float, + help="Estimate memory for this many repos.", + ) + parser.add_argument( + "--estimate-grants", + type=float, + help="Estimate memory for this many grants; defaults to users * repos.", + ) + parser.add_argument( + "--headroom-percent", + type=float, + default=30.0, + help="Headroom to add to estimates (default: 30).", + ) + parser.add_argument( + "--json", + action="store_true", + help="Write machine-readable JSON instead of a text report.", + ) + arguments = parser.parse_args() + + feature_names = parse_feature_names(arguments.features) + observations = load_observations(arguments.results_json) + filtered_observations = filter_observations( + observations, + variant=arguments.variant, + command=arguments.command, + case_regex=arguments.case_regex, + min_grants=arguments.min_grants, + ) + model_observations = observations_with_features(filtered_observations, feature_names) + minimum_observations = len(feature_names) + 1 + if len(model_observations) < minimum_observations: + print( + "Need at least " + f"{minimum_observations} observations with {', '.join(feature_names)} " + f"to fit this model; found {len(model_observations)}.", + file=sys.stderr, + ) + return 2 + + try: + model = fit_memory_model(model_observations, feature_names) + except ValueError as error: + print(f"Could not fit memory model: {error}", file=sys.stderr) + print( + "Try filtering to one command mode, adding varied users x repos shapes, " + "or using fewer --features.", + file=sys.stderr, + ) + return 2 + + estimate = build_estimate( + model, + feature_names, + estimate_users=arguments.estimate_users, + estimate_repos=arguments.estimate_repos, + estimate_grants=arguments.estimate_grants, + headroom_percent=arguments.headroom_percent, + ) + if arguments.json: + write_json_report(model, model_observations, estimate) + else: + write_text_report(model, model_observations, estimate) + return 0 + + +def parse_feature_names(raw_features: str) -> tuple[str, ...]: + names = tuple(name.strip() for name in raw_features.split(",") if name.strip()) + invalid = sorted(set(names) - set(FEATURE_NAMES)) + if invalid: + raise SystemExit(f"Unknown feature(s): {', '.join(invalid)}") + duplicates = sorted({name for name in names if names.count(name) > 1}) + if duplicates: + raise SystemExit(f"Duplicate feature(s): {', '.join(duplicates)}") + if not names: + raise SystemExit("At least one feature is required.") + return names + + +def load_observations(paths: list[Path]) -> list[MemoryObservation]: + observations: list[MemoryObservation] = [] + for path in paths: + with path.open(encoding="utf-8") as input_file: + payload: object = json.load(input_file) + for result in result_mappings(payload): + observation = observation_from_result(path, result) + if observation is not None: + observations.append(observation) + return observations + + +def result_mappings(payload: object) -> list[dict[str, Any]]: + if isinstance(payload, dict): + mapping = cast(dict[str, Any], payload) + results = mapping.get("results") + if isinstance(results, list): + return mapping_items(cast(list[object], results)) + if "memory" in mapping and "workload" in mapping: + return [mapping] + if isinstance(payload, list): + return mapping_items(cast(list[object], payload)) + return [] + + +def mapping_items(values: list[object]) -> list[dict[str, Any]]: + """Return only dict-like JSON objects from a JSON list.""" + return [cast(dict[str, Any], value) for value in values if isinstance(value, dict)] + + +def observation_from_result(path: Path, result: dict[str, Any]) -> MemoryObservation | None: + memory = object_mapping(result.get("memory")) + workload = object_mapping(result.get("workload")) + if memory is None or workload is None: + return None + peak_resident_megabytes = first_number(memory, ("peak_rss_mb", "external_peak_rss_mb")) + if peak_resident_megabytes is None: + return None + return MemoryObservation( + source_path=str(path), + variant=string_value(result.get("variant")), + case_name=string_value(result.get("case")), + command=string_value(result.get("command")), + iteration=integer_value(result.get("iteration")), + peak_resident_megabytes=peak_resident_megabytes, + dimensions=WorkloadDimensions( + users=first_number( + workload, + ( + "memory_model_user_count", + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "user_count", + "total_users_scanned", + "sourcegraph_user_count", + "total_users", + ), + ), + repos=first_number( + workload, + ( + "memory_model_repo_count", + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "repos_with_explicit_grants", + "loaded_repo_count", + "repo_count", + ), + ), + grants=first_number( + workload, + ( + "memory_model_grant_count", + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "total_grants", + "apply_payload_grant_count", + ), + ), + ), + ) + + +def filter_observations( + observations: list[MemoryObservation], + *, + variant: str | None, + command: str | None, + case_regex: str | None, + min_grants: float, +) -> list[MemoryObservation]: + pattern = re.compile(case_regex) if case_regex else None + filtered: list[MemoryObservation] = [] + for observation in observations: + if variant is not None and observation.variant != variant: + continue + if command is not None and observation.command != command: + continue + if pattern is not None and pattern.search(observation.case_name) is None: + continue + if observation.dimensions.grants is None or observation.dimensions.grants < min_grants: + continue + filtered.append(observation) + return filtered + + +def observations_with_features( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> list[MemoryObservation]: + return [ + observation + for observation in observations + if all(feature_value(observation.dimensions, name) is not None for name in feature_names) + ] + + +def fit_memory_model( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> MemoryModel: + feature_scales = feature_scale_by_name(observations, feature_names) + matrix = [ + [1.0] + + [ + required_feature_value(observation.dimensions, feature_name) + / feature_scales[feature_name] + for feature_name in feature_names + ] + for observation in observations + ] + targets = [observation.peak_resident_megabytes for observation in observations] + scaled_coefficients = solve_normal_equations(matrix, targets) + coefficients = {"intercept": scaled_coefficients[0]} + for feature_index, feature_name in enumerate(feature_names, start=1): + coefficients[feature_name] = ( + scaled_coefficients[feature_index] / feature_scales[feature_name] + ) + predictions = [ + predict_megabytes(coefficients, observation.dimensions) for observation in observations + ] + residuals = [ + target - prediction for target, prediction in zip(targets, predictions, strict=True) + ] + absolute_residuals = [abs(residual) for residual in residuals] + target_mean = statistics.fmean(targets) + residual_sum_squares = sum(residual * residual for residual in residuals) + total_sum_squares = sum((target - target_mean) ** 2 for target in targets) + return MemoryModel( + feature_names=feature_names, + coefficients_megabytes=coefficients, + observation_count=len(observations), + r_squared=( + None if total_sum_squares == 0 else 1.0 - residual_sum_squares / total_sum_squares + ), + mean_absolute_error_megabytes=statistics.fmean(absolute_residuals), + p95_absolute_error_megabytes=percentile(absolute_residuals, 95.0), + max_absolute_error_megabytes=max(absolute_residuals), + ) + + +def feature_scale_by_name( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> dict[str, float]: + scales: dict[str, float] = {} + for feature_name in feature_names: + maximum = max( + abs(required_feature_value(observation.dimensions, feature_name)) + for observation in observations + ) + scales[feature_name] = maximum if maximum > 0 else 1.0 + return scales + + +def solve_normal_equations(matrix: list[list[float]], targets: list[float]) -> list[float]: + column_count = len(matrix[0]) + normal_matrix = [[0.0 for _ in range(column_count)] for _ in range(column_count)] + normal_targets = [0.0 for _ in range(column_count)] + for row, target in zip(matrix, targets, strict=True): + for row_index in range(column_count): + normal_targets[row_index] += row[row_index] * target + for column_index in range(column_count): + normal_matrix[row_index][column_index] += row[row_index] * row[column_index] + return solve_linear_system(normal_matrix, normal_targets) + + +def solve_linear_system(matrix: list[list[float]], values: list[float]) -> list[float]: + size = len(values) + augmented = [matrix[row_index][:] + [values[row_index]] for row_index in range(size)] + for pivot_index in range(size): + pivot_row = max( + range(pivot_index, size), + key=lambda row_index: abs(augmented[row_index][pivot_index]), + ) + pivot_value = augmented[pivot_row][pivot_index] + if abs(pivot_value) < 1e-12: + raise ValueError("features are collinear or the sample is too small") + augmented[pivot_index], augmented[pivot_row] = augmented[pivot_row], augmented[pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[pivot_index][column_index] /= pivot_value + for row_index in range(size): + if row_index == pivot_index: + continue + factor = augmented[row_index][pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[row_index][column_index] -= factor * augmented[pivot_index][column_index] + return [augmented[row_index][size] for row_index in range(size)] + + +def build_estimate( + model: MemoryModel, + feature_names: tuple[str, ...], + *, + estimate_users: float | None, + estimate_repos: float | None, + estimate_grants: float | None, + headroom_percent: float, +) -> MemoryEstimate | None: + if estimate_users is None and estimate_repos is None and estimate_grants is None: + return None + if "users" in feature_names and estimate_users is None: + raise SystemExit("--estimate-users is required because users is in --features.") + if "repos" in feature_names and estimate_repos is None: + raise SystemExit("--estimate-repos is required because repos is in --features.") + if "grants" in feature_names and estimate_grants is None: + if estimate_users is None or estimate_repos is None: + raise SystemExit( + "--estimate-grants is required unless --estimate-users and --estimate-repos " + "are both set." + ) + estimate_grants = estimate_users * estimate_repos + dimensions = WorkloadDimensions( + users=estimate_users, + repos=estimate_repos, + grants=estimate_grants, + ) + peak_resident_megabytes = predict_megabytes(model.coefficients_megabytes, dimensions) + return MemoryEstimate( + dimensions=dimensions, + peak_resident_megabytes=peak_resident_megabytes, + peak_resident_megabytes_with_headroom=peak_resident_megabytes + * (1.0 + headroom_percent / 100.0), + headroom_percent=headroom_percent, + ) + + +def predict_megabytes( + coefficients_megabytes: dict[str, float], dimensions: WorkloadDimensions +) -> float: + prediction = coefficients_megabytes["intercept"] + for feature_name in FEATURE_NAMES: + coefficient = coefficients_megabytes.get(feature_name) + value = feature_value(dimensions, feature_name) + if coefficient is not None and value is not None: + prediction += coefficient * value + return prediction + + +def write_text_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + print(f"Observations used: {model.observation_count}") + print(f"Features: {', '.join(model.feature_names)}") + print("\nCoefficients:") + print(f" intercept: {model.coefficients_megabytes['intercept']:.3f} MiB") + for feature_name in model.feature_names: + coefficient_megabytes = model.coefficients_megabytes[feature_name] + coefficient_bytes = coefficient_megabytes * 1024.0 * 1024.0 + print( + f" {feature_name}: {coefficient_megabytes:.9f} MiB/unit " + f"({coefficient_bytes:.1f} {COEFFICIENT_SCALE[feature_name]})" + ) + r_squared = "n/a" if model.r_squared is None else f"{model.r_squared:.4f}" + print("\nFit quality:") + print(f" R²: {r_squared}") + print(f" mean absolute error: {model.mean_absolute_error_megabytes:.2f} MiB") + print(f" p95 absolute error: {model.p95_absolute_error_megabytes:.2f} MiB") + print(f" max absolute error: {model.max_absolute_error_megabytes:.2f} MiB") + print("\nObserved range:") + print_dimension_range(observations, "users") + print_dimension_range(observations, "repos") + print_dimension_range(observations, "grants") + if estimate is not None: + print("\nEstimate:") + print(f" users: {format_optional_number(estimate.dimensions.users)}") + print(f" repos: {format_optional_number(estimate.dimensions.repos)}") + print(f" grants: {format_optional_number(estimate.dimensions.grants)}") + print(f" peak RSS: {estimate.peak_resident_megabytes:.1f} MiB") + print( + f" with {estimate.headroom_percent:g}% headroom: " + f"{estimate.peak_resident_megabytes_with_headroom:.1f} MiB" + ) + + +def write_json_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + report: dict[str, Any] = { + "observation_count": model.observation_count, + "features": list(model.feature_names), + "coefficients_mib": model.coefficients_megabytes, + "coefficients_bytes": { + feature_name: model.coefficients_megabytes[feature_name] * 1024.0 * 1024.0 + for feature_name in model.feature_names + }, + "fit": { + "r_squared": model.r_squared, + "mean_absolute_error_mib": model.mean_absolute_error_megabytes, + "p95_absolute_error_mib": model.p95_absolute_error_megabytes, + "max_absolute_error_mib": model.max_absolute_error_megabytes, + }, + "observed_range": observed_range_to_json(observations), + "estimate": estimate_to_json(estimate), + } + json.dump(report, sys.stdout, indent=2, sort_keys=True) + sys.stdout.write("\n") + + +def print_dimension_range(observations: list[MemoryObservation], feature_name: str) -> None: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if not values: + print(f" {feature_name}: n/a") + return + print(f" {feature_name}: {format_number(min(values))} .. {format_number(max(values))}") + + +def observed_range_to_json(observations: list[MemoryObservation]) -> dict[str, dict[str, float]]: + ranges: dict[str, dict[str, float]] = {} + for feature_name in FEATURE_NAMES: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if values: + ranges[feature_name] = {"min": min(values), "max": max(values)} + return ranges + + +def estimate_to_json(estimate: MemoryEstimate | None) -> dict[str, Any] | None: + if estimate is None: + return None + return { + "users": estimate.dimensions.users, + "repos": estimate.dimensions.repos, + "grants": estimate.dimensions.grants, + "peak_rss_mib": estimate.peak_resident_megabytes, + "headroom_percent": estimate.headroom_percent, + "peak_rss_mib_with_headroom": estimate.peak_resident_megabytes_with_headroom, + } + + +def object_mapping(value: object) -> dict[str, Any] | None: + return cast(dict[str, Any], value) if isinstance(value, dict) else None + + +def first_number(mapping: dict[str, Any], names: tuple[str, ...]) -> float | None: + for name in names: + value = mapping.get(name) + if isinstance(value, bool): + continue + if isinstance(value, int | float): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + continue + return None + + +def string_value(value: object) -> str: + return value if isinstance(value, str) else "" + + +def integer_value(value: object) -> int: + if isinstance(value, bool): + return 0 + return value if isinstance(value, int) else 0 + + +def feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float | None: + if feature_name == "users": + return dimensions.users + if feature_name == "repos": + return dimensions.repos + if feature_name == "grants": + return dimensions.grants + raise ValueError(f"Unknown feature: {feature_name}") + + +def required_feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float: + value = feature_value(dimensions, feature_name) + if value is None: + raise ValueError(f"Observation is missing feature: {feature_name}") + return value + + +def percentile(values: list[float], percentile_value: float) -> float: + if not values: + return math.nan + sorted_values = sorted(values) + index = math.ceil((percentile_value / 100.0) * len(sorted_values)) - 1 + return sorted_values[min(max(index, 0), len(sorted_values) - 1)] + + +def format_optional_number(value: float | None) -> str: + return "n/a" if value is None else format_number(value) + + +def format_number(value: float) -> str: + return f"{value:.0f}" if value.is_integer() else f"{value:.3f}" + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/memory-efficiency-generate.py b/dev/memory-efficiency-generate.py new file mode 100755 index 0000000..e2b6f3d --- /dev/null +++ b/dev/memory-efficiency-generate.py @@ -0,0 +1,1142 @@ +#!/usr/bin/env python3 +"""Generate and optionally run maps.yaml files for memory-model sweeps. + +The generated maps use exact `users.usernames` and `repos.names` filters. +Different workload shapes stress different parts of mapping resolution and +full-set planning, while preserving known selected user/repo/grant counts. + +By default this script only generates the maps. Pass `--run` to execute the +CLI in dry-run mode. Pass `--mode apply-with-backup --allow-apply` or +`--mode apply-no-backup --allow-apply` only on a scratch instance; those +modes mutate explicit permissions. +""" + +from __future__ import annotations + +import argparse +import csv +import datetime +import json +import os +import re +import shlex +import subprocess +import sys +import time +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, TypeAlias, cast +from urllib.parse import urlsplit + +import src_py_lib as src +import yaml +from src_py_lib.utils.config import load_config + +QUERY_EXTERNAL_SERVICES = """ +query MemoryModelExternalServices($first: Int!, $after: String) { + externalServices(first: $first, after: $after) { + nodes { + id + kind + displayName + repoCount + url + } + pageInfo { hasNextPage endCursor } + } +} +""" + +QUERY_USERNAMES = """ +query MemoryModelUsers($first: Int!, $after: String) { + users(first: $first, after: $after) { + nodes { username } + pageInfo { hasNextPage endCursor } + } +} +""" + +QUERY_USER_COUNT = """ +query MemoryModelUserCount { + users(first: 1) { totalCount } +} +""" + +QUERY_REPOS_BY_EXTERNAL_SERVICE = """ +query MemoryModelRepos($externalService: ID!, $first: Int!, $after: String) { + repositories( + first: $first + after: $after + externalService: $externalService + cloned: true + notCloned: true + ) { + nodes { name } + pageInfo { hasNextPage endCursor } + } +} +""" + +DEFAULT_CASES = "auto" +DEFAULT_USER_POINTS = (1, 10, 100, 1000, 10000) +DEFAULT_REPO_POINTS = (1, 10, 100, 1000) +DEFAULT_COMMAND = "uv run src-auth-perms-sync" +LOG_PATH_PATTERN = re.compile(r"Writing log events to (.+?/log\.json)\.") +RunMode = Literal["dry-run", "apply-with-backup", "apply-no-backup"] +SweepSuite = Literal["gentle", "breaking"] +StressShape: TypeAlias = Literal[ + "rectangle", + "user-shards", + "repo-shards", + "diagonal-shards", + "duplicate-rules", +] +STRESS_SHAPES: tuple[StressShape, ...] = ( + "rectangle", + "user-shards", + "repo-shards", + "diagonal-shards", + "duplicate-rules", +) +BREAKING_SHAPES: tuple[StressShape, ...] = ( + "rectangle", + "user-shards", + "repo-shards", + "duplicate-rules", +) + + +class SweepSourcegraphConfig(src.SourcegraphClientConfig): + """Sourcegraph connection config for discovery queries.""" + + +@dataclass(frozen=True) +class SweepCase: + """One generated workload case.""" + + users: int + repos: int + shape: StressShape = "rectangle" + rule_count: int = 1 + + @property + def grants(self) -> int: + """Final unique planned grants after map-entry unioning.""" + return unique_grant_count(self) + + @property + def raw_rule_grants(self) -> int: + """Total per-rule rectangle grants before cross-rule unioning.""" + return raw_rule_grant_count(self) + + @property + def map_rule_count(self) -> int: + return map_rule_count(self) + + @property + def name(self) -> str: + return ( + f"{self.shape}-m{self.map_rule_count:03d}-" + f"u{self.users:05d}-r{self.repos:05d}-g{self.grants:012d}" + ) + + +@dataclass(frozen=True) +class ExternalServiceChoice: + """Code host connection selected for repo sampling.""" + + graphql_id: str + database_id: int + display_name: str + kind: str + url: str + repo_count: int + + +@dataclass(frozen=True) +class GeneratedMap: + """One generated maps.yaml file and its workload dimensions.""" + + case: SweepCase + path: Path + + +@dataclass(frozen=True) +class CommandRunResult: + """One CLI execution result written in memory-efficiency-analyze.py-compatible shape.""" + + generated_map: GeneratedMap + mode: RunMode + return_code: int + elapsed_seconds: float + output_path: Path + log_path: Path | None + run_record: dict[str, Any] | None + + +def main() -> int: + parser = build_parser() + arguments = parser.parse_args() + mode = cast(RunMode, arguments.mode) + suite = cast(SweepSuite, arguments.suite) + if mode != "dry-run" and not arguments.allow_apply: + parser.error(f"--mode {mode} requires --allow-apply") + if arguments.rule_count < 1: + parser.error("--rule-count must be >= 1") + + config = sourcegraph_config(arguments) + output_dir = arguments.output_dir or default_output_dir(config.src_endpoint) + maps_dir = output_dir / "maps" + output_dir.mkdir(parents=True, exist_ok=True) + maps_dir.mkdir(parents=True, exist_ok=True) + + requested_cases = parse_cases(arguments.cases) + + client = src.SourcegraphClient( + endpoint=config.src_endpoint, + token=config.src_access_token, + http=src.HTTPClient( + timeout=arguments.http_timeout_seconds, + max_connections=max(4, arguments.parallelism), + ), + ) + try: + external_services = list_external_services(client) + inventory_repo_count = sum(service.repo_count for service in external_services) + service = choose_external_service(external_services, arguments.external_service_id) + total_user_count = count_users(client) + base_cases = requested_cases or default_cases_for_inventory( + total_user_count, + service.repo_count, + suite=suite, + ) + shapes = parse_shapes(arguments.shapes, suite) + cases = expand_cases(base_cases, shapes, arguments.rule_count) + max_users = max(sweep_case.users for sweep_case in cases) + max_repos = max(sweep_case.repos for sweep_case in cases) + usernames = list_usernames(client, max_users, arguments.page_size) + repo_names = list_repo_names(client, service, max_repos, arguments.page_size) + finally: + client.http.close() + + generated_maps = write_maps(maps_dir, cases, usernames, repo_names, service) + write_manifest( + output_dir, + generated_maps, + service, + config.src_endpoint, + inventory_repo_count, + total_user_count, + ) + print(f"Generated {len(generated_maps)} maps.yaml file(s) under {maps_dir}") + print( + f"Selected code host: {service.display_name} id={service.database_id} " + f"repos={service.repo_count}; instance repoCount sum={inventory_repo_count}" + ) + + if not arguments.run: + print("Generation only. Re-run with --run to execute the sweep.") + return 0 + + run_results = run_sweep( + generated_maps, + endpoint=config.src_endpoint, + access_token=config.src_access_token, + output_dir=output_dir, + command=arguments.command, + mode=mode, + parallelism=arguments.parallelism, + explicit_permissions_batch_size=arguments.explicit_permissions_batch_size, + http_timeout_seconds=arguments.http_timeout_seconds, + sample_interval=arguments.sample_interval, + trace=arguments.trace, + sourcegraph_user_count=total_user_count, + sourcegraph_inventory_repo_count=inventory_repo_count, + ) + write_results(output_dir, run_results, inventory_repo_count, total_user_count) + return 0 if all(result.return_code == 0 for result in run_results) else 1 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Generate and optionally run maps.yaml memory-model sweep cases.", + ) + parser.add_argument( + "--env-file", + type=Path, + default=Path(".env"), + help="Environment file with SRC_ENDPOINT and SRC_ACCESS_TOKEN (default: .env).", + ) + parser.add_argument("--src-endpoint", help="Override SRC_ENDPOINT for discovery and runs.") + parser.add_argument("--src-access-token", help="Override SRC_ACCESS_TOKEN.") + parser.add_argument( + "--output-dir", + type=Path, + help=( + "Directory for generated maps and result files. " + "Defaults under src-auth-perms-sync-runs/." + ), + ) + parser.add_argument( + "--suite", + choices=("gentle", "breaking"), + default="gentle", + help=( + "Case-size preset. gentle keeps the previous low-risk auto sweep; " + "breaking adds larger dimensions intended to find the failure point." + ), + ) + parser.add_argument( + "--cases", + default=DEFAULT_CASES, + help=( + "Comma-separated users x repos cases, e.g. '100x10,1000x25', " + "or 'auto' for a gentle inventory-aware sweep. Default: auto." + ), + ) + parser.add_argument( + "--shapes", + default="auto", + help=( + "Comma-separated workload shapes: " + f"{', '.join(STRESS_SHAPES)}. " + "Default auto means rectangle for --suite gentle and a mixed set " + "for --suite breaking." + ), + ) + parser.add_argument( + "--rule-count", + type=int, + default=10, + help="Target map-rule/selector count for multi-rule shapes (default: 10).", + ) + parser.add_argument( + "--external-service-id", + type=int, + help="Decoded external service DB id to sample repos from. Defaults to largest repoCount.", + ) + parser.add_argument( + "--page-size", + type=int, + default=1000, + help="GraphQL page size for discovery queries (default: 1000).", + ) + parser.add_argument( + "--run", + action="store_true", + help="Run src-auth-perms-sync for each generated maps.yaml file.", + ) + parser.add_argument( + "--mode", + choices=("dry-run", "apply-with-backup", "apply-no-backup"), + default="dry-run", + help="Run mode when --run is set. Default is dry-run.", + ) + parser.add_argument( + "--allow-apply", + action="store_true", + help="Required safety acknowledgement for mutating --mode values.", + ) + parser.add_argument( + "--command", + default=DEFAULT_COMMAND, + help=( + "Base command used to invoke the CLI; the script appends " + f"'set --maps-path ...' (default: {DEFAULT_COMMAND!r})." + ), + ) + parser.add_argument( + "--parallelism", + type=int, + default=1, + help="CLI --parallelism for sweep runs. Default 1 is gentle on pgsql.", + ) + parser.add_argument( + "--explicit-permissions-batch-size", + type=int, + default=25, + help="CLI --explicit-permissions-batch-size for sweep runs (default: 25).", + ) + parser.add_argument( + "--http-timeout-seconds", + type=float, + default=120.0, + help="HTTP timeout for discovery and CLI runs (default: 120).", + ) + parser.add_argument( + "--sample-interval", + type=float, + default=1.0, + help="CLI --sample-interval for resource samples (default: 1).", + ) + parser.add_argument( + "--trace", + action="store_true", + help="Pass --trace to src-auth-perms-sync sweep runs.", + ) + return parser + + +def sourcegraph_config(arguments: argparse.Namespace) -> SweepSourcegraphConfig: + overrides: dict[str, object] = {} + if arguments.src_endpoint: + overrides["src_endpoint"] = arguments.src_endpoint + if arguments.src_access_token: + overrides["src_access_token"] = arguments.src_access_token + return load_config( + SweepSourcegraphConfig, + env_file=arguments.env_file, + cli_overrides=overrides, + base_dir=Path.cwd(), + resolve_op_refs=True, + require=("src_access_token",), + ) + + +def parse_cases(raw_cases: str) -> list[SweepCase] | None: + if raw_cases.strip().lower() == "auto": + return None + cases: list[SweepCase] = [] + for raw_case in raw_cases.split(","): + case = raw_case.strip().lower() + if not case: + continue + users_text, separator, repos_text = case.partition("x") + if not separator: + raise SystemExit(f"Invalid case {raw_case!r}; expected USERSxREPOS") + try: + users = int(users_text) + repos = int(repos_text) + except ValueError as error: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be integers") from error + if users < 1 or repos < 1: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be >= 1") + cases.append(SweepCase(users=users, repos=repos)) + if not cases: + raise SystemExit("At least one --cases entry is required") + return cases + + +def parse_shapes(raw_shapes: str, suite: SweepSuite) -> tuple[StressShape, ...]: + """Return workload shapes requested by the operator.""" + if raw_shapes.strip().lower() == "auto": + return BREAKING_SHAPES if suite == "breaking" else ("rectangle",) + + valid_shapes = set(STRESS_SHAPES) + shapes: list[StressShape] = [] + for raw_shape in raw_shapes.split(","): + shape = raw_shape.strip().lower() + if not shape: + continue + if shape not in valid_shapes: + raise SystemExit( + f"Invalid shape {raw_shape!r}; expected one of {', '.join(STRESS_SHAPES)}" + ) + shapes.append(shape) + if not shapes: + raise SystemExit("At least one --shapes entry is required") + return tuple(dict.fromkeys(shapes)) + + +def default_cases_for_inventory( + user_count: int, repo_count: int, *, suite: SweepSuite +) -> list[SweepCase]: + """Return an inventory-aware sweep that covers user, repo, and grant axes.""" + if user_count < 1: + raise SystemExit("Need at least one Sourcegraph user for an auto sweep") + if repo_count < 1: + raise SystemExit("Need at least one Sourcegraph repo for an auto sweep") + if suite == "breaking": + return breaking_cases_for_inventory(user_count, repo_count) + + user_points = bounded_points(user_count, DEFAULT_USER_POINTS) + repo_points = bounded_points(repo_count, DEFAULT_REPO_POINTS) + cases: list[SweepCase] = [SweepCase(users=users, repos=1) for users in user_points] + cases.extend(SweepCase(users=1, repos=repos) for repos in repo_points if repos != 1) + + for users, repos in ( + (1000, 10), + (10000, 10), + (1000, 100), + (100, 1000), + ): + if users <= user_count and repos <= repo_count: + cases.append(SweepCase(users=users, repos=repos)) + + return unique_cases(cases) + + +def breaking_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepCase]: + """Return larger cases ordered from likely-safe to likely-breaking.""" + capped_users = min(user_count, 10000) + capped_repos = min(repo_count, 50000) + candidate_dimensions = ( + (1, capped_repos), + (100, capped_repos), + (1000, min(capped_repos, 1000)), + (capped_users, 1), + (capped_users, 100), + (capped_users, 1000), + (capped_users, 5000), + (capped_users, 10000), + (capped_users, 25000), + (capped_users, capped_repos), + ) + cases = [ + SweepCase(users=users, repos=repos) + for users, repos in candidate_dimensions + if users <= user_count and repos <= repo_count + ] + return unique_cases(cases) + + +def expand_cases( + base_cases: Sequence[SweepCase], shapes: Sequence[StressShape], rule_count: int +) -> list[SweepCase]: + """Expand dimensions into the requested workload shapes.""" + cases: list[SweepCase] = [] + for base_case in base_cases: + for shape in shapes: + cases.append( + SweepCase( + users=base_case.users, + repos=base_case.repos, + shape=shape, + rule_count=rule_count if shape != "rectangle" else 1, + ) + ) + return unique_cases(cases) + + +def bounded_points(available_count: int, candidate_points: Sequence[int]) -> list[int]: + """Return candidate points that fit, plus the exact inventory cap if useful.""" + points = [point for point in candidate_points if point <= available_count] + if available_count not in points and available_count < candidate_points[-1]: + points.append(available_count) + return sorted(set(points)) + + +def unique_cases(cases: Sequence[SweepCase]) -> list[SweepCase]: + """Preserve case order while removing duplicates.""" + seen: set[tuple[int, int, StressShape, int]] = set() + unique: list[SweepCase] = [] + for sweep_case in cases: + key = (sweep_case.users, sweep_case.repos, sweep_case.shape, sweep_case.rule_count) + if key in seen: + continue + seen.add(key) + unique.append(sweep_case) + return unique + + +def list_external_services(client: src.SourcegraphClient) -> list[ExternalServiceChoice]: + services: list[ExternalServiceChoice] = [] + for node in client.stream_connection_nodes( + QUERY_EXTERNAL_SERVICES, + variables={"first": 100, "after": None}, + connection_path=("externalServices",), + page_size=100, + ): + service = cast(dict[str, Any], node) + graphql_id = str(service["id"]) + services.append( + ExternalServiceChoice( + graphql_id=graphql_id, + database_id=src.decode_external_service_id(graphql_id), + display_name=str(service.get("displayName") or ""), + kind=str(service.get("kind") or ""), + url=str(service.get("url") or ""), + repo_count=int(service.get("repoCount") or 0), + ) + ) + if not services: + raise SystemExit("No external services found on the Sourcegraph instance") + return services + + +def choose_external_service( + services: list[ExternalServiceChoice], requested_id: int | None +) -> ExternalServiceChoice: + if requested_id is not None: + for service in services: + if service.database_id == requested_id: + return service + raise SystemExit(f"External service id {requested_id} was not found") + return max(services, key=lambda service: service.repo_count) + + +def list_usernames(client: src.SourcegraphClient, count: int, page_size: int) -> list[str]: + usernames: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_USERNAMES, + connection_path=("users",), + page_size=page_size, + ): + username = node.get("username") + if isinstance(username, str) and username: + usernames.append(username) + if len(usernames) >= count: + break + if len(usernames) < count: + raise SystemExit(f"Need {count} users but discovered only {len(usernames)}") + return usernames + + +def count_users(client: src.SourcegraphClient) -> int: + """Return total users on the Sourcegraph instance.""" + data = client.graphql(QUERY_USER_COUNT) + users = cast(dict[str, Any], data.get("users") or {}) + total_count = users.get("totalCount") + if not isinstance(total_count, int): + raise SystemExit("CountUsers response did not include users.totalCount") + return total_count + + +def list_repo_names( + client: src.SourcegraphClient, + service: ExternalServiceChoice, + count: int, + page_size: int, +) -> list[str]: + repo_names: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_REPOS_BY_EXTERNAL_SERVICE, + variables={"externalService": service.graphql_id}, + connection_path=("repositories",), + page_size=page_size, + ): + name = node.get("name") + if isinstance(name, str) and name: + repo_names.append(name) + if len(repo_names) >= count: + break + if len(repo_names) < count: + raise SystemExit( + f"Need {count} repos from external service id={service.database_id} " + f"but discovered only {len(repo_names)}" + ) + return repo_names + + +def map_rule_count(sweep_case: SweepCase) -> int: + """Return the actual map-rule count for this case.""" + if sweep_case.shape == "rectangle": + return 1 + if sweep_case.shape == "duplicate-rules": + return sweep_case.rule_count + if sweep_case.shape == "user-shards": + return min(sweep_case.users, sweep_case.rule_count) + if sweep_case.shape == "repo-shards": + return min(sweep_case.repos, sweep_case.rule_count) + return min(sweep_case.users, sweep_case.repos, sweep_case.rule_count) + + +def user_selector_count(sweep_case: SweepCase) -> int: + """Return the number of user selectors emitted across all map rules.""" + return map_rule_count(sweep_case) + + +def repository_selector_count(sweep_case: SweepCase) -> int: + """Return the number of repository selectors emitted across all map rules.""" + return map_rule_count(sweep_case) + + +def unique_grant_count(sweep_case: SweepCase) -> int: + """Return final unique grants after unioning all map entries.""" + if sweep_case.shape == "diagonal-shards": + return sum( + user_count * repo_count + for user_count, repo_count in zip( + chunk_lengths(sweep_case.users, map_rule_count(sweep_case)), + chunk_lengths(sweep_case.repos, map_rule_count(sweep_case)), + strict=True, + ) + ) + return sweep_case.users * sweep_case.repos + + +def raw_rule_grant_count(sweep_case: SweepCase) -> int: + """Return total per-rule grants before cross-rule unioning.""" + if sweep_case.shape == "duplicate-rules": + return sweep_case.users * sweep_case.repos * sweep_case.rule_count + return unique_grant_count(sweep_case) + + +def chunk_lengths(total: int, chunk_count: int) -> list[int]: + """Return near-even chunk lengths for `total` items.""" + if chunk_count < 1: + raise ValueError("chunk_count must be >= 1") + base, extra = divmod(total, chunk_count) + return [base + (1 if index < extra else 0) for index in range(chunk_count)] + + +def chunked_values(values: Sequence[str], chunk_count: int) -> list[list[str]]: + """Split values into near-even non-empty chunks.""" + lengths = chunk_lengths(len(values), chunk_count) + chunks: list[list[str]] = [] + offset = 0 + for length in lengths: + if length < 1: + continue + chunks.append(list(values[offset : offset + length])) + offset += length + return chunks + + +def write_maps( + maps_dir: Path, + cases: Sequence[SweepCase], + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> list[GeneratedMap]: + generated: list[GeneratedMap] = [] + for sweep_case in cases: + map_path = maps_dir / f"maps-{sweep_case.name}.yaml" + rules = map_rules_for_case( + sweep_case, + usernames[: sweep_case.users], + repo_names[: sweep_case.repos], + service, + ) + payload = { + "maps": rules, + } + with map_path.open("w", encoding="utf-8") as output_file: + output_file.write( + "# Generated by dev/run-memory-model-sweep.py; safe to delete/regenerate.\n" + ) + output_file.write( + f"# users={sweep_case.users} repos={sweep_case.repos} " + f"planned_grants={sweep_case.grants} " + f"raw_rule_grants={sweep_case.raw_rule_grants} " + f"shape={sweep_case.shape} map_rules={sweep_case.map_rule_count}\n" + ) + yaml.safe_dump(payload, output_file, sort_keys=False, allow_unicode=True) + generated.append(GeneratedMap(case=sweep_case, path=map_path)) + return generated + + +def map_rules_for_case( + sweep_case: SweepCase, + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> list[dict[str, object]]: + """Build map rules for one workload shape.""" + if sweep_case.shape == "rectangle": + return [map_rule(sweep_case, 1, usernames, repo_names, service)] + if sweep_case.shape == "user-shards": + return [ + map_rule(sweep_case, index, user_chunk, repo_names, service) + for index, user_chunk in enumerate( + chunked_values(usernames, sweep_case.map_rule_count), start=1 + ) + ] + if sweep_case.shape == "repo-shards": + return [ + map_rule(sweep_case, index, usernames, repo_chunk, service) + for index, repo_chunk in enumerate( + chunked_values(repo_names, sweep_case.map_rule_count), start=1 + ) + ] + if sweep_case.shape == "diagonal-shards": + return [ + map_rule(sweep_case, index, user_chunk, repo_chunk, service) + for index, (user_chunk, repo_chunk) in enumerate( + zip( + chunked_values(usernames, sweep_case.map_rule_count), + chunked_values(repo_names, sweep_case.map_rule_count), + strict=True, + ), + start=1, + ) + ] + if sweep_case.shape == "duplicate-rules": + return [ + map_rule(sweep_case, index, usernames, repo_names, service) + for index in range(1, sweep_case.map_rule_count + 1) + ] + raise AssertionError(f"Unhandled shape {sweep_case.shape!r}") + + +def map_rule( + sweep_case: SweepCase, + index: int, + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> dict[str, object]: + """Build one rectangular map rule.""" + return { + "name": f"memory model {sweep_case.shape} rule {index}/{sweep_case.map_rule_count}", + "users": username_selector(usernames), + "repos": repository_selector(repo_names, service), + } + + +def username_selector(usernames: Sequence[str]) -> dict[str, object]: + return {"usernames": list(usernames)} + + +def repository_selector( + repo_names: Sequence[str], service: ExternalServiceChoice +) -> dict[str, object]: + return { + "codeHostConnection": { + "kind": service.kind, + "displayName": service.display_name, + "url": service.url, + }, + "names": list(repo_names), + } + + +def write_manifest( + output_dir: Path, + generated_maps: Sequence[GeneratedMap], + service: ExternalServiceChoice, + endpoint: str, + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + manifest = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "endpoint": endpoint, + "external_service": service_to_json(service), + "sourcegraph_user_count": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + "maps": [ + { + "case": generated_map.case.name, + "shape": generated_map.case.shape, + "map_rule_count": generated_map.case.map_rule_count, + "user_selector_count": user_selector_count(generated_map.case), + "repository_selector_count": repository_selector_count(generated_map.case), + "selected_user_count": generated_map.case.users, + "selected_repo_count": generated_map.case.repos, + "selected_total_grants": generated_map.case.grants, + "raw_rule_grant_count": generated_map.case.raw_rule_grants, + "users": generated_map.case.users, + "repos": generated_map.case.repos, + "grants": generated_map.case.grants, + "path": str(generated_map.path), + } + for generated_map in generated_maps + ], + } + write_json(output_dir / "manifest.json", manifest) + + +def run_sweep( + generated_maps: Sequence[GeneratedMap], + *, + endpoint: str, + access_token: str, + output_dir: Path, + command: str, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, + sourcegraph_user_count: int, + sourcegraph_inventory_repo_count: int, +) -> list[CommandRunResult]: + results: list[CommandRunResult] = [] + for generated_map in generated_maps: + print(f"Running {generated_map.case.name} ...", flush=True) + started = time.monotonic() + process_output_path = output_dir / f"{generated_map.case.name}.out" + arguments = command_arguments( + command, + generated_map.path, + mode=mode, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + http_timeout_seconds=http_timeout_seconds, + sample_interval=sample_interval, + trace=trace, + ) + environment = command_environment(endpoint, access_token) + process = subprocess.run( + arguments, + cwd=Path.cwd(), + env=environment, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=False, + ) + elapsed_seconds = time.monotonic() - started + process_output_path.write_text(process.stdout, encoding="utf-8") + log_path = log_path_from_output(process.stdout) + run_record = read_run_record(log_path) + result = CommandRunResult( + generated_map=generated_map, + mode=mode, + return_code=process.returncode, + elapsed_seconds=elapsed_seconds, + output_path=process_output_path, + log_path=log_path, + run_record=run_record, + ) + results.append(result) + write_results( + output_dir, + results, + inventory_repo_count=sourcegraph_inventory_repo_count, + sourcegraph_user_count=sourcegraph_user_count, + ) + print( + f" return_code={process.returncode} " + f"peak_rss_mb={memory_peak(result.run_record)} " + f"output={process_output_path}", + flush=True, + ) + if process.returncode != 0: + print("Stopping after first failed case.", file=sys.stderr) + break + return results + + +def command_arguments( + command: str, + map_path: Path, + *, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, +) -> list[str]: + arguments = [ + *shlex.split(command), + "set", + "--maps-path", + str(map_path.resolve()), + "--full", + "--parallelism", + str(parallelism), + "--explicit-permissions-batch-size", + str(explicit_permissions_batch_size), + "--http-timeout-seconds", + f"{http_timeout_seconds:g}", + "--sample-interval", + f"{sample_interval:g}", + ] + if mode != "dry-run": + arguments.append("--apply") + if mode == "apply-no-backup": + arguments.append("--no-backup") + if trace: + arguments.append("--trace") + return arguments + + +def result_arguments(map_path: Path, mode: RunMode) -> list[str]: + """Return the CLI argument shape captured in results.json.""" + arguments = ["set", "--maps-path", str(map_path), "--full"] + if mode != "dry-run": + arguments.append("--apply") + if mode == "apply-no-backup": + arguments.append("--no-backup") + return arguments + + +def command_environment(endpoint: str, access_token: str) -> dict[str, str]: + environment = dict(os.environ) + environment["SRC_ENDPOINT"] = endpoint + environment["SRC_ACCESS_TOKEN"] = access_token + return environment + + +def log_path_from_output(output: str) -> Path | None: + match = LOG_PATH_PATTERN.search(output) + return Path(match.group(1)) if match else None + + +def read_run_record(log_path: Path | None) -> dict[str, Any] | None: + if log_path is None or not log_path.exists(): + return None + run_record: dict[str, Any] | None = None + with log_path.open(encoding="utf-8") as input_file: + for line in input_file: + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + if not isinstance(record, dict): + continue + record_mapping = cast(dict[str, object], record) + if record_mapping.get("event") == "run" and record_mapping.get("phase") == "end": + run_record = cast(dict[str, Any], record_mapping) + return run_record + + +def write_results( + output_dir: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + result_payload = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "results": [ + result_to_json(result, inventory_repo_count, sourcegraph_user_count) + for result in results + ], + "comparisons": [], + } + write_json(output_dir / "results.json", result_payload) + write_results_csv( + output_dir / "results.csv", + results, + inventory_repo_count, + sourcegraph_user_count, + ) + + +def result_to_json( + result: CommandRunResult, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, Any]: + run_record = result.run_record or {} + peak_rss_mb = memory_peak(result.run_record) + case = result.generated_map.case + return { + "variant": "candidate", + "iteration": 1, + "case": case.name, + "shape": case.shape, + "run_mode": result.mode, + "arguments": result_arguments(result.generated_map.path, result.mode), + "return_code": result.return_code, + "elapsed_seconds": round(result.elapsed_seconds, 3), + "log_path": str(result.log_path) if result.log_path else None, + "run_directory": str(result.log_path.parent) if result.log_path else None, + "command": run_record.get("command") or "set_full", + "status": run_record.get("status"), + "jaeger_traces": [], + "memory": { + "peak_rss_mb": peak_rss_mb, + "sampled_peak_rss_mb": None, + "external_peak_rss_mb": None, + "resource_sample_count": 0, + "external_sample_count": 0, + "max_num_fds": run_record.get("num_fds"), + "max_num_threads": run_record.get("num_threads"), + "max_process_cpu_percent": None, + }, + "phase_memory": [], + "artifact_sizes": {}, + "workload": workload_json(case, inventory_repo_count, sourcegraph_user_count), + } + + +def workload_json( + sweep_case: SweepCase, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, int]: + return { + "selected_user_count": sweep_case.users, + "selected_repo_count": sweep_case.repos, + "selected_total_grants": sweep_case.grants, + "raw_rule_grant_count": sweep_case.raw_rule_grants, + "map_rule_count": sweep_case.map_rule_count, + "user_selector_count": user_selector_count(sweep_case), + "repository_selector_count": repository_selector_count(sweep_case), + "memory_model_user_count": sweep_case.users, + "memory_model_repo_count": sweep_case.repos, + "memory_model_grant_count": sweep_case.grants, + "sourcegraph_user_count": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + } + + +def write_results_csv( + path: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + fieldnames = [ + "case", + "shape", + "run_mode", + "map_rule_count", + "raw_rule_grants", + "users", + "repos", + "grants", + "sourcegraph_users_discovered", + "sourcegraph_inventory_repo_count", + "return_code", + "elapsed_seconds", + "peak_rss_mb", + "log_path", + "map_path", + "output_path", + ] + with path.open("w", encoding="utf-8", newline="") as output_file: + writer = csv.DictWriter(output_file, fieldnames=fieldnames) + writer.writeheader() + for result in results: + case = result.generated_map.case + writer.writerow( + { + "case": case.name, + "shape": case.shape, + "run_mode": result.mode, + "map_rule_count": case.map_rule_count, + "raw_rule_grants": case.raw_rule_grants, + "users": case.users, + "repos": case.repos, + "grants": case.grants, + "sourcegraph_users_discovered": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + "return_code": result.return_code, + "elapsed_seconds": f"{result.elapsed_seconds:.3f}", + "peak_rss_mb": memory_peak(result.run_record) or "", + "log_path": str(result.log_path) if result.log_path else "", + "map_path": str(result.generated_map.path), + "output_path": str(result.output_path), + } + ) + + +def memory_peak(run_record: Mapping[str, Any] | None) -> float | None: + if run_record is None: + return None + value = run_record.get("peak_rss_mb") + return float(value) if isinstance(value, int | float) else None + + +def write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as output_file: + json.dump(payload, output_file, indent=2, sort_keys=True) + output_file.write("\n") + + +def service_to_json(service: ExternalServiceChoice) -> dict[str, object]: + return { + "graphql_id": service.graphql_id, + "database_id": service.database_id, + "display_name": service.display_name, + "kind": service.kind, + "url": service.url, + "repo_count": service.repo_count, + } + + +def default_output_dir(endpoint: str) -> Path: + host = urlsplit(endpoint).hostname or "sourcegraph" + safe_host = re.sub(r"[^A-Za-z0-9_.-]+", "-", host) + timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d-%H-%M-%S") + return Path("src-auth-perms-sync-runs") / safe_host / "memory-model-sweep" / timestamp + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/memory-efficiency-monitor-sourcegraph.sh b/dev/memory-efficiency-monitor-sourcegraph.sh new file mode 100755 index 0000000..07fab15 --- /dev/null +++ b/dev/memory-efficiency-monitor-sourcegraph.sh @@ -0,0 +1,348 @@ +#!/usr/bin/env bash +set -euo pipefail + +namespace="${SRC_AUTH_PERMS_SYNC_MONITOR_NAMESPACE:-m}" +interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_INTERVAL_SECONDS:-5}" +postgres_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_INTERVAL_SECONDS:-10}" +statements_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_STATEMENTS_INTERVAL_SECONDS:-30}" +duration_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_DURATION_SECONDS:-}" +output_dir="${SRC_AUTH_PERMS_SYNC_MONITOR_OUTPUT_DIR:-}" +frontend_target="${SRC_AUTH_PERMS_SYNC_MONITOR_FRONTEND_TARGET:-deployment/sourcegraph-frontend}" +postgres_target="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_TARGET:-pod/pgsql-0}" +kubectl_bin="${KUBECTL:-kubectl}" +psql_command="${SRC_AUTH_PERMS_SYNC_MONITOR_PSQL_COMMAND:-psql -X -U sg -d sg}" +stream_logs=true + +usage() { + cat <<'EOF' +Usage: dev/memory-efficiency-monitor-sourcegraph.sh [options] + +Collect timestamped Sourcegraph pod load evidence while the e2e script runs. +Press Ctrl-C to stop, or pass --duration-seconds. + +Options: + --namespace NAME Kubernetes namespace (default: m) + --interval-seconds N Pod/process/cgroup sample interval (default: 5) + --postgres-interval-seconds N pg_stat_activity sample interval (default: 10) + --statements-interval-seconds N pg_stat_statements sample interval (default: 30) + --duration-seconds N Stop automatically after N seconds + --output-dir PATH Output directory (default: /tmp/src-auth-perms-sync-sourcegraph-load-) + --frontend-target TARGET kubectl target for frontend (default: deployment/sourcegraph-frontend) + --postgres-target TARGET kubectl target for Postgres (default: pod/pgsql-0) + --psql-command COMMAND Command to run inside Postgres pod (default: psql -X -U sg -d sg) + --no-logs Do not stream frontend logs + -h, --help Show this help + +Examples: + dev/memory-efficiency-monitor-sourcegraph.sh + + dev/memory-efficiency-monitor-sourcegraph.sh \ + --duration-seconds 1800 \ + --output-dir /tmp/src-auth-perms-sync-load-$(date -u +%Y%m%d-%H%M%S) + +In another terminal, run: + uv run python dev/test-end-to-end.py --trace --sample-interval 0 --external-sample-interval 0 +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --namespace) + namespace="$2" + shift 2 + ;; + --interval-seconds) + interval_seconds="$2" + shift 2 + ;; + --postgres-interval-seconds) + postgres_interval_seconds="$2" + shift 2 + ;; + --statements-interval-seconds) + statements_interval_seconds="$2" + shift 2 + ;; + --duration-seconds) + duration_seconds="$2" + shift 2 + ;; + --output-dir) + output_dir="$2" + shift 2 + ;; + --frontend-target) + frontend_target="$2" + shift 2 + ;; + --postgres-target) + postgres_target="$2" + shift 2 + ;; + --psql-command) + psql_command="$2" + shift 2 + ;; + --no-logs) + stream_logs=false + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 2 + ;; + esac +done + +if [[ -z "${output_dir}" ]]; then + output_dir="/tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S)" +fi +mkdir -p "${output_dir}" + +end_epoch="" +if [[ -n "${duration_seconds}" ]]; then + end_epoch="$(( $(date +%s) + duration_seconds ))" +fi + +pids=() + +timestamp() { + date -u +%Y-%m-%dT%H:%M:%SZ +} + +should_continue() { + [[ -z "${end_epoch}" || "$(date +%s)" -lt "${end_epoch}" ]] +} + +append_header() { + local file="$1" + local title="$2" + { + printf '\n===== %s %s =====\n' "$(timestamp)" "${title}" + } >>"${file}" +} + +run_sample_loop() { + local name="$1" + local sleep_seconds="$2" + local pid + shift 2 + ( + while should_continue; do + "$@" || true + sleep "${sleep_seconds}" + done + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} sampler: pid=${pid} interval=${sleep_seconds}s" +} + +run_stream() { + local name="$1" + local pid + shift + ( + "$@" || true + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} stream: pid=${pid}" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + if [[ ${#pids[@]} -gt 0 ]]; then + kill "${pids[@]}" 2>/dev/null || true + wait "${pids[@]}" 2>/dev/null || true + fi + echo "Stopped Sourcegraph load monitor. Output: ${output_dir}" + exit "${status}" +} + +trap cleanup EXIT INT TERM + +kubectl_exec() { + local target="$1" + shift + "${kubectl_bin}" exec -n "${namespace}" "${target}" -- "$@" +} + +kubectl_exec_stdin() { + local target="$1" + shift + "${kubectl_bin}" exec -i -n "${namespace}" "${target}" -- "$@" +} + +prepare_pg_stat_statements() { + local file="${output_dir}/postgres-statements-setup.log" + append_header "${file}" "create pg_stat_statements extension and reset stats" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select current_database(), current_user; +show shared_preload_libraries; +show track_io_timing; +create extension if not exists pg_stat_statements; +select pg_stat_statements_reset(); +SQL +} + +sample_kubectl_top() { + local file="${output_dir}/kubectl-top-pods-containers.log" + append_header "${file}" "kubectl top pods --containers" + "${kubectl_bin}" top pods -n "${namespace}" --containers >>"${file}" 2>&1 || true +} + +sample_frontend_processes() { + local file="${output_dir}/frontend-processes.log" + append_header "${file}" "${frontend_target} process CPU/RSS" + kubectl_exec "${frontend_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_postgres_processes() { + local file="${output_dir}/postgres-processes.log" + append_header "${file}" "${postgres_target} process CPU/RSS" + kubectl_exec "${postgres_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_cgroups() { + local file="${output_dir}/cgroups.log" + append_header "${file}" "cgroup CPU/memory" + for target in "${frontend_target}" "${postgres_target}"; do + { + echo "--- ${target} ---" + kubectl_exec "${target}" sh -lc ' + echo "cpu.stat" + cat /sys/fs/cgroup/cpu.stat 2>/dev/null || cat /sys/fs/cgroup/cpu/cpu.stat 2>/dev/null || true + echo "memory.current" + cat /sys/fs/cgroup/memory.current 2>/dev/null || cat /sys/fs/cgroup/memory/memory.usage_in_bytes 2>/dev/null || true + echo "memory.events" + cat /sys/fs/cgroup/memory.events 2>/dev/null || true + echo "memory.max" + cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null || true + ' + } >>"${file}" 2>&1 || true + done +} + +sample_postgres_activity() { + local file="${output_dir}/postgres-activity.log" + append_header "${file}" "pg_stat_activity, waits, locks" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + pid, + now() - query_start as age, + state, + wait_event_type, + wait_event, + left(query, 220) as query +from pg_stat_activity +where state <> 'idle' +order by age desc +limit 30; + +select + wait_event_type, + wait_event, + state, + count(*) +from pg_stat_activity +group by 1,2,3 +order by count(*) desc; + +select + locktype, + mode, + granted, + count(*) +from pg_locks +group by 1,2,3 +order by count(*) desc; +SQL +} + +sample_pg_stat_statements() { + local file="${output_dir}/postgres-statements.log" + append_header "${file}" "pg_stat_statements top total_exec_time" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + calls, + round(total_exec_time::numeric, 1) as total_ms, + round(mean_exec_time::numeric, 1) as mean_ms, + rows, + left(query, 260) as query +from pg_stat_statements +order by total_exec_time desc +limit 25; +SQL +} + +snapshot_pod_descriptions() { + local file="${output_dir}/pod-descriptions.log" + append_header "${file}" "kubectl describe selected targets" + "${kubectl_bin}" describe -n "${namespace}" "${frontend_target}" >>"${file}" 2>&1 || true + "${kubectl_bin}" describe -n "${namespace}" "${postgres_target}" >>"${file}" 2>&1 || true +} + +stream_frontend_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f \ + >"${output_dir}/frontend.log" 2>"${output_dir}/frontend-log-errors.log" +} + +stream_frontend_error_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f 2>/dev/null \ + | grep -Ei 'timeout|deadline|database|postgres|graphql|error|slow|cancel' \ + >"${output_dir}/frontend-errors-filtered.log" || true +} + +cat >"${output_dir}/metadata.txt" </runs//log.json @@ -28,13 +29,15 @@ jq -r ' select(.url | endswith("/.api/graphql")) | [ .duration_ms, - .request_headers.traceparent, - (.response_headers["x-trace-url"] // "") + (.response_headers["x-trace"] // ""), + (.response_headers["x-trace-url"] // ""), + (.request_headers.traceparent // "") ] | @tsv ' "$LOG" | sort -nr | head -20 ``` -Extract the W3C trace ID from a `traceparent` value: +Prefer the `x-trace` value when present. If Sourcegraph did not return one, +extract the trace ID from `traceparent`: ```bash TRACEPARENT=00---01 @@ -50,15 +53,9 @@ curl -sS \ > /tmp/sourcegraph-trace.json ``` -Jaeger ingestion can lag by a few seconds. If the API returns `trace not -found`, wait briefly and retry the same URL. - -For long runs such as `dev/test-end-to-end.py --trace`, fetch the -slow traces as soon as the relevant command finishes, or rerun a focused case -and fetch those traces immediately. On the sgdev test instance, a fully traced -end-to-end run can emit thousands of sampled traces; the in-memory Jaeger data -may evict or restart before the whole matrix finishes, returning `trace not -found` or temporary 502s for earlier trace IDs. +Jaeger ingestion can lag. If the API returns `trace not found`, wait briefly +and retry. For long runs, fetch traces as soon as the relevant command +finishes; older trace IDs can disappear before a full matrix ends. Summarize the hottest spans: @@ -85,18 +82,14 @@ for operation, durations in sorted( PY ``` -Do not commit tokens, customer URLs, or raw trace files. Keep trace JSON and -benchmark CSVs in `/tmp` unless a human asks to preserve them. +Do not commit tokens, customer URLs, raw trace JSON, benchmark CSVs, or monitor +artifacts. Keep them in `/tmp` unless a human asks to preserve them. -## Evidence to collect +## Trace the end-to-end matrix -To trace the full integration matrix, run the end-to-end script with its own -`--trace` flag. The runner forwards it to every child CLI invocation, then -tails each child run log and fetches all traced GraphQL Jaeger traces in the -background while that child command is still running. The runner uses -`src-py-lib` Config parsing, logging, Sourcegraph endpoint normalization, -`SourcegraphClient.fetch_jaeger_trace_summary()`, and a shared HTTP pool, so -trace summary and retry behavior match the CLI's Sourcegraph client: +Prefer the end-to-end runner as the single orchestrator. With `--trace`, it +passes tracing to every child CLI command, tails child JSON logs, and fetches +Jaeger traces in the background while each child command is still running. ```bash uv run python dev/test-end-to-end.py \ @@ -107,22 +100,21 @@ uv run python dev/test-end-to-end.py \ --results-csv /tmp/src-auth-perms-sync-end-to-end-trace.csv ``` -Use `--jaeger-trace-limit N` to fetch only the `N` slowest GraphQL traces per -case, or `--jaeger-trace-limit 0` to disable in-run Jaeger fetching while still -sending sampled trace headers. The default is to fetch every traced GraphQL -request. +Useful trace options: + +- `--jaeger-trace-limit N`: fetch only the `N` slowest GraphQL traces per case. +- `--jaeger-trace-limit 0`: send trace headers but skip Jaeger fetching. +- `--jaeger-trace-parallelism N`: tune concurrent Jaeger fetches. +- `--jaeger-trace-jsonl PATH`: stream compact trace summaries as JSON Lines. +- `--jaeger-trace-dir PATH`: store complete raw Jaeger payloads. -The runner writes trace summaries incrementally as JSON Lines. By default, it -uses a sibling of `--results-json` or `--results-csv`, named -`*-jaeger-traces.jsonl`. Override this with `--jaeger-trace-jsonl PATH`. +Raw trace files include: -The shared `src-py-lib` `stream_jaeger_trace_summaries()` helper now fetches in -parallel for in-process Sourcegraph clients. The end-to-end script still uses a -bounded global worker pool because the traced requests happen in child -processes and are discovered by tailing their JSON logs. Tune this with -`--jaeger-trace-parallelism N` (default 16). The runner drains outstanding -background collectors once at the end, before it writes JSON/CSV results, so -Jaeger collection does not add a blocking phase between child cases. +- `trace_request`: CLI-side HTTP and `graphql_query` correlation metadata, + including query name, page number, page size, cursor presence, query byte + count, variable names, response fields, status, and timing. +- `jaeger_summary`: compact hot-operation and GraphQL-operation summary. +- `jaeger_trace`: the complete Jaeger trace JSON returned by Sourcegraph. All runner flags are Config-backed. You can set them in the shell or `.env` with `SRC_AUTH_PERMS_SYNC_E2E_*` names, plus `SRC_ENDPOINT`, @@ -131,33 +123,59 @@ with `SRC_AUTH_PERMS_SYNC_E2E_*` names, plus `SRC_ENDPOINT`, For each tested batch size and parallelism, record: - CLI `capture_explicit_grants` duration from the structured log -- slowest `http_request` duration and its `x-trace` / `traceparent` metadata +- slowest GraphQL `http_request` duration and its trace metadata - Jaeger counts and summed duration for `GraphQL Request`, `repos.Get`, `sql.conn.query`, and `database.PermsStore.LoadUserPermissions` -- retries/timeouts from the CLI log - -In a traced sgdev end-to-end run after the matrix was trimmed to avoid -overlapping code paths, all 36 cases passed. Child command time summed to about -1,126 seconds. The JSONL trace summary file contained 3,256 GraphQL trace -lookups, but Jaeger returned only 26 summaries; most lookups returned `trace -not found`. The expensive cases were still dominated by full snapshot capture -and full apply / restore paths: - -| Case | Elapsed | GraphQL requests | Slowest GraphQL request | Dominant phase | -| --- | ---: | ---: | ---: | --- | -| `restore-full-apply-cleanup` | 234s | 913 | 3.2s | `capture_explicit_grants` / restore | -| `set-full-apply` | 214s | 917 | 3.2s | `capture_explicit_grants` / apply | -| `restore-full-no-backup-cleanup` | 135s | 510 | 3.2s | `capture_explicit_grants` / restore | -| `set-full-no-backup-apply` | 129s | 129 | 1.2s | apply mutations | -| `get-sync-saml-orgs-dry-run` | 116s | 510 | 3.2s | `capture_explicit_grants` | - -Fetch Jaeger traces immediately for long runs. In that same full matrix, older -trace IDs were no longer available by the time the run finished. Focused reruns -with immediate fetches gave stable Jaeger data. - -For current `src-auth-perms-sync`, `UserExplicitReposBatch` requests only repo -IDs from `User.permissionsInfo.repositories(source: API)`. A focused traced -batch for one user with 19 explicit repos showed per-user fanout: +- run-end `http_retry_count`, `http_request_attempt_count`, and timeout/error + counts + +## Monitor Sourcegraph load during e2e runs + +The runner can start the Sourcegraph pod/Postgres monitor and write monitor +artifact paths into the result JSON: + +```bash +uv run python dev/test-end-to-end.py \ + --trace \ + --monitor-sourcegraph-load \ + --sample-interval 0 \ + --external-sample-interval 0 \ + --results-json /tmp/src-auth-perms-sync-end-to-end-trace.json \ + --results-csv /tmp/src-auth-perms-sync-end-to-end-trace.csv +``` + +By default, monitor output is written beside `--results-json` or +`--results-csv` as `*-sourcegraph-load`, and the monitor's stdout/stderr goes +to `*-sourcegraph-load.log`. Override the location with +`--monitor-output-dir PATH`. Tune Kubernetes targets and sample intervals with +the `--monitor-*` flags if the test namespace or pod names differ. + +The lower-level helper remains available for focused profiling outside a full +e2e run: + +```bash +dev/memory-efficiency-monitor-sourcegraph.sh \ + --namespace m \ + --output-dir /tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S) +``` + +Stop the helper with Ctrl-C, or add `--duration-seconds N`. It samples +Kubernetes CPU/memory, frontend and Postgres processes, cgroup CPU/memory +pressure, Postgres active queries/waits/locks, `pg_stat_statements` when +enabled, and frontend logs. On startup it runs `CREATE EXTENSION IF NOT EXISTS +pg_stat_statements` and `pg_stat_statements_reset()` through `kubectl exec` +against `pod/pgsql-0`, so statement summaries start clean for the monitored +run. + +## Current trace findings + +Current `src-auth-perms-sync` snapshots explicit API grants by calling +`User.permissionsInfo.repositories(source: API)` through aliased +`UserExplicitReposBatch` queries. It requests only permission repo IDs, then +hydrates names separately with `RepositoryNamesByID`. + +A focused traced batch for one user with 19 explicit repos showed per-user +fanout even when only IDs were requested: | User aliases | CLI request | Jaeger spans | `LoadUserPermissions` | `sql.conn.query` | | ---: | ---: | ---: | ---: | ---: | @@ -165,9 +183,9 @@ batch for one user with 19 explicit repos showed per-user fanout: | 25 | 508ms | 157 | 25 | 127 | | 100 | 1,185ms | 607 | 100 | 502 | -The remaining repository-name hydration is a second fanout. A traced -`RepositoryNamesByID` query for 19 repos produced 46 spans, including 19 -`repos.Get` spans and 22 `sql.conn.query` spans. +The second hydration query also fans out. A traced `RepositoryNamesByID` query +for 19 repos produced 46 spans, including 19 `repos.Get` spans and 22 +`sql.conn.query` spans. An older trace shape that resolved repository objects directly inside `permissionsInfo.repositories` showed the per-repo resolver fanout more @@ -175,12 +193,36 @@ dramatically: | Request shape | Root GraphQL span | Jaeger fanout | | --- | ---: | --- | -| 25 user aliases, 19 explicit repos each | ~770 ms | 475 `repos.Get`, 603 `sql.conn.query` | -| 100 user aliases, 19 explicit repos each | ~3,769 ms | 1,900 `repos.Get`, 2,403 `sql.conn.query` | +| 25 user aliases, 19 explicit repos each | ~770ms | 475 `repos.Get`, 603 `sql.conn.query` | +| 100 user aliases, 19 explicit repos each | ~3,769ms | 1,900 `repos.Get`, 2,403 `sql.conn.query` | Together these point to Sourcegraph server-side GraphQL / DB resolver fanout, -not local Python CPU. Larger batches reduce request count but increase per -request resolver and SQL work enough to create timeouts on this instance. +not local Python CPU. Larger batches reduce request count but can increase +per-request resolver and SQL work enough to cause timeouts on the test +instance. + +One live-instance behavior is expected: if Sourcegraph returns a GraphQL +application error showing that a repo/user disappeared between planning and the +mutation, `src-auth-perms-sync` logs a skipped mutation and continues. The next +scheduled run will re-plan against the then-current users/repos. Other GraphQL +application errors still fail normally. + +## Stress-run evidence + +A prior hard stress map used about 10,001 users and about 1,000 repos, planning +roughly 10 million explicit grants. That run showed Sourcegraph-side read and +write costs were the bottleneck. `pg_stat_statements` attributed most database +time to explicit-permissions helpers: + +| Sourcegraph operation | Calls | Total time | Mean time | +| --- | ---: | ---: | ---: | +| `permsStore.ListUserPermissions` | 19,974 | 30,862.6s | 1,545ms | +| `permsStore.upsertUserRepoPermissions-range1` | 472 | 1,178.8s | 2,497ms | + +Compared with focused traces at normal scale, `ListUserPermissions` became much +slower under the large explicit-perms state. This reinforces that the CLI needs +better Sourcegraph bulk read and write APIs for very large explicit permission +sets. ## Sourcegraph engineering request @@ -204,7 +246,7 @@ from `github.com/sourcegraph/sourcegraph`: creating an N+1 query pattern for repository hydration. - Even when the client asks only for permission repo IDs, each aliased user still runs `LoadUserPermissions` and several SQL queries. Current - `src-auth-perms-sync` then has to hydrate repository names separately through + `src-auth-perms-sync` then hydrates repository names separately through `node(id)`, which also resolves as one `repos.Get` per repository ID. - `internal/database/perms_store.go` has bulk write helpers for setting repo permissions, but the read path uses per-user connection queries and repo @@ -254,6 +296,11 @@ Important requirements: Expected benefit: replace hundreds or thousands of per-repo resolver SQL spans per request with one indexed `user_repo_permissions` join per user batch. +The stress profile also needs attention on the write path. A purpose-built +bulk overwrite API that accepts many repo/user edges at once, streams or stages +the input server-side, and avoids repeated per-repo permission reconciliation +would make worst-case full syncs much safer. + ## Copy/paste request Title: Add a bulk GraphQL read path for explicit repository permissions @@ -287,3 +334,6 @@ Acceptance criteria: latency visible. - `src-auth-perms-sync` can replace its aliased `User.permissionsInfo.repositories(source: API)` calls with this API. +- Follow-up: evaluate a bulk overwrite API for large full-set applies. The + stress run planned roughly 10 million grants and observed + `permsStore.upsertUserRepoPermissions-range1` averaging about 2.5s per call. diff --git a/dev/test-end-to-end.py b/dev/test-end-to-end.py index 6022f00..7f5edd1 100755 --- a/dev/test-end-to-end.py +++ b/dev/test-end-to-end.py @@ -13,19 +13,22 @@ from __future__ import annotations +import contextlib import csv import datetime +import heapq import json import os import re import shlex +import signal import statistics import subprocess import sys import threading import time from collections.abc import Iterable, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future from concurrent.futures import wait as wait_for_futures from dataclasses import dataclass from pathlib import Path @@ -33,38 +36,42 @@ from urllib.parse import urlsplit import src_py_lib as src -from src_py_lib.clients.sourcegraph import ( - JAEGER_TRACE_RETRY_DELAYS_SECONDS, - sourcegraph_trace_from_headers, -) +from src_py_lib.clients.sourcegraph import sourcegraph_trace_from_headers, summarize_jaeger_trace LOG_PATH_PATTERN = re.compile(r"Writing log events to (.+?/log\.json)\.") +SAFE_PATH_PART_PATTERN = re.compile(r"[^A-Za-z0-9_.-]+") DEFAULT_FUTURE_DATE = "2099-01-01" REMOVED_SRC_AUTH_PERMS_SYNC_ENVIRONMENT_PREFIX = "SRC_AUTH_PERMS_SYNC_" DEFAULT_SAMPLE_INTERVAL_SECONDS = 1.0 DEFAULT_REPEAT_COUNT = 1 DEFAULT_JAEGER_TRACE_LIMIT: int | None = None -DEFAULT_JAEGER_TRACE_PARALLELISM = 16 +DEFAULT_JAEGER_TRACE_PARALLELISM = 8 +DEFAULT_JAEGER_INITIAL_DELAY_SECONDS = 35.0 +DEFAULT_JAEGER_RETRY_DELAYS_SECONDS = ( + 2.0, + 5.0, + 10.0, + 20.0, + 30.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, +) DEFAULT_PARALLELISM = 4 DEFAULT_FULL_RESTORE_PARALLELISM = 1 +DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES = False DEFAULT_MEMORY_SUMMARY_LIMIT = 20 DEFAULT_SRC_AUTH_PERMS_SYNC_COMMAND = "uv run src-auth-perms-sync" -WORKLOAD_FIELDS = ( - "user_count", - "total_users", - "total_users_scanned", - "repo_count", - "repos_with_explicit_grants", - "total_grants", - "mapping_count", - "plan_size", - "payload_count", - "target_organizations", - "desired_memberships", - "mutations_succeeded", - "mutations_failed", - "mutations_canceled", -) +DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE = "m" +DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS = 5 +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS = 10 +DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS = 30 +DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET = "deployment/sourcegraph-frontend" +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET = "pod/pgsql-0" +DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND = "psql -X -U sg -d sg" def format_jaeger_retry_delays(delays: Sequence[float]) -> str: @@ -160,6 +167,16 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_FULL_RESTORE_PARALLELISM})" ), ) + include_redundant_scale_cases: bool = src.config_field( + default=DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES, + env_var="SRC_AUTH_PERMS_SYNC_E2E_INCLUDE_REDUNDANT_SCALE_CASES", + cli_flag="--include-redundant-scale-cases", + cli_action="store_true", + help=( + "Also run older overlapping full-scale cases. Default keeps one heavy full " + "snapshot path and uses smaller cases for overlapping coverage." + ), + ) allow_non_test_endpoint: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_E2E_ALLOW_NON_TEST_ENDPOINT", @@ -203,6 +220,17 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_JAEGER_TRACE_PARALLELISM})" ), ) + jaeger_initial_delay_seconds: float = src.config_field( + default=DEFAULT_JAEGER_INITIAL_DELAY_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_INITIAL_DELAY_SECONDS", + cli_flag="--jaeger-initial-delay-seconds", + metavar="SECONDS", + ge=0, + help=( + "Seconds to wait before first fetching each Jaeger trace, to allow OTel tail " + f"sampling to decide (default: {DEFAULT_JAEGER_INITIAL_DELAY_SECONDS:g})" + ), + ) jaeger_trace_jsonl: Path | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_JSONL", @@ -213,14 +241,26 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "of --results-json or --results-csv when --trace is set." ), ) + jaeger_trace_directory: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_DIR", + cli_flag="--jaeger-trace-dir", + metavar="PATH", + help=( + "Directory where complete raw Jaeger trace JSON files are written. Defaults " + "to a sibling directory of --results-json or --results-csv when --trace is set." + ), + ) jaeger_retry_delays: tuple[float, ...] = src.config_field( - default=JAEGER_TRACE_RETRY_DELAYS_SECONDS, + default=DEFAULT_JAEGER_RETRY_DELAYS_SECONDS, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_RETRY_DELAYS", cli_flag="--jaeger-retry-delays", metavar="SECONDS[,SECONDS...]", help=( - "Comma-separated retry delays for Jaeger trace lookup lag " - f"(default: {format_jaeger_retry_delays(JAEGER_TRACE_RETRY_DELAYS_SECONDS)})" + "Comma-separated delays between queued Jaeger trace fetch retries. " + "Each value schedules one retry after the initial fetch; add more values " + "to try for longer " + f"(default: {format_jaeger_retry_delays(DEFAULT_JAEGER_RETRY_DELAYS_SECONDS)})" ), ) sample_interval: float = src.config_field( @@ -271,6 +311,103 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "beside it as *-phases.csv" ), ) + monitor_sourcegraph_load: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_SOURCEGRAPH_LOAD", + cli_flag="--monitor-sourcegraph-load", + cli_action="store_true", + help=( + "Start the Sourcegraph pod/Postgres load monitor for this e2e run and write " + "its output beside the result artifacts." + ), + ) + sourcegraph_monitor_namespace: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NAMESPACE", + cli_flag="--monitor-namespace", + metavar="NAME", + help=( + "Kubernetes namespace for Sourcegraph load monitoring " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE})" + ), + ) + sourcegraph_monitor_output_dir: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_OUTPUT_DIR", + cli_flag="--monitor-output-dir", + metavar="PATH", + help="Directory for Sourcegraph load monitor output; defaults beside result artifacts.", + ) + sourcegraph_monitor_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_INTERVAL_SECONDS", + cli_flag="--monitor-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Pod/process/cgroup monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_postgres_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_INTERVAL_SECONDS", + cli_flag="--monitor-postgres-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Postgres activity monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_statements_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_STATEMENTS_INTERVAL_SECONDS", + cli_flag="--monitor-statements-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "pg_stat_statements monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_frontend_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_FRONTEND_TARGET", + cli_flag="--monitor-frontend-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph frontend " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET})" + ), + ) + sourcegraph_monitor_postgres_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_TARGET", + cli_flag="--monitor-postgres-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph Postgres " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET})" + ), + ) + sourcegraph_monitor_psql_command: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_PSQL_COMMAND", + cli_flag="--monitor-psql-command", + metavar="COMMAND", + help=( + "psql command to run inside the Postgres pod " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND})" + ), + ) + sourcegraph_monitor_no_logs: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NO_LOGS", + cli_flag="--monitor-no-logs", + cli_action="store_true", + help="Do not stream frontend logs while Sourcegraph load monitoring is enabled.", + ) fail_on_memory_regression_percent: float | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_FAIL_ON_MEMORY_REGRESSION_PERCENT", @@ -431,22 +568,136 @@ def sample_once(self) -> None: self.peak_rss_mb = max_optional_float(self.peak_rss_mb, rss_mb) +class SourcegraphLoadMonitor: + """Run the Sourcegraph pod/Postgres monitor for the duration of the e2e suite.""" + + def __init__(self, config: EndToEndConfig, output_dir: Path) -> None: + self.config = config + self.output_dir = output_dir + self.log_path = output_dir.with_name(f"{output_dir.name}.log") + self._log_file: TextIO | None = None + self._process: subprocess.Popen[str] | None = None + + def start(self) -> None: + script_path = sourcegraph_monitor_script_path() + if not script_path.exists(): + raise RuntimeError(f"Sourcegraph load monitor script not found: {script_path}") + self.output_dir.parent.mkdir(parents=True, exist_ok=True) + self.log_path.parent.mkdir(parents=True, exist_ok=True) + command = [ + str(script_path), + "--namespace", + self.config.sourcegraph_monitor_namespace, + "--output-dir", + str(self.output_dir), + "--interval-seconds", + str(self.config.sourcegraph_monitor_interval_seconds), + "--postgres-interval-seconds", + str(self.config.sourcegraph_monitor_postgres_interval_seconds), + "--statements-interval-seconds", + str(self.config.sourcegraph_monitor_statements_interval_seconds), + "--frontend-target", + self.config.sourcegraph_monitor_frontend_target, + "--postgres-target", + self.config.sourcegraph_monitor_postgres_target, + "--psql-command", + self.config.sourcegraph_monitor_psql_command, + ] + if self.config.sourcegraph_monitor_no_logs: + command.append("--no-logs") + print(f"Starting Sourcegraph load monitor: {self.output_dir}") + self._log_file = self.log_path.open("w", encoding="utf-8") + self._process = subprocess.Popen( # noqa: S603 - command is trusted test config. + command, + cwd=Path.cwd(), + stdout=self._log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + ) + self._wait_until_started() + + def stop(self) -> None: + process = self._process + if process is None: + self._close_log_file() + return + if process.poll() is None: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGTERM) + try: + process.wait(timeout=15) + except subprocess.TimeoutExpired: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGKILL) + process.wait(timeout=15) + return_code = process.returncode + self._close_log_file() + if return_code not in {0, -15, 143}: + print( + f"Sourcegraph load monitor exited with status {return_code}; see {self.log_path}", + file=sys.stderr, + ) + else: + print(f"Stopped Sourcegraph load monitor. Output: {self.output_dir}") + + def _wait_until_started(self) -> None: + process = self._process + if process is None: + return + deadline = time.monotonic() + 60 + while time.monotonic() < deadline: + if process.poll() is not None: + raise RuntimeError( + f"Sourcegraph load monitor exited before startup completed; see {self.log_path}" + ) + if self.log_path.exists() and "Started kubectl-top" in self.log_path.read_text( + encoding="utf-8", errors="ignore" + ): + return + time.sleep(0.2) + raise RuntimeError( + f"Timed out waiting for Sourcegraph load monitor startup; see {self.log_path}" + ) + + def _close_log_file(self) -> None: + if self._log_file is not None: + self._log_file.close() + self._log_file = None + + +@dataclass +class JaegerTraceFetchTask: + """One trace fetch request that can be retried across the whole e2e run.""" + + trace_request: dict[str, Any] + future: Future[dict[str, Any]] + fetch_attempts: int = 0 + first_fetch_at: str | None = None + last_fetch_at: str | None = None + + class JaegerTraceFetchPool: - """Fetch Sourcegraph Jaeger trace summaries through one bounded HTTP pool.""" + """Fetch Sourcegraph Jaeger traces through one bounded retry queue.""" def __init__( self, config: EndToEndConfig, *, parallelism: int, + initial_delay_seconds: float, retry_delays_seconds: Sequence[float], jsonl_path: Path | None, + trace_directory: Path | None, ) -> None: + self.initial_delay_seconds = initial_delay_seconds self.retry_delays_seconds = tuple(retry_delays_seconds) - self._executor = ThreadPoolExecutor( - max_workers=parallelism, - thread_name_prefix="JaegerTraceFetch", - ) + self.max_fetch_attempts = len(self.retry_delays_seconds) + 1 + self._trace_directory = trace_directory + self._tasks: list[tuple[float, int, JaegerTraceFetchTask]] = [] + self._condition = threading.Condition() + self._sequence = 0 + self._closed = False self._jsonl_file: TextIO | None = None self._lock = threading.Lock() http = src.HTTPClient( @@ -459,36 +710,166 @@ def __init__( jsonl_path.parent.mkdir(parents=True, exist_ok=True) self._jsonl_file = jsonl_path.open("w", encoding="utf-8") print(f"Writing Jaeger trace summaries incrementally to {jsonl_path}") + if self._trace_directory is not None: + self._trace_directory.mkdir(parents=True, exist_ok=True) + print(f"Writing complete Jaeger traces to {self._trace_directory}") + self._workers = [ + threading.Thread( + target=self._worker, + name=f"JaegerTraceFetch-{worker_number}", + daemon=True, + ) + for worker_number in range(1, parallelism + 1) + ] + for worker in self._workers: + worker.start() def submit( self, trace_request: dict[str, Any], collector: JaegerTraceCollector, ) -> Future[dict[str, Any]]: - future = src.submit_with_log_context(self._executor, self._fetch_summary, trace_request) + future: Future[dict[str, Any]] = Future() future.add_done_callback(lambda completed: self._record_summary(collector, completed)) + task = JaegerTraceFetchTask( + trace_request=trace_request, + future=future, + ) + self._schedule(task, self.initial_delay_seconds) return future def close(self) -> None: - self._executor.shutdown(wait=True) + with self._condition: + self._closed = True + self._condition.notify_all() + for worker in self._workers: + worker.join() self._client.http.close() if self._jsonl_file is not None: self._jsonl_file.close() - def _fetch_summary(self, trace_request: dict[str, Any]) -> dict[str, Any]: + def _schedule(self, task: JaegerTraceFetchTask, delay_seconds: float) -> None: + with self._condition: + self._sequence += 1 + heapq.heappush( + self._tasks, + (time.monotonic() + delay_seconds, self._sequence, task), + ) + self._condition.notify() + + def _worker(self) -> None: + while True: + task = self._next_ready_task() + if task is None: + return + self._process(task) + + def _next_ready_task(self) -> JaegerTraceFetchTask | None: + with self._condition: + while True: + if self._closed and not self._tasks: + return None + if not self._tasks: + self._condition.wait() + continue + ready_at, _sequence, task = self._tasks[0] + delay_seconds = ready_at - time.monotonic() + if delay_seconds > 0: + self._condition.wait(delay_seconds) + continue + heapq.heappop(self._tasks) + return task + + def _process(self, task: JaegerTraceFetchTask) -> None: + if task.future.done(): + return + summary = self._fetch_summary(task) + if summary.get("jaeger_found") is True or not self._should_retry(task, summary): + task.future.set_result(summary) + return + self._schedule(task, self._retry_delay_seconds(task.fetch_attempts)) + + def _fetch_summary(self, task: JaegerTraceFetchTask) -> dict[str, Any]: + task.fetch_attempts += 1 + now = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds") + if task.first_fetch_at is None: + task.first_fetch_at = now + task.last_fetch_at = now try: - trace = sourcegraph_trace_from_request(trace_request) - summary = self._client.fetch_jaeger_trace_summary( - trace, - retry_delays_seconds=self.retry_delays_seconds, - ).to_json() - return {**trace_request, **summary} + trace = sourcegraph_trace_from_request(task.trace_request) + jaeger_trace = self._client.fetch_jaeger_trace( + trace.trace_id, + retry_delays_seconds=(0.0,), + ) + summary = summarize_jaeger_trace(trace, jaeger_trace).to_json() + try: + trace_path = self._write_complete_trace(task, jaeger_trace, summary) + if trace_path is not None: + summary["jaeger_trace_path"] = str(trace_path) + except OSError as write_error: + summary["jaeger_trace_write_error"] = f"{type(write_error).__name__}: {write_error}" + return self._with_fetch_fields(task, summary) except Exception as exception: # noqa: BLE001 - keep long-running evidence collection alive. - return { - **trace_request, - "jaeger_found": False, - "error": f"{type(exception).__name__}: {exception}", - } + return self._with_fetch_fields( + task, + { + **task.trace_request, + "jaeger_found": False, + "error": f"{type(exception).__name__}: {exception}", + }, + ) + + def _with_fetch_fields( + self, task: JaegerTraceFetchTask, summary: dict[str, Any] + ) -> dict[str, Any]: + return { + **task.trace_request, + **summary, + "fetch_attempts": task.fetch_attempts, + "first_fetch_at": task.first_fetch_at, + "last_fetch_at": task.last_fetch_at, + "max_fetch_attempts": self.max_fetch_attempts, + } + + def _write_complete_trace( + self, + task: JaegerTraceFetchTask, + jaeger_trace: dict[str, Any], + summary: dict[str, Any], + ) -> Path | None: + if self._trace_directory is None: + return None + path = complete_jaeger_trace_path(self._trace_directory, task.trace_request) + payload = { + "collected_at": task.last_fetch_at, + "fetch_attempts": task.fetch_attempts, + "max_fetch_attempts": self.max_fetch_attempts, + "trace_request": task.trace_request, + "jaeger_summary": summary, + "jaeger_trace": jaeger_trace, + } + path.parent.mkdir(parents=True, exist_ok=True) + temporary_path = path.with_name( + f".{path.name}.tmp-{threading.get_ident()}-{time.monotonic_ns()}" + ) + temporary_path.write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + temporary_path.replace(path) + return path + + def _should_retry(self, task: JaegerTraceFetchTask, summary: dict[str, Any]) -> bool: + if self._closed or task.fetch_attempts >= self.max_fetch_attempts: + return False + error = str(summary.get("error") or "") + return error.startswith(("HTTP 404", "HTTP 502", "HTTP 503", "HTTP 504")) + + def _retry_delay_seconds(self, fetch_attempts: int) -> float: + if not self.retry_delays_seconds: + return 0.0 + delay_index = min(fetch_attempts - 1, len(self.retry_delays_seconds) - 1) + return self.retry_delays_seconds[delay_index] def _record_summary( self, @@ -527,6 +908,8 @@ def __init__( self.iteration = iteration self.case_name = case_name self.summaries: list[dict[str, Any]] = [] + self._graphql_queries_by_span: dict[tuple[str, str], dict[str, Any]] = {} + self._trace_requests_by_graphql_span: dict[tuple[str, str], dict[str, Any]] = {} self._requests_by_trace_id: dict[str, dict[str, Any]] = {} self._queued_trace_ids: set[str] = set() self._futures: list[Future[dict[str, Any]]] = [] @@ -599,15 +982,22 @@ def _record_line(self, line: str) -> None: return if not isinstance(record, dict): return + self._record_graphql_query_metadata(cast(dict[str, Any], record)) trace_request = graphql_trace_request_from_record(cast(dict[str, Any], record)) if trace_request is None: return trace_request.update( {"variant": self.variant, "iteration": self.iteration, "case": self.case_name} ) + graphql_span_key = self._graphql_span_key_for_http_record(cast(dict[str, Any], record)) trace_id = trace_request["trace_id"] submit_request: dict[str, Any] | None = None with self._lock: + if graphql_span_key is not None: + graphql_query = self._graphql_queries_by_span.get(graphql_span_key) + if graphql_query is not None: + trace_request["graphql_query"] = dict(graphql_query) + self._trace_requests_by_graphql_span[graphql_span_key] = trace_request existing_request = self._requests_by_trace_id.get(trace_id) if existing_request is None or trace_summary_duration_ms( trace_request @@ -621,6 +1011,29 @@ def _record_line(self, line: str) -> None: with self._lock: self._futures.append(future) + def _record_graphql_query_metadata(self, record: dict[str, Any]) -> None: + metadata = graphql_query_metadata_from_record(record) + if metadata is None: + return + span_key = graphql_query_span_key(record) + if span_key is None: + return + with self._lock: + existing_metadata = self._graphql_queries_by_span.get(span_key, {}) + merged_metadata = existing_metadata | metadata + self._graphql_queries_by_span[span_key] = merged_metadata + trace_request = self._trace_requests_by_graphql_span.get(span_key) + if trace_request is not None: + trace_request["graphql_query"] = dict(merged_metadata) + + @staticmethod + def _graphql_span_key_for_http_record(record: dict[str, Any]) -> tuple[str, str] | None: + trace_id = optional_string(record.get("trace")) + parent_span_id = optional_string(record.get("parent_span")) + if trace_id is None or parent_span_id is None: + return None + return trace_id, parent_span_id + def _submit_limited_requests(self) -> None: if self.limit is None: return @@ -849,7 +1262,16 @@ def _assert_result(self, result: CommandResult) -> None: def main() -> None: config = load_end_to_end_config() - with src.logging(config, command="test_end_to_end", git_cwd=Path.cwd()): + logging_settings = src.logging_settings_from_config( + config, + logs_dir=Path("logs-test-end-to-end"), + ) + with src.logging( + config, + command="test_end_to_end", + git_cwd=Path.cwd(), + logging_config=logging_settings, + ): run_end_to_end(config) @@ -879,14 +1301,22 @@ def run_end_to_end(config: EndToEndConfig) -> None: all_failures: list[str] = [] all_jaeger_collectors: list[JaegerTraceCollector] = [] jaeger_trace_fetch_pool = create_jaeger_trace_fetch_pool(config) + sourcegraph_load_monitor = create_sourcegraph_load_monitor(config) latest_baseline_repositories: set[str] = set() try: + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.start() with src.event( "end_to_end_matrix", repeat=config.repeat, variant_count=len(variants), trace=config.trace, + sourcegraph_load_monitor=sourcegraph_load_monitor is not None, ) as matrix_summary: + if sourcegraph_load_monitor is not None: + matrix_summary["sourcegraph_load_monitor_dir"] = str( + sourcegraph_load_monitor.output_dir + ) for iteration in range(1, config.repeat + 1): for variant in variants: with src.stage("matrix_variant", variant=variant.name, iteration=iteration): @@ -915,6 +1345,8 @@ def run_end_to_end(config: EndToEndConfig) -> None: wait_for_jaeger_trace_collectors(all_jaeger_collectors) if jaeger_trace_fetch_pool is not None: jaeger_trace_fetch_pool.close() + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.stop() if all_failures: print("\nFailures:", file=sys.stderr) for failure in all_failures: @@ -928,7 +1360,7 @@ def run_end_to_end(config: EndToEndConfig) -> None: print_phase_memory_summary(all_results, config.memory_summary_limit) comparisons = compare_variants(all_results) print_comparison_summary(comparisons) - write_results_files(all_results, comparisons, config) + write_results_files(all_results, comparisons, config, sourcegraph_load_monitor) raise_for_memory_regressions(comparisons, config) @@ -955,8 +1387,10 @@ def create_jaeger_trace_fetch_pool( return JaegerTraceFetchPool( config, parallelism=config.jaeger_trace_parallelism, + initial_delay_seconds=config.jaeger_initial_delay_seconds, retry_delays_seconds=config.jaeger_retry_delays, jsonl_path=jaeger_trace_jsonl_path(config), + trace_directory=jaeger_trace_directory(config), ) @@ -971,6 +1405,56 @@ def jaeger_trace_jsonl_path(config: EndToEndConfig) -> Path | None: return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}.jsonl" +def jaeger_trace_directory(config: EndToEndConfig) -> Path: + """Return the directory where complete raw Jaeger traces should be stored.""" + if config.jaeger_trace_directory is not None: + return config.jaeger_trace_directory + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-jaeger-traces") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}" + + +def create_sourcegraph_load_monitor(config: EndToEndConfig) -> SourcegraphLoadMonitor | None: + """Return the Sourcegraph load monitor for this run, if enabled.""" + if not config.monitor_sourcegraph_load: + return None + return SourcegraphLoadMonitor(config, sourcegraph_monitor_output_dir(config)) + + +def sourcegraph_monitor_output_dir(config: EndToEndConfig) -> Path: + """Return where Sourcegraph pod/Postgres monitor artifacts should be stored.""" + if config.sourcegraph_monitor_output_dir is not None: + return config.sourcegraph_monitor_output_dir + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-sourcegraph-load") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-sourcegraph-load-{stamp}" + + +def sourcegraph_monitor_script_path() -> Path: + """Return the lower-level monitor script used by the e2e orchestrator.""" + return Path(__file__).resolve().with_name("memory-efficiency-monitor-sourcegraph.sh") + + +def complete_jaeger_trace_path(trace_directory: Path, trace_request: dict[str, Any]) -> Path: + """Return the stable per-trace path for a complete Jaeger trace payload.""" + variant = safe_path_part(trace_request.get("variant"), default="variant") + iteration = int_field(trace_request, "iteration") or 0 + case_name = safe_path_part(trace_request.get("case"), default="case") + trace_id = safe_path_part(trace_request.get("trace_id"), default="trace") + return trace_directory / variant / f"iteration-{iteration:04d}" / case_name / f"{trace_id}.json" + + +def safe_path_part(value: object, *, default: str) -> str: + """Return a filesystem-safe path segment for generated trace artifacts.""" + text = str(value) if value is not None else "" + safe_text = SAFE_PATH_PART_PATTERN.sub("-", text).strip("-.") + return safe_text[:120] or default + + def command_environment(config: EndToEndConfig) -> dict[str, str]: """Return a deterministic child environment for CLI config parsing.""" environment = dict(os.environ) @@ -1012,7 +1496,7 @@ def run_matrix( baseline_result: CommandResult | None = None for case in read_only_cases(config): result = runner.run(case) - if case.name == "implicit-get-user": + if case.name == "get-users-baseline": baseline_result = result assert baseline_result is not None baseline_repositories = repositories_for_user(snapshot_path(baseline_result), config.user) @@ -1044,93 +1528,103 @@ def invalid_configuration_cases(config: EndToEndConfig) -> list[CommandCase]: restore_placeholder = "definitely-missing-before.json" missing_maps = "definitely-missing-command-permutation-maps.yaml" command_pairs: list[tuple[str, tuple[str, ...]]] = [ - ("get-set", ("--get", "--set", "maps.yaml")), - ("get-restore", ("--get", "--restore", restore_placeholder)), - ("set-restore", ("--set", "maps.yaml", "--restore", restore_placeholder)), + ("get-set", ("get", "set")), + ("get-restore", ("get", "restore", "--restore-path", restore_placeholder)), + ("set-restore", ("set", "--maps-path", "maps.yaml", "restore")), ] cases = [ CommandCase( name=f"invalid-multiple-commands-{name}", arguments=command_arguments, expected_exit_code=2, - must_contain=("choose only one",), + must_contain=("unrecognized arguments",), ) for name, command_arguments in command_pairs ] cases.append( CommandCase( name="invalid-restore-sync-saml-orgs", - arguments=("--restore", restore_placeholder, "--sync-saml-orgs"), + arguments=("restore", "--restore-path", restore_placeholder, "--sync-saml-orgs"), expected_exit_code=2, - must_contain=("with --get or --set",), + must_contain=("unrecognized arguments",), ) ) cases.extend( [ CommandCase( name="invalid-full-without-set", - arguments=("--full",), + arguments=("get", "--full"), expected_exit_code=2, - must_contain=("--full requires --set",), + must_contain=("unrecognized arguments",), ), CommandCase( name="invalid-set-full-and-user", - arguments=("--set", "maps.yaml", "--full", "--user", config.user), + arguments=("set", "--full", "--users", config.user), expected_exit_code=2, must_contain=("choose at most one",), ), CommandCase( name="invalid-set-full-and-users-without-explicit-perms", - arguments=("--set", "maps.yaml", "--full", "--users-without-explicit-perms"), + arguments=( + "set", + "--full", + "--users-without-explicit-perms", + ), expected_exit_code=2, must_contain=("choose at most one",), ), CommandCase( name="invalid-user-filter-conflict", - arguments=("--get", "--user", config.user, "--users-without-explicit-perms"), + arguments=("get", "--users", config.user, "--users-without-explicit-perms"), expected_exit_code=2, - must_contain=("choose only one of --user or --users-without-explicit-perms",), + must_contain=("choose only one of --users or --users-without-explicit-perms",), ), CommandCase( name="invalid-restore-user-filter", - arguments=("--restore", restore_placeholder, "--user", config.user), + arguments=( + "restore", + "--restore-path", + restore_placeholder, + "--users", + config.user, + ), expected_exit_code=2, - must_contain=("require --get or --set",), + must_contain=("unrecognized arguments",), ), CommandCase( name="invalid-sync-created-after-filter", - arguments=("--sync-saml-orgs", "--created-after", config.future_date), + arguments=("sync-saml-orgs", "--created-after", config.future_date), expected_exit_code=2, - must_contain=("require --get or --set",), + must_contain=("unrecognized arguments",), ), CommandCase( name="invalid-date-shape", - arguments=("--get", "--created-after", "2026-1-01"), + arguments=("get", "--created-after", "2026-1-01"), expected_exit_code=2, ), CommandCase( name="invalid-date-value", - arguments=("--get", "--created-after", "2026-02-31"), + arguments=("get", "--created-after", "2026-02-31"), expected_exit_code=1, must_contain=("--created-after must use YYYY-MM-DD",), ), CommandCase( name="invalid-missing-set-file", - arguments=("--set", missing_maps), + arguments=("set", "--maps-path", missing_maps), expected_exit_code=1, expected_log_command="set_full", expected_log_status="error", - must_contain=("--set input file does not exist",), + must_contain=("set input file does not exist",), ), CommandCase( name="invalid-removed-repositories-created-after-flag", - arguments=("--repositories-created-after", config.future_date), + arguments=("get", "--repositories-created-after", config.future_date), expected_exit_code=2, must_contain=("unrecognized arguments",), ), CommandCase( name="invalid-removed-get-schema-flag", - arguments=("--get-schema", "definitely-missing-schema.gql"), + arguments=("get", "--get-schema", "definitely-missing-schema.gql"), expected_exit_code=2, must_contain=("unrecognized arguments",), ), @@ -1144,30 +1638,24 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: CommandCase( name="help", arguments=("--help",), - must_contain=("usage: src-auth-perms-sync", "--set [FILE]"), + must_contain=("usage: src-auth-perms-sync", "commands:"), must_not_contain=("--repositories-created-after", "--get-schema"), ), CommandCase( - name="implicit-get-user", - arguments=("--user", config.user), - expected_log_command="get", - must_contain=("Wrote before-snapshot",), - ), - CommandCase( - name="explicit-get-user", - arguments=("--get", "--user", config.user), + name="get-users-baseline", + arguments=("get", "--users", config.user), expected_log_command="get", must_contain=("Wrote before-snapshot",), ), CommandCase( name="get-created-after-future", - arguments=("--get", "--created-after", config.future_date), + arguments=("get", "--created-after", config.future_date), expected_log_command="get", must_contain=("Selected 0 user(s) for get output",), ), CommandCase( name="get-user-created-after-future", - arguments=("--get", "--user", config.user, "--created-after", config.future_date), + arguments=("get", "--users", config.user, "--created-after", config.future_date), expected_log_command="get", must_contain_one_of=( "Selected 0 user(s) for get output", @@ -1177,7 +1665,7 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: CommandCase( name="get-users-without-explicit-perms-created-after-future", arguments=( - "--get", + "get", "--users-without-explicit-perms", "--created-after", config.future_date, @@ -1185,12 +1673,6 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: expected_log_command="get", must_contain=("Selected 0 user(s) for get output",), ), - CommandCase( - name="get-sync-saml-orgs-dry-run", - arguments=("--get", "--sync-saml-orgs"), - expected_log_command="get_sync_saml_orgs", - must_contain=("Wrote before-snapshot", "Dry run complete"), - ), ] return cases @@ -1200,8 +1682,7 @@ def run_safe_set_cases(config: EndToEndConfig, runner: CommandPermutationRunner) CommandCase( name="set-explicit-full-no-op-apply", arguments=( - "--set", - "maps.yaml", + "set", "--full", "--created-after", config.future_date, @@ -1219,8 +1700,8 @@ def run_safe_set_cases(config: EndToEndConfig, runner: CommandPermutationRunner) def set_user_dry_run_case(config: EndToEndConfig) -> CommandCase: return CommandCase( name="set-user-dry-run", - arguments=("--set", "maps.yaml", "--user", config.user), - expected_log_command="set_user", + arguments=("set", "--users", config.user), + expected_log_command="set_users", must_contain=("Dry run complete",), ) @@ -1229,15 +1710,14 @@ def set_user_apply_case(config: EndToEndConfig) -> CommandCase: return CommandCase( name="set-user-apply", arguments=( - "--set", - "maps.yaml", - "--user", + "set", + "--users", config.user, "--apply", "--parallelism", str(config.parallelism), ), - expected_log_command="set_user", + expected_log_command="set_users", must_contain_one_of=( "VALIDATION OK: all", "All selected users already have the mapped explicit grants", @@ -1249,8 +1729,7 @@ def users_without_explicit_permissions_no_op_case(config: EndToEndConfig) -> Com return CommandCase( name="set-users-without-explicit-perms-no-op-apply", arguments=( - "--set", - "maps.yaml", + "set", "--users-without-explicit-perms", "--created-after", config.future_date, @@ -1268,7 +1747,8 @@ def restore_scoped_dry_run_case(snapshot: Path, config: EndToEndConfig) -> Comma return CommandCase( name="restore-scoped-dry-run", arguments=( - "--restore", + "restore", + "--restore-path", str(snapshot), "--parallelism", str(config.parallelism), @@ -1282,7 +1762,8 @@ def restore_scoped_apply_case(snapshot: Path, config: EndToEndConfig) -> Command return CommandCase( name="restore-scoped-apply-cleanup", arguments=( - "--restore", + "restore", + "--restore-path", str(snapshot), "--apply", "--parallelism", @@ -1299,7 +1780,7 @@ def restore_scoped_apply_case(snapshot: Path, config: EndToEndConfig) -> Command def sync_saml_apply_case() -> CommandCase: return CommandCase( name="sync-saml-orgs-apply", - arguments=("--sync-saml-orgs", "--apply"), + arguments=("sync-saml-orgs", "--apply"), expected_log_command="sync_saml_orgs", must_contain=("VALIDATION OK: all target org memberships match",), ) @@ -1308,7 +1789,7 @@ def sync_saml_apply_case() -> CommandCase: def final_get_user_case(config: EndToEndConfig) -> CommandCase: return CommandCase( name="final-get-user-baseline-check", - arguments=("--get", "--user", config.user), + arguments=("get", "--users", config.user), expected_log_command="get", must_contain=("Wrote before-snapshot",), ) @@ -1318,43 +1799,44 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne dry_run_result = runner.run( CommandCase( name="set-full-dry-run", - arguments=("--set",), + arguments=("set",), expected_log_command="set_full", must_contain=("Dry run complete",), ) ) baseline_snapshot = snapshot_path(dry_run_result) - try: - runner.run( - CommandCase( - name="set-full-apply", - arguments=( - "--set", - "--apply", - "--parallelism", - str(config.parallelism), - ), - expected_log_command="set_full", - must_contain=("VALIDATION OK",), + if config.include_redundant_scale_cases: + try: + runner.run( + CommandCase( + name="set-full-apply", + arguments=( + "set", + "--apply", + "--parallelism", + str(config.parallelism), + ), + expected_log_command="set_full", + must_contain=("VALIDATION OK",), + ) ) - ) - finally: - runner.run( - restore_full_apply_case( - "restore-full-apply-cleanup", - baseline_snapshot, - config, - no_backup=False, + finally: + runner.run( + restore_full_apply_case( + "restore-full-apply-cleanup", + baseline_snapshot, + config, + no_backup=False, + ) ) - ) try: runner.run( CommandCase( name="set-full-no-backup-apply", arguments=( - "--set", + "set", "--apply", "--no-backup", "--parallelism", @@ -1374,14 +1856,19 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne ) ) - # Covers the combined set+SAML dispatch and SAML dry-run path without - # repeating the full set apply and full restore cleanup paths, which are - # already covered above. + # Covers combined set+SAML dispatch and SAML dry-run with a user-scoped + # set path, so the default suite keeps only one expensive full-snapshot + # case. Pass --include-redundant-scale-cases to restore older overlap. runner.run( CommandCase( - name="set-full-sync-saml-orgs-dry-run", - arguments=("--set", "--sync-saml-orgs"), - expected_log_command="set_full_sync_saml_orgs", + name="set-user-sync-saml-orgs-dry-run", + arguments=( + "set", + "--users", + config.user, + "--sync-saml-orgs", + ), + expected_log_command="set_users_sync_saml_orgs", must_contain=("Dry run complete",), ) ) @@ -1395,7 +1882,8 @@ def restore_full_apply_case( no_backup: bool, ) -> CommandCase: restore_arguments = [ - "--restore", + "restore", + "--restore-path", str(snapshot), "--apply", "--parallelism", @@ -1602,20 +2090,178 @@ def parse_log_timestamp(value: object) -> datetime.datetime | None: def workload_from_records(records: list[dict[str, Any]]) -> dict[str, int | float | str]: - """Collect stable workload-size fields so memory can be normalized.""" + """Collect named workload dimensions from structured log records. + + Earlier e2e summaries used raw field names from unrelated events, which made + values like `total_users` and `repo_count` ambiguous. Keep this summary + event-aware so each key says what it counts. + """ workload: dict[str, int | float | str] = {} for record in records: - for field_name in WORKLOAD_FIELDS: - value = record.get(field_name) - if isinstance(value, int | float): - old_value = workload.get(field_name) - if not isinstance(old_value, int | float) or value > old_value: - workload[field_name] = value - elif isinstance(value, str) and field_name not in workload: - workload[field_name] = value + event_name = optional_string(record.get("event")) + phase = optional_string(record.get("phase")) + if event_name == "capture_explicit_grants": + record_workload_max(workload, "sourcegraph_user_count", record.get("total_users")) + if phase == "end": + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name in {"build_snapshot", "build_user_scoped_snapshot"} and phase == "end": + record_workload_max(workload, "snapshot_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "snapshot_repos_with_explicit_grants_max", + record.get("repos_with_explicit_grants"), + ) + record_workload_max(workload, "snapshot_total_grants_max", record.get("total_grants")) + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name == "user_explicit_repos_batch_fetch" and phase == "end": + record_workload_max(workload, "batch_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "batch_fetched_grant_count_max", + record.get("fetched_grant_count") + if "fetched_grant_count" in record + else record.get("repo_count"), + ) + elif event_name == "load_repos_by_external_service" and phase == "end": + record_workload_max(workload, "loaded_repo_count", record.get("repo_count")) + record_workload_max( + workload, + "expected_repo_count", + record.get("expected_repo_count"), + ) + elif event_name == "apply_username_overwrites": + record_workload_max(workload, "apply_payload_count", record.get("payload_count")) + record_workload_max( + workload, + "apply_payload_grant_count", + record.get("payload_grant_count") + if "payload_grant_count" in record + else record.get("total_users"), + ) + record_workload_max(workload, "parallelism", record.get("parallelism")) + if phase == "end": + record_workload_max( + workload, + "apply_mutations_succeeded", + record.get("succeeded"), + ) + record_workload_max(workload, "apply_mutations_failed", record.get("failed")) + record_workload_max(workload, "apply_mutations_canceled", record.get("canceled")) + elif ( + event_name + in { + "cmd_get", + "cmd_restore", + "cmd_restore_user_scoped", + "cmd_set", + "cmd_set_additive_user", + "cmd_set_additive_users_without_explicit_perms", + } + and phase == "end" + ): + record_command_workload(workload, record) + elif event_name in {"sync_saml_orgs", "cmd_sync_saml_orgs"} and phase == "end": + record_workload_max( + workload, + "target_organizations", + record.get("target_organizations"), + ) + record_workload_max(workload, "desired_memberships", record.get("desired_memberships")) + + record_workload_model_dimensions(workload) return workload +def record_command_workload(workload: dict[str, int | float | str], record: dict[str, Any]) -> None: + """Copy command-level counts using names that preserve their meaning.""" + event_name = optional_string(record.get("event")) + repo_count = record.get("repo_count") + total_grants = record.get("total_grants") + if event_name == "cmd_set": + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + elif event_name == "cmd_get": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "selected_total_grants", total_grants) + elif event_name == "cmd_restore": + record_workload_max(workload, "restore_snapshot_repo_count", record.get("snapshot_repos")) + record_workload_max( + workload, + "restore_snapshot_total_grants", + record.get("snapshot_grants"), + ) + elif event_name == "cmd_set_additive_user": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + + record_workload_max(workload, "mapping_count", record.get("mapping_count")) + record_workload_max(workload, "mutations_succeeded", record.get("mutations_succeeded")) + record_workload_max(workload, "mutations_failed", record.get("mutations_failed")) + record_workload_max(workload, "mutations_canceled", record.get("mutations_canceled")) + + +def record_workload_model_dimensions(workload: dict[str, int | float | str]) -> None: + """Add the canonical dimensions used by memory modeling.""" + user_count = max_workload_number( + workload, + ( + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "sourcegraph_user_count", + ), + ) + repo_count = max_workload_number( + workload, + ( + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "loaded_repo_count", + ), + ) + grant_count = max_workload_number( + workload, + ( + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "apply_payload_grant_count", + ), + ) + if user_count is not None: + workload["memory_model_user_count"] = user_count + if repo_count is not None: + workload["memory_model_repo_count"] = repo_count + if grant_count is not None: + workload["memory_model_grant_count"] = grant_count + + +def max_workload_number( + workload: dict[str, int | float | str], field_names: Sequence[str] +) -> int | float | None: + """Return the largest numeric value found for the supplied workload fields.""" + values = [ + value + for field_name in field_names + if isinstance((value := workload.get(field_name)), int | float) + ] + return max(values) if values else None + + +def record_workload_max( + workload: dict[str, int | float | str], field_name: str, value: object +) -> None: + """Record the maximum numeric value for a named workload dimension.""" + if isinstance(value, bool) or not isinstance(value, int | float): + return + old_value = workload.get(field_name) + if not isinstance(old_value, int | float) or value > old_value: + workload[field_name] = value + + def artifact_sizes_for_run(log_path: Path) -> dict[str, int]: """Return sizes of JSON artifacts in the same run directory as the log.""" run_directory = log_path.parent @@ -1636,6 +2282,54 @@ def wait_for_jaeger_trace_collectors(collectors: list[JaegerTraceCollector]) -> collector.wait() +def graphql_query_metadata_from_record(record: dict[str, Any]) -> dict[str, Any] | None: + """Return correlation metadata from a structured `graphql_query` log record.""" + if record.get("event") != "graphql_query": + return None + metadata: dict[str, Any] = { + "span_id": record.get("span"), + "parent_span_id": record.get("parent_span"), + "trace_id": record.get("trace"), + } + phase = record.get("phase") + if phase == "start": + metadata["started_at"] = record.get("ts") + elif phase == "end": + metadata["ended_at"] = record.get("ts") + for field_name in ( + "cursor_present", + "duration_ms", + "error_type", + "graphql_client", + "page_number", + "page_size", + "query_bytes", + "query_name", + "response_fields", + "status", + "url", + "variable_names", + # Current src-py-lib logs variable names only. Keep these optional fields + # so raw trace artifacts automatically include values if the GraphQL log + # event grows an opt-in sanitized-variable field later. + "input_variables", + "variable_values", + "variables", + ): + if field_name in record: + metadata[field_name] = record[field_name] + return {key: value for key, value in metadata.items() if value is not None} + + +def graphql_query_span_key(record: dict[str, Any]) -> tuple[str, str] | None: + """Return the `(trace_id, span_id)` key for a GraphQL query log span.""" + trace_id = optional_string(record.get("trace")) + span_id = optional_string(record.get("span")) + if trace_id is None or span_id is None: + return None + return trace_id, span_id + + def graphql_trace_request_from_record(record: dict[str, Any]) -> dict[str, Any] | None: if record.get("event") != "http_request" or record.get("phase") != "end": return None @@ -1997,9 +2691,10 @@ def write_results_files( results: list[CommandResult], comparisons: list[CaseComparison], config: EndToEndConfig, + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: if config.results_json is not None: - write_results_json(config.results_json, results, comparisons) + write_results_json(config.results_json, results, comparisons, sourcegraph_load_monitor) if config.results_csv is not None: write_results_csv(config.results_csv, results) phase_csv = phase_results_csv_path(config.results_csv) @@ -2010,12 +2705,20 @@ def write_results_json( path: Path, results: list[CommandResult], comparisons: list[CaseComparison], + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: path.parent.mkdir(parents=True, exist_ok=True) + sourcegraph_monitor: dict[str, Any] | None = None + if sourcegraph_load_monitor is not None: + sourcegraph_monitor = { + "output_dir": str(sourcegraph_load_monitor.output_dir), + "log_path": str(sourcegraph_load_monitor.log_path), + } with path.open("w", encoding="utf-8") as output_file: json.dump( { "generated_at": datetime.datetime.now(datetime.UTC).isoformat(), + "sourcegraph_load_monitor": sourcegraph_monitor, "results": [result_to_json(result) for result in results], "comparisons": [comparison_to_json(comparison) for comparison in comparisons], }, @@ -2240,7 +2943,11 @@ def normalized_memory(result: CommandResult) -> dict[str, float]: if peak_rss_mb is None: return {} normalized: dict[str, float] = {} - for field_name in ("user_count", "total_users", "repo_count", "total_grants"): + for field_name in ( + "memory_model_user_count", + "memory_model_repo_count", + "memory_model_grant_count", + ): value = result.workload.get(field_name) if isinstance(value, int | float) and value > 0: normalized[f"peak_rss_mb_per_{field_name}"] = peak_rss_mb / float(value) diff --git a/dev/test-plan.md b/dev/test-plan.md deleted file mode 100644 index 521d165..0000000 --- a/dev/test-plan.md +++ /dev/null @@ -1,207 +0,0 @@ -# Large-scale test plan for `src-auth-perms-sync` - -## Known constraints - -1. **`src-auth-perms-sync` snapshot/diff/expected-set state is fully in-memory.** - Snapshot cost scales with the **number of explicit grants**, not with the - number of synced repos. A `1M repos × 10K users = 10⁹ grants` literal - one-shot run will OOM on `build_snapshot` long before it stresses - anything server-side. Split scaling tests into three axes: - **repo count**, **users-per-mutation payload**, and **total grants**. - -2. **`setRepositoryPermissionsForUsers` is one short transaction with an - in-transaction `DELETE` of stale rows.** Concurrent mutations on - *different* repos do not block each other on rows but do contend on - B-tree pages of `user_repo_permissions_perms_unique_idx` and - `user_repo_permissions_repo_id_user_id_idx`. Practical ceiling on a - reasonable Postgres: **~200–500 mutations/s** at `--parallelism` 100–200. - Throughput plateaus or degrades above ~256 workers. - ---- - -## Scaling risks - -1. **Snapshot OOM** in scenarios d / g. The script holds - `repo_users`, `expected_users`, `user_repos`, and `Snapshot.repos` - simultaneously in RAM. Mitigation: bound scenario d at the smallest - cliff that reproduces the OOM; do not insist on full-corpus backup. -2. **Inode exhaustion** at corpus generation. Mitigation: `mkfs.ext4 -i 4096`. -3. **`/v1/list-repos` timeout** if a single shard accidentally gets - >50K repos. Mitigation: hard-cap shard size at 10K and assert - directory count after generation. -4. **`repoConcurrentExternalServiceSyncers=3` default** silently - serializing 100-shard sync to ~33× the expected wall clock. - Mitigation: assert site config value before triggering sync. -5. **GraphQL request body size** in scenario g. A 10K-user payload at - ~80 bytes/user is ~800KB, well under the typical 1MB body limit but - close. Watch for HTTP 413. -6. **FD exhaustion** at `--parallelism 256`. Mitigation: - `ulimit -n 8192` before each run; monitor `num_fds` in - `resource_sample`. -7. **`externalAccounts(first: 50)` truncation** if any user gets - over-seeded. Mitigation: SQL assertion - `SELECT user_id, COUNT(*) FROM user_external_accounts GROUP BY 1 HAVING COUNT(*) > 50;` - must return zero rows. - ---- - -## Measurement plan - -The script already emits the right primitives in -`src-auth-perms-sync-runs//runs//log.json`. -Use `jq` and Python for post-run analysis; do not modify the script. - -### Per-run assertions (correctness gates, fail the test on violation) - -```bash -F=src-auth-perms-sync-runs//runs//log.json -# Every event() emits paired phase=="start"/"end" records; aggregations -# below filter on phase=="end" so they only see completed operations -# (start records have no duration_ms / status / mutation counters). -jq -s '.[-1]' $F | jq '.event == "run" and .phase == "end" and .exit_code == 0' -jq 'select(.event=="apply_payloads" and .phase=="end") | .failed' $F | grep -v '^0$' && exit 1 -jq 'select(.event=="cmd_set" and .phase=="end") | .mutations_failed' $F | grep -v '^0$' && exit 1 -# For backup runs, restore residual diff must be empty -``` - -### Per-run KPIs (extract and store; plot across the sweep) - -For each event of interest (`set_repo_perms`, `graphql_request` filtered -by `query_name == "SetRepoPerms"`, `paginate_page`, `resource_sample`): - -- p50 / p95 / p99 / max `duration_ms` -- count, retry count (`retry_wait` events) -- `request_bytes_total`, `response_bytes_total` -- `peak_rss_mb`, `max_num_fds`, `max_num_threads`, `max_process_cpu_percent` -- mutation throughput = - `apply_payloads.succeeded / (apply_payloads.duration_ms / 1000)` - -### Pagination cliff plot - -```bash -jq -r 'select(.event=="paginate_page" and .phase=="end") | - [.query_name, .page_index, .duration_ms] | @tsv' $F \ - > paginate.tsv -# Plot duration_ms vs page_index, faceted by query_name. -# Expect: -# ListUsers: ~constant (~10 pages of 1000 users) -# ReposByExternalService: ~constant per ES (10 pages × 100 ES) -# UserExplicitRepos: grows linearly with explicit repo count -``` - -### Parallelism sweep - -```bash -for P in 1 16 64 128 256; do - ulimit -n 8192 - uv run src-auth-perms-sync --set full-00.yaml --apply --parallelism $P --no-backup - cp src-auth-perms-sync-runs//runs/*/log.json results/sweep-p$P.jsonl -done -``` - -Plot mutation throughput vs P. Expect the throughput curve to rise to -~P=64 then plateau or regress as Postgres lock contention dominates. -Watch retry rate (`retry_wait` count) — if it climbs steeply past P=128 -that's the server saying "back off." - -### Snapshot scaling sweep (scenario d) - -```text -100K grants → expect peak_rss ~ ?, snapshot wall ~ ? -500K → ? -1M → ? -2M → ? (likely first cliff) -5M → goal: confirm OOM threshold -``` - -These numbers don't have published baselines yet; this run *creates* -them. The deliverable is "we now know `--set --apply` with `--parallelism 16` -hits N MB RSS and W seconds of snapshot wall-clock at G grants." - ---- - -## Failure injection (scenario e) - -### Kill a shard mid-sync - -1. Start fresh ES creation in the SG admin UI / GraphQL. -2. After ~30s of sync, `docker kill sg-serve-042`. -3. Sync wait probe should observe `lastSyncError != null` and stay - below `repoCount == 10000` for that shard. -4. **Assert: `--apply` is not started yet** (the wait probe blocks). -5. `docker start sg-serve-042`. Re-probe. Sync resumes. Proceed. - -### Kill Sourcegraph mid-apply - -1. Start `--set full-00.yaml --apply --parallelism 64 --no-backup`. -2. After ~10–20% of `set_repo_perms` events appear in the JSONL, - `kubectl rollout restart deployment/frontend` (or equivalent). -3. The script will record retries, then GraphQL errors, and exit non-zero. -4. **Assert: `cmd_set.mutations_failed > 0` and `set_repo_perms.error_type` - includes `GraphQLError`.** -5. Re-run the same command. Because `setRepositoryPermissionsForUsers` is - an idempotent overwrite, the second run should converge to - `mutations_failed == 0` and post-apply validation should match the - expected per-repo user sets. - -### Concurrent race (scenario f) - -1. Run `maps-A.yaml` and `maps-B.yaml` (overlapping ES IDs, different - buckets) in two terminals with `--apply` simultaneously. -2. **Expected**: last-writer-wins per repo; the `validate_post_apply` - step in at least one of the two runs logs a per-repo expected-vs-actual - mismatch warning (drift detected). Both runs exit 0; the warning is - the signal. - ---- - -## Cleanup and iteration - -### Per-run reset (fast, ~seconds) - -```sql -DELETE FROM user_repo_permissions WHERE source = 'api'; -DELETE FROM user_pending_permissions; -``` - -### Topology reset (after a full scenario, ~minutes) - -- Delete all 100 shard external services via GraphQL (this cascades to - `external_service_repos`; `repo` rows are GC'd by the syncer). -- Recreate when the next scenario starts. - -### Best inner-loop primitive: DB snapshot - -After steps 1–4 of §3 are complete (users seeded, providers -configured, ES synced, `user_repo_permissions` empty), take a `pg_dump` -or take a logical Postgres snapshot. Restore between runs in seconds. -Keep the on-disk repo corpus immutable across runs — it is the slow -part to rebuild. - -### What to NEVER do between runs - -- Do not regenerate the 1M repo corpus. -- Do not re-create the 100 external services unless the shard topology - changes. -- Do not restart `src serve-git` containers between runs (they're - stateless; killing them only forces a re-walk on next sync). - ---- - -## Deliverables - -1. Generator scripts under `scripts/loadtest/` (corpus + docker compose - for 100 shards + SG site-config snippet + user/account seeding SQL). -2. 10 generated `full-NN.yaml` mapping configs + the smoke / medium / - giant-payload / failure / race configs. -3. A `runner.sh` that drives scenarios a → g in order, gates on the - per-run assertions in §5, and copies - `src-auth-perms-sync-runs//runs//log.json` plus - same-directory snapshots into a timestamped - `results//` dir. -4. `analyze.py` that consumes a `results//` dir and emits a - markdown report: KPIs, percentile tables, paginate-cost plot, - parallelism sweep curve, snapshot-cliff curve. -5. This document, updated with measured numbers once scenarios c, c′, - and d have actually run (the snapshot cliffs especially are unknown - until measured). diff --git a/dev/python-versions.md b/dev/update-python-versions.md similarity index 100% rename from dev/python-versions.md rename to dev/update-python-versions.md diff --git a/maps-example.yaml b/maps-example.yaml index c854791..49afbd6 100644 --- a/maps-example.yaml +++ b/maps-example.yaml @@ -1,39 +1,130 @@ -# Auth provider → code host connection mapping rules -# Maintain this file using auth-providers.yaml and code-hosts.yaml as references. -# Those files are generated under src-auth-perms-sync-runs//. +# User → Repo permission mapping rules + +# Maintain your maps.yaml file, using the values from auth-providers.yaml and code-hosts.yaml, +# which are created by the --get command, under `src-auth-perms-sync-runs//` + +# Schema details: +# maps: list[map] +# - name: string +# users: map +# authProvider: map +# type: string +# serviceID: string +# clientID: string +# displayName: string +# configID: string +# samlGroup: string +# emails: list[string] # exact verified email addresses +# emailRegexes: list[string] # Python regexes for verified email addresses +# usernames: list[string] # exact Sourcegraph usernames +# usernameRegexes: list[string] # Python regexes for Sourcegraph usernames +# repos: map +# codeHostConnection: map +# displayName: string +# kind: string +# url: string +# username: string +# names: list[string] # exact Sourcegraph repo names +# nameRegexes: list[string] # Python regexes for Sourcegraph repo names + +# Filter scopes: +# - Children of lists are ORed together (casting a wider net) +# - Children of maps are ANDed together (casting a narrower net) maps: -- name: All users from Line of Business 1 - User Group 1 get access to all repos synced from service account 1 +# Widest net +- name: All users get all repos + users: + usernameRegexes: + - '.*' + repos: + nameRegexes: + - '.*' + +# Wide net +- name: All Okta SAML users get access to all Bitbucket repos + users: + authProvider: + configID: okta + type: saml + repos: + codeHostConnection: + kind: BITBUCKETSERVER + +# Medium net +- name: | + Members of samlGroup LOB1-GROUP1, from any auth provider + get any repos cloned using username LOB1-SA1, from any code host users: authProvider: samlGroup: LOB1-GROUP1 repos: codeHostConnection: - config: - username: LOB1-SA1 + username: LOB1-SA1 -- name: All users from Line of Business 1 - User Group 2 get access to all repos synced from service account 2 +# Narrower net +- name: | + Members of samlGroup LOB1-GROUP1 from the okta saml provider + get repos cloned from a specific Bitbucket code host connection users: authProvider: - samlGroup: LOB1-GROUP2 + configID: okta + samlGroup: LOB1-GROUP1 + type: saml repos: codeHostConnection: - config: - username: LOB1-SA2 + displayName: 'BITBUCKETSERVER #1' + kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 -- name: All Okta SAML users get access to all Bitbucket repos +# Even narrower net +- name: | + Alice and Bob get access to bitbucket.example.com/example/private-repo, + if they are members of LOB1-GROUP1 from okta saml users: authProvider: configID: okta + samlGroup: LOB1-GROUP1 type: saml + emails: + - alice@example.com + - bob@example.com repos: codeHostConnection: + displayName: Bitbucket kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 + names: + - bitbucket.example.com/example/private-repo -- name: All builtin users get access to all repos under the github.com/example org, from any code host connection +# Narrowest net +- name: Alice gets private-repo repo, if all stars align users: authProvider: - type: builtin + clientID: https://sourcegraph.example.com/.auth/saml/metadata + configID: okta + displayName: Okta + samlGroup: LOB1-GROUP1 + serviceID: http://www.okta.com/example123 + type: saml + emails: + - alice@example.com + emailRegexes: + - '@example\.com$' + usernames: + - alice + usernameRegexes: + - '^alice$' repos: - regex: https://github.com/example/.* + codeHostConnection: + displayName: 'BITBUCKETSERVER #1' + kind: BITBUCKETSERVER + url: https://bitbucket.example.com/ + username: LOB1-SA1 + names: + - bitbucket.example.com/example/private-repo + nameRegexes: + - '^bitbucket\.example\.com/example/private-repo$' diff --git a/pyproject.toml b/pyproject.toml index 25ab347..de5696e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [dependency-groups] @@ -11,7 +11,7 @@ dev = [ [project] name = "src-auth-perms-sync" -version = "0.2.2" +dynamic = ["version"] description = "Set Sourcegraph permissions from authentication provider data" readme = "README.md" requires-python = ">=3.11" @@ -48,6 +48,9 @@ Issues = "https://github.com/sourcegraph/src-auth-perms-sync/issues" [tool.hatch.build.targets.wheel] packages = ["src/src_auth_perms_sync"] +[tool.hatch.version] +source = "vcs" + [tool.pyright] include = ["src", "tests", "dev"] exclude = ["src-auth-perms-sync-runs", ".venv", "build", "dist"] diff --git a/src/src_auth_perms_sync/__init__.py b/src/src_auth_perms_sync/__init__.py index cfcdeaf..8b16d8e 100644 --- a/src/src_auth_perms_sync/__init__.py +++ b/src/src_auth_perms_sync/__init__.py @@ -1 +1,11 @@ -"""Project package for src-auth-perms-sync.""" +"""Importable API for src-auth-perms-sync.""" + +from .cli import Config, Get, Restore, Set, SyncSamlOrgs + +__all__ = [ + "Config", + "Get", + "Restore", + "Set", + "SyncSamlOrgs", +] diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 8761abb..dfbe308 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -9,15 +9,18 @@ from __future__ import annotations +import argparse import logging import os import sys +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Literal, NoReturn, TypeAlias +from typing import Literal, NoReturn, TypeAlias, cast import src_py_lib as src +from src_py_lib.utils import config as config_utils from .orgs import command as organizations_command from .permissions import command as permissions_command @@ -27,38 +30,79 @@ log = logging.getLogger(__name__) + CommandName: TypeAlias = Literal["get", "set", "restore", "sync_saml_orgs"] +DEFAULT_MAPS_FILE_NAME = "maps.yaml" +COMMON_CONFIG_FIELDS = src.config_field_names( + src.SourcegraphClientConfig, + src.LoggingConfig, + "parallelism", + "http_timeout_seconds", + "max_attempts", + "sample_interval", + "trace", +) +GET_CONFIG_FIELDS = src.config_field_names( + "users", + "users_without_explicit_perms", + "created_after", + "explicit_permissions_batch_size", + *COMMON_CONFIG_FIELDS, +) +SET_CONFIG_FIELDS = src.config_field_names( + "maps_path", + "full", + "users", + "users_without_explicit_perms", + "created_after", + "sync_saml_organizations", + "apply", + "no_backup", + "explicit_permissions_batch_size", + *COMMON_CONFIG_FIELDS, +) +RESTORE_CONFIG_FIELDS = src.config_field_names( + "restore_path", + "apply", + "no_backup", + "explicit_permissions_batch_size", + *COMMON_CONFIG_FIELDS, +) +SYNC_SAML_ORGS_CONFIG_FIELDS = src.config_field_names( + "apply", + "no_backup", + *COMMON_CONFIG_FIELDS, +) LogCommandName: TypeAlias = Literal[ "get", "set_full", - "set_user", + "set_users", "set_users_without_explicit_perms", "restore", "sync_saml_orgs", - "get_sync_saml_orgs", "set_full_sync_saml_orgs", - "set_user_sync_saml_orgs", + "set_users_sync_saml_orgs", "set_users_without_explicit_perms_sync_saml_orgs", ] SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { "full": "set_full", - "user": "set_user", + "users": "set_users", "users_without_explicit_perms": "set_users_without_explicit_perms", } SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-{run_mode}", - "user": "set-add-user-{run_mode}", + "users": "set-add-users-{run_mode}", "users_without_explicit_perms": "set-add-users-without-explicit-perms-{run_mode}", } SYNC_SET_COMMAND_LOG_NAMES: dict[permission_types.SetCommandMode, LogCommandName] = { "full": "set_full_sync_saml_orgs", - "user": "set_user_sync_saml_orgs", + "users": "set_users_sync_saml_orgs", "users_without_explicit_perms": "set_users_without_explicit_perms_sync_saml_orgs", } SYNC_SET_COMMAND_ARTIFACT_NAMES: dict[permission_types.SetCommandMode, str] = { "full": "set-sync-saml-orgs-{run_mode}", - "user": "set-add-user-sync-saml-orgs-{run_mode}", + "users": "set-add-users-sync-saml-orgs-{run_mode}", "users_without_explicit_perms": ( "set-add-users-without-explicit-perms-sync-saml-orgs-{run_mode}" ), @@ -77,48 +121,72 @@ class ResolvedCommand: @property def set_mode(self) -> permission_types.SetCommandMode | None: - """Return the concrete `--set` mode when this is a set command.""" + """Return the concrete set mode when this is a set command.""" if self.set_options is None: return None return self.set_options.mode -class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfig): +@dataclass(frozen=True) +class CliInput: + """Parsed CLI command and runtime config.""" + + command_name: CommandName + config: Config + + +@dataclass(frozen=True) +class CliCommand: + """Argparse subcommand metadata.""" + + argument_name: str + command_name: CommandName + help: str + description: str + config_fields: tuple[str, ...] + + +class Config(src.SourcegraphClientConfig, src.LoggingConfig): """Config values loaded from defaults, .env, environment, and CLI flags.""" - get: bool = src.config_field( - default=False, - env_var="SRC_AUTH_PERMS_SYNC_GET", - cli_flag="--get", - cli_action="store_true", - help="Query the SG instance and write/refresh auth-providers.yaml and code-hosts.yaml", + maps_path: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_MAPS_PATH", + cli_flag="--maps-path", + metavar="FILE", + help=( + "Maps YAML file for the set command.\n" + "If omitted, set uses maps.yaml under src-auth-perms-sync-runs//.\n" + "Relative paths are resolved from the current working directory." + ), + help_group="Permission sync", ) - set_path: Path | None = src.config_field( + restore_path: Path | None = src.config_field( default=None, - env_var="SRC_AUTH_PERMS_SYNC_SET", - cli_flag="--set", - cli_nargs="?", - cli_const="maps.yaml", + env_var="SRC_AUTH_PERMS_SYNC_RESTORE_PATH", + cli_flag="--restore-path", metavar="FILE", help=( - "Read the YAML config file and execute the mapping rules.\n" - "Defaults to maps.yaml under src-auth-perms-sync-runs//.\n" - "Relative paths are resolved from that path." + "Snapshot JSON file for the restore command.\n" + "Relative paths are resolved from the current working directory." ), + help_group="Restore", ) full: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_FULL", cli_flag="--full", cli_action="store_true", - help="With --set: run the full overwrite reconciliation mode (default)", + help="With the set command: run the full overwrite reconciliation mode (default)", + help_group="Permission sync", ) - user: str | None = src.config_field( - default=None, - env_var="SRC_AUTH_PERMS_SYNC_USER", - cli_flag="--user", - metavar="USER", - help="Process a specific Sourcegraph user by username or email address", + users: tuple[str, ...] = src.config_field( + default=(), + env_var="SRC_AUTH_PERMS_SYNC_USERS", + cli_flag="--users", + metavar="USERS", + help="Process comma-delimited Sourcegraph usernames and/or email addresses", + help_group="User filters", ) users_without_explicit_perms: bool = src.config_field( default=False, @@ -126,6 +194,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi cli_flag="--users-without-explicit-perms", cli_action="store_true", help="Process Sourcegraph users without explicit permissions", + help_group="User filters", ) created_after: str | None = src.config_field( default=None, @@ -134,16 +203,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi metavar="YYYY-MM-DD", pattern=r"^\d{4}-\d{2}-\d{2}$", help="Process Sourcegraph users created on or after this date", - ) - restore_path: Path | None = src.config_field( - default=None, - env_var="SRC_AUTH_PERMS_SYNC_RESTORE", - cli_flag="--restore", - metavar="FILE", - help=( - "Restore explicit-permissions state to match the given snapshot JSON file.\n" - "Relative paths are resolved under 'src-auth-perms-sync-runs//.'" - ), + help_group="User filters", ) sync_saml_organizations: bool = src.config_field( default=False, @@ -151,6 +211,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi cli_flag="--sync-saml-orgs", cli_action="store_true", help="Create/update Sourcegraph organizations for each discovered SAML group", + help_group="Organization sync", ) apply: bool = src.config_field( default=False, @@ -158,6 +219,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi cli_flag="--apply", cli_action="store_true", help="With mutating commands: actually mutate state. Default is dry-run", + help_group="Mutation", ) no_backup: bool = src.config_field( default=False, @@ -165,6 +227,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi cli_flag="--no-backup", cli_action="store_true", help="With mutating commands: skip before/after snapshots and validation", + help_group="Mutation", ) parallelism: int = src.config_field( default=16, @@ -173,6 +236,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi metavar="N", ge=1, help="Concurrent Sourcegraph API worker threads (default: 16)", + help_group="Runtime", ) explicit_permissions_batch_size: int = src.config_field( default=25, @@ -183,6 +247,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi help=( "Users per GraphQL request when capturing explicit repository permissions (default: 25)" ), + help_group="Runtime", ) max_attempts: int = src.config_field( default=5, @@ -191,6 +256,16 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi metavar="N", ge=1, help="Max attempts per HTTP request before giving up (default: 5)", + help_group="Runtime", + ) + http_timeout_seconds: float = src.config_field( + default=60.0, + env_var="SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS", + cli_flag="--http-timeout-seconds", + metavar="SECONDS", + gt=0, + help="HTTP read timeout per request in seconds (default: 60)", + help_group="Runtime", ) sample_interval: float = src.config_field( default=10.0, @@ -199,6 +274,7 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi metavar="SECONDS", ge=0, help="Seconds between logging compute resource samples; set 0 to disable (default: 10)", + help_group="Runtime", ) trace: bool = src.config_field( default=False, @@ -206,68 +282,103 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi cli_flag="--trace", cli_action="store_true", help=("Ask Sourcegraph to retain traces for GraphQL requests and return trace metadata"), + help_group="Runtime", ) +CLI_COMMANDS: tuple[CliCommand, ...] = ( + CliCommand( + argument_name="get", + command_name="get", + help="Discover auth providers and code hosts", + description="Gather auth providers, code hosts, users, and permissions.", + config_fields=GET_CONFIG_FIELDS, + ), + CliCommand( + argument_name="set", + command_name="set", + help="Reconcile repo permissions from maps.yaml", + description="Reconcile Sourcegraph explicit repo permissions from maps.yaml.", + config_fields=SET_CONFIG_FIELDS, + ), + CliCommand( + argument_name="restore", + command_name="restore", + help="Restore repo permissions from a snapshot", + description="Restore Sourcegraph explicit repo permissions from a snapshot JSON file.", + config_fields=RESTORE_CONFIG_FIELDS, + ), + CliCommand( + argument_name="sync-saml-orgs", + command_name="sync_saml_orgs", + help="Sync orgs from SAML groups", + description="Create/update Sourcegraph organizations and memberships from SAML groups.", + config_fields=SYNC_SAML_ORGS_CONFIG_FIELDS, + ), +) + + def config_error(message: str) -> NoReturn: """Exit with a concise config/argument error.""" print(f"src-auth-perms-sync: error: {message}", file=sys.stderr) raise SystemExit(2) -def validate_config(config: SrcAuthPermissionsSyncConfig) -> None: +def validate_config(command_name: CommandName, config: Config) -> None: """Validate cross-field CLI/config constraints.""" - validate_command_selection(config) - validate_user_filter_selection(config) - validate_set_mode_selection(config) - - -def validate_command_selection(config: SrcAuthPermissionsSyncConfig) -> None: - """Validate compatible top-level command flags.""" - if sum((config.get, config.set_path is not None, config.restore_path is not None)) > 1: - config_error("choose only one of --get, --set, or --restore") - if config.restore_path is not None and config.sync_saml_organizations: - config_error("--sync-saml-orgs can run by itself or with --get or --set") - - -def validate_user_filter_selection(config: SrcAuthPermissionsSyncConfig) -> None: + validate_command_options(command_name, config) + validate_user_filter_selection(command_name, config) + validate_set_mode_selection(command_name, config) + + +def validate_command_options(command_name: CommandName, config: Config) -> None: + """Validate options that only make sense with specific commands.""" + if command_name == "get" and config.apply: + config_error("--apply cannot be used with the read-only get command") + if command_name == "get" and config.no_backup: + config_error("--no-backup cannot be used with the read-only get command") + if config.sync_saml_organizations and command_name != "set": + config_error("--sync-saml-orgs can only be combined with set") + if command_name == "restore" and config.restore_path is None: + config_error("restore requires --restore-path") + if config.restore_path is not None and command_name != "restore": + config_error("--restore-path requires the restore command") + + +def validate_user_filter_selection(command_name: CommandName, config: Config) -> None: """Validate user-scope filters and their compatible commands.""" - user_identifier_filters = sum((config.user is not None, config.users_without_explicit_perms)) - if user_identifier_filters > 1: - config_error("choose only one of --user or --users-without-explicit-perms") - - user_filter_selected = user_identifier_filters > 0 or config.created_after is not None - user_filter_allowed = ( - config.get - or config.set_path is not None - or (config.restore_path is None and not config.sync_saml_organizations) - ) + user_scope_filter_count = sum((bool(config.users), config.users_without_explicit_perms)) + if user_scope_filter_count > 1: + config_error("choose only one of --users or --users-without-explicit-perms") + + user_filter_selected = user_scope_filter_count > 0 or config.created_after is not None + user_filter_allowed = command_name in {"get", "set"} if user_filter_selected and not user_filter_allowed: config_error( - "--user, --users-without-explicit-perms, and --created-after require --get or --set" + "--users, --users-without-explicit-perms, and --created-after require get or set" ) -def validate_set_mode_selection(config: SrcAuthPermissionsSyncConfig) -> None: - """Validate `--set` mode flags.""" - if config.full and config.set_path is None: - config_error("--full requires --set") +def validate_set_mode_selection(command_name: CommandName, config: Config) -> None: + """Validate set command mode flags.""" + if config.full and command_name != "set": + config_error("--full requires the set command") - if config.set_path is None: + if command_name != "set": return - if sum((config.full, config.user is not None, config.users_without_explicit_perms)) > 1: + if sum((config.full, bool(config.users), config.users_without_explicit_perms)) > 1: config_error( - "with --set, choose at most one of --full, --user, or --users-without-explicit-perms" + "with set, choose at most one of --full, --users, or --users-without-explicit-perms" ) -def set_command_options(config: SrcAuthPermissionsSyncConfig) -> permission_types.SetCommandOptions: - """Return the validated `--set` mode options.""" - if config.user is not None: +def set_command_options(config: Config) -> permission_types.SetCommandOptions: + """Return the validated set mode options.""" + if config.users: return permission_types.SetCommandOptions( - mode="user", - user_identifier=config.user, + mode="users", + user_identifiers=config.users, user_created_after=config.created_after, ) if config.users_without_explicit_perms: @@ -281,38 +392,29 @@ def set_command_options(config: SrcAuthPermissionsSyncConfig) -> permission_type ) -def resolve_command(config: SrcAuthPermissionsSyncConfig) -> ResolvedCommand: +def resolve_command(command_name: CommandName, config: Config) -> ResolvedCommand: """Return the command execution plan derived from config.""" run_mode = "apply" if config.apply else "dry-run" - if config.set_path is not None: + if command_name == "set": return resolve_set_command(config, run_mode) - if config.restore_path is not None: + if command_name == "restore": return ResolvedCommand( name="restore", log_name="restore", artifact_name=f"restore-{run_mode}", ) - if config.get and config.sync_saml_organizations: - return ResolvedCommand( - name="get", - log_name="get_sync_saml_orgs", - artifact_name=f"get-sync-saml-orgs-{run_mode}", - sync_saml_organizations=True, - ) - if config.get: + if command_name == "get": return ResolvedCommand(name="get", log_name="get", artifact_name="get") - if config.sync_saml_organizations: - return ResolvedCommand( - name="sync_saml_orgs", - log_name="sync_saml_orgs", - artifact_name=f"sync-saml-orgs-{run_mode}", - sync_saml_organizations=True, - ) - return ResolvedCommand(name="get", log_name="get", artifact_name="get") + return ResolvedCommand( + name="sync_saml_orgs", + log_name="sync_saml_orgs", + artifact_name=f"sync-saml-orgs-{run_mode}", + sync_saml_organizations=True, + ) -def resolve_set_command(config: SrcAuthPermissionsSyncConfig, run_mode: str) -> ResolvedCommand: - """Return resolved metadata for the selected `--set` command mode.""" +def resolve_set_command(config: Config, run_mode: str) -> ResolvedCommand: + """Return resolved metadata for the selected set command mode.""" set_options = set_command_options(config) log_names = ( SYNC_SET_COMMAND_LOG_NAMES if config.sync_saml_organizations else SET_COMMAND_LOG_NAMES @@ -331,50 +433,78 @@ def resolve_set_command(config: SrcAuthPermissionsSyncConfig, run_mode: str) -> ) -def load_config() -> SrcAuthPermissionsSyncConfig: - """Parse and validate CLI/environment config.""" - config = src.parse_args( - SrcAuthPermissionsSyncConfig, - description=__doc__, - base_dir=Path("."), +def load_cli(argv: Sequence[str] | None = None) -> CliInput: + """Parse and validate the CLI command plus environment/config options.""" + parser = argparse.ArgumentParser( + description=__doc__.strip() if __doc__ is not None else None, + formatter_class=argparse.RawDescriptionHelpFormatter, + allow_abbrev=False, ) - validate_config(config) - return config - - -def endpoint_scoped_config( - config: SrcAuthPermissionsSyncConfig, endpoint: str -) -> SrcAuthPermissionsSyncConfig: - """Return config with relative operator artifact paths scoped to this endpoint.""" - updates: dict[str, object] = {} - if config.set_path is not None: - updates["set_path"] = backups.endpoint_artifact_path(endpoint, config.set_path) - if config.restore_path is not None: - updates["restore_path"] = backups.endpoint_artifact_path(endpoint, config.restore_path) - if not updates: + subparsers = parser.add_subparsers( + title="commands", + metavar="COMMAND", + dest="command_argument", + required=True, + ) + for command in CLI_COMMANDS: + command_parser = subparsers.add_parser( + command.argument_name, + help=command.help, + description=command.description, + formatter_class=config_utils.config_help_formatter( + Config, + include_fields=command.config_fields, + ), + allow_abbrev=False, + ) + command_parser.set_defaults(command_name=command.command_name) + config_utils.add_config_arguments( + command_parser, + Config, + include_fields=command.config_fields, + ) + arguments = parser.parse_args(argv) + try: + config = config_utils.load_config_from_args( + Config, + arguments, + base_dir=Path("."), + resolve_op_refs=True, + ) + except src.ConfigError as exception: + parser.error(str(exception)) + command_name = cast(CommandName, arguments.command_name) + validate_config(command_name, config) + return CliInput(command_name=command_name, config=config) + + +def default_maps_path(endpoint: str) -> Path: + """Return the generated maps path for a Sourcegraph endpoint.""" + return backups.endpoint_artifacts_directory(endpoint) / DEFAULT_MAPS_FILE_NAME + + +def config_with_default_paths(command_name: CommandName, config: Config, endpoint: str) -> Config: + """Return config with omitted file paths filled from generated defaults.""" + if command_name != "set" or config.maps_path is not None: return config - return config.model_copy(update=updates) + return config.model_copy(update={"maps_path": default_maps_path(endpoint)}) -def require_set_input_file(config: SrcAuthPermissionsSyncConfig) -> None: +def require_set_input_file(maps_path: Path) -> None: """Exit with a clear error if the selected maps file is missing.""" - if config.set_path is None: - return - if config.set_path.is_file(): + if maps_path.is_file(): return - if config.set_path.exists(): - raise SystemExit(f"--set input path is not a file: {config.set_path}") + if maps_path.exists(): + raise SystemExit(f"set input path is not a file: {maps_path}") raise SystemExit( - "--set input file does not exist: " - f"{config.set_path}\n" - "Run `uv run src-auth-perms-sync --get` to create the default maps.yaml, " + "set input file does not exist: " + f"{maps_path}\n" + "Run `uv run src-auth-perms-sync get` to create the default maps.yaml, " "or pass a path to an existing maps file." ) -def run_fields( - config: SrcAuthPermissionsSyncConfig, command: ResolvedCommand, endpoint: str -) -> dict[str, object]: +def run_fields(config: Config, command: ResolvedCommand, endpoint: str) -> dict[str, object]: """Return run-level fields for structured logging.""" return { "cli_cmd": command.log_name, @@ -387,6 +517,7 @@ def run_fields( "explicit_permissions_batch_size": config.explicit_permissions_batch_size, "trace": config.trace, "max_attempts": config.max_attempts, + "http_timeout_seconds": config.http_timeout_seconds, "no_backup": config.no_backup, "sample_interval": config.sample_interval, "user_created_after": config.created_after, @@ -397,13 +528,14 @@ def run_fields( def run_with_client( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, endpoint: str, worker_pool: ThreadPoolExecutor, ) -> None: """Create a client, run the selected command, and always close HTTP resources.""" http = src.HTTPClient( + timeout=config.http_timeout_seconds, user_agent="src-auth-perms-sync/0.1 (+python)", max_attempts=config.max_attempts, max_connections=config.parallelism, @@ -421,7 +553,7 @@ def run_with_client( def run_command( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, client: src.SourcegraphClient, worker_pool: ThreadPoolExecutor, @@ -456,19 +588,21 @@ def run_command( def run_set( - config: SrcAuthPermissionsSyncConfig, + config: Config, command: ResolvedCommand, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, ) -> run_context.CommandData: """Run the selected repo-permission sync command.""" - assert config.set_path is not None assert command.set_options is not None - require_set_input_file(config) + maps_path = config.maps_path + if maps_path is None: + raise SystemExit("set requires a maps file path") + require_set_input_file(maps_path) return permissions_command.cmd_set( client, - config.set_path, + maps_path, command.set_options, dry_run=not config.apply, parallelism=config.parallelism, @@ -484,7 +618,7 @@ def run_set( def run_restore( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, @@ -504,7 +638,7 @@ def run_restore( def run_sync_saml_organizations( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, command_data: run_context.CommandData, @@ -525,14 +659,14 @@ def run_sync_saml_organizations( def run_get( - config: SrcAuthPermissionsSyncConfig, + config: Config, client: src.SourcegraphClient, sourcegraph_site_config: site_config.SiteConfig, worker_pool: ThreadPoolExecutor, ) -> run_context.CommandData: """Run the default read-only discovery command.""" artifacts_directory = backups.endpoint_artifacts_directory(client.endpoint) - maps_path = artifacts_directory / "maps.yaml" + maps_path = default_maps_path(client.endpoint) maps_created = permissions_maps.create_maps_yaml_if_missing(maps_path) if maps_created: log.info("maps.yaml missing, created %s with an empty maps list.", maps_path) @@ -544,7 +678,7 @@ def run_get( artifacts_directory / "code-hosts.yaml", artifacts_directory / "auth-providers.yaml", maps_path, - user_identifier=config.user, + user_identifiers=config.users, users_without_explicit_perms=config.users_without_explicit_perms, user_created_after=config.created_after, parallelism=config.parallelism, @@ -554,7 +688,7 @@ def run_get( sourcegraph_site_config.saml_groups_attribute_name_by_config_id ), auth_providers_by_config_id=sourcegraph_site_config.auth_providers_by_config_id, - retain_saml_group_users=config.sync_saml_organizations, + retain_saml_group_users=False, worker_pool=worker_pool, ) @@ -567,14 +701,47 @@ def reraise_system_exit_with_logged_error(exception: SystemExit) -> NoReturn: raise exception -def main() -> None: - config = load_config() - command = resolve_command(config) +def Get(config: Config) -> bool: + """Run repository permission discovery and return whether it succeeded.""" + return _run("get", config) + + +def Set(config: Config) -> bool: + """Run repository permission reconciliation and return whether it succeeded.""" + return _run("set", config) + + +def Restore(config: Config) -> bool: + """Run repository permission restore and return whether it succeeded.""" + return _run("restore", config) + + +def SyncSamlOrgs(config: Config) -> bool: + """Run SAML organization sync and return whether it succeeded.""" + return _run("sync_saml_orgs", config) + + +def _run(command_name: CommandName, config: Config) -> bool: + """Run a command and return whether it completed successfully.""" + try: + _run_or_raise(command_name, config) + except SystemExit as exception: + return exception.code in (None, 0) + except Exception: + log.exception("src-auth-perms-sync run failed.") + return False + return True + + +def _run_or_raise(command_name: CommandName, config: Config) -> None: + """Run src-auth-perms-sync, preserving CLI-style exceptions.""" + validate_config(command_name, config) + command = resolve_command(command_name, config) try: endpoint = src.normalize_sourcegraph_endpoint(config.src_endpoint) except ValueError as error: config_error(str(error)) - config = endpoint_scoped_config(config, endpoint) + config = config_with_default_paths(command_name, config, endpoint) run_timestamp = backups.backup_timestamp() run_directory = backups.artifact_run_directory( run_timestamp, @@ -604,3 +771,8 @@ def main() -> None: run_with_client(config, command, endpoint, worker_pool) except SystemExit as exception: reraise_system_exit_with_logged_error(exception) + + +def main() -> None: + cli_input = load_cli() + _run_or_raise(cli_input.command_name, cli_input.config) diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 7849855..302cef6 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -26,6 +26,12 @@ log = logging.getLogger(__name__) +MISSING_MUTATION_RESOURCE_TERMS = ( + "repo", + "repository", + "user", +) + @dataclass class CircuitBreaker: @@ -179,6 +185,21 @@ def _mutate_repo_permission_for_user( ) +def is_missing_mutation_resource_error(exception: BaseException) -> bool: + """Return whether a mutation failed because its repo/user disappeared. + + Sourcegraph instances are live systems: users and repos can be deleted + between discovery/planning and the eventual mutation. Those races should + be logged and skipped, not treated as backend-health failures. + """ + if not isinstance(exception, src.GraphQLError): + return False + message = str(exception).lower() + if not any(term in message for term in MISSING_MUTATION_RESOURCE_TERMS): + return False + return "not found" in message or "could not resolve" in message + + def _apply_permission_changes( client: src.SourcegraphClient, changes: Sequence[PermissionChange], @@ -198,6 +219,7 @@ def _apply_permission_changes( succeeded = 0 failed = 0 canceled = 0 + skipped = 0 breaker = CircuitBreaker() with run_context.thread_pool(parallelism, worker_pool) as executor: futures = { @@ -228,6 +250,17 @@ def _apply_permission_changes( canceled += 1 continue except Exception as exception: + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s %s → %s (id=%d): repo/user no longer exists: %s", + action, + change.username, + change.repo_name, + src.decode_repository_id(change.repo_id), + exception, + ) + continue failed += 1 breaker.record(success=False) log.error( @@ -246,11 +279,13 @@ def _apply_permission_changes( batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled + batch_event["skipped"] = skipped batch_event["circuit_broken"] = breaker.is_open() return shared_types.MutationCounts( succeeded=succeeded, failed=failed, canceled=canceled, + skipped=skipped, ) @@ -301,17 +336,18 @@ def _apply_repo_overwrite_plans( ) -> shared_types.MutationCounts: """Dispatch per-repo overwrite mutations with bounded in-flight work.""" max_pending_futures = max(1, parallelism * 2) - total_users = sum(len(overwrite.usernames) for overwrite in overwrites) + payload_grant_count = sum(len(overwrite.usernames) for overwrite in overwrites) with src.event( "apply_username_overwrites", payload_count=len(overwrites), parallelism=parallelism, - total_users=total_users, + payload_grant_count=payload_grant_count, max_pending_futures=max_pending_futures, ) as batch_event: succeeded = 0 failed = 0 canceled = 0 + skipped = 0 submitted_count = 0 submissions_stopped = False breaker = CircuitBreaker() @@ -371,6 +407,15 @@ def _stop_submissions() -> None: canceled += 1 continue except Exception as exception: + if is_missing_mutation_resource_error(exception): + skipped += 1 + log.warning( + " SKIP %s (id=%d): repo/user no longer exists: %s", + overwrite.repository_name, + src.decode_repository_id(overwrite.repository_id), + exception, + ) + continue failed += 1 breaker.record(success=False) log.error( @@ -395,12 +440,14 @@ def _stop_submissions() -> None: batch_event["succeeded"] = succeeded batch_event["failed"] = failed batch_event["canceled"] = canceled + batch_event["skipped"] = skipped batch_event["circuit_broken"] = breaker.is_open() batch_event["submitted"] = submitted_count return shared_types.MutationCounts( succeeded=succeeded, failed=failed, canceled=canceled, + skipped=skipped, ) diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 95d4041..c9f4e5e 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Any, cast +from typing import Any import src_py_lib as src @@ -43,7 +43,7 @@ class _ResolvedMapping: index: int name: str - users_section: dict[str, object] + user_selector: permission_types.UserSelector repos: list[permission_types.Repository] @@ -52,9 +52,9 @@ def resolve_additive_mappings(context: permission_types.MappingContext) -> list[ resolved: list[_ResolvedMapping] = [] for mapping_index, mapping in enumerate(context.mapping_rules, start=1): name = mapping.get("name", f"") - repos_section = cast(dict[str, object], mapping["repos"]) + repository_selector = mapping["repos"] matched_repos = permissions_mapping.resolve_repos( - repos_section, + repository_selector, context.services_by_id, context.repos_by_external_service_id, context.all_repos_by_id, @@ -72,7 +72,7 @@ def resolve_additive_mappings(context: permission_types.MappingContext) -> list[ _ResolvedMapping( index=mapping_index, name=name, - users_section=cast(dict[str, object], mapping["users"]), + user_selector=mapping["users"], repos=matched_repos, ) ) @@ -85,7 +85,7 @@ def cmd_get( auth_providers_path: Path, maps_path: Path, *, - user_identifier: str | None, + user_identifiers: tuple[str, ...], users_without_explicit_perms: bool, user_created_after: str | None, parallelism: int, @@ -120,7 +120,7 @@ def cmd_get( code_hosts_path=str(code_hosts_path), auth_providers_path=str(auth_providers_path), maps_path=str(maps_path), - user_identifier=user_identifier, + user_identifiers=user_identifiers, users_without_explicit_perms=users_without_explicit_perms, user_created_after=user_created_after, parallelism=parallelism, @@ -134,7 +134,7 @@ def cmd_get( users = _load_get_users( client, - user_identifier=user_identifier, + user_identifiers=user_identifiers, users_without_explicit_perms=users_without_explicit_perms, user_created_after=user_created_after, ) @@ -206,7 +206,7 @@ def cmd_get( raw_providers, attribute_names_by_provider, ) - if user_identifier is None + if not user_identifiers and not users_without_explicit_perms and user_created_after is None and retain_saml_group_users @@ -221,24 +221,27 @@ def cmd_get( def _load_get_users( client: src.SourcegraphClient, *, - user_identifier: str | None, + user_identifiers: tuple[str, ...], users_without_explicit_perms: bool, user_created_after: str | None, ) -> list[shared_types.User]: """Load the Sourcegraph users selected by get/set-compatible user filters.""" - if user_identifier is not None: - user = _resolve_user_identifier(client, user_identifier) + if user_identifiers: + users = _resolve_user_identifiers(client, user_identifiers) if user_created_after is None: - return [user] + return users candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) - if user["id"] in candidate_user_ids: - return [user] - log.info( - "User %s was not created on or after %s — no user metadata selected.", - user["username"], - user_created_after, - ) - return [] + selected_users: list[shared_types.User] = [] + for user in users: + if user["id"] in candidate_user_ids: + selected_users.append(user) + continue + log.info( + "User %s was not created on or after %s — no user metadata selected.", + user["username"], + user_created_after, + ) + return selected_users if users_without_explicit_perms or user_created_after is not None: created_after_filter: str | None = None @@ -317,7 +320,7 @@ def cmd_set( retain_saml_group_users: bool = False, worker_pool: ThreadPoolExecutor | None = None, ) -> run_context.CommandData: - """Dispatch the selected `--set` mode.""" + """Dispatch the selected set mode.""" if options.mode == "full": return permissions_full_set.cmd_set_full( client, @@ -332,12 +335,12 @@ def cmd_set( retain_saml_group_users, worker_pool, ) - if options.mode == "user": - assert options.user_identifier is not None - return cmd_set_additive_user( + if options.mode == "users": + assert options.user_identifiers + return cmd_set_additive_users( client, input_path, - options.user_identifier, + options.user_identifiers, options.user_created_after, dry_run, parallelism, @@ -361,10 +364,10 @@ def cmd_set( return run_context.CommandData() -def cmd_set_additive_user( +def cmd_set_additive_users( client: src.SourcegraphClient, input_path: Path, - user_identifier: str, + user_identifiers: tuple[str, ...], user_created_after: str | None, dry_run: bool, parallelism: int, @@ -373,11 +376,11 @@ def cmd_set_additive_user( do_backup: bool, worker_pool: ThreadPoolExecutor | None = None, ) -> run_context.CommandData: - """Add missing mapped permissions for one resolved user.""" + """Add missing mapped permissions for resolved users.""" with src.event( - "cmd_set_additive_user", + "cmd_set_additive_users", input_path=str(input_path), - user_identifier=user_identifier, + user_identifiers=user_identifiers, user_created_after=user_created_after, dry_run=dry_run, parallelism=parallelism, @@ -386,33 +389,50 @@ def cmd_set_additive_user( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() - user = _resolve_user_identifier(client, user_identifier) + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) + users = _resolve_user_identifiers( + client, + user_identifiers, + include_emails=include_user_emails, + ) if user_created_after is not None: candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) - if user["id"] not in candidate_user_ids: + selected_users: list[shared_types.User] = [] + for user in users: + if user["id"] in candidate_user_ids: + selected_users.append(user) + continue log.info( "User %s was not created on or after %s — nothing to do.", user["username"], user_created_after, ) + users = selected_users + if not users: return run_context.CommandData(auth_providers=context.providers) resolved_mappings = resolve_additive_mappings(context) - additions = _plan_additions_for_user( - client, - context, - resolved_mappings, - user, - ) + additions: list[permissions_apply.PermissionAddition] = [] + for user in users: + additions.extend( + _plan_additions_for_user( + client, + context, + resolved_mappings, + user, + ) + ) _run_additive_apply( client, input_path, - [user], + users, additions, dry_run=dry_run, parallelism=parallelism, bind_id_mode=bind_id_mode, do_backup=do_backup, - command_name="set-add-user", + command_name="set-add-users", worker_pool=worker_pool, ) return run_context.CommandData(auth_providers=context.providers) @@ -446,6 +466,9 @@ def cmd_set_additive_users_without_explicit_perms( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) resolved_mappings = resolve_additive_mappings(context) candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) log.info("Received %d non-deleted user candidate(s).", len(candidates)) @@ -455,7 +478,11 @@ def cmd_set_additive_users_without_explicit_perms( for candidate in candidates: if permissions_sourcegraph.user_has_explicit_repos(client, candidate["id"]): continue - user = permissions_sourcegraph.get_user_by_id(client, candidate["id"]) + user = permissions_sourcegraph.get_user_by_id( + client, + candidate["id"], + include_emails=include_user_emails, + ) if user is None: log.warning( "Skipping user candidate %s: user no longer exists.", @@ -491,19 +518,56 @@ def cmd_set_additive_users_without_explicit_perms( return run_context.CommandData(auth_providers=context.providers) +def _resolve_user_identifiers( + client: src.SourcegraphClient, + user_identifiers: tuple[str, ...], + *, + include_emails: bool = False, +) -> list[shared_types.User]: + """Resolve username/email inputs to distinct Sourcegraph users in caller order.""" + users: list[shared_types.User] = [] + seen_user_ids: set[str] = set() + for user_identifier in user_identifiers: + user = _resolve_user_identifier( + client, + user_identifier, + include_emails=include_emails, + ) + if user["id"] in seen_user_ids: + continue + seen_user_ids.add(user["id"]) + users.append(user) + return users + + def _resolve_user_identifier( - client: src.SourcegraphClient, user_identifier: str + client: src.SourcegraphClient, + user_identifier: str, + *, + include_emails: bool = False, ) -> shared_types.User: """Resolve username/email input to one Sourcegraph user.""" user: shared_types.User | None if "@" in user_identifier: user = permissions_sourcegraph.get_user_by_email( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_username(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_username( + client, + user_identifier, + include_emails=include_emails, + ) else: user = permissions_sourcegraph.get_user_by_username( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_email(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_email( + client, + user_identifier, + include_emails=include_emails, + ) if user is None: raise SystemExit(f"No Sourcegraph user found for {user_identifier!r}.") if user["username"] != user_identifier: @@ -521,8 +585,8 @@ def _plan_additions_for_user( """Return missing additive permission edges for one user.""" desired_repos: dict[str, permission_types.Repository] = {} for resolved_mapping in resolved_mappings: - if not permissions_mapping.user_matches_users_section( - resolved_mapping.users_section, + if not permissions_mapping.user_matches_user_selector( + resolved_mapping.user_selector, user, context.providers, context.saml_groups_attribute_names, @@ -648,8 +712,9 @@ def _apply_additive_permissions( worker_pool=worker_pool, ) log.info( - "Additive apply done. %d succeeded, %d failed, %d canceled.", + "Additive apply done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 1ef3fda..a3534d8 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Any, cast +from typing import Any import src_py_lib as src @@ -98,6 +98,7 @@ def _capture_full_set_snapshot_state( explicit_permissions_batch_size: int, bind_id_mode: str, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" total_users = shared_sourcegraph.count_users(client) @@ -110,7 +111,11 @@ def _capture_full_set_snapshot_state( before_timestamp = backups.backup_timestamp() before_snapshot = permission_snapshot.build_snapshot( client, - shared_sourcegraph.list_users_streaming(client, collect_into=users), + shared_sourcegraph.list_users_streaming( + client, + collect_into=users, + include_emails=include_user_emails, + ), parallelism, bind_id_mode, input_path, @@ -140,6 +145,7 @@ def _load_full_set_snapshot_state( bind_id_mode: str, capture_before: bool, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load all users, optionally with a before-snapshot.""" if capture_before: @@ -150,10 +156,14 @@ def _load_full_set_snapshot_state( explicit_permissions_batch_size, bind_id_mode, worker_pool, + include_user_emails=include_user_emails, ) log.info("Loading users from %s ...", client.endpoint) - users = shared_sourcegraph.list_users_with_accounts(client) + users = shared_sourcegraph.list_users_with_accounts( + client, + include_emails=include_user_emails, + ) log.info("Received %d total users.", len(users)) return _FullSetUserState(users=users) @@ -256,23 +266,24 @@ def _write_noop_full_set_snapshots( return before_path, after_path, diff_path, maps_backup_path -def _plan_full_set_permissions( +def plan_full_set_permissions( context: permission_types.MappingContext, users: list[shared_types.User], ) -> _FullSetPlan: """Resolve mapping rules into one repo-to-users overwrite plan.""" - repo_usernames: dict[str, set[str]] = {} + expected_users: dict[str, tuple[str, ...]] = {} + union_usernames_by_repo_id: dict[str, set[str]] = {} repo_names: dict[str, str] = {} for mapping_index, mapping in enumerate(context.mapping_rules, start=1): name = mapping.get("name", f"") log.info("=== Mapping %d / %d: %s ===", mapping_index, len(context.mapping_rules), name) - users_section = cast(dict[str, object], mapping["users"]) - repos_section = cast(dict[str, object], mapping["repos"]) + user_selector = mapping["users"] + repository_selector = mapping["repos"] matched_users = permissions_mapping.resolve_users( - users_section, + user_selector, users, context.providers, context.saml_groups_attribute_names, @@ -283,7 +294,7 @@ def _plan_full_set_permissions( continue matched_repos = permissions_mapping.resolve_repos( - repos_section, + repository_selector, context.services_by_id, context.repos_by_external_service_id, context.all_repos_by_id, @@ -293,15 +304,28 @@ def _plan_full_set_permissions( log.warning(" No repos matched — skipping rule.") continue - matched_usernames = tuple(user["username"] for user in matched_users) + matched_usernames = tuple(sorted({user["username"] for user in matched_users})) for repo in matched_repos: - bucket = repo_usernames.setdefault(repo["id"], set()) - repo_names[repo["id"]] = repo["name"] - bucket.update(matched_usernames) + repo_id = repo["id"] + repo_names[repo_id] = repo["name"] + union_usernames = union_usernames_by_repo_id.get(repo_id) + if union_usernames is not None: + union_usernames.update(matched_usernames) + continue + + existing_usernames = expected_users.get(repo_id) + if existing_usernames is not None: + union_usernames = set(existing_usernames) + union_usernames.update(matched_usernames) + union_usernames_by_repo_id[repo_id] = union_usernames + del expected_users[repo_id] + continue + + expected_users[repo_id] = matched_usernames + + for repo_id, usernames in union_usernames_by_repo_id.items(): + expected_users[repo_id] = tuple(sorted(usernames)) - expected_users = { - repo_id: tuple(sorted(usernames)) for repo_id, usernames in repo_usernames.items() - } total_grants = sum(len(usernames) for usernames in expected_users.values()) if expected_users: log.info( @@ -480,8 +504,9 @@ def _apply_full_set_plans( worker_pool=worker_pool, ) log.info( - "Apply done. %d succeeded, %d failed, %d canceled.", + "Apply done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) @@ -501,6 +526,7 @@ def _record_full_set_event_fields( command_event["repo_count"] = len(plan.expected_users) command_event["total_grants"] = plan.total_grants command_event["mutations_succeeded"] = apply_result.mutations.succeeded + command_event["mutations_skipped"] = apply_result.mutations.skipped command_event["mutations_failed"] = apply_result.mutations.failed command_event["mutations_canceled"] = apply_result.mutations.canceled command_event["full_short_circuit"] = apply_result.full_short_circuit @@ -562,7 +588,7 @@ def _finish_full_set_apply_with_backup( log.info( "To roll back the explicit-permissions state captured in " "the before-snapshot, run:\n" - " uv run src-auth-perms-sync --restore %s --apply", + " uv run src-auth-perms-sync restore --restore-path %s --apply", before_path, ) @@ -576,7 +602,7 @@ def _raise_for_failed_full_set_apply( log.error( "RUN FAILED: %d mutation(s) failed, %d canceled by circuit " "breaker (out of %d planned). Review the log file and the " - "before/after snapshots for details, then re-run --set --apply " + "before/after snapshots for details, then re-run set --apply " "(after addressing the underlying cause) to retry the " "remaining work.", apply_result.mutations.failed, @@ -656,6 +682,7 @@ def _load_full_set_plan( retain_saml_group_users: bool, worker_pool: ThreadPoolExecutor | None = None, ) -> _FullSetLoadedPlan: + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) user_state = _load_full_set_snapshot_state( client, input_path, @@ -664,6 +691,7 @@ def _load_full_set_plan( bind_id_mode, capture_before=capture_before, worker_pool=worker_pool, + include_user_emails=include_user_emails, ) before_path: Path | None = None if capture_before: @@ -687,7 +715,7 @@ def _load_full_set_plan( user_state.users, user_created_after, ) - plan = _plan_full_set_permissions(context, users) + plan = plan_full_set_permissions(context, users) snapshot_state = _compact_full_set_snapshot_state(user_state, users) saml_group_users = ( saml_groups.compact_saml_group_users( diff --git a/src/src_auth_perms_sync/permissions/mapping.py b/src/src_auth_perms_sync/permissions/mapping.py index 904a63e..69d78f1 100644 --- a/src/src_auth_perms_sync/permissions/mapping.py +++ b/src/src_auth_perms_sync/permissions/mapping.py @@ -1,16 +1,15 @@ """Permission mapping resolution: validate rules and match users/repos. -Each mapping rule has a `users:` section and a `repos:` section, each -containing one or more matchers (today: `authProvider`, -`codeHostConnection`, and `regex`). Within a matcher, the supplied -keys AND together against the discovered auth-provider / external- -service entries. Across mapping rules, `cmd_set` unions the per-repo -user sets at apply time — see `src/src_auth_perms_sync/permissions/types.py` for the rationale. +Each mapping rule has a `users:` section and a `repos:` section. Top-level +selectors under each section AND together to keep each rule restrictive. +Values inside each supplied selector list OR together. Across mapping rules, +`cmd_set` unions the per-repo user sets at apply time — see +`src/src_auth_perms_sync/permissions/types.py` for the rationale. Adding a new matcher type: 1. Add the TypedDict in `src/src_auth_perms_sync/permissions/types.py`. - 2. Add it as a sibling key on `UsersFilter` or `ReposFilter`. + 2. Add it as a sibling key on `UserSelector` or `RepositorySelector`. 3. Add a branch in `resolve_users` / `resolve_repos` below. 4. Add structural validation in `validate_mapping_rules`. 5. Add an example rule using the new matcher to `maps-example.yaml`. @@ -20,7 +19,7 @@ import logging import re -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any, cast import json5 @@ -50,7 +49,7 @@ "configID", "samlGroup", } -CODE_HOST_MATCHER_FIELDS: set[str] = {"id", "kind", "displayName", "url", "config"} +CODE_HOST_MATCHER_FIELDS: set[str] = {"kind", "displayName", "url", "username"} AUTH_PROVIDER_VALUE_MATCHES: tuple[tuple[str, str], ...] = ( ("type", "serviceType"), ("serviceID", "serviceID"), @@ -58,7 +57,15 @@ ("displayName", "displayName"), ("configID", "configID"), ) -CODE_HOST_VALUE_MATCHES: tuple[str, ...] = ("kind", "displayName", "url") +CODE_HOST_DIRECT_VALUE_MATCHES: tuple[str, ...] = ("kind", "displayName", "url") +USER_SELECTOR_FIELDS: set[str] = { + "authProvider", + "emails", + "emailRegexes", + "usernames", + "usernameRegexes", +} +REPOSITORY_SELECTOR_FIELDS: set[str] = {"codeHostConnection", "names", "nameRegexes"} # --------------------------------------------------------------------------- @@ -66,7 +73,7 @@ # --------------------------------------------------------------------------- -def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: +def validate_mapping_rules(rules: Sequence[object]) -> None: """Fail fast on structural problems in the YAML before doing any work. Catches operator typos that would otherwise produce confusing partial @@ -81,20 +88,37 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: bugs. """ errors: list[str] = [] - for rule_index, rule in enumerate(rules, start=1): + for rule_index, rule_object in enumerate(rules, start=1): + if not isinstance(rule_object, dict): + errors.append( + f"mapping {rule_index}: each `maps:` entry must be a mapping " + f"(got {type(rule_object).__name__})" + ) + continue + + rule = cast(Mapping[str, object], rule_object) label = rule.get("name") or f"" prefix = f"mapping {rule_index} ({label!r})" - users_section = cast(dict[str, object], rule.get("users") or {}) - repos_section = cast(dict[str, object], rule.get("repos") or {}) - - if not users_section: - errors.append(f"{prefix}: `users:` section is empty (matches no users)") - if not repos_section: - errors.append(f"{prefix}: `repos:` section is empty (matches no repos)") - - errors.extend(_validate_users_section(users_section, prefix)) - errors.extend(_validate_repos_section(repos_section, prefix)) + errors.extend(_validate_mapping_name(rule.get("name"), prefix)) + errors.extend( + _validate_selector_section( + rule.get("users"), + prefix, + "users", + USER_SELECTOR_FIELDS, + _validate_user_selector, + ) + ) + errors.extend( + _validate_selector_section( + rule.get("repos"), + prefix, + "repos", + REPOSITORY_SELECTOR_FIELDS, + _validate_repository_selector, + ) + ) if errors: bullet = "\n - " @@ -103,30 +127,116 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: ) -_KNOWN_USER_MATCHERS: set[str] = {"authProvider"} +def mapping_rules_need_user_emails(mapping_rules: list[permission_types.MappingRule]) -> bool: + """Return whether any mapping rule filters users by verified email.""" + return any( + "emails" in mapping["users"] or "emailRegexes" in mapping["users"] + for mapping in mapping_rules + ) + + +def _validate_mapping_name(value: object, prefix: str) -> list[str]: + """Validate the required human-readable mapping name.""" + if value is None: + return [f"{prefix}: `name:` is missing"] + if not isinstance(value, str): + return [f"{prefix}: `name:` must be a string (got {type(value).__name__})"] + if not value: + return [f"{prefix}: `name:` is empty"] + return [] + + +def _validate_selector_section( + value: object, + prefix: str, + section_name: str, + known_fields: set[str], + validate_selector: Callable[[dict[str, object], str, str], list[str]], +) -> list[str]: + """Validate a top-level user or repo selector mapping.""" + if value is None: + return [f"{prefix}: `{section_name}:` section is missing"] + if not isinstance(value, dict): + return [ + f"{prefix}: `{section_name}:` must be a selector mapping (got {type(value).__name__})" + ] + + selector = cast(dict[str, object], value) + errors: list[str] = [] + if not selector: + errors.append(f"{prefix}: `{section_name}:` section is empty (matches nothing)") + return errors + + for field_name in sorted(set(selector) - known_fields): + errors.append(f"{prefix}: unknown {section_name} field {field_name!r}") + errors.extend(validate_selector(selector, prefix, section_name)) + return errors -def _validate_users_section(section: dict[str, object], prefix: str) -> list[str]: - """Reject unknown matcher keys and validate each matcher's shape.""" +def _validate_user_selector( + selector: dict[str, object], prefix: str, selector_path: str +) -> list[str]: + """Validate one user selector's ANDed matcher fields.""" errors: list[str] = [] - for key in section: - if key not in _KNOWN_USER_MATCHERS: - errors.append(f"{prefix}: unknown users matcher {key!r}") - auth_provider = cast(dict[str, object] | None, section.get("authProvider")) + auth_provider = selector.get("authProvider") if auth_provider is not None: - unknown = set(auth_provider) - AUTH_PROVIDER_MATCHER_FIELDS - for field_name in sorted(unknown): - errors.append(f"{prefix}: unknown authProvider field {field_name!r}") - if not auth_provider: - errors.append( - f"{prefix}: authProvider is empty (would match every provider on the instance)" + errors.extend(_validate_auth_provider_matcher(auth_provider, prefix, selector_path)) + if "emails" in selector: + errors.extend(_validate_string_list(selector["emails"], prefix, f"{selector_path}.emails")) + if "emailRegexes" in selector: + errors.extend( + _validate_regexes(selector["emailRegexes"], prefix, f"{selector_path}.emailRegexes") + ) + if "usernames" in selector: + errors.extend( + _validate_string_list(selector["usernames"], prefix, f"{selector_path}.usernames") + ) + if "usernameRegexes" in selector: + errors.extend( + _validate_regexes( + selector["usernameRegexes"], prefix, f"{selector_path}.usernameRegexes" ) - if "samlGroup" in auth_provider: - errors.extend(_validate_saml_group(auth_provider, prefix)) + ) + return errors + + +def _validate_repository_selector( + selector: dict[str, object], prefix: str, selector_path: str +) -> list[str]: + """Validate one repository selector's ANDed matcher fields.""" + errors: list[str] = [] + code_host_connection = selector.get("codeHostConnection") + if code_host_connection is not None: + errors.extend( + _validate_code_host_connection_matcher(code_host_connection, prefix, selector_path) + ) + if "names" in selector: + errors.extend(_validate_string_list(selector["names"], prefix, f"{selector_path}.names")) + if "nameRegexes" in selector: + errors.extend( + _validate_regexes(selector["nameRegexes"], prefix, f"{selector_path}.nameRegexes") + ) + return errors + + +def _validate_auth_provider_matcher(value: object, prefix: str, selector_path: str) -> list[str]: + """Validate an `authProvider:` matcher.""" + path = f"{selector_path}.authProvider" + if not isinstance(value, dict): + return [f"{prefix}: {path} must be a mapping (got {type(value).__name__})"] + + auth_provider = cast(dict[str, object], value) + errors: list[str] = [] + for field_name in sorted(set(auth_provider) - AUTH_PROVIDER_MATCHER_FIELDS): + errors.append(f"{prefix}: unknown {path} field {field_name!r}") + if not auth_provider: + errors.append(f"{prefix}: {path} is empty (would match every provider on the instance)") + if "samlGroup" in auth_provider: + errors.extend(_validate_saml_group(auth_provider, prefix, path)) return errors -def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[str]: +def _validate_saml_group(auth_provider: dict[str, object], prefix: str, path: str) -> list[str]: """`authProvider.samlGroup`, if present, must be a non-empty string and incompatible with a non-SAML `type:` (the rule could never match). """ @@ -134,12 +244,12 @@ def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[ value = auth_provider["samlGroup"] if not isinstance(value, str): errors.append( - f"{prefix}: authProvider.samlGroup must be a single group-name " + f"{prefix}: {path}.samlGroup must be a single group-name " f"string (got {type(value).__name__} {value!r}); to OR multiple " - f"groups, write multiple rules" + f"groups, add multiple top-level maps entries" ) elif not value: - errors.append(f"{prefix}: authProvider.samlGroup is an empty string") + errors.append(f"{prefix}: {path}.samlGroup is an empty string") declared_type = auth_provider.get("type") if ( isinstance(declared_type, str) @@ -147,60 +257,71 @@ def _validate_saml_group(auth_provider: dict[str, object], prefix: str) -> list[ and declared_type != saml_groups.SAML_SERVICE_TYPE ): errors.append( - f"{prefix}: authProvider.samlGroup is set but authProvider.type " + f"{prefix}: {path}.samlGroup is set but {path}.type " f"is {declared_type!r}; only SAML providers carry group claims" ) return errors -def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str]: - """Reject unknown matcher keys and validate `codeHostConnection:` shape.""" +def _validate_code_host_connection_matcher( + value: object, prefix: str, selector_path: str +) -> list[str]: + """Validate a `codeHostConnection:` matcher.""" + path = f"{selector_path}.codeHostConnection" + if not isinstance(value, dict): + return [f"{prefix}: {path} must be a mapping (got {type(value).__name__})"] + + code_host_section = cast(dict[str, object], value) errors: list[str] = [] - for key in section: - if key not in {"codeHostConnection", "regex"}: - errors.append(f"{prefix}: unknown repos matcher {key!r}") - code_host_section = cast(dict[str, object] | None, section.get("codeHostConnection")) - if code_host_section is not None: - unknown = set(code_host_section) - CODE_HOST_MATCHER_FIELDS - for field_name in sorted(unknown): - errors.append(f"{prefix}: unknown codeHostConnection field {field_name!r}") - if not (set(code_host_section) & CODE_HOST_MATCHER_FIELDS): + for field_name in sorted(set(code_host_section) - CODE_HOST_MATCHER_FIELDS): + errors.append(f"{prefix}: unknown {path} field {field_name!r}") + if not code_host_section: + errors.append( + f"{prefix}: {path} is empty (would match every external service on " + f"the instance); supply at least one of {sorted(CODE_HOST_MATCHER_FIELDS)}" + ) + for field_name in sorted(CODE_HOST_MATCHER_FIELDS & set(code_host_section)): + field_value = code_host_section[field_name] + if not isinstance(field_value, str): errors.append( - f"{prefix}: codeHostConnection is empty (would match every " - f"external service on the instance); supply at least one of " - f"{sorted(CODE_HOST_MATCHER_FIELDS)}" + f"{prefix}: {path}.{field_name} must be a string " + f"(got {type(field_value).__name__} {field_value!r})" ) - if "id" in code_host_section: - external_service_id = code_host_section["id"] - if external_service_id is None or external_service_id == "": - errors.append( - f"{prefix}: codeHostConnection.id, if supplied, must be " - f"a non-empty integer (e.g. `id: 5`)" - ) - elif not isinstance(external_service_id, int) or isinstance(external_service_id, bool): - errors.append( - f"{prefix}: codeHostConnection.id must be an integer " - f"(got {type(external_service_id).__name__} {external_service_id!r}); " - f"the YAML config holds the decoded DB primary key, not the " - f"opaque base64 GraphQL Node ID" - ) - if "config" in code_host_section and not isinstance(code_host_section["config"], dict): + elif not field_value: + errors.append(f"{prefix}: {path}.{field_name} is an empty string") + return errors + + +def _validate_regexes(value: object, prefix: str, path: str) -> list[str]: + """Validate list-based regex filters.""" + errors = _validate_string_list(value, prefix, path) + if errors: + return errors + + for index, pattern in enumerate(cast(list[str], value)): + try: + re.compile(pattern) + except re.error as exception: + errors.append(f"{prefix}: {path}[{index}] is not a valid Python regex: {exception}") + return errors + + +def _validate_string_list(value: object, prefix: str, path: str) -> list[str]: + """Validate list-based exact-match filters.""" + if not isinstance(value, list): + return [f"{prefix}: {path} must be a list of strings (got {type(value).__name__})"] + + items = cast(list[object], value) + errors: list[str] = [] + if not items: + errors.append(f"{prefix}: {path} is empty (matches nothing)") + for index, item in enumerate(items): + if not isinstance(item, str): errors.append( - f"{prefix}: codeHostConnection.config must be a mapping of " - f"key/value pairs to deep-subset-match against the service's " - f"parsed config (got {type(code_host_section['config']).__name__})" + f"{prefix}: {path}[{index}] must be a string (got {type(item).__name__} {item!r})" ) - regex = section.get("regex") - if regex is not None: - if not isinstance(regex, str): - errors.append(f"{prefix}: repos.regex must be a string (got {type(regex).__name__})") - elif not regex: - errors.append(f"{prefix}: repos.regex is an empty string") - else: - try: - re.compile(regex) - except re.error as exception: - errors.append(f"{prefix}: repos.regex is not a valid Python regex: {exception}") + elif not item: + errors.append(f"{prefix}: {path}[{index}] is an empty string") return errors @@ -210,12 +331,12 @@ def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str def resolve_users( - section: dict[str, object], + selector: permission_types.UserSelector, all_users: list[shared_types.User], all_providers: list[shared_types.AuthProvider], saml_groups_attribute_names: saml_groups.SamlGroupsAttributeNameByProvider | None = None, ) -> list[shared_types.User]: - """Return users matching ALL matchers under `users:` (intersection). + """Return users matching ALL top-level selectors under `users:`. `saml_groups_attribute_names` overrides the default `"groups"` SAML assertion attribute name per (serviceID, clientID) — see @@ -223,61 +344,178 @@ def resolve_users( `None`, every SAML provider falls back to the default. Only consulted by the `authProvider.samlGroup` sub-field. - Empty section returns an empty user set — `validate_mapping_rules` + Empty sections return an empty user set — `validate_mapping_rules` rejects this at config-load time, so this branch only fires for programmatic callers. """ - if not section: + if not selector: return [] - users_by_id: dict[str, shared_types.User] = {user["id"]: user for user in all_users} - matched_ids: set[str] | None = None - for key, matcher in section.items(): - if key == "authProvider": - current_ids = { + selector_matches: list[set[str]] = [] + auth_provider = selector.get("authProvider") + if auth_provider is not None: + selector_matches.append( + { user["id"] for user in _users_matching_auth_provider( - cast(permission_types.AuthProviderMatcher, matcher), + auth_provider, all_users, all_providers, saml_groups_attribute_names, ) } - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown users matcher {key!r}") - matched_ids = current_ids if matched_ids is None else matched_ids & current_ids + ) + + emails = selector.get("emails") + if emails is not None: + selector_matches.append( + {user["id"] for user in _users_matching_email_values(emails, all_users)} + ) + + email_regexes = selector.get("emailRegexes") + if email_regexes is not None: + selector_matches.append( + {user["id"] for user in _users_matching_email_regexes(email_regexes, all_users)} + ) + + usernames = selector.get("usernames") + if usernames is not None: + selector_matches.append( + {user["id"] for user in _users_matching_username_values(usernames, all_users)} + ) + + username_regexes = selector.get("usernameRegexes") + if username_regexes is not None: + selector_matches.append( + {user["id"] for user in _users_matching_username_regexes(username_regexes, all_users)} + ) + + if not selector_matches: + return [] + + matched_ids = selector_matches[0] + for current_ids in selector_matches[1:]: + matched_ids &= current_ids if not matched_ids: return [] - assert matched_ids is not None - return [users_by_id[user_id] for user_id in matched_ids] + return [user for user in all_users if user["id"] in matched_ids] -def user_matches_users_section( - section: dict[str, object], +def user_matches_user_selector( + selector: permission_types.UserSelector, user: shared_types.User, all_providers: list[shared_types.AuthProvider], saml_groups_attribute_names: saml_groups.SamlGroupsAttributeNameByProvider | None = None, ) -> bool: - """Return whether one user matches ALL matchers under `users:`.""" - if not section: + """Return whether one user matches ALL top-level selectors under `users:`.""" + if not selector: return False - for key, matcher in section.items(): - if key == "authProvider": - if not _user_matches_auth_provider( - cast(permission_types.AuthProviderMatcher, matcher), - user, - all_providers, - saml_groups_attribute_names, - ): - return False - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown users matcher {key!r}") - return True + auth_provider = selector.get("authProvider") + if auth_provider is not None and not _user_matches_auth_provider( + auth_provider, + user, + all_providers, + saml_groups_attribute_names, + ): + return False + + emails = selector.get("emails") + if emails is not None and not _user_matches_email(user, set(emails), []): + return False + + email_regexes = selector.get("emailRegexes") + if email_regexes is not None and not _user_matches_email( + user, set(), _compiled_regexes(email_regexes) + ): + return False + + usernames = selector.get("usernames") + if usernames is not None and not _text_matches(user["username"], set(usernames), []): + return False + + username_regexes = selector.get("usernameRegexes") + if username_regexes is None: + return True + return _text_matches(user["username"], set(), _compiled_regexes(username_regexes)) + + +def _users_matching_email_values( + emails: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users with at least one verified email equal to a listed email.""" + exact_values = set(emails) + matched = [user for user in all_users if _user_matches_email(user, exact_values, [])] + log.info( + " emails → %d user(s) matched %d email selector(s)", + len(matched), + len(exact_values), + ) + return matched + + +def _users_matching_email_regexes( + email_regexes: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users with at least one verified email matching a listed regex.""" + patterns = _compiled_regexes(email_regexes) + matched = [user for user in all_users if _user_matches_email(user, set(), patterns)] + log.info( + " emailRegexes → %d user(s) matched %d email regex selector(s)", + len(matched), + len(set(email_regexes)), + ) + return matched + + +def _user_matches_email( + user: shared_types.User, exact_values: set[str], patterns: list[re.Pattern[str]] +) -> bool: + """Match only verified emails, mirroring Sourcegraph's `user(email:)` lookup.""" + return any( + user_email["verified"] and _text_matches(user_email["email"], exact_values, patterns) + for user_email in user.get("emails", []) + ) + + +def _users_matching_username_values( + usernames: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users whose Sourcegraph username equals a listed username.""" + exact_values = set(usernames) + matched = [user for user in all_users if _text_matches(user["username"], exact_values, [])] + log.info( + " usernames → %d user(s) matched %d username selector(s)", + len(matched), + len(exact_values), + ) + return matched + + +def _users_matching_username_regexes( + username_regexes: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users whose Sourcegraph username matches a listed regex.""" + patterns = _compiled_regexes(username_regexes) + matched = [user for user in all_users if _text_matches(user["username"], set(), patterns)] + log.info( + " usernameRegexes → %d user(s) matched %d username regex selector(s)", + len(matched), + len(set(username_regexes)), + ) + return matched + + +def _compiled_regexes(regexes: list[str]) -> list[re.Pattern[str]]: + """Return compiled regexes.""" + return [re.compile(pattern) for pattern in regexes] + + +def _text_matches(value: str, exact_values: set[str], patterns: list[re.Pattern[str]]) -> bool: + """Return whether text matches exact values or any regex.""" + if value in exact_values: + return True + return any(pattern.search(value) for pattern in patterns) def _users_matching_auth_provider( @@ -434,49 +672,78 @@ def _user_has_saml_group_in_provider( def resolve_repos( - section: dict[str, object], + selector: permission_types.RepositorySelector, services_by_id: dict[int, permission_types.ExternalService], repos_by_external_service_id: dict[int, list[permission_types.Repository]], all_repos_by_id: dict[str, permission_types.Repository], ) -> list[permission_types.Repository]: - """Return repos matching ALL matchers under `repos:` (intersection). + """Return repos matching ALL top-level selectors under `repos:`. - Empty section returns an empty repo set; `validate_mapping_rules` + Empty sections return an empty repo set; `validate_mapping_rules` rejects this at config-load time. """ - if not section: + if not selector: return [] - matched_ids: set[str] | None = None - repo_index: dict[str, permission_types.Repository] = {} - ordered_keys = [key for key in ("codeHostConnection", "regex") if key in section] - for key in ordered_keys: - matcher = section[key] - if key == "codeHostConnection": - repos = _repos_matching_code_host_connection( - cast(permission_types.CodeHostConnectionMatcher, matcher), - services_by_id, - repos_by_external_service_id, - ) - elif key == "regex": - candidate_repos = ( - [repo_index[repo_id] for repo_id in matched_ids] - if matched_ids is not None - else list(all_repos_by_id.values()) - ) - repos = _repos_matching_regex(cast(str, matcher), candidate_repos) - else: - # validate_mapping_rules catches this earlier with a clearer - # message; this only fires for programmatic callers. - raise ValueError(f"unknown repos matcher {key!r}") - current_ids = {repo["id"] for repo in repos} - for repo in repos: - repo_index[repo["id"]] = repo - matched_ids = current_ids if matched_ids is None else matched_ids & current_ids + selector_matches: list[set[str]] = [] + repo_index = dict(all_repos_by_id) + candidate_repos = list(all_repos_by_id.values()) + code_host_connection = selector.get("codeHostConnection") + if code_host_connection is not None: + repos = _repos_matching_code_host_connection( + code_host_connection, + services_by_id, + repos_by_external_service_id, + ) + repo_index.update({repo["id"]: repo for repo in repos}) + candidate_repos = repos + selector_matches.append({repo["id"] for repo in repos}) + + names = selector.get("names") + if names is not None: + selector_matches.append(_repo_ids_matching_names(names, candidate_repos)) + + name_regexes = selector.get("nameRegexes") + if name_regexes is not None: + selector_matches.append(_repo_ids_matching_name_regexes(name_regexes, candidate_repos)) + + if not selector_matches: + return [] + + matched_ids = selector_matches[0] + for current_ids in selector_matches[1:]: + matched_ids &= current_ids if not matched_ids: return [] - assert matched_ids is not None - return [repo_index[repo_id] for repo_id in matched_ids] + return [repo for repo in repo_index.values() if repo["id"] in matched_ids] + + +def _repo_ids_matching_names( + names: list[str], repos: list[permission_types.Repository] +) -> set[str]: + """Return repo IDs whose Sourcegraph name equals a listed name.""" + exact_values = set(names) + matched = {repo["id"] for repo in repos if _repo_name_matches(repo["name"], exact_values, [])} + log.info( + " names → %d repo(s) matched %d name selector(s)", + len(matched), + len(exact_values), + ) + return matched + + +def _repo_ids_matching_name_regexes( + name_regexes: list[str], repos: list[permission_types.Repository] +) -> set[str]: + """Return repo IDs whose Sourcegraph name matches a listed regex.""" + patterns = _compiled_regexes(name_regexes) + matched = {repo["id"] for repo in repos if _repo_name_matches(repo["name"], set(), patterns)} + log.info( + " nameRegexes → %d repo(s) matched %d name regex selector(s)", + len(matched), + len(set(name_regexes)), + ) + return matched def _repos_matching_code_host_connection( @@ -505,70 +772,47 @@ def _repos_matching_code_host_connection( return list(matched_repos.values()) -def _repos_matching_regex( - pattern: str, repos: list[permission_types.Repository] -) -> list[permission_types.Repository]: - """Return repos whose name matches `pattern` using Python `re`. +def _repo_name_matches( + repository_name: str, exact_values: set[str], patterns: list[re.Pattern[str]] +) -> bool: + """Return whether a repo name matches exact values or regexes. Sourcegraph repo names usually omit the URL scheme (for example - `github.com/example/repo`). To keep URL-looking operator patterns - useful, also test `https://`. + `github.com/example/repo`). To keep URL-looking operator regexes useful, + also test `https://` for regex matches. Exact matches remain + exact Sourcegraph repo names. """ - compiled = re.compile(pattern) - matched = [ - repo - for repo in repos - if compiled.search(repo["name"]) or compiled.search(f"https://{repo['name']}") - ] - log.info(" regex → %d repo(s) matched %r", len(matched), pattern) - return matched + if repository_name in exact_values: + return True + return any( + pattern.search(repository_name) or pattern.search(f"https://{repository_name}") + for pattern in patterns + ) def _services_matching( services_by_id: dict[int, permission_types.ExternalService], matcher: permission_types.CodeHostConnectionMatcher, ) -> list[permission_types.ExternalService]: - """AND across the supplied matcher fields. If `id` is supplied we - short-circuit to a single candidate; remaining fields then act as a - defensive cross-check against an ES recreated/renamed under the - same id. Without `id`, every other supplied field is a primary - discriminator across the full service list. - """ - if "id" in matcher: - single_service = services_by_id.get(matcher["id"]) - if single_service is None: - return [] - candidates = [single_service] - else: - candidates = list(services_by_id.values()) - + """AND across the supplied human-readable code-host matcher fields.""" matched: list[permission_types.ExternalService] = [] matcher_values = cast(Mapping[str, object], matcher) - for service in candidates: + for service in services_by_id.values(): service_values = cast(Mapping[str, object], service) if not all( field_name not in matcher_values or matcher_values[field_name] == service_values[field_name] - for field_name in CODE_HOST_VALUE_MATCHES + for field_name in CODE_HOST_DIRECT_VALUE_MATCHES ): continue - if "config" in matcher and not _config_subset_matches( - matcher["config"], _parsed_service_config(service) - ): + if "username" in matcher and matcher["username"] != _service_username(service): continue matched.append(service) return matched def _parsed_service_config(service: permission_types.ExternalService) -> dict[str, Any]: - """Best-effort parse of `ExternalService.config` (JSONC string). - - Returns an empty dict if the config is missing or unparseable — - callers treat that as "no keys to match against", so a `config:` - matcher against such a service simply fails to match instead of - raising. Sourcegraph's resolver returns a JSON object string, so - parse failures here are anomalies worth not crashing on. - """ + """Best-effort parse of `ExternalService.config` (JSONC string).""" raw_config = service.get("config") if not raw_config: return {} @@ -581,46 +825,10 @@ def _parsed_service_config(service: permission_types.ExternalService) -> dict[st return cast(dict[str, Any], parsed) -def _config_subset_matches(matcher_config: dict[str, Any], service_config: dict[str, Any]) -> bool: - """True iff every key in `matcher_config` is present in `service_config` - with a matching value. Nested dicts are matched recursively - (subset semantics); lists and scalars are matched by equality. - - Sourcegraph's `REDACTED` sentinel is left as-is on the service side: - a matcher that names a redacted key (e.g. `token`) compares against - the literal `"REDACTED"` string and almost certainly fails to - match — exactly the semantics we want, since the operator can't - have known the real secret value. - """ - for key, expected in matcher_config.items(): - if key not in service_config: - return False - actual = service_config[key] - if isinstance(expected, dict) and isinstance(actual, dict): - if not _config_subset_matches( - cast(dict[str, Any], expected), cast(dict[str, Any], actual) - ): - return False - continue - if expected != actual: - return False - return True - - -def referenced_external_service_ids(rules: list[permission_types.MappingRule]) -> set[int]: - """Collect all external_service IDs referenced by the mapping rules. - - Returns integer DB primary keys (the YAML-facing form). Used by - `cmd_set` to pre-flight-warn about any IDs that the live instance - doesn't know about, before per-mapping resolution runs. - """ - referenced: set[int] = set() - for rule in rules: - repos_section = rule.get("repos") or {} - code_host_section = repos_section.get("codeHostConnection") - if code_host_section and "id" in code_host_section: - referenced.add(code_host_section["id"]) - return referenced +def _service_username(service: permission_types.ExternalService) -> str | None: + """Return the code-host username from `ExternalService.config`, if present.""" + username = _parsed_service_config(service).get("username") + return username if isinstance(username, str) else None def _format_matcher(matcher: dict[str, object]) -> str: diff --git a/src/src_auth_perms_sync/permissions/maps.py b/src/src_auth_perms_sync/permissions/maps.py index 98ff4c6..d6cb86d 100644 --- a/src/src_auth_perms_sync/permissions/maps.py +++ b/src/src_auth_perms_sync/permissions/maps.py @@ -133,26 +133,21 @@ def count_users_per_provider( def external_service_to_yaml(service: permission_types.ExternalService) -> dict[str, Any]: """Render an external service for the YAML config. - Keys mirror Sourcegraph GraphQL `ExternalService` field names directly - (camelCase). Every scalar field exposed by the GraphQL schema is - surfaced here, including the JSONC `config` blob (parsed and emitted - as a nested mapping). Sourcegraph's `config` resolver redacts secrets - by replacing their values with the literal string `"REDACTED"`; we - strip those keys recursively via `_strip_redacted` so the YAML - contains no useless redaction placeholders. Nested arrays - (e.g. `webhooks[]`, `exclude[]`) are walked too. - - `id` is the decoded integer DB primary key, NOT the opaque base64 - GraphQL Node ID — operators copy this into mapping rules' `repos. - codeHostConnection.id` field, where the integer form is much - friendlier than `RXh0ZXJuYWxTZXJ2aWNlOjU=`. + Keys mirror the human-readable Sourcegraph GraphQL `ExternalService` + fields that maps can match. The opaque GraphQL `id` is omitted; + maps should identify code host connections with `kind`, `displayName`, + `url`, and/or `username`. + + The JSONC `config` blob is parsed only to lift its top-level + `username` into the read-only discovery YAML. The rest of `config` + is intentionally omitted because maps no longer support matching + code-host connections by arbitrary config subtrees. Optional / nullable fields are omitted when null/empty so the YAML stays readable. Booleans are always emitted (true or false) so the discovered state is explicit. """ rendered: dict[str, Any] = { - "id": src.decode_external_service_id(service["id"]), "kind": service["kind"], "displayName": service["displayName"], "url": service["url"], @@ -181,21 +176,22 @@ def external_service_to_yaml(service: permission_types.ExternalService) -> dict[ raw_config = service.get("config") if raw_config: try: - parsed_config = cast(dict[str, Any], json5.loads(raw_config)) + parsed_config = cast(Any, json5.loads(raw_config)) except ValueError: - # Unparsable JSONC: surface the raw string verbatim so the - # operator can still see what's there. Stripping doesn't - # apply since we have no structure to walk. - rendered["config"] = raw_config + pass else: - rendered["config"] = _strip_redacted(parsed_config) + if isinstance(parsed_config, dict): + config_values = cast(dict[str, Any], parsed_config) + username = config_values.get("username") + if isinstance(username, str) and username: + rendered["username"] = username return rendered def dump_auth_providers_yaml(path: Path, providers: list[dict[str, Any]]) -> None: header = ( "# Sourcegraph auth provider configs.\n" - "# Generated/refreshed by: src-auth-perms-sync --get\n" + "# Generated/refreshed by: src-auth-perms-sync get\n" "# Use these values when writing maps.yaml rules under `users.authProvider`.\n" "# This file is read-only reference data; edit maps.yaml, not this file.\n" ) @@ -205,9 +201,9 @@ def dump_auth_providers_yaml(path: Path, providers: list[dict[str, Any]]) -> Non def dump_code_hosts_yaml(path: Path, code_hosts: list[dict[str, Any]]) -> None: header = ( "# Sourcegraph code host connection configs.\n" - "# Generated/refreshed by: src-auth-perms-sync --get\n" + "# Generated/refreshed by: src-auth-perms-sync get\n" "# Use these values when writing maps.yaml rules under `repos.codeHostConnection`.\n" - "# Secrets from ExternalService.config are stripped.\n" + "# ExternalService.config.username is surfaced as top-level `username` when present.\n" "# This file is read-only reference data; edit maps.yaml, not this file.\n" ) _dump_readonly_discovery_yaml(path, header, "codeHostConnections", code_hosts) diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index afa83b5..7a7e2ec 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -89,32 +89,57 @@ } """ -QUERY_USER_BY_USERNAME = f""" +USER_EMAIL_FIELDS = """ +emails { + email + verified +} +""" + + +def user_fields(*, include_emails: bool = False) -> str: + """Return user fields, adding emails only when downstream matching needs them.""" + if include_emails: + return f"{USER_FIELDS}\n{USER_EMAIL_FIELDS}" + return USER_FIELDS + + +def query_user_by_username(*, include_emails: bool = False) -> str: + return f""" query UserByUsername($username: String!) {{ user(username: $username) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_EMAIL = f""" + +def query_user_by_email(*, include_emails: bool = False) -> str: + return f""" query UserByEmail($email: String!) {{ user(email: $email) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_ID = f""" + +def query_user_by_id(*, include_emails: bool = False) -> str: + return f""" query UserByID($id: ID!) {{ node(id: $id) {{ ... on User {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} }} """ + +QUERY_USER_BY_USERNAME = query_user_by_username() +QUERY_USER_BY_EMAIL = query_user_by_email() +QUERY_USER_BY_ID = query_user_by_id() + QUERY_SITE_USERS = """ query SiteUsers($limit: Int!, $offset: Int!, $createdAt: SiteUsersDateRangeInput) { site { diff --git a/src/src_auth_perms_sync/permissions/restore.py b/src/src_auth_perms_sync/permissions/restore.py index da7ab1a..5ab1ece 100644 --- a/src/src_auth_perms_sync/permissions/restore.py +++ b/src/src_auth_perms_sync/permissions/restore.py @@ -489,6 +489,12 @@ def _log_user_scoped_restore_done(mutations: _UserScopedRestoreMutationResult) - mutations.additions.succeeded, mutations.removals.succeeded, ) + skipped = mutations.additions.skipped + mutations.removals.skipped + if skipped: + log.warning( + "Scoped restore skipped %d vanished repo/user mutation(s); the next run will re-plan.", + skipped, + ) def _restore_command_name(dry_run: bool) -> str: @@ -725,8 +731,9 @@ def _apply_restore_overwrites( worker_pool=worker_pool, ) log.info( - "Restore done. %d succeeded, %d failed, %d canceled.", + "Restore done. %d succeeded, %d skipped, %d failed, %d canceled.", mutations.succeeded, + mutations.skipped, mutations.failed, mutations.canceled, ) @@ -744,6 +751,7 @@ def _record_restore_event_fields( command_event["repos_short_circuited"] = plan.skipped_repo_count command_event["snapshot_grants"] = snapshot_state.target_snapshot["stats"]["total_grants"] command_event["mutations_succeeded"] = mutations.succeeded + command_event["mutations_skipped"] = mutations.skipped command_event["mutations_failed"] = mutations.failed command_event["mutations_canceled"] = mutations.canceled diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index ae3928c..3073205 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -230,7 +230,7 @@ def _fetch( user["username"]: repository_ids_by_user_id.get(user["id"], []) for user in batch_users } - fetch_event["repo_count"] = sum( + fetch_event["fetched_grant_count"] = sum( len(repository_ids) for repository_ids in repository_ids_by_username.values() ) fetch_event["per_user_failures"] = failures diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index c64c4b4..ed4e32b 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -41,29 +41,53 @@ def list_repos_for_external_service( ] -def get_user_by_username(client: src.SourcegraphClient, username: str) -> shared_types.User | None: +def get_user_by_username( + client: src.SourcegraphClient, + username: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the exact Sourcegraph user for `username`, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_USERNAME, cast(src.JSONDict, {"username": username})), + client.graphql( + queries.query_user_by_username(include_emails=include_emails), + cast(src.JSONDict, {"username": username}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_email(client: src.SourcegraphClient, email: str) -> shared_types.User | None: +def get_user_by_email( + client: src.SourcegraphClient, + email: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the user owning the verified email address, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_EMAIL, cast(src.JSONDict, {"email": email})), + client.graphql( + queries.query_user_by_email(include_emails=include_emails), + cast(src.JSONDict, {"email": email}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_id(client: src.SourcegraphClient, user_id: str) -> shared_types.User | None: +def get_user_by_id( + client: src.SourcegraphClient, + user_id: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Hydrate a User node by GraphQL ID.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_ID, cast(src.JSONDict, {"id": user_id})), + client.graphql( + queries.query_user_by_id(include_emails=include_emails), + cast(src.JSONDict, {"id": user_id}), + ), ) return cast(shared_types.User | None, data.get("node")) @@ -171,11 +195,10 @@ def list_users_explicit_repo_ids( repository_ids_by_user_id: dict[str, list[str]] = {user_id: [] for user_id in user_ids} pending_pages: list[tuple[str, str | None]] = [(user_id, None) for user_id in user_ids] - graphql_client = _graphql_client_without_auto_pagination(client) while pending_pages: batch = pending_pages[:batch_size] del pending_pages[:batch_size] - data = graphql_client.execute( + data = client.graphql( _user_explicit_repos_batch_query(len(batch)), _user_explicit_repos_batch_variables(batch), follow_pages=False, @@ -237,15 +260,6 @@ def list_repositories_by_ids( return repositories -def _graphql_client_without_auto_pagination(client: src.SourcegraphClient) -> src.GraphQLClient: - return src.GraphQLClient( - url=f"{client.endpoint}/.api/graphql", - headers={"Authorization": f"token {client.token}"}, - label="Sourcegraph", - http=client.http, - ) - - def _batches(values: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]: for start_index in range(0, len(values), batch_size): yield values[start_index : start_index + batch_size] diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index f57ffab..4f57eed 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -3,23 +3,23 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal, NotRequired, TypeAlias, TypedDict +from typing import Any, Literal, TypeAlias, TypedDict from ..shared import types as shared_types SetCommandMode: TypeAlias = Literal[ "full", - "user", + "users", "users_without_explicit_perms", ] @dataclass(frozen=True) class SetCommandOptions: - """Operator-selected mode for `--set`.""" + """Operator-selected mode for the set command.""" mode: SetCommandMode - user_identifier: str | None = None + user_identifiers: tuple[str, ...] = () user_created_after: str | None = None @@ -76,26 +76,34 @@ class AuthProviderMatcher(TypedDict, total=False): class CodeHostConnectionMatcher(TypedDict, total=False): """Match repos by Sourcegraph code-host connection discovery fields.""" - id: int kind: str displayName: str url: str - config: dict[str, Any] + username: str + +class UserSelector(TypedDict, total=False): + """User selectors. Fields AND together; values inside each field OR.""" -class UsersFilter(TypedDict, total=False): authProvider: AuthProviderMatcher + emails: list[str] + emailRegexes: list[str] + usernames: list[str] + usernameRegexes: list[str] -class ReposFilter(TypedDict, total=False): +class RepositorySelector(TypedDict, total=False): + """Repository selectors. Fields AND together; values inside each field OR.""" + codeHostConnection: CodeHostConnectionMatcher - regex: str + names: list[str] + nameRegexes: list[str] class MappingRule(TypedDict): - name: NotRequired[str] - users: UsersFilter - repos: ReposFilter + name: str + users: UserSelector + repos: RepositorySelector class ConfigFile(TypedDict, total=False): diff --git a/src/src_auth_perms_sync/permissions/workflow.py b/src/src_auth_perms_sync/permissions/workflow.py index d882d95..7f0ab66 100644 --- a/src/src_auth_perms_sync/permissions/workflow.py +++ b/src/src_auth_perms_sync/permissions/workflow.py @@ -30,9 +30,9 @@ def load_discovery( dict[tuple[str, str], str], ]: """Fetch auth providers + external services and resolve the SAML attribute - names map, with consistent logging. Shared by --get and --set; returns the - raw lists so each caller can transform them as needed (YAML form for --get, - keyed-by-id dict for --set). + names map, with consistent logging. Shared by get and set; returns the + raw lists so each caller can transform them as needed (YAML form for get, + keyed-by-id dict for set). Both commands need exactly the same instance state to do their work, so centralizing this avoids drift in which providers/services are considered @@ -143,7 +143,6 @@ def load_mapping_context_for_rules( len(all_repos_by_id), len(services_by_id), ) - warn_unknown_external_services(mapping_rules, services_by_id) return permission_types.MappingContext( mapping_rules=mapping_rules, providers=providers, @@ -154,23 +153,6 @@ def load_mapping_context_for_rules( ) -def warn_unknown_external_services( - mapping_rules: list[permission_types.MappingRule], - services_by_id: dict[int, permission_types.ExternalService], -) -> None: - """Warn when maps reference code-host connection IDs absent on the instance.""" - for external_service_id in sorted( - permissions_mapping.referenced_external_service_ids(mapping_rules) - ): - if external_service_id not in services_by_id: - log.warning( - "External service id %s is referenced by the maps but " - "is not present on the instance — rules using it will " - "resolve to zero repos.", - external_service_id, - ) - - def snapshot_path( input_path: Path, timestamp: str, diff --git a/src/src_auth_perms_sync/shared/backups.py b/src/src_auth_perms_sync/shared/backups.py index 5d276ab..e4a2ca7 100644 --- a/src/src_auth_perms_sync/shared/backups.py +++ b/src/src_auth_perms_sync/shared/backups.py @@ -97,13 +97,6 @@ def endpoint_directory_name(endpoint: str) -> str: return safe_filename_part(directory_name) -def endpoint_artifact_path(endpoint: str, path: Path) -> Path: - """Resolve a user-facing artifact path within the endpoint directory by default.""" - if path.is_absolute(): - return path - return endpoint_artifacts_directory(endpoint) / path - - def _fallback_endpoint_port(hostname_and_port: str) -> int | None: """Parse a port from an endpoint netloc that urlsplit could not fully parse.""" if ":" not in hostname_and_port: diff --git a/src/src_auth_perms_sync/shared/queries.py b/src/src_auth_perms_sync/shared/queries.py index 4ffb1e4..c833e42 100644 --- a/src/src_auth_perms_sync/shared/queries.py +++ b/src/src_auth_perms_sync/shared/queries.py @@ -38,15 +38,25 @@ } """ -QUERY_USERS = """ -query ListUsers($first: Int!, $after: String) { - users(first: $first, after: $after) { - nodes { +USER_EMAIL_FIELDS = """ emails { + email + verified + } +""" + + +def query_users(*, include_emails: bool = False) -> str: + """Return the users page query, adding email fields only when requested.""" + email_fields = USER_EMAIL_FIELDS if include_emails else "" + return f""" +query ListUsers($first: Int!, $after: String) {{ + users(first: $first, after: $after) {{ + nodes {{ id username builtinAuth - externalAccounts(first: 50) { - nodes { +{email_fields} externalAccounts(first: 50) {{ + nodes {{ serviceType serviceID clientID @@ -56,10 +66,13 @@ # Admin. Returns null for serviceType where the resolver does # not expose data (e.g. plain GitHub OAuth without SSO). accountData - } - } - } - pageInfo { hasNextPage endCursor } - } -} + }} + }} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} """ + + +QUERY_USERS = query_users() diff --git a/src/src_auth_perms_sync/shared/sourcegraph.py b/src/src_auth_perms_sync/shared/sourcegraph.py index 3b39d94..f138e2d 100644 --- a/src/src_auth_perms_sync/shared/sourcegraph.py +++ b/src/src_auth_perms_sync/shared/sourcegraph.py @@ -32,11 +32,15 @@ def count_users(client: src.SourcegraphClient) -> int: return cast(int, data["users"]["totalCount"]) -def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types.User]: +def list_users_with_accounts( + client: src.SourcegraphClient, + *, + include_emails: bool = False, +) -> list[shared_types.User]: return [ cast(shared_types.User, node) for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ) @@ -46,6 +50,8 @@ def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types def list_users_streaming( client: src.SourcegraphClient, collect_into: list[shared_types.User] | None = None, + *, + include_emails: bool = False, ) -> Iterator[shared_types.User]: """Stream ListUsers pages one at a time, yielding each User as it arrives. @@ -59,7 +65,7 @@ def list_users_streaming( streaming benefit in one pass — no double-pagination. """ for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ): diff --git a/src/src_auth_perms_sync/shared/types.py b/src/src_auth_perms_sync/shared/types.py index 9429a41..46f2032 100644 --- a/src/src_auth_perms_sync/shared/types.py +++ b/src/src_auth_perms_sync/shared/types.py @@ -30,11 +30,17 @@ class ExternalAccountConnection(TypedDict): nodes: list[ExternalAccount] +class UserEmail(TypedDict): + email: str + verified: bool + + class User(TypedDict): id: str username: str builtinAuth: bool externalAccounts: ExternalAccountConnection + emails: NotRequired[list[UserEmail]] @dataclass(frozen=True, slots=True) @@ -48,6 +54,7 @@ class MutationCounts: succeeded: int = 0 failed: int = 0 canceled: int = 0 + skipped: int = 0 @dataclass(frozen=True, slots=True) diff --git a/tests/integration/test_cli_entrypoint.py b/tests/integration/test_cli_entrypoint.py index a936982..efc968a 100644 --- a/tests/integration/test_cli_entrypoint.py +++ b/tests/integration/test_cli_entrypoint.py @@ -15,6 +15,62 @@ def test_module_help_prints_usage(self) -> None: ) self.assertIn("src-auth-perms-sync", completed_process.stdout) - self.assertIn("--set", completed_process.stdout) - self.assertIn("--sync-saml-orgs", completed_process.stdout) + self.assertIn("set:\n- Explicit repo permissions", completed_process.stdout) + self.assertIn("Organizations and memberships\n\nSee", completed_process.stdout) + self.assertIn("commands:", completed_process.stdout) + self.assertIn("COMMAND", completed_process.stdout) + self.assertIn("get", completed_process.stdout) + self.assertIn("set", completed_process.stdout) + self.assertIn("sync-saml-orgs", completed_process.stdout) + self.assertIn("Sync orgs from SAML groups", completed_process.stdout) + self.assertNotIn("--maps-path", completed_process.stdout) self.assertEqual("", completed_process.stderr) + + def test_command_help_prints_command_specific_options(self) -> None: + get_help = subprocess.run( + [sys.executable, "-m", "src_auth_perms_sync", "get", "--help"], + check=True, + capture_output=True, + text=True, + ) + set_help = subprocess.run( + [sys.executable, "-m", "src_auth_perms_sync", "set", "--help"], + check=True, + capture_output=True, + text=True, + ) + restore_help = subprocess.run( + [sys.executable, "-m", "src_auth_perms_sync", "restore", "--help"], + check=True, + capture_output=True, + text=True, + ) + sync_saml_orgs_help = subprocess.run( + [sys.executable, "-m", "src_auth_perms_sync", "sync-saml-orgs", "--help"], + check=True, + capture_output=True, + text=True, + ) + + self.assertNotIn("--apply", get_help.stdout) + self.assertNotIn("--no-backup", get_help.stdout) + self.assertNotIn("--sync-saml-orgs", get_help.stdout) + self.assertIn("--users USERS", get_help.stdout) + self.assertNotIn("--user USER", get_help.stdout) + self.assertIn("--maps-path FILE", set_help.stdout) + self.assertIn("--users USERS", set_help.stdout) + self.assertIn("--sync-saml-orgs", set_help.stdout) + self.assertNotIn("--restore-path", set_help.stdout) + self.assertIn("Permission sync:", set_help.stdout) + self.assertIn("Organization sync:", set_help.stdout) + self.assertIn("Sourcegraph:", set_help.stdout) + self.assertIn("Logging:", set_help.stdout) + self.assertLess(set_help.stdout.index("\nLogging:"), set_help.stdout.index("\nConfig:")) + self.assertIn("--restore-path FILE", restore_help.stdout) + self.assertNotIn("--maps-path", restore_help.stdout) + self.assertIn("--apply", sync_saml_orgs_help.stdout) + self.assertNotIn("--sync-saml-orgs", sync_saml_orgs_help.stdout) + self.assertEqual("", get_help.stderr) + self.assertEqual("", set_help.stderr) + self.assertEqual("", restore_help.stderr) + self.assertEqual("", sync_saml_orgs_help.stderr) diff --git a/tests/unit/test_apply.py b/tests/unit/test_apply.py new file mode 100644 index 0000000..59326a0 --- /dev/null +++ b/tests/unit/test_apply.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import unittest +from typing import Any, cast + +import src_py_lib as src + +from src_auth_perms_sync.permissions import apply +from src_auth_perms_sync.permissions import types as permission_types + + +class _FakeSourcegraphClient: + def __init__(self, exception: BaseException | None = None) -> None: + self.exception = exception + self.calls: list[tuple[str, dict[str, Any]]] = [] + + def graphql(self, query: str, variables: src.JSONDict) -> dict[str, Any]: + self.calls.append((query, dict(variables))) + if self.exception is not None: + raise self.exception + return {} + + +class ApplyTests(unittest.TestCase): + def test_repo_not_found_overwrite_is_skipped_not_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'repo not found: id=264'}]", + is_application_error=True, + ) + ) + counts = apply.apply_username_overwrites( + cast(src.SourcegraphClient, client), + [ + permission_types.RepositoryUsernameOverwrite( + repository_id=src.encode_repository_id(264), + repository_name="test-repo-0241", + usernames=("alice",), + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(1, counts.skipped) + self.assertEqual(0, counts.failed) + self.assertEqual(0, counts.canceled) + self.assertEqual(1, len(client.calls)) + + def test_user_not_found_addition_is_skipped_not_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'user not found: id=123'}]", + is_application_error=True, + ) + ) + counts = apply.apply_additions( + cast(src.SourcegraphClient, client), + [ + apply.PermissionAddition( + user_id="VXNlcjoxMjM=", + username="deleted-user", + repo_id=src.encode_repository_id(264), + repo_name="test-repo-0241", + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(1, counts.skipped) + self.assertEqual(0, counts.failed) + self.assertEqual(0, counts.canceled) + self.assertEqual(1, len(client.calls)) + + def test_non_missing_graphql_error_is_failed(self) -> None: + client = _FakeSourcegraphClient( + src.GraphQLError( + "Sourcegraph GraphQL errors: [{'message': 'permission denied'}]", + is_application_error=True, + ) + ) + counts = apply.apply_username_overwrites( + cast(src.SourcegraphClient, client), + [ + permission_types.RepositoryUsernameOverwrite( + repository_id=src.encode_repository_id(264), + repository_name="test-repo-0241", + usernames=("alice",), + ) + ], + parallelism=1, + ) + + self.assertEqual(0, counts.succeeded) + self.assertEqual(0, counts.skipped) + self.assertEqual(1, counts.failed) + self.assertEqual(0, counts.canceled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_backups.py b/tests/unit/test_backups.py index c0d2603..340f773 100644 --- a/tests/unit/test_backups.py +++ b/tests/unit/test_backups.py @@ -25,31 +25,19 @@ def test_endpoint_artifacts_directory_uses_current_directory(self) -> None: ), ) - def test_endpoint_artifact_path_scopes_relative_paths(self) -> None: - self.assertEqual( - Path.cwd() / backups.ARTIFACTS_DIR_NAME / "sourcegraph.example.com" / "maps.yaml", - backups.endpoint_artifact_path("https://sourcegraph.example.com", Path("maps.yaml")), - ) - self.assertEqual( - Path("/tmp/maps.yaml"), - backups.endpoint_artifact_path( - "https://sourcegraph.example.com", Path("/tmp/maps.yaml") - ), - ) - def test_backup_path_uses_safe_endpoint_source_command_and_state(self) -> None: self.assertEqual( Path.cwd() / backups.ARTIFACTS_DIR_NAME / "sourcegraph.example.com" / backups.RUNS_DIR_NAME - / "2026-05-23-set_user" + / "2026-05-23-set_users" / "before.json", backups.backup_path( "repo/1", "2026-05-23", "https://sourcegraph.example.com", - "set:user", + "set:users", "before", ), ) diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index fe6ef17..afc6262 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -2,6 +2,7 @@ import contextlib import io +import os import tempfile import unittest from concurrent.futures import ThreadPoolExecutor @@ -12,21 +13,22 @@ import src_py_lib as src from src_py_lib.utils import config as shared_config +import src_auth_perms_sync from src_auth_perms_sync import cli from src_auth_perms_sync.shared import backups -def make_config(**updates: object) -> cli.SrcAuthPermissionsSyncConfig: - base_config = cli.SrcAuthPermissionsSyncConfig( +def make_config(**updates: object) -> cli.Config: + base_config = cli.Config( src_endpoint="https://sourcegraph.example.com", src_access_token="secret", ) return base_config.model_copy(update=updates) -def load_config_from_env(**env: str) -> cli.SrcAuthPermissionsSyncConfig: +def load_config_from_env(**env: str) -> cli.Config: return shared_config.load_config( - cli.SrcAuthPermissionsSyncConfig, + cli.Config, env_file=None, env={ "SRC_ENDPOINT": "https://sourcegraph.example.com", @@ -38,40 +40,120 @@ def load_config_from_env(**env: str) -> cli.SrcAuthPermissionsSyncConfig: class CliConfigTests(unittest.TestCase): - def test_resolve_command_defaults_to_get(self) -> None: - command = cli.resolve_command(make_config()) + def test_resolve_command_uses_explicit_command_name(self) -> None: + command = cli.resolve_command("get", make_config()) self.assertEqual("get", command.name) self.assertEqual("get", command.log_name) self.assertEqual("get", command.artifact_name) - - def test_resolve_command_prefers_explicit_commands(self) -> None: self.assertEqual( - "set", cli.resolve_command(make_config(set_path=Path("maps.yaml"), full=True)).name + "set", + cli.resolve_command("set", make_config(maps_path=Path("maps.yaml"), full=True)).name, ) self.assertEqual( - "restore", cli.resolve_command(make_config(restore_path=Path("snapshot.json"))).name + "restore", + cli.resolve_command("restore", make_config(restore_path=Path("snapshot.json"))).name, ) self.assertEqual( - "sync_saml_orgs", cli.resolve_command(make_config(sync_saml_organizations=True)).name + "sync_saml_orgs", + cli.resolve_command("sync_saml_orgs", make_config()).name, ) + def test_maps_path_does_not_select_set_command(self) -> None: + command = cli.resolve_command("get", make_config(maps_path=Path("custom-maps.yaml"))) + + self.assertEqual("get", command.name) + + def test_load_cli_returns_command_and_config_options(self) -> None: + with ( + tempfile.TemporaryDirectory() as directory, + mock.patch.dict( + os.environ, + { + "SRC_ENDPOINT": "https://sourcegraph.example.com", + "SRC_ACCESS_TOKEN": "secret", + }, + clear=True, + ), + ): + env_file = Path(directory) / ".env" + env_file.write_text("") + cli_input = cli.load_cli( + [ + "set", + "--env-file", + str(env_file), + "--maps-path", + "custom-maps.yaml", + "--users", + "alice,bob@example.com", + ] + ) + + self.assertEqual("set", cli_input.command_name) + self.assertEqual(Path("custom-maps.yaml"), cli_input.config.maps_path) + self.assertEqual(("alice", "bob@example.com"), cli_input.config.users) + + def test_maps_path_is_none_until_defaulted_for_an_endpoint(self) -> None: + with ( + tempfile.TemporaryDirectory() as directory, + mock.patch.dict( + os.environ, + { + "SRC_ENDPOINT": "https://sourcegraph.example.com", + "SRC_ACCESS_TOKEN": "secret", + }, + clear=True, + ), + ): + env_file = Path(directory) / ".env" + env_file.write_text("") + cli_input = cli.load_cli(["set", "--env-file", str(env_file)]) + + self.assertEqual("set", cli_input.command_name) + self.assertIsNone(cli_input.config.maps_path) + + def test_load_cli_rejects_singular_user_option(self) -> None: + captured_stderr = io.StringIO() + + with ( + contextlib.redirect_stderr(captured_stderr), + self.assertRaises(SystemExit) as exit_context, + ): + cli.load_cli(["get", "--user", "alice"]) + + self.assertEqual(2, exit_context.exception.code) + self.assertIn("unrecognized arguments: --user alice", captured_stderr.getvalue()) + + def test_restore_path_config_loads_without_selecting_a_command(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_RESTORE_PATH="snapshot.json") + + self.assertEqual(Path.cwd() / "snapshot.json", config.restore_path) + + def test_users_config_loads_comma_delimited_values(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_USERS="alice, bob@example.com,,carol") + + self.assertEqual(("alice", "bob@example.com", "carol"), config.users) + def test_set_command_options_match_each_incremental_mode(self) -> None: self.assertEqual( - "full", cli.set_command_options(make_config(set_path=Path("maps.yaml"))).mode + "full", + cli.set_command_options(make_config(maps_path=Path("maps.yaml"))).mode, ) self.assertEqual( - ("user", "alice"), + ("users", ("alice", "bob@example.com")), ( - cli.set_command_options(make_config(set_path=Path("maps.yaml"), user="alice")).mode, cli.set_command_options( - make_config(set_path=Path("maps.yaml"), user="alice") - ).user_identifier, + make_config(maps_path=Path("maps.yaml"), users=("alice", "bob@example.com")) + ).mode, + cli.set_command_options( + make_config(maps_path=Path("maps.yaml"), users=("alice", "bob@example.com")) + ).user_identifiers, ), ) users_without_permissions = cli.set_command_options( make_config( - set_path=Path("maps.yaml"), + maps_path=Path("maps.yaml"), users_without_explicit_perms=True, created_after="2026-01-01", ) @@ -79,88 +161,121 @@ def test_set_command_options_match_each_incremental_mode(self) -> None: self.assertEqual("users_without_explicit_perms", users_without_permissions.mode) self.assertEqual("2026-01-01", users_without_permissions.user_created_after) filtered_full = cli.set_command_options( - make_config(set_path=Path("maps.yaml"), created_after="2026-01-01") + make_config(maps_path=Path("maps.yaml"), created_after="2026-01-01") ) self.assertEqual("full", filtered_full.mode) self.assertEqual("2026-01-01", filtered_full.user_created_after) def test_resolve_command_includes_set_mode_names(self) -> None: - user_command = cli.resolve_command( - make_config(set_path=Path("maps.yaml"), user="alice", apply=True) + users_command = cli.resolve_command( + "set", + make_config(maps_path=Path("maps.yaml"), users=("alice",), apply=True), ) - full_command = cli.resolve_command(make_config(set_path=Path("maps.yaml"))) + full_command = cli.resolve_command("set", make_config(maps_path=Path("maps.yaml"))) - self.assertEqual("set_user", user_command.log_name) - self.assertEqual("set-add-user-apply", user_command.artifact_name) - self.assertEqual("user", user_command.set_mode) + self.assertEqual("set_users", users_command.log_name) + self.assertEqual("set-add-users-apply", users_command.artifact_name) + self.assertEqual("users", users_command.set_mode) self.assertEqual("set_full", full_command.log_name) self.assertEqual("set-dry-run", full_command.artifact_name) - def test_resolve_command_includes_combined_sync_names(self) -> None: - get_command = cli.resolve_command(make_config(get=True, sync_saml_organizations=True)) + def test_resolve_command_includes_combined_set_sync_names(self) -> None: set_command = cli.resolve_command( - make_config(set_path=Path("maps.yaml"), apply=True, sync_saml_organizations=True) + "set", + make_config( + maps_path=Path("maps.yaml"), + apply=True, + sync_saml_organizations=True, + ), ) - self.assertEqual("get", get_command.name) - self.assertEqual("get_sync_saml_orgs", get_command.log_name) - self.assertEqual("get-sync-saml-orgs-dry-run", get_command.artifact_name) - self.assertTrue(get_command.sync_saml_organizations) self.assertEqual("set", set_command.name) self.assertEqual("set_full_sync_saml_orgs", set_command.log_name) self.assertEqual("set-sync-saml-orgs-apply", set_command.artifact_name) self.assertTrue(set_command.sync_saml_organizations) - def test_validate_config_rejects_multiple_commands(self) -> None: + def test_validate_config_allows_sync_saml_orgs_with_set(self) -> None: + cli.validate_config( + "set", + make_config(maps_path=Path("maps.yaml"), sync_saml_organizations=True), + ) + + def test_validate_config_rejects_sync_saml_orgs_without_set(self) -> None: + self.assert_config_error( + "get", + make_config(sync_saml_organizations=True), + "can only be combined with set", + ) + self.assert_config_error( + "restore", + make_config(restore_path=Path("snapshot.json"), sync_saml_organizations=True), + "can only be combined with set", + ) self.assert_config_error( - make_config(get=True, set_path=Path("maps.yaml"), full=True), - "choose only one", + "sync_saml_orgs", + make_config(sync_saml_organizations=True), + "can only be combined with set", ) - def test_validate_config_allows_sync_saml_orgs_with_get_or_set(self) -> None: - cli.validate_config(make_config(get=True, sync_saml_organizations=True)) - cli.validate_config(make_config(set_path=Path("maps.yaml"), sync_saml_organizations=True)) + def test_validate_config_rejects_mutating_options_with_get(self) -> None: + self.assert_config_error( + "get", + make_config(apply=True), + "--apply cannot be used with the read-only get command", + ) + self.assert_config_error( + "get", + make_config(no_backup=True), + "--no-backup cannot be used with the read-only get command", + ) + + def test_validate_config_rejects_restore_without_restore_path(self) -> None: + self.assert_config_error("restore", make_config(), "restore requires --restore-path") - def test_validate_config_rejects_sync_saml_orgs_with_restore(self) -> None: + def test_validate_config_rejects_restore_path_without_restore(self) -> None: self.assert_config_error( - make_config(restore_path=Path("snapshot.json"), sync_saml_organizations=True), - "with --get or --set", + "get", + make_config(restore_path=Path("snapshot.json")), + "--restore-path requires the restore command", ) def test_validate_config_rejects_set_modes_without_set(self) -> None: - self.assert_config_error(make_config(full=True), "requires --set") + self.assert_config_error("get", make_config(full=True), "requires the set command") def test_validate_config_allows_get_user_filters_without_set(self) -> None: - cli.validate_config(make_config(user="alice")) - cli.validate_config(make_config(users_without_explicit_perms=True)) - cli.validate_config(make_config(created_after="2026-01-01")) + cli.validate_config("get", make_config(users=("alice", "bob@example.com"))) + cli.validate_config("get", make_config(users_without_explicit_perms=True)) + cli.validate_config("get", make_config(created_after="2026-01-01")) def test_validate_config_rejects_get_user_filter_conflicts(self) -> None: self.assert_config_error( - make_config(user="alice", users_without_explicit_perms=True), - "choose only one of --user or --users-without-explicit-perms", + "get", + make_config(users=("alice",), users_without_explicit_perms=True), + "choose only one of --users or --users-without-explicit-perms", ) def test_validate_config_rejects_user_filters_on_non_get_set_commands(self) -> None: self.assert_config_error( - make_config(restore_path=Path("snapshot.json"), user="alice"), - "require --get or --set", + "restore", + make_config(restore_path=Path("snapshot.json"), users=("alice",)), + "require get or set", ) def test_validate_config_allows_set_without_explicit_mode(self) -> None: - cli.validate_config(make_config(set_path=Path("maps.yaml"))) + cli.validate_config("set", make_config(maps_path=Path("maps.yaml"))) def test_created_after_config_accepts_yyyy_mm_dd_date_arguments(self) -> None: config = load_config_from_env(SRC_AUTH_PERMS_SYNC_CREATED_AFTER="2026-01-01") self.assertEqual("2026-01-01", config.created_after) - cli.validate_config(make_config(get=True, created_after="2026-01-01")) + cli.validate_config("get", make_config(created_after="2026-01-01")) cli.validate_config( + "set", make_config( - set_path=Path("maps.yaml"), - user="alice", + maps_path=Path("maps.yaml"), + users=("alice",), created_after="2026-01-01", - ) + ), ) def test_created_after_config_rejects_values_outside_yyyy_mm_dd_shape(self) -> None: @@ -180,6 +295,15 @@ def test_explicit_permissions_batch_size_rejects_values_below_one(self) -> None: with self.assertRaisesRegex(shared_config.ConfigError, "greater than or equal to 1"): load_config_from_env(SRC_AUTH_PERMS_SYNC_EXPLICIT_PERMISSIONS_BATCH_SIZE="0") + def test_http_timeout_config_is_loaded_from_env(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="90") + + self.assertEqual(90, config.http_timeout_seconds) + + def test_http_timeout_rejects_values_at_or_below_zero(self) -> None: + with self.assertRaisesRegex(shared_config.ConfigError, "greater than 0"): + load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="0") + def test_trace_config_is_loaded_from_env(self) -> None: config = load_config_from_env(SRC_AUTH_PERMS_SYNC_TRACE="true") @@ -187,11 +311,11 @@ def test_trace_config_is_loaded_from_env(self) -> None: def test_run_with_client_enables_sourcegraph_trace_collection(self) -> None: configuration = make_config(trace=True) - command = cli.resolve_command(configuration) + command = cli.resolve_command("get", configuration) captured_clients: list[src.SourcegraphClient] = [] def capture_client( - _config: cli.SrcAuthPermissionsSyncConfig, + _config: cli.Config, _command: cli.ResolvedCommand, client: src.SourcegraphClient, _worker_pool: ThreadPoolExecutor, @@ -212,9 +336,37 @@ def capture_client( self.assertEqual(1, len(captured_clients)) self.assertTrue(captured_clients[0].trace) + def test_run_with_client_uses_configured_http_timeout(self) -> None: + configuration = make_config(http_timeout_seconds=75.0) + command = cli.resolve_command("get", configuration) + captured_clients: list[src.SourcegraphClient] = [] + + def capture_client( + _config: cli.Config, + _command: cli.ResolvedCommand, + client: src.SourcegraphClient, + _worker_pool: ThreadPoolExecutor, + ) -> None: + captured_clients.append(client) + + with ( + ThreadPoolExecutor(max_workers=1) as worker_pool, + mock.patch.object(cli, "run_command", side_effect=capture_client), + ): + cli.run_with_client( + configuration, + command, + "https://sourcegraph.example.com", + worker_pool, + ) + + self.assertEqual(1, len(captured_clients)) + self.assertEqual(75.0, captured_clients[0].http.timeout) + def test_validate_config_rejects_multiple_set_modes(self) -> None: self.assert_config_error( - make_config(set_path=Path("maps.yaml"), full=True, user="alice"), + "set", + make_config(maps_path=Path("maps.yaml"), full=True, users=("alice",)), "choose at most one", ) @@ -222,37 +374,57 @@ def test_require_set_input_file_reports_missing_maps_file(self) -> None: with tempfile.TemporaryDirectory() as directory: existing_path = Path(directory) / "maps.yaml" existing_path.write_text("maps: []\n") - cli.require_set_input_file(make_config(set_path=existing_path)) + cli.require_set_input_file(existing_path) with self.assertRaises(SystemExit) as exit_context: - cli.require_set_input_file(make_config(set_path=Path(directory) / "missing.yaml")) - self.assertIn("--set input file does not exist", str(exit_context.exception)) + cli.require_set_input_file(Path(directory) / "missing.yaml") + self.assertIn("set input file does not exist", str(exit_context.exception)) + + def test_config_with_default_paths_only_defaults_omitted_maps_path(self) -> None: + endpoint_directory = Path.cwd() / backups.ARTIFACTS_DIR_NAME / "sourcegraph.example.com" - def test_endpoint_scoped_config_rewrites_relative_artifact_paths(self) -> None: - scoped_config = cli.endpoint_scoped_config( - make_config(set_path=Path("maps.yaml"), restore_path=Path("snapshot.json")), + defaulted_set_config = cli.config_with_default_paths( + "set", + make_config(), "https://sourcegraph.example.com", ) - endpoint_directory = Path.cwd() / backups.ARTIFACTS_DIR_NAME / "sourcegraph.example.com" - self.assertEqual(endpoint_directory / "maps.yaml", scoped_config.set_path) - self.assertEqual(endpoint_directory / "snapshot.json", scoped_config.restore_path) + self.assertEqual(endpoint_directory / "maps.yaml", defaulted_set_config.maps_path) + + explicit_set_config = cli.config_with_default_paths( + "set", + make_config(maps_path=Path("maps.yaml")), + "https://sourcegraph.example.com", + ) + self.assertEqual(Path("maps.yaml"), explicit_set_config.maps_path) + + restore_config = cli.config_with_default_paths( + "restore", + make_config(restore_path=Path("snapshot.json")), + "https://sourcegraph.example.com", + ) + self.assertEqual(Path("snapshot.json"), restore_config.restore_path) def test_run_fields_include_concrete_command(self) -> None: - configuration = make_config(set_path=Path("maps.yaml"), user="alice", apply=True) - command = cli.resolve_command(configuration) + configuration = make_config( + maps_path=Path("maps.yaml"), + users=("alice",), + apply=True, + ) + command = cli.resolve_command("set", configuration) fields = cli.run_fields(configuration, command, "https://sourcegraph.example.com") - self.assertEqual("set_user", fields["cli_cmd"]) + self.assertEqual("set_users", fields["cli_cmd"]) self.assertEqual("set", fields["base_cmd"]) - self.assertEqual("user", fields["set_mode"]) + self.assertEqual("users", fields["set_mode"]) self.assertEqual(True, fields["apply_flag"]) self.assertEqual(25, fields["explicit_permissions_batch_size"]) self.assertEqual(False, fields["trace"]) + self.assertEqual(60.0, fields["http_timeout_seconds"]) - def test_run_command_passes_primary_data_to_combined_sync(self) -> None: - configuration = make_config(get=True, sync_saml_organizations=True) - command = cli.resolve_command(configuration) + def test_run_command_passes_set_data_to_combined_sync(self) -> None: + configuration = make_config(sync_saml_organizations=True) + command = cli.resolve_command("set", configuration) client = cast(src.SourcegraphClient, object()) sourcegraph_site_config = object() command_data = cli.run_context.CommandData() @@ -264,13 +436,14 @@ def test_run_command_passes_primary_data_to_combined_sync(self) -> None: "validate_site_config", return_value=sourcegraph_site_config, ), - mock.patch.object(cli, "run_get", return_value=command_data) as run_get, + mock.patch.object(cli, "run_set", return_value=command_data) as run_set, mock.patch.object(cli, "run_sync_saml_organizations") as run_sync_saml_orgs, ): cli.run_command(configuration, command, client, worker_pool) - run_get.assert_called_once_with( + run_set.assert_called_once_with( configuration, + command, client, sourcegraph_site_config, worker_pool, @@ -283,14 +456,71 @@ def test_run_command_passes_primary_data_to_combined_sync(self) -> None: worker_pool, ) + def test_package_exports_programmatic_runner_and_config(self) -> None: + self.assertIs(src_auth_perms_sync.Config, cli.Config) + self.assertIs(src_auth_perms_sync.Get, cli.Get) + self.assertIs(src_auth_perms_sync.Set, cli.Set) + self.assertIs(src_auth_perms_sync.Restore, cli.Restore) + self.assertIs(src_auth_perms_sync.SyncSamlOrgs, cli.SyncSamlOrgs) + self.assertEqual( + ["Config", "Get", "Restore", "Set", "SyncSamlOrgs"], + src_auth_perms_sync.__all__, + ) + + def test_programmatic_runner_uses_supplied_config(self) -> None: + configuration = make_config(parallelism=1, sample_interval=0) + captured: list[tuple[cli.Config, cli.ResolvedCommand, str]] = [] + + def capture_run( + scoped_config: cli.Config, + command: cli.ResolvedCommand, + endpoint: str, + _worker_pool: ThreadPoolExecutor, + ) -> None: + captured.append((scoped_config, command, endpoint)) + + with ( + mock.patch.object(cli, "run_with_client", side_effect=capture_run), + mock.patch.object( + cli.src, + "logging_settings_from_config", + return_value=object(), + ), + mock.patch.object(cli.src, "logging", return_value=contextlib.nullcontext(None)), + ): + self.assertTrue(src_auth_perms_sync.Get(configuration)) + + self.assertEqual(1, len(captured)) + scoped_config, command, endpoint = captured[0] + self.assertIs(configuration, scoped_config) + self.assertEqual("get", command.name) + self.assertEqual("https://sourcegraph.example.com", endpoint) + + def test_programmatic_runner_returns_false_on_failure(self) -> None: + configuration = make_config(parallelism=1, sample_interval=0) + + with ( + mock.patch.object(cli, "run_with_client", side_effect=SystemExit(1)), + mock.patch.object( + cli.src, + "logging_settings_from_config", + return_value=object(), + ), + mock.patch.object(cli.src, "logging", return_value=contextlib.nullcontext(None)), + ): + self.assertFalse(src_auth_perms_sync.Get(configuration)) + def assert_config_error( - self, config: cli.SrcAuthPermissionsSyncConfig, expected_message: str + self, + command_name: cli.CommandName, + config: cli.Config, + expected_message: str, ) -> None: captured_stderr = io.StringIO() with ( contextlib.redirect_stderr(captured_stderr), self.assertRaises(SystemExit) as exit_context, ): - cli.validate_config(config) + cli.validate_config(command_name, config) self.assertEqual(2, exit_context.exception.code) self.assertIn(expected_message, captured_stderr.getvalue()) diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index 34bb483..d0d74f3 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -1,12 +1,19 @@ from __future__ import annotations +import base64 +import itertools +import json import tempfile import unittest from pathlib import Path +from typing import cast import yaml -from src_auth_perms_sync.permissions import maps +from src_auth_perms_sync.permissions import full_set, mapping, maps +from src_auth_perms_sync.permissions import queries as permission_queries +from src_auth_perms_sync.permissions import types as permission_types +from src_auth_perms_sync.shared import queries as shared_queries from src_auth_perms_sync.shared import types as shared_types @@ -73,3 +80,461 @@ def test_count_users_per_provider_counts_each_user_once_per_provider(self) -> No self.assertEqual(1, counts[maps.BUILTIN_PROVIDER_KEY]) self.assertEqual(1, counts[("saml", "https://idp.example.com", "sourcegraph")]) self.assertEqual(1, counts[("github", "https://github.com/", "github-client")]) + + def test_external_service_to_yaml_lifts_username_without_config(self) -> None: + service: permission_types.ExternalService = { + "id": "RXh0ZXJuYWxTZXJ2aWNlOjE=", + "kind": "BITBUCKETSERVER", + "displayName": "Bitbucket LOB1", + "url": "https://bitbucket.example.com/", + "repoCount": 0, + "createdAt": "2026-05-30T00:00:00Z", + "updatedAt": "2026-05-30T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": json.dumps({"username": "LOB1-SA1", "token": "REDACTED"}), + } + + rendered = maps.external_service_to_yaml(service) + + self.assertEqual("LOB1-SA1", rendered["username"]) + self.assertNotIn("config", rendered) + + +class MappingTests(unittest.TestCase): + def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: + rules_without_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "name": "username only", + "users": {"usernames": ["alice"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + rules_with_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "name": "email only", + "users": {"emails": ["alice@example.com"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + + self.assertFalse(mapping.mapping_rules_need_user_emails(rules_without_email_filters)) + self.assertTrue(mapping.mapping_rules_need_user_emails(rules_with_email_filters)) + + def test_user_filter_matchers_intersect_without_expanding_selection(self) -> None: + providers: list[shared_types.AuthProvider] = [ + { + "serviceType": "builtin", + "serviceID": "", + "clientID": "", + "displayName": "Builtin", + "isBuiltin": True, + "configID": "", + } + ] + users = [ + self.make_user("user-1", "alice", True, "alice@example.com", True), + self.make_user("user-2", "bob", True, "bob@example.com", True), + self.make_user("user-3", "carol", True, "carol@example.com", False), + self.make_user("user-4", "dana", False, "dana@example.com", True), + ] + user_fields: dict[str, object] = { + "authProvider": {"type": "builtin"}, + "emails": ["alice@example.com", "carol@example.com", "dana@example.com"], + "emailRegexes": [r"^(alice|bob|carol)@example\.com$"], + "usernames": ["alice", "bob", "carol"], + "usernameRegexes": [r"^(alice|dana)$"], + } + single_filter_usernames = { + name: self.usernames_for( + mapping.resolve_users( + cast(permission_types.UserSelector, {name: matcher}), users, providers + ), + ) + for name, matcher in user_fields.items() + } + + for filter_count in range(2, len(user_fields) + 1): + for filter_names in itertools.combinations(user_fields, filter_count): + matched_usernames = self.usernames_for( + mapping.resolve_users( + cast( + permission_types.UserSelector, + {name: user_fields[name] for name in filter_names}, + ), + users, + providers, + ) + ) + expected_usernames = self.intersection_for(filter_names, single_filter_usernames) + + self.assertEqual(expected_usernames, matched_usernames) + for name in filter_names: + self.assertLessEqual(matched_usernames, single_filter_usernames[name]) + + self.assertEqual( + {"alice"}, + self.usernames_for( + mapping.resolve_users( + cast(permission_types.UserSelector, user_fields), users, providers + ) + ), + ) + + def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + example_private_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + example_public_repo = self.make_repo("repo-4", "github.com/example/public-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + example_private_repo["id"]: example_private_repo, + gitlab_repo["id"]: gitlab_repo, + example_public_repo["id"]: example_public_repo, + } + services_by_id = { + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise", "enterprise-sync"), + 2: self.make_external_service(2, "GITHUB", "GitHub Cloud", "cloud-sync"), + } + repos_by_external_service_id = { + 1: [sourcegraph_repo, example_private_repo, gitlab_repo], + 2: [example_public_repo], + } + repository_fields: dict[str, object] = { + "codeHostConnection": {"username": "enterprise-sync"}, + "names": [ + "github.com/example/private-repo", + "gitlab.com/example/private-repo", + ], + "nameRegexes": [r"^github\.com/example/"], + } + single_filter_repo_names = { + name: self.repo_names_for( + mapping.resolve_repos( + cast(permission_types.RepositorySelector, {name: matcher}), + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + for name, matcher in repository_fields.items() + } + + for filter_count in range(2, len(repository_fields) + 1): + for filter_names in itertools.combinations(repository_fields, filter_count): + matched_repo_names = self.repo_names_for( + mapping.resolve_repos( + cast( + permission_types.RepositorySelector, + {name: repository_fields[name] for name in filter_names}, + ), + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + expected_repo_names = self.intersection_for(filter_names, single_filter_repo_names) + + self.assertEqual(expected_repo_names, matched_repo_names) + for name in filter_names: + self.assertLessEqual(matched_repo_names, single_filter_repo_names[name]) + + self.assertEqual( + {"github.com/example/private-repo"}, + self.repo_names_for( + mapping.resolve_repos( + cast(permission_types.RepositorySelector, repository_fields), + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ), + ) + + def test_validate_mapping_rules_accepts_flat_text_selector_lists(self) -> None: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "name": "flat selector lists", + "users": { + "emails": ["alice@example.com"], + "emailRegexes": [r"^team-.*@example\.com$"], + "usernames": ["alice"], + "usernameRegexes": [r"^team-.*"], + }, + "repos": { + "names": ["github.com/example/private-repo"], + "nameRegexes": [r"^github\.com/example/"], + }, + } + ], + ) + ) + + def test_repository_name_matches_any_pattern(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + github_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + github_repo["id"]: github_repo, + gitlab_repo["id"]: gitlab_repo, + } + + matched_repos = mapping.resolve_repos( + { + "nameRegexes": [ + r"^github\.com/example/", + r"^gitlab\.com/example/", + ], + }, + {}, + {}, + all_repos, + ) + + self.assertEqual( + {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + self.repo_names_for(matched_repos), + ) + + def test_username_matches_any_pattern(self) -> None: + providers: list[shared_types.AuthProvider] = [] + users = [ + self.make_user("user-1", "alice", True, "alice@example.com", True), + self.make_user("user-2", "test_user_00001", True, "one@example.com", True), + self.make_user("user-3", "test_user_00100", True, "hundred@example.com", True), + self.make_user("user-4", "service-account", True, "service@example.com", True), + ] + + matched_users = mapping.resolve_users( + {"usernameRegexes": [r"^(alice|test_user_00[0-9]{3})$"]}, + users, + providers, + ) + + self.assertEqual( + {"alice", "test_user_00001", "test_user_00100"}, + self.usernames_for(matched_users), + ) + + def test_validate_mapping_rules_rejects_invalid_text_matchers(self) -> None: + with self.assertRaises(SystemExit) as raised: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "name": "invalid flat selector lists", + "users": { + "emails": "alice@example.com", + "usernames": [""], + }, + "repos": {"names": [123], "nameRegexes": ["["]}, + }, + { + "name": "invalid code host field", + "users": {"usernames": ["alice"]}, + "repos": { + "codeHostConnection": {"config": {"username": "old"}, "id": 1}, + "regex": r"^github\.com/example/", + }, + }, + { + "name": "invalid username regex", + "users": {"usernameRegexes": ["["]}, + "repos": {"names": ["github.com/example/private-repo"]}, + }, + { + "users": {"usernames": ["alice"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + }, + ], + ) + ) + + message = str(raised.exception) + self.assertIn("users.emails must be a list of strings", message) + self.assertIn("users.usernames[0] is an empty string", message) + self.assertIn("repos.names[0] must be a string", message) + self.assertIn("repos.nameRegexes[0] is not a valid Python regex", message) + self.assertIn("users.usernameRegexes[0] is not a valid Python regex", message) + self.assertIn("unknown repos field 'regex'", message) + self.assertIn("unknown repos.codeHostConnection field 'config'", message) + self.assertIn("unknown repos.codeHostConnection field 'id'", message) + self.assertIn("`name:` is missing", message) + + def make_user( + self, + user_id: str, + username: str, + builtin_auth: bool, + email: str, + verified: bool, + ) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": builtin_auth, + "emails": [{"email": email, "verified": verified}], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + def make_external_service( + self, + external_service_id: int, + kind: str, + display_name: str, + username: str | None = None, + ) -> permission_types.ExternalService: + graphql_id = base64.b64encode(f"ExternalService:{external_service_id}".encode()).decode() + return { + "id": graphql_id, + "kind": kind, + "displayName": display_name, + "url": f"https://code-host-{external_service_id}.example.com", + "repoCount": 0, + "createdAt": "2026-05-30T00:00:00Z", + "updatedAt": "2026-05-30T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": json.dumps({"username": username} if username else {}), + } + + def usernames_for(self, users: list[shared_types.User]) -> set[str]: + return {user["username"] for user in users} + + def repo_names_for(self, repos: list[permission_types.Repository]) -> set[str]: + return {repo["name"] for repo in repos} + + def intersection_for( + self, names: tuple[str, ...], sets_by_name: dict[str, set[str]] + ) -> set[str]: + matched = set(sets_by_name[names[0]]) + for name in names[1:]: + matched &= sets_by_name[name] + return matched + + +class FullSetPlanningTests(unittest.TestCase): + def test_full_set_plan_reuses_user_tuple_for_non_overlapping_repos(self) -> None: + users = [self.make_user("user-1", "bob"), self.make_user("user-2", "alice")] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + ] + context = self.make_context( + [ + { + "name": "alice and bob get example repos", + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + } + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob"), plan.expected_users["repo-2"]) + self.assertIs(plan.expected_users["repo-1"], plan.expected_users["repo-2"]) + self.assertEqual(4, plan.total_grants) + + def test_full_set_plan_unions_only_overlapping_repos(self) -> None: + users = [ + self.make_user("user-1", "alice"), + self.make_user("user-2", "bob"), + self.make_user("user-3", "chris"), + ] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + self.make_repo("repo-3", "github.com/example/three"), + ] + context = self.make_context( + [ + { + "name": "alice and bob get first repos", + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + }, + { + "name": "bob and chris get second repos", + "users": {"usernames": ["bob", "chris"]}, + "repos": {"names": ["github.com/example/two", "github.com/example/three"]}, + }, + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob", "chris"), plan.expected_users["repo-2"]) + self.assertEqual(("bob", "chris"), plan.expected_users["repo-3"]) + self.assertEqual(7, plan.total_grants) + + def make_context( + self, + mapping_rules: list[permission_types.MappingRule], + repositories: list[permission_types.Repository], + ) -> permission_types.MappingContext: + return permission_types.MappingContext( + mapping_rules=mapping_rules, + providers=[], + saml_groups_attribute_names={}, + services_by_id={}, + repos_by_external_service_id={}, + all_repos_by_id={repository["id"]: repository for repository in repositories}, + ) + + def make_user(self, user_id: str, username: str) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": True, + "emails": [], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + +class QueryTests(unittest.TestCase): + def test_user_email_fields_are_opt_in(self) -> None: + self.assertNotIn("emails {", shared_queries.QUERY_USERS) + self.assertNotIn("emails {", shared_queries.query_users()) + self.assertIn("emails {", shared_queries.query_users(include_emails=True)) + + self.assertNotIn("emails {", permission_queries.QUERY_USER_BY_ID) + self.assertNotIn("emails {", permission_queries.query_user_by_id()) + self.assertIn("emails {", permission_queries.query_user_by_id(include_emails=True)) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3525fc9..bcb72f9 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -177,22 +177,14 @@ def test_list_users_explicit_repos_batches_aliases_and_follows_pages(self) -> No ), ] - class FakeGraphQLClient: - def __init__(self, **_kwargs: object) -> None: - pass - - def execute( - self, - query: str, - variables: src.JSONDict, - *, - follow_pages: bool = True, - ) -> src.JSONDict: - calls.append((query, dict(variables), follow_pages)) - return responses.pop(0) - - def graphql(query: str, variables: object = None) -> src.JSONDict: - return FakeGraphQLClient().execute(query, cast(src.JSONDict, variables or {})) + def graphql( + query: str, + variables: object = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + calls.append((query, dict(cast(src.JSONDict, variables or {})), follow_pages)) + return responses.pop(0) client = cast( src.SourcegraphClient, @@ -203,12 +195,11 @@ def graphql(query: str, variables: object = None) -> src.JSONDict: graphql=graphql, ), ) - with patch.object(permissions_sourcegraph.src, "GraphQLClient", FakeGraphQLClient): - repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( - client, - ["user-1", "user-2"], - batch_size=2, - ) + repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( + client, + ["user-1", "user-2"], + batch_size=2, + ) self.assertEqual( { diff --git a/uv.lock b/uv.lock index 490b812..3feed68 100644 --- a/uv.lock +++ b/uv.lock @@ -318,7 +318,6 @@ wheels = [ [[package]] name = "src-auth-perms-sync" -version = "0.2.2" source = { editable = "." } dependencies = [ { name = "json5" },