diff --git a/.copyright.tmpl b/.copyright.tmpl new file mode 100644 index 0000000..e46736a --- /dev/null +++ b/.copyright.tmpl @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: 2026 AOT Technologies +SPDX-License-Identifier: Apache-2.0 diff --git a/.github/workflows/docker-policy.yml b/.github/workflows/docker-policy.yml new file mode 100644 index 0000000..4a0f53e --- /dev/null +++ b/.github/workflows/docker-policy.yml @@ -0,0 +1,43 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# Enforce baseline container hygiene: digest-pinned base image and non-root USER. +name: Docker policy + +on: + push: + branches: [main, master] + paths: + - "Dockerfile" + - "docker/**" + pull_request: + paths: + - "Dockerfile" + - "docker/**" + +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Verify Dockerfiles use digest-pinned FROM and non-root USER + shell: bash + run: | + set -euo pipefail + shopt -s nullglob + files=(Dockerfile docker/*/Dockerfile) + for f in "${files[@]}"; do + echo "Checking $f" + if ! grep -qE '^FROM [^[:space:]]+@sha256:[a-f0-9]{64}' "$f"; then + echo "ERROR: $f must use FROM image@sha256:<64-hex-digest>" >&2 + exit 1 + fi + if ! grep -qE '^USER[[:space:]]' "$f"; then + echo "ERROR: $f must end with a non-root USER directive" >&2 + exit 1 + fi + done + echo "PASS: all Dockerfiles pinned and non-root" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..2c3ae51 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,54 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + + +name: Lint and Type Check + +on: + pull_request: + branches: [ "main" ] + +jobs: + ruff: + name: Ruff Linters + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run Ruff (Linting) + # --output-format=github automatically creates inline annotations on your PR code! + run: ruff check --output-format=github . + + - name: Run Ruff (Formatting Check) + run: ruff format --check . + + mypy: + name: Mypy Type Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run Mypy (Type checking) + run: mypy diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..74a2252 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,233 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# Pinned action SHAs are immutable; update intentionally when upgrading. +# yamllint disable rule:line-length +name: Publish Node Wire package + +# Manual trigger: go to Actions → "Publish Node Wire package" → Run workflow. +# package_path must match the allowlist below (prevents confused-deputy path abuse). +# +# Examples: +# package_path: packages/runtime +# package_path: packages/connectors/fhir_epic +# package_path: packages/connectors/google_drive +on: + workflow_dispatch: + inputs: + package_path: + description: | + Relative path to the package directory (must be allowlisted). + Examples: packages/runtime | packages/connectors/fhir_epic + required: true + type: string + version: + description: "Semver version to publish (e.g. 0.2.0)" + required: true + type: string + +env: + PIP_AUDIT_VERSION: "2.7.3" + CYCLONEDX_BOM_VERSION: "4.6.1" + +jobs: + # ───────────────────────────────────────────────────────────────────────────── + # Build a binary wheel for each platform (Linux / macOS / Windows). + # cibuildwheel compiles Cython extensions and produces manylinux / macosx / + # win_amd64 wheels. The NoPyBuild override in setup.py ensures .py source + # files are excluded from all wheels. + # ───────────────────────────────────────────────────────────────────────────── + build-wheels: + name: Build (${{ matrix.os }}) + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + permissions: + contents: read + + steps: + - name: Validate package path (allowlist) + shell: python + run: | + import os + import sys + + raw = "${{ inputs.package_path }}".strip().replace("\\", "/") + norm = os.path.normpath(raw).replace("\\", "/") + # Reject traversal / absolute paths + if norm.startswith("..") or os.path.isabs(raw): + print("ERROR: invalid package_path", file=sys.stderr) + sys.exit(1) + allowed = { + "packages/runtime", + "packages/connectors/http_generic", + "packages/connectors/stripe", + "packages/connectors/smtp", + "packages/connectors/google_drive", + "packages/connectors/fhir_cerner", + "packages/connectors/fhir_epic", + } + if norm not in allowed: + print(f"ERROR: package_path {norm!r} is not allowlisted.", file=sys.stderr) + print("Allowed:", sorted(allowed), file=sys.stderr) + sys.exit(1) + print(f"PASS: package_path {norm!r} is allowlisted") + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: "3.11" + + - name: Install build tools + run: python -m pip install --upgrade pip "cython>=3.0" "cibuildwheel>=2.16.0" + + - name: Build platform wheel(s) + run: | + cd "${{ inputs.package_path }}" + python -m cibuildwheel --output-dir dist + env: + # Build only Python 3.11+ wheels (matches requires-python in pyproject.toml) + CIBW_BUILD: "cp311-* cp312-*" + # Skip 32-bit targets and PyPy — not supported + CIBW_SKIP: "*-win32 *-manylinux_i686 pp*" + + # ── Security gate: verify no .py source files leaked into any wheel ────── + - name: Verify binary-only wheel (no .py source) + shell: python + run: | + import glob, sys, zipfile + + wheels = glob.glob("${{ inputs.package_path }}/dist/*.whl") + if not wheels: + print("ERROR: No wheels produced", file=sys.stderr) + sys.exit(1) + + leaked: dict[str, list[str]] = {} + for whl in wheels: + with zipfile.ZipFile(whl) as zf: + bad = [n for n in zf.namelist() if n.endswith(".py")] + if bad: + leaked[whl] = bad + + if leaked: + print("SECURITY FAIL: .py files found in wheel(s):", file=sys.stderr) + for whl, files in leaked.items(): + print(f" {whl}:", file=sys.stderr) + for f in files: + print(f" {f}", file=sys.stderr) + sys.exit(1) + + print(f"PASS: {len(wheels)} wheel(s) verified — no .py source files") + + - name: Record wheel SHA256 (artifact integrity) + shell: python + run: | + import glob, hashlib, pathlib, sys + dist = pathlib.Path("${{ inputs.package_path }}") / "dist" + wheels = sorted(dist.glob("*.whl")) + if not wheels: + print("ERROR: no wheels to hash", file=sys.stderr) + sys.exit(1) + lines = [] + for w in wheels: + h = hashlib.sha256(w.read_bytes()).hexdigest() + line = f"{h} {w.name}" + print(line) + lines.append(line) + (dist / "sha256sums.txt").write_text("\n".join(lines) + "\n", encoding="utf-8") + + - name: Upload wheel artifacts + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: wheels-${{ matrix.os }} + path: ${{ inputs.package_path }}/dist/*.whl + if-no-files-found: error + + # ───────────────────────────────────────────────────────────────────────────── + # Collect all platform wheels, run supply-chain checks, publish to PyPI. + # Uses Trusted Publisher (OIDC) — no long-lived PyPI API tokens needed. + # Configure on PyPI: Settings → Publishing → Add Publisher (GitHub, this repo, + # workflow name = "Publish Node Wire package"). + # ───────────────────────────────────────────────────────────────────────────── + publish: + name: Publish to PyPI + needs: build-wheels + runs-on: ubuntu-latest + permissions: + id-token: write # Required for Trusted Publisher OIDC + contents: read + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: "3.11" + + - name: Download all wheel artifacts + uses: actions/download-artifact@v4 + with: + path: dist-all + + - name: Flatten into dist/ directory + run: | + mkdir -p dist + find dist-all -name "*.whl" -exec cp {} dist/ \; + echo "Wheels collected:" + ls dist/ + sha256sum dist/*.whl | tee dist/sha256sums.txt + + - name: Validate wheel version matches input + shell: python + run: | + import glob, sys + + expected = "${{ inputs.version }}" + wheels = glob.glob("dist/*.whl") + if not wheels: + print("ERROR: No wheels found for publish", file=sys.stderr) + sys.exit(1) + + bad = [w for w in wheels if f"-{expected}-" not in w and f"-{expected.replace('.', '_')}-" not in w] + if bad: + print(f"ERROR: Version mismatch. Expected {expected!r} in filename.", file=sys.stderr) + for w in bad: + print(f" {w}", file=sys.stderr) + sys.exit(1) + + print(f"PASS: {len(wheels)} wheel(s) match version {expected!r}") + + - name: Install built wheels for CVE scan + run: pip install dist/*.whl + + - name: Vulnerability scan (CVE gate — blocks publish on HIGH or higher) + run: | + pip install "pip-audit==${{ env.PIP_AUDIT_VERSION }}" + pip-audit --fail-on HIGH + + - name: Generate SBOM + run: | + pip install "cyclonedx-bom==${{ env.CYCLONEDX_BOM_VERSION }}" + cyclonedx-py environment -o sbom.json + echo "SBOM generated: sbom.json" + + - name: Upload SBOM as release artifact + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: sbom-${{ inputs.version }} + path: sbom.json + + - name: Publish to PyPI (Trusted Publisher / OIDC + Sigstore attestations) + uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + with: + packages-dir: dist/ + # attestations: true generates a Sigstore attestation automatically. + # Clients can verify with: pip download && python -m pypi_attestation_viewer + attestations: true diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000..21e2774 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,137 @@ + +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +name: CI – Pytest + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + workflow_dispatch: + +jobs: + test: + name: Run pytest (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12"] + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: | + pyproject.toml + + - name: Install dependencies + run: uv sync --all-extras --dev + + - name: Run pytest + run: uv run pytest --cov --cov-report=html --cov-report=term-missing + + - name: Upload HTML coverage report + if: always() + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: coverage-html-py${{ matrix.python-version }} + path: htmlcov/ + if-no-files-found: ignore + + + playground-integration: + name: Playground integration tests + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: | + pyproject.toml + + - name: Install dependencies + run: uv sync --all-extras --dev + + - name: Install Playwright browsers + run: uv run python -m playwright install chromium --with-deps + + - name: Run playground integration tests + env: + + #gdrive + GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} + GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} + GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} + + #stripe + STRIPE_API_KEY: ${{ secrets.STRIPE_API_KEY }} + STRIPE_TEST_CUSTOMER_ID: ${{ secrets.STRIPE_TEST_CUSTOMER_ID }} + STRIPE_TEST_PRICE_ID: ${{ secrets.STRIPE_TEST_PRICE_ID }} + + #salesforce + SALESFORCE_INSTANCE_URL: ${{ secrets.SALESFORCE_INSTANCE_URL }} + SALESFORCE_TOKEN_URL: ${{ secrets.SALESFORCE_TOKEN_URL }} + SALESFORCE_CLIENT_ID: ${{ secrets.SALESFORCE_CLIENT_ID }} + SALESFORCE_CLIENT_SECRET: ${{ secrets.SALESFORCE_CLIENT_SECRET }} + SALESFORCE_REFRESH_TOKEN: ${{ secrets.SALESFORCE_REFRESH_TOKEN }} + + #slack + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + SLACK_TEST_CHANNEL: ${{ secrets.SLACK_TEST_CHANNEL }} + SLACK_TEST_USER_ID: ${{ secrets.SLACK_TEST_USER_ID }} + SLACK_TEST_CHANNEL_ID: ${{ secrets.SLACK_TEST_CHANNEL_ID }} + + #epic fhir + EPIC_CLIENT_ID: ${{ secrets.EPIC_CLIENT_ID }} + EPIC_PRIVATE_KEY: ${{ secrets.EPIC_PRIVATE_KEY }} + EPIC_TOKEN_URL: ${{ secrets.EPIC_TOKEN_URL }} + EPIC_KID: ${{ secrets.EPIC_KID }} + EPIC_FHIR_BASE_URL: ${{ secrets.EPIC_FHIR_BASE_URL }} + + #cerner + CERNER_CLIENT_ID: ${{ secrets.CERNER_CLIENT_ID }} + CERNER_PRIVATE_KEY: ${{ secrets.CERNER_PRIVATE_KEY }} + CERNER_TOKEN_URL: ${{ secrets.CERNER_TOKEN_URL }} + CERNER_KID: ${{ secrets.CERNER_KID }} + CERNER_FHIR_BASE_URL: ${{ secrets.CERNER_FHIR_BASE_URL }} + CERNER_SCOPES: ${{ secrets.CERNER_SCOPES }} + + # Disable authentication and dotenv loading for playground tests, and restrict connectors + NW_REST_AUTH_DISABLED: "true" + NW_REST_LOAD_DOTENV: "false" + NW_ALLOWED_CONNECTORS: "google_drive,salesforce,stripe,slack,fhir_epic,fhir_cerner,http_generic" + run: uv run pytest tests/playground/ --no-cov -v + + - name: Upload Playwright traces on failure + if: failure() + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: playwright-traces + path: test-results/ + if-no-files-found: ignore diff --git a/.github/workflows/quality-gates.yml b/.github/workflows/quality-gates.yml new file mode 100644 index 0000000..524fc09 --- /dev/null +++ b/.github/workflows/quality-gates.yml @@ -0,0 +1,88 @@ + +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +name: Quality gates + +on: + pull_request: + push: + branches: [main, master] + +# This workflow enforces Bandit, tests/coverage, and SonarQube. + +jobs: + bandit: + name: Bandit security scan + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: "3.11" + - name: Install dependencies + run: python -m pip install --upgrade pip && pip install -e ".[dev,agents]" + # Bandit exits non-zero when *any* finding exists (incl. low/medium). The + # enforce step below gates on high only; use --exit-zero here so this step + # always produces the JSON artifact and the job can print a summary. + - name: Generate Bandit JSON report + run: bandit -c pyproject.toml -r src -f json -o bandit-report.json --exit-zero + - name: Bandit findings summary (log) + run: python scripts/bandit_report_summary.py bandit-report.json + - name: Upload Bandit report artifact + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: bandit-report + path: bandit-report.json + if-no-files-found: error + - name: Enforce high-severity Bandit gate + run: bandit -c pyproject.toml -r src --severity-level high + + test: + name: Tests and coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: "3.11" + - name: Install dependencies + run: python -m pip install --upgrade pip && pip install -e ".[dev,agents]" + - name: Run tests (coverage.xml generated via pyproject addopts) + run: pytest tests/ -v + - name: Upload coverage artifact + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: coverage-xml + path: coverage.xml + if-no-files-found: error + + sonar: + name: SonarQube analysis + runs-on: ubuntu-latest + needs: [bandit, test] + # Sonar scan requires repository secrets; skip gracefully when unavailable + # (e.g. PRs from forks where secrets are not exposed). + if: ${{ secrets.SONAR_TOKEN != '' && secrets.SONAR_HOST_URL != '' }} + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Download coverage artifact + uses: actions/download-artifact@v4 + with: + name: coverage-xml + path: . + - name: SonarQube scan (wait for quality gate) + uses: SonarSource/sonarqube-scan-action@v5.3.1 + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + SONAR_HOST_URL: ${{ secrets.SONAR_HOST_URL }} + with: + args: > + -Dsonar.qualitygate.wait=true + -Dsonar.qualitygate.timeout=300 diff --git a/.github/workflows/security-pr.yml b/.github/workflows/security-pr.yml new file mode 100644 index 0000000..264ef7f --- /dev/null +++ b/.github/workflows/security-pr.yml @@ -0,0 +1,69 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +### Continuous security checks for publishable Python packages on pull requests. + + +name: Python package security PR checks + +on: + pull_request: + paths: + - ".github/workflows/security-pr.yml" + - ".github/workflows/publish.yml" + - "pyproject.toml" + - "uv.lock" + - "packages/**" + - "src/**" + push: + branches: [main, master] + paths: + - ".github/workflows/security-pr.yml" + - ".github/workflows/publish.yml" + - "pyproject.toml" + - "uv.lock" + - "packages/**" + - "src/**" + schedule: + - cron: "17 3 * * *" + +env: + PIP_AUDIT_VERSION: "2.7.3" + +jobs: + vulnerability-scan: + name: Vulnerability scan (${{ matrix.package_path }}) + runs-on: ubuntu-latest + permissions: + contents: read + strategy: + fail-fast: false + matrix: + package_path: + - packages/runtime + - packages/connectors/http_generic + - packages/connectors/stripe + - packages/connectors/smtp + - packages/connectors/google_drive + - packages/connectors/fhir_cerner + - packages/connectors/fhir_epic + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@v5.3.0 + with: + python-version: "3.11" + + # Connector packages declare node-wire-runtime>=0.1.0 as a PyPI-style dep; install + # packages/runtime from the repo so pip resolves it without requiring a published wheel. + - name: Install package and audit tool + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade "setuptools>=78.1.1" + python -m pip install "packages/runtime" "${{ matrix.package_path }}" + python -m pip install "pip-audit==${{ env.PIP_AUDIT_VERSION }}" + + - name: Vulnerability scan (blocks any vulnerability) + run: pip-audit diff --git a/.gitignore b/.gitignore index 96bb6a5..534874b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 build/ .pytest_cache @@ -7,4 +10,33 @@ __pycache__/ dist/ .env .DS_Store -**/.DS_Store \ No newline at end of file +**/.DS_Store +.coverage +.coverage.* +htmlcov/ +*.gitattributes +# GCP / cloud credentials +connectorplatform-*.json +*-service-account.json +*credentials*.json +service_account.json + +# Grafana exports (auto-generated) +grafana/*.json + +# Python lock files +uv.lock + +# Temporary test/upload files +upload-test.txt +*-test.txt +upload-*.txt +coverage.xml +bandit-report.json +.venv*/ + +# debug +launch.json + +# log +mcp_debug.log diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e813c81 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + pass_filenames: false + additional_dependencies: ["pydantic", "fastapi"] + + - repo: https://github.com/PyCQA/bandit + rev: 1.8.6 + hooks: + - id: bandit + args: ["-c", "pyproject.toml", "-r", "src"] + pass_filenames: false diff --git a/DEPENDENCIES.md b/DEPENDENCIES.md new file mode 100644 index 0000000..dd83324 --- /dev/null +++ b/DEPENDENCIES.md @@ -0,0 +1,197 @@ +# Node Wire Open Source Dependencies + +This file is automatically generated and contains an inventory of all third-party dependencies used in the Node Wire project. + +## License Classification Criteria +To maintain open-source compliance, dependencies are evaluated against the following criteria: +* **? Safe (Permissive):** MIT, Apache-2.0, BSD, PSF. These licenses are universally safe for our Apache 2.0 open-source release and can be freely used, modified, and distributed. +* **?? Needs Review:** Custom or obscure licenses. These require manual review by the engineering team to ensure they don't impose conflicting obligations. +* **? Risky (Copyleft):** GPLv2, GPLv3, AGPL. These licenses are strictly prohibited in the runtime application as they force derivative works to adopt the same open-source license. They may only be used as isolated, non-distributed Development/Linting tools. + +--- + +| Name | Version | License | URL | +|---------------------------------------------------|-----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------| +| Cython | 3.2.4 | Apache Software License | https://cython.org/ | +| Deprecated | 1.3.1 | MIT License | https://github.com/laurent-laporte-pro/deprecated | +| Jinja2 | 3.1.6 | BSD License | https://github.com/pallets/jinja/ | +| MarkupSafe | 3.0.3 | BSD-3-Clause | https://github.com/pallets/markupsafe/ | +| PyJWT | 2.12.1 | MIT | https://github.com/jpadilla/pyjwt | +| PyYAML | 6.0.3 | MIT License | https://pyyaml.org/ | +| Pygments | 2.20.0 | BSD-2-Clause | https://pygments.org | +| aiohappyeyeballs | 2.6.1 | Python Software Foundation License | https://github.com/aio-libs/aiohappyeyeballs | +| aiohttp | 3.13.5 | Apache-2.0 AND MIT | https://github.com/aio-libs/aiohttp | +| aiosignal | 1.4.0 | Apache Software License | https://github.com/aio-libs/aiosignal | +| aiosmtplib | 5.1.0 | MIT | https://github.com/cole/aiosmtplib/blob/main/CHANGELOG.rst | +| annotated-doc | 0.0.4 | MIT | https://github.com/fastapi/annotated-doc | +| annotated-types | 0.7.0 | MIT License | https://github.com/annotated-types/annotated-types | +| anthropic | 0.89.0 | MIT License | https://github.com/anthropics/anthropic-sdk-python | +| anyio | 4.13.0 | MIT | https://anyio.readthedocs.io/en/stable/versionhistory.html | +| asgiref | 3.11.1 | BSD License | https://github.com/django/asgiref/ | +| attrs | 26.1.0 | MIT | https://www.attrs.org/en/stable/changelog.html | +| boolean.py | 5.0 | BSD-2-Clause | https://github.com/bastikr/boolean.py | +| build | 1.4.2 | MIT | https://build.pypa.io | +| certifi | 2026.2.25 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/certifi/python-certifi | +| cffi | 2.0.0 | MIT | https://cffi.readthedocs.io/en/latest/whatsnew.html | +| cfgv | 3.5.0 | MIT | https://github.com/asottile/cfgv | +| charset-normalizer | 3.4.7 | MIT | https://github.com/jawah/charset_normalizer/blob/master/CHANGELOG.md | +| click | 8.3.2 | BSD-3-Clause | https://github.com/pallets/click/ | +| colorama | 0.4.6 | BSD License | https://github.com/tartley/colorama | +| coverage | 7.13.5 | Apache-2.0 | https://github.com/coveragepy/coveragepy | +| cryptography | 46.0.6 | Apache-2.0 OR BSD-3-Clause | https://github.com/pyca/cryptography | +| cuid | 0.4 | Apache Software License | http://github.com/necaris/cuid.py | +| distlib | 0.4.0 | Python Software Foundation License | https://github.com/pypa/distlib | +| distro | 1.9.0 | Apache Software License | https://github.com/python-distro/distro | +| dnspython | 2.8.0 | ISC License (ISCL) | https://www.dnspython.org | +| docstring_parser | 0.17.0 | MIT License | https://github.com/rr-/docstring_parser | +| email-validator | 2.3.0 | The Unlicense (Unlicense) | https://github.com/JoshData/python-email-validator | +| fastapi | 0.135.3 | MIT | https://github.com/fastapi/fastapi | +| filelock | 3.25.2 | MIT | https://github.com/tox-dev/py-filelock | +| frozenlist | 1.8.0 | Apache-2.0 | https://github.com/aio-libs/frozenlist | +| google-ai-generativelanguage | 0.6.15 | Apache Software License | https://github.com/googleapis/google-cloud-python/tree/main/packages/google-ai-generativelanguage | +| google-api-core | 2.25.2 | Apache Software License | https://github.com/googleapis/python-api-core | +| google-api-python-client | 2.193.0 | Apache Software License | https://github.com/googleapis/google-api-python-client/ | +| google-auth | 2.49.1 | Apache Software License | https://github.com/googleapis/google-auth-library-python | +| google-auth-httplib2 | 0.3.1 | Apache Software License | https://github.com/googleapis/google-cloud-python/packages/google-auth-httplib2 | +| google-generativeai | 0.8.6 | Apache Software License | https://github.com/google/generative-ai-python | +| googleapis-common-protos | 1.74.0 | Apache Software License | https://github.com/googleapis/google-cloud-python/tree/main/packages/googleapis-common-protos | +| groq | 1.1.2 | Apache Software License | https://github.com/groq/groq-python | +| grpcio | 1.80.0 | Apache-2.0 | https://grpc.io | +| grpcio-status | 1.71.2 | Apache Software License | https://grpc.io | +| grpcio-tools | 1.71.2 | Apache Software License | https://grpc.io | +| h11 | 0.16.0 | MIT License | https://github.com/python-hyper/h11 | +| h2 | 4.3.0 | MIT License | https://github.com/python-hyper/h2/ | +| hpack | 4.1.0 | MIT License | https://github.com/python-hyper/hpack/ | +| httpcore | 1.0.9 | BSD-3-Clause | https://www.encode.io/httpcore/ | +| httplib2 | 0.31.2 | MIT License | https://github.com/httplib2/httplib2 | +| httptools | 0.7.1 | MIT | https://github.com/MagicStack/httptools | +| httpx | 0.27.2 | BSD-3-Clause | https://github.com/encode/httpx | +| httpx-sse | 0.4.3 | MIT | https://github.com/florimondmanca/httpx-sse | +| hyperframe | 6.1.0 | MIT License | https://github.com/python-hyper/hyperframe/ | +| identify | 2.6.18 | MIT | https://github.com/pre-commit/identify | +| idna | 3.11 | BSD-3-Clause | https://github.com/kjd/idna | +| importlib_metadata | 8.7.1 | Apache-2.0 | https://github.com/python/importlib_metadata | +| inflection | 0.5.1 | MIT License | https://github.com/jpvanhal/inflection | +| iniconfig | 2.3.0 | MIT | https://github.com/pytest-dev/iniconfig | +| jiter | 0.13.0 | MIT License | https://github.com/pydantic/jiter/ | +| jsonschema | 4.26.0 | MIT | https://github.com/python-jsonschema/jsonschema | +| jsonschema-specifications | 2025.9.1 | MIT | https://github.com/python-jsonschema/jsonschema-specifications | +| librt | 0.9.0 | MIT | https://github.com/mypyc/librt | +| license-expression | 30.4.4 | Apache-2.0 | https://github.com/aboutcode-org/license-expression | +| licenseheaders | 0.8.8 | MIT License | http://github.com/johann-petrak/licenseheaders | +| mcp | 1.27.0 | MIT License | https://modelcontextprotocol.io | +| multidict | 6.7.1 | Apache License 2.0 | https://github.com/aio-libs/multidict | +| mypy | 1.20.0 | MIT | https://www.mypy-lang.org/ | +| mypy_extensions | 1.1.0 | MIT | https://github.com/python/mypy_extensions | +| node-wire | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-fhir-cerner | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-fhir-epic | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-google-drive | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-http-generic | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-runtime | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-smtp | 0.1.0 | UNKNOWN | UNKNOWN | +| node-wire-stripe | 0.1.0 | UNKNOWN | UNKNOWN | +| nodeenv | 1.10.0 | BSD License | https://github.com/ekalinin/nodeenv | +| openai | 2.30.0 | Apache Software License | https://github.com/openai/openai-python | +| opentelemetry-api | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/opentelemetry-api | +| opentelemetry-exporter-otlp | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/exporter/opentelemetry-exporter-otlp | +| opentelemetry-exporter-otlp-proto-common | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/exporter/opentelemetry-exporter-otlp-proto-common | +| opentelemetry-exporter-otlp-proto-grpc | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/exporter/opentelemetry-exporter-otlp-proto-grpc | +| opentelemetry-exporter-otlp-proto-http | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/exporter/opentelemetry-exporter-otlp-proto-http | +| opentelemetry-instrumentation | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/opentelemetry-instrumentation | +| opentelemetry-instrumentation-agno | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-agno | +| opentelemetry-instrumentation-alephalpha | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-alephalpha | +| opentelemetry-instrumentation-anthropic | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-anthropic | +| opentelemetry-instrumentation-asgi | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-asgi | +| opentelemetry-instrumentation-bedrock | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-bedrock | +| opentelemetry-instrumentation-chromadb | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-chromadb | +| opentelemetry-instrumentation-cohere | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-cohere | +| opentelemetry-instrumentation-crewai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-crewai | +| opentelemetry-instrumentation-fastapi | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-fastapi | +| opentelemetry-instrumentation-google-generativeai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-google-generativeai | +| opentelemetry-instrumentation-groq | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-groq | +| opentelemetry-instrumentation-haystack | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-haystack | +| opentelemetry-instrumentation-lancedb | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-lancedb | +| opentelemetry-instrumentation-langchain | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-langchain | +| opentelemetry-instrumentation-llamaindex | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-llamaindex | +| opentelemetry-instrumentation-logging | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-logging | +| opentelemetry-instrumentation-marqo | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-marqo | +| opentelemetry-instrumentation-mcp | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mcp | +| opentelemetry-instrumentation-milvus | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-milvus | +| opentelemetry-instrumentation-mistralai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mistralai | +| opentelemetry-instrumentation-ollama | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-ollama | +| opentelemetry-instrumentation-openai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-openai | +| opentelemetry-instrumentation-openai-agents | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-openai-agents | +| opentelemetry-instrumentation-pinecone | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-pinecone | +| opentelemetry-instrumentation-qdrant | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-qdrant | +| opentelemetry-instrumentation-redis | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-redis | +| opentelemetry-instrumentation-replicate | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-replicate | +| opentelemetry-instrumentation-requests | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-requests | +| opentelemetry-instrumentation-sagemaker | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-sagemaker | +| opentelemetry-instrumentation-sqlalchemy | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-sqlalchemy | +| opentelemetry-instrumentation-threading | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/instrumentation/opentelemetry-instrumentation-threading | +| opentelemetry-instrumentation-together | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-together | +| opentelemetry-instrumentation-transformers | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-transformers | +| opentelemetry-instrumentation-urllib3 | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-urllib3 | +| opentelemetry-instrumentation-vertexai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-vertexai | +| opentelemetry-instrumentation-voyageai | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-voyageai | +| opentelemetry-instrumentation-watsonx | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-watsonx | +| opentelemetry-instrumentation-weaviate | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-weaviate | +| opentelemetry-instrumentation-writer | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-writer | +| opentelemetry-proto | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/opentelemetry-proto | +| opentelemetry-sdk | 1.40.0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/opentelemetry-sdk | +| opentelemetry-semantic-conventions | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python/tree/main/opentelemetry-semantic-conventions | +| opentelemetry-semantic-conventions-ai | 0.5.1 | Apache-2.0 | UNKNOWN | +| opentelemetry-util-http | 0.61b0 | Apache-2.0 | https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/util/opentelemetry-util-http | +| packaging | 26.0 | Apache-2.0 OR BSD-2-Clause | https://github.com/pypa/packaging | +| pathspec | 1.0.4 | Mozilla Public License 2.0 (MPL 2.0) | https://python-path-specification.readthedocs.io/en/latest/index.html | +| platformdirs | 4.9.6 | MIT | https://github.com/tox-dev/platformdirs | +| pluggy | 1.6.0 | MIT License | UNKNOWN | +| pre_commit | 4.5.1 | MIT | https://github.com/pre-commit/pre-commit | +| propcache | 0.4.1 | Apache Software License | https://github.com/aio-libs/propcache | +| proto-plus | 1.27.2 | Apache Software License | https://github.com/googleapis/google-cloud-python/tree/main/packages/proto-plus | +| protobuf | 5.29.6 | 3-Clause BSD License | https://developers.google.com/protocol-buffers/ | +| pyasn1 | 0.6.3 | BSD-2-Clause | https://github.com/pyasn1/pyasn1 | +| pyasn1_modules | 0.4.2 | BSD License | https://github.com/pyasn1/pyasn1-modules | +| pybreaker | 1.4.1 | BSD License | http://github.com/danielfm/pybreaker | +| pycparser | 3.0 | BSD-3-Clause | https://github.com/eliben/pycparser | +| pydantic | 2.12.5 | MIT | https://github.com/pydantic/pydantic | +| pydantic-settings | 2.13.1 | MIT | https://github.com/pydantic/pydantic-settings | +| pydantic_core | 2.41.5 | MIT | https://github.com/pydantic/pydantic-core | +| pyparsing | 3.3.2 | MIT | https://github.com/pyparsing/pyparsing/ | +| pyproject_hooks | 1.2.0 | MIT License | https://github.com/pypa/pyproject-hooks | +| pytest | 9.0.2 | MIT | https://docs.pytest.org/en/latest/ | +| pytest-asyncio | 1.3.0 | Apache-2.0 | https://github.com/pytest-dev/pytest-asyncio | +| pytest-cov | 7.1.0 | MIT | https://pytest-cov.readthedocs.io/en/latest/changelog.html | +| python-debian | 1.1.0 | DFSG approved; GNU General Public License v2 or later (GPLv2+) | https://salsa.debian.org/python-debian-team/python-debian | +| python-discovery | 1.2.2 | MIT License | https://github.com/tox-dev/python-discovery | +| python-dotenv | 1.2.2 | BSD-3-Clause | https://github.com/theskumar/python-dotenv | +| python-magic | 0.4.27 | MIT License | http://github.com/ahupp/python-magic | +| python-multipart | 0.0.24 | Apache-2.0 | https://github.com/Kludex/python-multipart | +| pywin32 | 311 | Python Software Foundation License | https://github.com/mhammond/pywin32 | +| referencing | 0.37.0 | MIT | https://github.com/python-jsonschema/referencing | +| regex | 2026.4.4 | Apache-2.0 AND CNRI-Python | https://github.com/mrabarnett/mrab-regex | +| requests | 2.33.1 | Apache Software License | https://github.com/psf/requests | +| reuse | 6.2.0 | Apache Software License; CC0 1.0 Universal (CC0 1.0) Public Domain Dedication; DFSG approved; GNU General Public License v3 or later (GPLv3+); Other/Proprietary License | https://reuse.software/ | +| rpds-py | 0.30.0 | MIT | https://github.com/crate-py/rpds | +| ruff | 0.15.10 | MIT | https://docs.astral.sh/ruff | +| sniffio | 1.3.1 | Apache Software License; MIT License | https://github.com/python-trio/sniffio | +| sse-starlette | 3.3.4 | BSD-3-Clause | https://github.com/sysid/sse-starlette | +| starlette | 1.0.0 | BSD-3-Clause | https://github.com/Kludex/starlette | +| stripe | 15.0.1 | MIT License | https://stripe.com/ | +| tenacity | 9.1.4 | Apache Software License | https://github.com/jd/tenacity | +| tomlkit | 0.14.0 | MIT License | https://github.com/sdispater/tomlkit | +| tqdm | 4.67.3 | MPL-2.0 AND MIT | https://tqdm.github.io | +| traceloop-sdk | 0.57.0 | Apache-2.0 | https://github.com/traceloop/openllmetry | +| typing-inspection | 0.4.2 | MIT | https://github.com/pydantic/typing-inspection | +| typing_extensions | 4.15.0 | PSF-2.0 | https://github.com/python/typing_extensions | +| uritemplate | 4.2.0 | BSD 3-Clause OR Apache-2.0 | https://uritemplate.readthedocs.org | +| urllib3 | 2.6.3 | MIT | https://github.com/urllib3/urllib3/blob/main/CHANGES.rst | +| uvicorn | 0.43.0 | BSD-3-Clause | https://uvicorn.dev/ | +| virtualenv | 21.2.0 | MIT | https://github.com/pypa/virtualenv | +| watchfiles | 1.1.1 | MIT License | https://github.com/samuelcolvin/watchfiles | +| websockets | 16.0 | BSD-3-Clause | https://github.com/python-websockets/websockets | +| wrapt | 1.17.3 | BSD License | https://github.com/GrahamDumpleton/wrapt | +| yarl | 1.23.0 | Apache-2.0 | https://github.com/aio-libs/yarl | +| zipp | 3.23.0 | MIT | https://github.com/jaraco/zipp | diff --git a/Dockerfile b/Dockerfile index 4afee60..f94cafa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,11 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + # Node Wire — Docker Image # ======================== -# This image packages the connector platform as a FastMCP server. +# This image packages the connector platform as an MCP stdio server (manifest-driven). # ToolHive runs it as a container, injects secrets as env vars, # and proxies the stdio MCP transport to HTTP/SSE. # @@ -11,7 +16,8 @@ # thv run --name node-wire-connectors --transport stdio \ # --secret ... node-wire:latest -FROM python:3.12-slim +# Digest-pinned base (update when bumping tag). See .github/workflows/docker-policy.yml. +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 # Install system deps needed by some connector libs RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -21,19 +27,44 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app # Copy source (build context = repo root) -COPY pyproject.toml ./ COPY src/ ./src/ COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/http_generic/dist/*.whl /wheels/ +COPY packages/connectors/stripe/dist/*.whl /wheels/ +COPY packages/connectors/smtp/dist/*.whl /wheels/ +COPY packages/connectors/slack/dist/*.whl /wheels/ +COPY packages/connectors/google_drive/dist/*.whl /wheels/ +COPY packages/connectors/fhir_cerner/dist/*.whl /wheels/ +COPY packages/connectors/fhir_epic/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src + +# Install runtime + connector packages using local wheel artifacts +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime \ + node-wire-http-generic \ + node-wire-stripe \ + node-wire-smtp \ + node-wire-slack \ + node-wire-google-drive \ + node-wire-fhir-cerner \ + node-wire-fhir-epic \ + "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app -# Install platform + agents extras -RUN pip install --no-cache-dir -e ".[agents]" +USER app # Expose nothing — ToolHive manages the stdio proxy port internally # MCP_PORT / FASTMCP_PORT will be set by ToolHive if ever needed # Healthcheck: verify the package is importable HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.mcp_entrypoint import _make_server; print('ok')" || exit 1 + python -c "from agents.mcp_entrypoint import main; assert callable(main); print('ok')" || exit 1 -# Default entrypoint: run the FastMCP server on stdio +# Default entrypoint: run the MCP server on stdio CMD ["python", "-m", "agents.mcp_entrypoint"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2595a26 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 AOT Technologies + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 0000000..49cdcde --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. + +Copyright 2026 AOT Technologies + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..079bc51 --- /dev/null +++ b/NOTICE @@ -0,0 +1,9 @@ +Node Wire + +Copyright 2026 AOT Technologies +Developed by AOT Technologies Engineering Team + +Licensed under the Apache License, Version 2.0. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 \ No newline at end of file diff --git a/NOTICE.license b/NOTICE.license new file mode 100644 index 0000000..28e0dbe --- /dev/null +++ b/NOTICE.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2026 AOT Technologies + +SPDX-License-Identifier: Apache-2.0 diff --git a/README.md b/README.md index 66d68ab..3e0eb30 100644 --- a/README.md +++ b/README.md @@ -1,195 +1,214 @@ + + # Node Wire -This repository implements **Node Wire**: a three-layer Python platform that runs connector adapters over REST, gRPC, or MCP. Each connector talks to an external system (e.g. Google Drive, SMTP, Stripe); the runtime provides a consistent execution contract, error handling, and resilience. This is a POC—intended to validate the architecture and be understandable for developers new to the codebase. +Node Wire is a three-layer Python platform that runs connector adapters (Google Drive, SMTP, Stripe, FHIR, etc.) and exposes them over REST, gRPC, or MCP. It provides a consistent execution contract with built-in validation, resilience, and telemetry. -For dependency management use any tool that understands `pyproject.toml` (e.g. `uv`, `pip`, or `poetry`). +## Prerequisites ---- +Before getting started, make sure you have: -## Individual MCP servers +| Requirement | Version | Notes | +|---|---|---| +| Python | 3.11+ | Required to run the platform | +| `uv` or `pip` | Latest | `uv` is recommended for local development | +| Git | Any recent version | Required to clone the repository | +| Docker | Latest | Required for MCP server image builds and `docker-compose.mcp.yml` | +| Node.js | Any LTS | Only needed for MCP Inspector | -Each connector can run as its own independent MCP server (Docker image). +## Quick Start -| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | +### 1. Install +```bash +git clone +cd node-wire +uv sync --extra agents +``` +*(Requires `uv`. Alternatively, use `pip install -e ".[agents]"`)* -See [docs/mcp-servers.md](docs/mcp-servers.md) for build, env config, docker-compose, and ToolHive registration. +### 2. Configure +Copy the sample environment file and add your `NW_ALLOWED_CONNECTORS`: +```bash +# Linux/macOS/PowerShell +cp sample.env .env ---- +# Windows (CMD) +copy sample.env .env +``` +*(Edit `.env` and set `NW_ALLOWED_CONNECTORS=http_generic` or others)* -## High-level architecture +### 3. Run Grafana/OpenTelemetry (optional) -The platform is split into three layers: +For telemetry visualization, start the Grafana stack before running the application: -- **Layer A – Runtime** (`runtime`): The engine that every connector runs inside. It defines the execution contract, a standard error taxonomy, retries and circuit breaking, and telemetry. -- **Layer B – Connectors** (`connectors`): Adapters that implement that contract and call external systems (HTTP Generic, SMTP, Stripe, Google Drive, FHIR Epic, FHIR Cerner). Each connector has its own input/output schema and business logic. -- **Layer C – Bindings** (`bindings`): How the platform is exposed to the outside world—REST API, gRPC server, MCP server—and how connectors are loaded from configuration (ConnectorFactory + `config/connectors.yaml`). +```bash +cd grafana && docker compose up -d +``` -**Data flow (simplified):** A request arrives via REST, gRPC, or MCP → the factory resolves the right connector → the runtime runs it (validate input → optional policy check → retry/circuit-breaker wrapper → execute) → the response is returned in a standard shape (`ConnectorResponse`). ---- +### 4. Run +**Bash (Linux/macOS):** +```bash +# Using uv (recommended) +MODE=API uv run node-wire -## Layer A – `runtime` +# Using python +MODE=API python -m bindings_entrypoint +``` -**Purpose:** Provide shared execution and reliability so every connector behaves in a consistent way (validation, errors, retries, telemetry) without each connector reimplementing the same plumbing. +**PowerShell (Windows):** +```powershell +# Using uv +$env:MODE="API"; uv run node-wire -**Location:** `src/runtime/` (base.py, models.py, errors.py, resilience.py, secrets.py, policy.py). +# Using python +$env:MODE="API"; python -m bindings_entrypoint +``` +*(Modes: `API`, `GRPC`, `MCP`)* -### Main pieces +Open [http://localhost:8000/docs](http://localhost:8000/docs) to see the Swagger UI. -- **BaseConnector** - Abstract base class for all connectors. Subclasses implement `internal_execute(...)`. The runtime’s `run()` method: - 1. Generates a trace ID and starts an OpenTelemetry span. - 2. Validates the raw request body with Pydantic (using the connector’s input model). - 3. Calls the optional policy hook (if configured). - 4. Wraps execution with retries and a circuit breaker (resilience). - 5. Maps any exception to the standard error taxonomy. - 6. Returns a `ConnectorResponse` (success + data, or error_code + error_category + message). +### 5. Playground -- **ConnectorResponse / ErrorCategory** - Every connector returns the same response shape: `success`, `data`, `error_code`, `error_category`, `message`, `trace_id`. Categories are `RETRYABLE`, `BUSINESS`, `AUTH`, `FATAL`. Bindings (e.g. REST) map these to HTTP status codes (e.g. BUSINESS → 400, AUTH → 401, RETRYABLE → 503, FATAL → 500). +The platform includes an interactive web playground at [http://localhost:8000/playground/](http://localhost:8000/playground/) (available when the REST API is running). -- **ErrorMapper** - A registry that maps exception types to a stable error code and category. Connectors register their own exception types in their Layer B `registration` module. Unmapped exceptions default to FATAL. +--- -- **Resilience** - A decorator (e.g. Tenacity for retries, PyBreaker for circuit breaker) wraps the actual execution. Transient failures are retried; after too many failures the circuit opens to avoid overloading the external system. +## Build Packages (Wheels) -- **SecretProvider** - Abstraction for fetching secrets (API keys, credentials). The POC uses environment variables via `EnvSecretProvider` in the factory. Connectors receive the provider and use it to resolve connector-specific keys (e.g. Google Drive’s service account JSON). +Before building Docker images, build the Python packages as binary wheels: -- **PolicyHook** - Optional hook to allow or deny execution (e.g. by principal or tenant). Not required for the POC; when present, the runtime calls it after validation and before execution. +```bash +bash scripts/build-packages.sh +``` -- **Telemetry** - OpenTelemetry span around `connector.run` with attributes such as connector id, action, trace id, tenant, principal. +See [docs/packaging.md](docs/packaging.md) for details on the wheel build lifecycle. --- -## Layer B – `connectors` +## Build MCP Server Images -**Purpose:** System adapters that talk to external services. Each connector defines input/output models and implements `internal_execute` (and optionally registers its own exceptions with the ErrorMapper). +Use this workflow when you want Docker images for the individual MCP servers such as Google Drive, SMTP, Stripe, Salesforce, or Slack. -**Location:** `src/connectors/`. Each connector lives in its own subpackage (e.g. `google_drive/`, `smtp/`, `stripe/`, `http_generic/`). +### Build prerequisites -### Common structure per connector +Before building images, make sure: -- **schema.py** – Pydantic models for request (input) and response (output). Some connectors use a single action (e.g. `execute`) with a discriminated union in the payload (e.g. Google Drive: `action: "files.list" | "files.get" | ...`). -- **logic.py** – Connector class (subclass of `BaseConnector`) and the actual calls to the external SDK or API inside `internal_execute`. -- **registration.py** – Registers connector-specific exception types with `ErrorMapper` (category and optional error code). Loaded at startup via `auto_register()`. -- **exceptions.py** (optional) – Connector-specific exception classes. +- Docker is installed and available on your shell path. +- You are running commands from the repository root. +- Local wheels have been built first. -### Connectors included +See [docs/local-packages-to-images.md](docs/local-packages-to-images.md) for the full package -> image workflow and required wheel artifacts per image. -| Connector | Description | REST action | Exposed via (from config) | -|----------------|--------------------------------------------------|---------------|-----------------------------| -| **http_generic** | Generic HTTP request (any URL, method, headers) | `request` | rest, grpc, mcp | -| **smtp** | Send email via SMTP | `send_email` | rest, grpc, mcp | -| **stripe** | Stripe charge | `charge` | grpc, mcp (no rest in config)| -| **google_drive**| Google Drive (list, create, get, update, upload, delete, permissions) | `execute` (payload discriminator) | rest, grpc, mcp | -| **fhir_epic** | FHIR R4 integration for Epic (multi-action) | `read_patient`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | -| **fhir_cerner** | FHIR R4 integration for Cerner (multi-action) | `read_patient`, `search_encounter`, `create_document_reference`, `search_document_reference` | rest, grpc, mcp | +### Build all MCP server images -### Connector-specific documentation +All MCP server images are built from the repository root using the automation script: -**Details for each connector**—operations, request/response bodies, examples, and error handling—**are documented in that connector’s folder.** +```bash +./scripts/build-mcp-images.sh +``` -Examples: Google Drive has a full doc at `src/connectors/google_drive/README.md`; FHIR connectors are documented at `src/connectors/fhir_epic/README.md` and `src/connectors/fhir_cerner/README.md`. Other connectors may have a similar `.md` in their folder or document behavior in code and docstrings; always check the connector’s folder for up-to-date details. +To tag with a specific version (defaults to the version in `pyproject.toml`): ---- +```bash +./scripts/build-mcp-images.sh --version 0.1.0 +``` -## Layer C – `bindings` +This produces images tagged as both `latest` and the version string: -**Purpose:** Expose connectors over different protocols and load them from configuration. No business logic lives here—only routing, config, and protocol translation. +| Image name | Tags | +|---|---| +| `nw-google-drive` | `nw-google-drive:latest`, `nw-google-drive:0.1.0` | +| `nw-smartonfhir-epic` | `nw-smartonfhir-epic:latest`, `nw-smartonfhir-epic:0.1.0` | +| `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner:latest`, `nw-smartonfhir-cerner:0.1.0` | +| `nw-smtp` | `nw-smtp:latest`, `nw-smtp:0.1.0` | +| `nw-stripe` | `nw-stripe:latest`, `nw-stripe:0.1.0` | +| `nw-salesforce` | `nw-salesforce:latest`, `nw-salesforce:0.1.0` | +| `nw-slack` | `nw-slack:latest`, `nw-slack:0.1.0` | -**Location:** `src/bindings/` (factory.py, rest_api/app.py, grpc_server/, mcp_server/), and the entrypoint `bindings_entrypoint.py` at the package root. +### Build one image manually -### ConnectorFactory +To build a single image manually from the repo root: -- Reads `config/connectors.yaml` (list of connectors with `enabled` and `exposed_via` per protocol: rest, grpc, mcp). -- Instantiates each enabled connector (with a shared `EnvSecretProvider`). -- `list_for_protocol("rest" | "grpc" | "mcp")` returns only connectors that are exposed for that protocol. Used by the REST app and MCP server to build routes or tool lists. +```bash +# Google Drive only +docker build -f docker/google-drive/Dockerfile -t nw-google-drive:latest . -### REST API (FastAPI) +# Epic FHIR only +docker build -f docker/fhir-epic/Dockerfile -t nw-smartonfhir-epic:latest . -- **GET /health** – Health check. -- **GET /docs** – Swagger UI. -- Routes are built dynamically: **POST /connectors/{connector_id}/{action}** (e.g. `POST /connectors/google_drive/execute`). The request body is JSON; the response is the standard `ConnectorResponse`. HTTP status is derived from `error_category` (e.g. BUSINESS → 400, AUTH → 401, RETRYABLE → 503, FATAL → 500). +# Cerner FHIR only +docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner:latest . -### gRPC / MCP +# SMTP only +docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . -- **gRPC:** Started when `MODE=GRPC`; server listens on port 50051. -- **MCP:** Started when `MODE=MCP`; server exposes tools for discovery and invocation. +# Stripe only +docker build -f docker/stripe/Dockerfile -t nw-stripe:latest . -### Entrypoint +# Salesforce only +docker build -f docker/salesforce/Dockerfile -t nw-salesforce:latest . -- Run with `python -m bindings_entrypoint` (or the `node-wire` script after install). The **MODE** environment variable selects: - - **API** (default) – REST API on port 8000. - - **GRPC** – gRPC server on port 50051. - - **MCP** – MCP server. +# Slack only +docker build -f docker/slack/Dockerfile -t nw-slack:latest . +``` ---- +> **Note:** The build context must be the repository root (`.`) so the `COPY src/` and `COPY config/` instructions resolve correctly. -## Configuration +--- -- **config/connectors.yaml** - Lists each connector with: - - `enabled`: whether to load it. - - `exposed_via`: list of protocols (`rest`, `grpc`, `mcp`). Only listed protocols expose that connector. +## Run MCP Servers with Docker Compose -- **Secrets** - Supplied via environment variables. The factory uses `EnvSecretProvider`; keys are connector-specific (e.g. Google Drive expects a variable documented in `src/connectors/google_drive/README.md`). +### Compose prerequisites -### Google Drive service account setup (quick) +Before starting the MCP containers, make sure: -1. In Google Cloud Console, select your project and enable **Google Drive API** (`APIs & Services` -> `Library`). -2. Go to `APIs & Services` -> `Credentials` -> `Create Credentials` -> `Service Account`. -3. Create a JSON key for that service account (`Keys` -> `Add Key` -> `Create new key` -> `JSON`). -4. Save the key file securely (example name: `service_account.json`) and never commit it to Git. -5. Open the JSON and copy the `client_email` value. -6. In Google Drive, share the target folder with that service-account email (Editor permission). -7. Copy that folder ID and use it in demo requests / env as needed. +- The MCP server images have already been built locally. +- Your `.env` file is populated with the credentials needed by the connectors you want to run. -Set credential secret used by this platform (`GOOGLE_DRIVE_SA_JSON`): +`docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +Each service pins `NW_ALLOWED_CONNECTORS` to its own connector so a broad value in `.env` does not make per-connector images import optional dependencies they do not contain. -- **Option A (recommended for local):** set it to the absolute path of the JSON file. -- **Option B:** set it to the full JSON content as a string. +```bash +# Ensure local wheels exist and your .env is populated, then: +docker compose -f docker-compose.mcp.yml up --build +``` -PowerShell example (load JSON content into env var for current shell): +To start only a specific server: -```powershell -$saPath = "C:\path\to\service_account.json" -$env:GOOGLE_DRIVE_SA_JSON = Get-Content -Path $saPath -Raw +```bash +docker compose -f docker-compose.mcp.yml up --build nw-smartonfhir-epic ``` --- -## Running the platform - -1. **Install** from the repo root: - ```bash - pip install . - ``` +## Documentation -2. **Start the REST API** (default): - - **Windows (cmd):** `set MODE=API && python -m bindings_entrypoint` - (Or omit `MODE`; API is the default.) - - **Windows (PowerShell):** `$env:MODE="API"; python -m bindings_entrypoint` - - **Linux/macOS:** `MODE=API python -m bindings_entrypoint` +For more detailed information, please refer to the following guides: - Then open: - - **Health:** http://localhost:8000/health - - **Swagger:** http://localhost:8000/docs +- **[Architecture](docs/architecture.md)** — Layered design and data flow. +- **[Installation](docs/installation.md)** — Detailed setup and prerequisites. +- **[Configuration](docs/configuration.md)** — Environment variables and `connectors.yaml`. +- **[Connectors Guide](docs/connectors.md)** — How to use and build connectors. +- **[MCP Integration](docs/mcp.md)** — Using Node Wire with AI agents. +- **[Troubleshooting](docs/troubleshooting.md)** — Common errors and fixes. +- **[MCP Servers & Docker](docs/mcp-servers.md)** — Deploying individual connectors as MCP servers. +- **[Packaging & Publishing](docs/packaging.md)** — Wheel builds and CI flow. +- **[Code Quality & Compliance](docs/code-quality-compliance.md)** — Ruff, Mypy, pre-commit, REUSE, and dependency compliance. +## Developer docs -3. **Start gRPC or MCP** - Set `MODE=GRPC` or `MODE=MCP` using your shell’s syntax (same as above for Windows). +- Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) +- Creating a new connector: [docs/connectors.md](docs/connectors.md) +- Code quality/compliance (Ruff, Mypy, REUSE, pip-audit): [docs/code-quality-compliance.md](docs/code-quality-compliance.md) +- Quality/security gates (Bandit, SonarQube): [docs/quality-security-gates.md](docs/quality-security-gates.md) --- -## Dependencies +## License -All dependencies are declared in `pyproject.toml` (Python >=3.11). They include: pydantic, FastAPI, uvicorn, tenacity, pybreaker, OpenTelemetry, grpcio, and connector-specific libraries (httpx, aiosmtplib, stripe, google-auth, google-api-python-client, etc.). See `pyproject.toml` for the full list and versions. +This project is licensed under the Apache License 2.0. +See the LICENSE file for details. diff --git a/Setup.md b/Setup.md deleted file mode 100644 index 7be559f..0000000 --- a/Setup.md +++ /dev/null @@ -1,361 +0,0 @@ -# Node Wire — Setup Guide - -Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP, FHIR, Stripe, and more) and exposes them over REST, gRPC, or MCP. It includes a built-in AI agent layer so LLMs can discover and orchestrate these connectors automatically. - ---- - -## Table of Contents - -- [Prerequisites](#prerequisites) -- [Installation](#installation) -- [Configuration](#configuration) -- [Running the Platform](#running-the-platform) -- [Connectors Overview](#connectors-overview) -- [Connector Setup](#connector-setup) -- [MCP Server & ToolHive](#mcp-server--toolhive) -- [Running Tests](#running-tests) -- [Playground UI](#playground-ui) - ---- - -## Prerequisites - - -| Requirement | Version | Notes | -| ----------- | ------- | --------------------------------------- | -| Python | 3.12+ | `python --version` to check | -| pip or uv | Latest | `pip install --upgrade pip` | -| Git | Any | To clone the repo | -| Docker | Latest | Only needed for ToolHive MCP deployment | - - ---- - -## Installation - -```bash -# 1. Clone the repository -git clone -cd connector-platform - -# 2. Install dependencies (recommended: uv) -uv sync --extra agents - -# 3. Verify the install -uv run node-wire --help -``` - -> **REST/gRPC only** (no AI agent features): `uv sync` without the extra is sufficient. -> -> **Alternative (pip):** If you’re not using `uv`, install editable deps with pip: -> -> - `pip install -e ".[agents]"` (includes MCP/LLM agent dependencies) -> - `pip install -e .` (REST/gRPC only, no agent dependencies) - ---- - -## Configuration - -All secrets and settings are loaded from environment variables. A template is provided at `sample.env`. - -```bash -# Copy the template -cp sample.env .env - -# Open and fill in the values you need -``` - -You only need to fill in the sections for the connectors you plan to use. The platform starts successfully even if some credentials are missing — those connectors will simply return an error when called. - -### Environment Variable Sections - - -| Section | Key Variables | When Needed | -| ---------------- | ------------------------------------------------------------------------------------------------------------------- | ---------------------- | -| **FHIR Epic** | `EPIC_FHIR_BASE_URL`, `EPIC_TOKEN_URL`, `EPIC_CLIENT_ID`, `EPIC_KID`, `EPIC_PRIVATE_KEY` | Epic EHR integration | -| **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | -| **Google Drive** | `google_drive_sa_json`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | -| **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | -| **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | -| **ToolHive** | `TOOLHIVE_MCP_URL` (single) or `TOOLHIVE_MCP_URLS` (comma-separated, multi-server) | ToolHive MCP proxy | - - -See `sample.env` for the full list with example values. - ---- - -## Running the Platform - -The platform supports three modes. Set the `MODE` environment variable to switch between them. - - -| Mode | Command | Default Port | Use Case | -| ---------------------- | --------------------------------- | ------------ | ----------------------------------- | -| **REST API** (default) | `uv run node-wire` | `8000` | HTTP clients, Swagger UI, curl | -| **gRPC** | `MODE=GRPC uv run node-wire` | `50051` | gRPC clients | -| **MCP (stdio)** | `python -m agents.mcp_entrypoint` | stdio | AI agents, ToolHive, Claude Desktop | - - -### REST API Quick Start - -```bash -# Default port 8000 -uv run node-wire - -# If port 8000 is in use, override with PORT -PORT=8001 uv run node-wire -``` - -Once running: - -- **Health check:** `GET http://localhost:8000/health` -- **Interactive docs (Swagger UI):** `http://localhost:8000/docs` -- **Call a connector:** `POST http://localhost:8000/connectors/{connector_id}/{action}` - -Example — send an HTTP request via the generic connector: - -```bash -curl -X POST http://localhost:8000/connectors/http_generic/request \ - -H "Content-Type: application/json" \ - -d '{"url": "https://httpbin.org/get", "method": "GET"}' -``` - -All responses use the same standard shape: - -```json -{ - "success": true, - "data": { "raw": { ... }, "description": "..." }, - "error_code": null, - "error_category": null, - "message": null, - "trace_id": "..." -} -``` - ---- - -## Connectors Overview - - -| Connector | What It Does | Credentials Needed | Setup Guide | -| ---------------- | ------------------------------------------ | -------------------------------------- | --------------------------------------------------------------------------------------------- | -| **http_generic** | Make HTTP requests to any URL | None | No setup needed | -| **smtp** | Send emails via SMTP | SMTP host/port/username/password | [SMTP Setup](#smtp) | -| **stripe** | Process Stripe payments | Stripe API key | [Stripe Setup](#stripe) | -| **google_drive** | List, upload, download, manage Drive files | GCP service account JSON | [Google Drive setup & API](docs/google_drive_connector.md#google-drive-service-account-setup) | -| **fhir_epic** | Read/write patient data from Epic EHR | Epic SMART credentials + private key | [FHIR Epic Setup](#fhir-epic) | -| **fhir_cerner** | Read/write patient data from Cerner EHR | Cerner SMART credentials + private key | [FHIR Cerner Setup](#fhir-cerner) | - - ---- - -## Connector Setup - -### HTTP Generic - -No credentials required. Works out of the box. - -```bash -curl -X POST http://localhost:8000/connectors/http_generic/request \ - -H "Content-Type: application/json" \ - -d '{ - "url": "https://api.example.com/data", - "method": "POST", - "headers": {"Authorization": "Bearer your-token"}, - "body": {"key": "value"} - }' -``` - ---- - -### SMTP - -Add these to your `.env`: - -```env -SMTP_HOST=smtp.gmail.com -SMTP_PORT=587 -SMTP_USERNAME=you@gmail.com -SMTP_PASSWORD=your-app-password -``` - -> **Gmail users:** You must use an [App Password](https://support.google.com/accounts/answer/185833), not your regular Gmail password. Enable 2-Factor Authentication on your Google account first, then generate an App Password under Security settings. - -Supported configurations: - -- Port `587` with STARTTLS (recommended for Gmail, most SMTP providers) -- Port `465` with implicit TLS - ---- - -### Stripe - -Add to your `.env`: - -```env -stripe_api_key=sk_test_your_key_here -``` - -Use a **test key** (`sk_test_...`) during development. Switch to a live key (`sk_live_...`) for production. - ---- - -### Google Drive - -The Google Drive connector uses a **service account** — a non-human Google account your application uses to authenticate with Google Drive APIs. - -**Full documentation:** [docs/google_drive_connector.md](docs/google_drive_connector.md) — service account setup, verification, and REST `execute` API (all seven operations). - -Quick summary of what you'll need: - -1. A Google Cloud project with the Drive API enabled -2. A service account with a downloaded JSON key file -3. A shared Drive folder (share it with the service account's email) - -Add to your `.env`: - -```env -google_drive_sa_json=/absolute/path/to/service-account.json -GOOGLE_DRIVE_FOLDER_ID=your-folder-id-from-drive-url -``` - ---- - -### FHIR Epic - -Epic EHR integration uses the SMART Backend Services OAuth2 flow with RS384 JWT authentication. - -Add to your `.env`: - -```env -EPIC_FHIR_BASE_URL=https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4 -EPIC_TOKEN_URL=https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token -EPIC_CLIENT_ID=your-epic-client-id -EPIC_KID=your-key-id -EPIC_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" -``` - -You obtain these credentials by registering a backend application in the [Epic App Orchard](https://appmarket.epic.com/) (or your organization's Epic sandbox). - -**Available actions:** `read_patient`, `search_encounter`, `create_document_reference`, `search_document_reference` - ---- - -### FHIR Cerner - -Cerner EHR integration also uses SMART Backend Services with `private_key_jwt` client authentication. - -Add to your `.env`: - -```env -CERNER_FHIR_BASE_URL=https://fhir-ehr-code.cerner.com/r4/your-tenant-id -CERNER_TOKEN_URL=https://authorization.cerner.com/tenants/your-tenant-id/protocols/oauth2/profiles/smart-v1/token -CERNER_CLIENT_ID=your-cerner-client-id -CERNER_KID=your-key-id -CERNER_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" -CERNER_SCOPES="system/Patient.read system/Encounter.read system/DocumentReference.read system/DocumentReference.write" -``` - -Register your application in the [Cerner Developer Portal](https://code.cerner.com/) to obtain these credentials. - -**Available actions:** `read_patient`, `search_encounter`, `create_document_reference`, `search_document_reference` - ---- - -## MCP Server & ToolHive - -The platform exposes connector tools for AI agents via the MCP (Model Context Protocol). There are two deployment modes: - -### Individual MCP servers (recommended) - -Each connector runs as its own independent MCP server. This is the preferred approach for modular, scalable deployments. - - -| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | - - -**Full guide (build, env config, ToolHive registration, multi-server agent usage):** [docs/mcp-servers.md](docs/mcp-servers.md) - -Quick start: - -```bash -# Build all three images -./scripts/build-mcp-images.sh - -# Start all three locally -docker compose -f docker-compose.mcp.yml up -``` - -### Combined MCP server (all connectors in one) - -For simpler setups all connectors can be exposed from a single MCP server: - -```bash -python -m agents.mcp_entrypoint -``` - -**ToolHive** runs the MCP server inside a secure Docker container, manages secrets injection, and provides an HTTP proxy that any MCP-compatible client (Claude Desktop, Cursor, custom agents) can connect to. - -**See the full ToolHive workflow guide:** [docs/toolhive_agent_scenario.md](docs/toolhive_agent_scenario.md) - -### Quick Local Test (No ToolHive) - -```bash -# Inspect any individual server with MCP Inspector -npx @modelcontextprotocol/inspector python -m agents.fhir_epic_mcp -npx @modelcontextprotocol/inspector python -m agents.google_drive_mcp - -# Or test the combined server -npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint -``` - ---- - -## Running Tests - -```bash -# Install dev dependencies (if not already installed) -pip install -e ".[dev,agents]" - -# Run all tests -pytest tests/ -v - -# Run a specific connector's tests -pytest tests/test_google_drive.py -v -pytest tests/test_fhir_epic.py -v -pytest tests/test_toolhive_agent.py -v -``` - -Most tests are unit tests that run without real credentials. Integration tests that call live APIs are skipped unless the relevant environment variables are set. - ---- - -## Playground UI - -The repository includes an interactive web playground that showcases 5 orchestration scenarios: - -> **Note:** The UI is served under the `/playground/` path (not at the server root). - -```bash -# Start the REST API (if not already running) -uv run node-wire - -# Open in your browser -open http://localhost:8000/playground/ -``` - -Scenarios include: - -1. Epic FHIR patient lookup and clinical note upload -2. IT Ops automation via HTTP Generic -3. Cerner FHIR orchestration -4. Google Drive document archival -5. AI agent orchestration via MCP - -See `playground/README.md` for details on each scenario and how to configure them. \ No newline at end of file diff --git a/config/connectors.yaml b/config/connectors.yaml index 0c1682c..bc5d18e 100644 --- a/config/connectors.yaml +++ b/config/connectors.yaml @@ -1,25 +1,98 @@ -connectors: - http_generic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - smtp: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - stripe: - enabled: true - exposed_via: ["grpc", "mcp"] - google_drive: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - fhir_epic: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" - oauth_token_url: "${EPIC_TOKEN_URL:https://fhir.epic.sandbox/api/FHIR/R4/oauth/token}" - fhir_cerner: - enabled: true - exposed_via: ["rest", "grpc", "mcp"] - base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" - oauth_token_url: "${CERNER_TOKEN_URL:https://authorization.cerner.com/tenants/your-tenant-id/protocols/oauth2/profiles/smart-v1/token}" - scopes: "${CERNER_SCOPES:system/Patient.read system/Encounter.read system/DocumentReference.read system/DocumentReference.write}" - +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# connectors.yaml — Node Wire connector configuration +# +# REST API auth (not stored here; set in environment): +# NW_REST_API_KEY — required for /connectors, /playground, /scenarios unless NW_REST_AUTH_DISABLED=true +# +# SECURITY RULE: This file must never contain secrets. +# - Non-sensitive config (base_url, host, port) → safe in YAML +# - Secrets (client_id, private_key, api_key) → environment variables (or cloud backend) +# +connectors: + http_generic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + # auth: not set — defaults to NoAuthProvider + + smtp: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + host: "smtp.example.com" + port: 587 + from_email: "noreply@example.com" + auth: + provider: static_credentials + username_secret: SMTP_USERNAME + password_secret: SMTP_PASSWORD + + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] # keeping REST unless you explicitly want to remove it + auth: + provider: static_token + secret_key: stripe_api_key + header_name: Authorization + prefix: "" # Stripe expects raw key + + google_drive: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: service_account + sa_json_secret: GOOGLE_DRIVE_SA_JSON + scopes: + - https://www.googleapis.com/auth/drive + + fhir_epic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: EPIC_TOKEN_URL + client_id_secret: EPIC_CLIENT_ID + private_key_secret: EPIC_PRIVATE_KEY + kid_secret: EPIC_KID + algorithm: RS384 + + fhir_cerner: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: CERNER_TOKEN_URL + client_id_secret: CERNER_CLIENT_ID + private_key_secret: CERNER_PRIVATE_KEY + kid_secret: CERNER_KID + algorithm: RS384 + scopes_secret: CERNER_SCOPES + scopes: + - system/Patient.read + - system/Encounter.read + - system/DocumentReference.read + - system/DocumentReference.write + + salesforce: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN + + slack: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: static_token + secret_key: SLACK_BOT_TOKEN diff --git a/docker-compose.mcp.yml b/docker-compose.mcp.yml index 1004b3f..0af631d 100644 --- a/docker-compose.mcp.yml +++ b/docker-compose.mcp.yml @@ -1,28 +1,88 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## services: nw-google-drive: image: nw-google-drive:latest + build: + context: . + dockerfile: docker/google-drive/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: google_drive stdin_open: true tty: true restart: unless-stopped nw-smartonfhir-epic: image: nw-smartonfhir-epic:latest + build: + context: . + dockerfile: docker/fhir-epic/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: fhir_epic stdin_open: true tty: true restart: unless-stopped nw-smartonfhir-cerner: image: nw-smartonfhir-cerner:latest + build: + context: . + dockerfile: docker/fhir-cerner/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: fhir_cerner stdin_open: true tty: true restart: unless-stopped nw-smtp: image: nw-smtp:latest + build: + context: . + dockerfile: docker/smtp/Dockerfile env_file: .env + environment: + NW_ALLOWED_CONNECTORS: smtp + stdin_open: true + tty: true + restart: unless-stopped + + nw-stripe: + image: nw-stripe:latest + build: + context: . + dockerfile: docker/stripe/Dockerfile + env_file: .env + environment: + NW_ALLOWED_CONNECTORS: stripe + stdin_open: true + tty: true + restart: unless-stopped + + nw-salesforce: + image: nw-salesforce:latest + build: + context: . + dockerfile: docker/salesforce/Dockerfile + env_file: .env + environment: + NW_ALLOWED_CONNECTORS: salesforce + stdin_open: true + tty: true + restart: unless-stopped + + nw-slack: + image: nw-slack:latest + build: + context: . + dockerfile: docker/slack/Dockerfile + env_file: .env + environment: + NW_ALLOWED_CONNECTORS: slack stdin_open: true tty: true restart: unless-stopped diff --git a/docker/fhir-cerner/Dockerfile b/docker/fhir-cerner/Dockerfile index f53bb53..779f46b 100644 --- a/docker/fhir-cerner/Dockerfile +++ b/docker/fhir-cerner/Dockerfile @@ -1,8 +1,12 @@ -FROM python:3.12-slim +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 LABEL org.opencontainers.image.title="nw-smartonfhir-cerner" \ org.opencontainers.image.description="Node Wire — SMART on FHIR Cerner MCP server" \ - org.opencontainers.image.source="https://github.com/AOT-Technologies/connector-platform" + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" RUN apt-get update && apt-get install -y --no-install-recommends \ curl ca-certificates \ @@ -10,14 +14,25 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app -COPY pyproject.toml ./ COPY src/ ./src/ COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/fhir_cerner/dist/*.whl /wheels/ -RUN pip install --no-cache-dir -e ".[agents]" +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=fhir_cerner + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-fhir-cerner "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_cerner_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_cerner_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_cerner_mcp"] - diff --git a/docker/fhir-epic/Dockerfile b/docker/fhir-epic/Dockerfile index 633f031..7218cf6 100644 --- a/docker/fhir-epic/Dockerfile +++ b/docker/fhir-epic/Dockerfile @@ -1,8 +1,12 @@ -FROM python:3.12-slim +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 LABEL org.opencontainers.image.title="nw-smartonfhir-epic" \ org.opencontainers.image.description="Node Wire — SMART on FHIR Epic MCP server" \ - org.opencontainers.image.source="https://github.com/AOT-Technologies/connector-platform" + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" RUN apt-get update && apt-get install -y --no-install-recommends \ curl ca-certificates \ @@ -10,14 +14,25 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app -COPY pyproject.toml ./ COPY src/ ./src/ COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/fhir_epic/dist/*.whl /wheels/ -RUN pip install --no-cache-dir -e ".[agents]" +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=fhir_epic + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-fhir-epic "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_epic_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_epic_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_epic_mcp"] - diff --git a/docker/google-drive/Dockerfile b/docker/google-drive/Dockerfile index 43cbe2b..dca8ffb 100644 --- a/docker/google-drive/Dockerfile +++ b/docker/google-drive/Dockerfile @@ -1,8 +1,12 @@ -FROM python:3.12-slim +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 LABEL org.opencontainers.image.title="nw-google-drive" \ org.opencontainers.image.description="Node Wire — Google Drive MCP server" \ - org.opencontainers.image.source="https://github.com/AOT-Technologies/connector-platform" + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" RUN apt-get update && apt-get install -y --no-install-recommends \ curl ca-certificates \ @@ -10,14 +14,25 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app -COPY pyproject.toml ./ COPY src/ ./src/ COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/google_drive/dist/*.whl /wheels/ -RUN pip install --no-cache-dir -e ".[agents]" +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=google_drive + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-google-drive "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.google_drive_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.google_drive_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.google_drive_mcp"] - diff --git a/docker/salesforce/Dockerfile b/docker/salesforce/Dockerfile new file mode 100644 index 0000000..e255d6a --- /dev/null +++ b/docker/salesforce/Dockerfile @@ -0,0 +1,35 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-salesforce" \ + org.opencontainers.image.description="Node Wire — Salesforce MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +# Wheels are optional for local dev builds; build-mcp-images.sh populates them +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/salesforce/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=salesforce + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-salesforce "mcp>=1.6.0" httpx \ + || pip install --no-cache-dir "mcp>=1.6.0" httpx # Fallback if wheels missing + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.salesforce_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.salesforce_mcp"] diff --git a/docker/slack/Dockerfile b/docker/slack/Dockerfile new file mode 100644 index 0000000..8b4b3c4 --- /dev/null +++ b/docker/slack/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-slack" \ + org.opencontainers.image.description="Node Wire — Slack MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/slack/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=slack + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-slack "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.slack_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.slack_mcp"] diff --git a/docker/smtp/Dockerfile b/docker/smtp/Dockerfile index c4d725b..ee39cfd 100644 --- a/docker/smtp/Dockerfile +++ b/docker/smtp/Dockerfile @@ -1,8 +1,12 @@ -FROM python:3.12-slim +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 LABEL org.opencontainers.image.title="nw-smtp" \ org.opencontainers.image.description="Node Wire — SMTP MCP server" \ - org.opencontainers.image.source="https://github.com/AOT-Technologies/connector-platform" + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" RUN apt-get update && apt-get install -y --no-install-recommends \ curl ca-certificates \ @@ -10,14 +14,25 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /app -COPY pyproject.toml ./ COPY src/ ./src/ COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/smtp/dist/*.whl /wheels/ -RUN pip install --no-cache-dir -e ".[agents]" +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=smtp + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-smtp "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.smtp_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.smtp_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.smtp_mcp"] - diff --git a/docker/stripe/Dockerfile b/docker/stripe/Dockerfile new file mode 100644 index 0000000..552c46a --- /dev/null +++ b/docker/stripe/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.12-slim@sha256:3d5ed973e45820f5ba5e46bd065bd88b3a504ff0724d85980dcd05eab361fcf4 + +LABEL org.opencontainers.image.title="nw-stripe" \ + org.opencontainers.image.description="Node Wire — Stripe MCP server" \ + org.opencontainers.image.source="https://github.com/AOT-Technologies/node-wire" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY src/ ./src/ +COPY config/ ./config/ +COPY packages/runtime/dist/*.whl /wheels/ +COPY packages/connectors/stripe/dist/*.whl /wheels/ + +ENV PYTHONPATH=/app/src \ + NW_ALLOWED_CONNECTORS=stripe + +RUN pip install --no-cache-dir --find-links=/wheels \ + node-wire-runtime node-wire-stripe "mcp>=1.6.0" \ + && rm -rf /wheels + +RUN groupadd --system --gid 1000 app \ + && useradd --system --uid 1000 --gid app --home /app app \ + && chown -R app:app /app + +USER app + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ + python -c "from agents.stripe_mcp import main; assert callable(main); print('ok')" || exit 1 + +CMD ["python", "-m", "agents.stripe_mcp"] diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..781e2d2 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,69 @@ +# Node Wire Architecture + +The Node Wire platform is designed as a three-layer Python platform that runs connector adapters over REST, gRPC, or MCP. Each connector talks to an external system (e.g., Google Drive, SMTP, Stripe); the runtime provides a consistent execution contract, error handling, and resilience. + +## High-Level Architecture + +The platform is split into three layers: + +- **Layer A – Runtime** (`src/node_wire_runtime/`): The engine that every connector runs inside. It defines the execution contract, a standard error taxonomy, retries and circuit breaking, and telemetry. +- **Layer B – Connectors** (`src/node_wire_/`): Adapters that implement that contract and call external systems (HTTP Generic, SMTP, Stripe, Google Drive, FHIR Epic, FHIR Cerner, Salesforce, Slack). Each connector has its own input/output schema and business logic. +- **Layer C – Bindings** (`src/bindings/`): How the platform is exposed to the outside world—REST API, gRPC server, MCP server—and how connectors are loaded from configuration (ConnectorFactory + `config/connectors.yaml`). + +### Data Flow (Simplified) + +1. A request arrives via REST, gRPC, or MCP. +2. The `ConnectorFactory` resolves the right connector. +3. The runtime runs the connector: + - Validate input via Pydantic. + - Optional policy check. + - Retry/circuit-breaker wrapper (resilience). + - Execute internal logic. + - Map any exceptions to the standard error taxonomy. +4. The response is returned in a standard shape (`ConnectorResponse`). + +--- + +## Layer A – `runtime` + +**Purpose:** Provide shared execution and reliability so every connector behaves in a consistent way (validation, errors, retries, telemetry) without each connector reimplementing the same plumbing. + +**Location:** `src/node_wire_runtime/` + +### Main Components + +- **BaseConnector**: Abstract base class for all connectors. It handles the `run()` method pipeline. +- **ConnectorResponse / ErrorCategory**: Unified response shape and error categorization (`RETRYABLE`, `BUSINESS`, `AUTH`, `FATAL`). +- **ErrorMapper**: Maps exception types to stable error codes and categories. +- **Resilience**: Decorators for retries (Tenacity) and circuit breaking (PyBreaker). +- **SecretProvider**: Abstraction for fetching secrets (API keys, credentials). +- **PolicyHook**: Optional hook to allow or deny execution based on principal or tenant. +- **Telemetry**: OpenTelemetry integration for tracing. + +--- + +## Layer B – `connectors` + +**Purpose:** System adapters that talk to external services. Each connector defines input/output models and implements `internal_execute`. + +**Location:** `src/node_wire_/` + +### Common Structure + +- `schema.py`: Pydantic models for request and response. +- `logic.py`: Connector class and external service logic. +- `registration.py`: Registers connector-specific exceptions. + +--- + +## Layer C – `bindings` + +**Purpose:** Expose connectors over different protocols and load them from configuration. + +**Location:** `src/bindings/` + +### Bindings Offered + +- **REST API (FastAPI)**: Dynamic routes at `POST /connectors/{connector_id}/{action}`. +- **gRPC Server**: Protocol buffers based interface on port 50051. +- **MCP Server**: Model Context Protocol implementation for AI agents. diff --git a/docs/code-quality-compliance.md b/docs/code-quality-compliance.md new file mode 100644 index 0000000..b9e5c4f --- /dev/null +++ b/docs/code-quality-compliance.md @@ -0,0 +1,95 @@ + + +# Code Quality and Compliance + +This project uses **Ruff** for linting and formatting, **Mypy** for static type checking, **Bandit** for SAST, **pip-audit** for dependency vulnerability checks, and **REUSE** for open-source licensing compliance. + +Linting and type checks run automatically in CI on pull requests against the `main` branch via `.github/workflows/lint.yml`. Security and package compliance checks are additionally enforced through `.github/workflows/quality-gates.yml` and `.github/workflows/security-pr.yml`. + +## Manual usage for developers + +Install development dependencies first: + +```bash +pip install -e ".[dev]" +``` + +Then run the local quality checks: + +- **Check formatting and linting errors:** `ruff check .` +- **Auto-fix and format code:** `ruff check --fix . && ruff format .` +- **Run static type validation:** `mypy` + +`mypy` uses the default `files` target from `[tool.mypy]` in `pyproject.toml`, which is currently `src`. Avoid `mypy .`, because it can pull in packaging `setup.py` scripts under `packages/` and produce duplicate-module noise. To include tests explicitly, run: + +```bash +mypy src tests +``` + +## Pre-commit hooks + +You can attach `.pre-commit-config.yaml` so checks run before each commit: + +```bash +pre-commit install +``` + +To run all configured hooks across the repository: + +```bash +pre-commit run --all-files +``` + +The current pre-commit setup includes Ruff, Ruff formatting, Mypy, and Bandit. + +## Copyright headers and REUSE compliance + +This repository enforces open-source licensing compliance using [REUSE](https://reuse.software/). First-party files should contain the appropriate SPDX copyright and license headers. + +### Verify compliance + +```bash +uv pip install reuse +uv run reuse lint +``` + +### Add missing headers + +If `reuse lint` reports missing headers, you can apply the repository header template with: + +```bash +bash scripts/add-license-headers.sh +``` + +## Dependency inventory and compliance + +To maintain an open-source compliant dependency set, the repository tracks third-party packages and their licenses in `DEPENDENCIES.md`. + +### License classification criteria + +- **Safe (permissive):** MIT, Apache-2.0, BSD, PSF. Safe for the Apache-2.0 release. +- **Needs review:** Custom or uncommon licenses that require manual review. +- **Risky (copyleft):** GPLv2, GPLv3, AGPL. Not allowed in the runtime application. They may be acceptable only for isolated, non-distributed development tooling. + +### Update the inventory and run compliance checks + +When adding dependencies or preparing a release, run the unified compliance script: + +```bash +bash scripts/run-compliance-checks.sh +``` + +That script: + +1. Regenerates `DEPENDENCIES.md`. +2. Runs **Bandit** for static application security testing. +3. Runs **pip-audit** for dependency vulnerability scanning. + +## Related docs + +- [Quality and security gates](quality-security-gates.md) +- [Installation guide](installation.md) diff --git a/docs/compliance/hipaa-considerations.md b/docs/compliance/hipaa-considerations.md new file mode 100644 index 0000000..ce4be2c --- /dev/null +++ b/docs/compliance/hipaa-considerations.md @@ -0,0 +1,35 @@ +# HIPAA Compliance Considerations + +Node-Wire provides connectors for healthcare systems, including Epic and Cerner FHIR APIs. While Node-Wire is designed with security in mind, deploying Node-Wire in a healthcare environment to process Protected Health Information (PHI) requires careful consideration to maintain compliance with the Health Insurance Portability and Accountability Act (HIPAA). + +> [!WARNING] +> Node-Wire is a software framework, not a managed service. You are solely responsible for ensuring that your deployment, configuration, and infrastructure meet all applicable HIPAA requirements. + +## 1. Business Associate Agreements (BAAs) + +Since Node-Wire acts as a middleware layer routing data between various systems (e.g., your EHR, your LLM provider, and external services), you must have a Business Associate Agreement (BAA) in place with **every** third-party service provider that touches PHI. + +- **LLM Providers:** If you are using OpenAI, Anthropic, Google, or Groq to process PHI via Node-Wire agents, you must have a BAA signed with that provider and ensure you are using their HIPAA-eligible endpoints/models. +- **Hosting Infrastructure:** If you deploy Node-Wire on AWS, Azure, Google Cloud, or another cloud provider, you must have a BAA with the hosting provider. +- **External Connectors:** If you use connectors like SMTP or Google Drive to send or store PHI, those services must also be covered under a BAA. + +## 2. Data in Transit (Encryption) + +All network traffic involving PHI must be encrypted. +- **EHR Communication:** Node-Wire's FHIR connectors use HTTPS/TLS to communicate with Epic and Cerner APIs. +- **Client Communication:** When deploying the Node-Wire REST API or MCP Server, you must place it behind a reverse proxy (e.g., Nginx, Traefik) or API Gateway configured with strict TLS 1.2+ encryption. + +## 3. Data at Rest (Persistence) + +Node-Wire itself does not include a database and does not persistently store PHI. It processes data in memory during execution. However, consider the following: +- **Logs:** Ensure that your logging infrastructure does not capture PHI. Node-Wire's default `INFO` logging levels do not log payloads, but running in `DEBUG` or `TRACE` mode may expose PHI to logs. You must configure your logging systems to redact PHI or ensure the logging environment is HIPAA-compliant. +- **Caching:** If you implement caching layers on top of Node-Wire, ensure the cache is encrypted at rest. + +## 4. Authentication and Authorization + +- **API Keys & JWTs:** Node-Wire's REST API supports API keys and JWTs. Ensure these secrets are strong, rotated regularly, and never hardcoded in source control. +- **OAuth 2.0 / SMART on FHIR:** The FHIR connectors rely on the underlying authentication provided by the EHR. Ensure that the service accounts or client applications registered in Epic/Cerner are provisioned with the principle of least privilege, granting access only to the specific FHIR resources required by the agents. + +## 5. Safe Secret Management + +Do not store credentials (e.g., `client_secret`, API keys) in plain text environment files in production. Node-Wire supports Pluggable Secret Providers (e.g., HashiCorp Vault, Azure Key Vault, AWS Secrets Manager). You should use a secure secret management solution to inject credentials at runtime. diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 0000000..b27645a --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,93 @@ +# Configuration Guide + +Node Wire is configured primarily through environment variables and a YAML configuration file. + +## Environment Variables + +All secrets and settings are loaded from environment variables. A template is provided at `sample.env`. + +```bash +# Linux/macOS/PowerShell +cp sample.env .env + +# Windows (CMD) +copy sample.env .env +``` + +### Required Variables + +| Variable | Description | +|----------|-------------| +| `NW_ALLOWED_CONNECTORS` | **Required.** A comma-separated list of connector names to load (e.g., `fhir_epic,http_generic`). Node Wire defaults to a fail-closed policy. | + +### Connector Secrets + +| Section | Key Variables | When Needed | +|---------|---------------|-------------| +| **FHIR Epic** | `EPIC_FHIR_BASE_URL`, `EPIC_TOKEN_URL`, `EPIC_CLIENT_ID`, `EPIC_KID`, `EPIC_PRIVATE_KEY` | Epic EHR integration | +| **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | +| **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | +| **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | +| **Slack** | `SLACK_BOT_TOKEN` | Sending Slack messages | +| **Stripe** | `STRIPE_API_KEY` | Stripe payments | +| **Salesforce** | `SALESFORCE_INSTANCE_URL`, `SALESFORCE_TOKEN_URL`, `SALESFORCE_CLIENT_ID`, `SALESFORCE_CLIENT_SECRET`, `SALESFORCE_REFRESH_TOKEN` | Salesforce CRM integration | +| **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | + +### Transport & Binding Config + +| Variable | Description | Default | +|----------|-------------|---------| +| `MODE` | Execution mode (`API`, `GRPC`, `MCP`) | `API` | +| `PORT` | Port for the REST API | `8000` | +| `NW_MCP_TRANSPORT` | MCP transport mode (`stdio` or `streamable-http`) | `stdio` | +| `NW_MCP_PORT` | Port for streamable-http MCP | `8080` | +| `NW_MCP_AUTH_DISABLED` | Disable MCP authentication (local dev only) | `false` | +| `NW_REST_AUTH_DISABLED` | Disable REST API authentication (local dev only) | `false` | + +--- + +## Configuration File (`config/connectors.yaml`) + +This file determines which connectors are enabled and which protocols they are exposed through. + +```yaml +connectors: + google_drive: + enabled: true + exposed_via: + - rest + - grpc + - mcp +``` + +- **enabled**: Whether to load the connector at startup. +- **exposed_via**: List of protocols (`rest`, `grpc`, `mcp`). + +--- + +## Secrets Management + +The factory uses an `EnvSecretProvider` by default. It looks up keys exactly as provided, and then in uppercase (e.g., `my_key` then `MY_KEY`). + +### Google Drive Service Account (Local Example) + +For local development, you can set `GOOGLE_DRIVE_SA_JSON` to the absolute path of your service account JSON file. + +**PowerShell (Windows):** +```powershell +$saPath = "C:\path\to\service_account.json" +$env:GOOGLE_DRIVE_SA_JSON = Get-Content -Path $saPath -Raw +``` + +**Bash (Linux/macOS):** +```bash +export GOOGLE_DRIVE_SA_JSON=$(cat /path/to/service_account.json) +``` + +--- + +## Security Best Practices + +- **Production REST:** Set `NW_REST_API_KEY` and send `Authorization: Bearer ` or `X-API-Key: `. +- **Disable Dotenv:** Set `NW_REST_LOAD_DOTENV=false` in production to prevent loading from a `.env` file on disk. +- **Fail-Closed:** Always explicitly list allowed connectors in `NW_ALLOWED_CONNECTORS`. diff --git a/docs/connectors.md b/docs/connectors.md new file mode 100644 index 0000000..9259a14 --- /dev/null +++ b/docs/connectors.md @@ -0,0 +1,551 @@ + + +# Connectors guide (`src/node_wire_*`) + +This guide explains how **connectors** fit into Node Wire, how to build your own connector, and how the runtime and bindings wire everything together. Connector implementations live under `src/node_wire_/` (e.g. `src/node_wire_google_drive/`); the shared base class lives at **`src/node_wire_runtime/base_connector.py`**. + +## How connectors fit into the platform + +- **Layer B — Connectors** (`src/node_wire_/`): adapter packages (schemas, logic, optional `registration.py`). +- **Layer C — Bindings** (`src/bindings/`): REST, gRPC, and MCP servers plus `ConnectorFactory` loading from `config/connectors.yaml`. + +At startup, bindings call **`node_wire_runtime.connector_registry.auto_register()`**, which loads connector entry points, imports each connector’s `logic` module (registering the class), then imports optional `registration.py` for `ErrorMapper` side effects. **`ConnectorFactory`** resolves connectors from the registry — **do not add per-connector branches in `src/bindings/factory.py`.** + +--- + +## Package layout and registration + +Each connector is a **top-level package** under `src/` (e.g. `node_wire_fhir_epic`): + +| File | Role | +|------|------| +| `schema.py` | Pydantic input/output models. Each input model has an `action: Literal[...]` discriminator field (often combined into a discriminated union). | +| `logic.py` | Connector class: `BaseConnector` subclass — either explicit `@nw_action` methods, or **`action_specs`** plus an optional `_execute_action_spec` override for SDK dispatch. | +| `action_spec.py` (optional) | Declarative `SdkActionSpec` entries mapping validated models to vendor SDK calls (see Google Drive). | +| `registration.py` | Optional: registers connector-specific exceptions with `ErrorMapper`. | +| `exceptions.py` | Optional: custom exception types. | + +At startup, call **`node_wire_runtime.connector_registry.auto_register()`**: it loads entry points in group `node_wire.connectors`, imports each connector's `logic` module (triggering `BaseConnector.__init_subclass__` and `_CONNECTOR_REGISTRY`), then imports optional `registration.py` for `ErrorMapper` side effects. + +--- + +## The unified `BaseConnector` + +There is one base class for all connectors: **`BaseConnector`** (`src/node_wire_runtime/base_connector.py`). It handles: + +- Input validation via a Pydantic **discriminated union** (the `action` field selects the right model) +- Optional **policy hook** enforcement +- **Retries and circuit breaking** via `with_resilience` +- **Error mapping** via `ErrorMapper` +- OpenTelemetry **tracing** +- A standard **`ConnectorResponse`** envelope + +Actions are declared either with the **`@nw_action("name")`** decorator on async methods, or by listing them in **`action_specs`** (the runtime generates equivalent handlers). A connector can have **one or many** actions — there is no separate "single-action" type. + +``` +flowchart LR + yaml[connectors.yaml] + factory[ConnectorFactory.load] + inst[BaseConnector subclass] + run[connector.run] + exec[internal_execute → @nw_action dispatch] + resp[ConnectorResponse] + yaml --> factory --> inst --> run --> exec --> resp +``` + +--- + +## Building a connector (Google Drive SDK example) + +The production **Google Drive** connector (`src/node_wire_google_drive/`) is a good template for wrapping a **vendor Python SDK** (here `googleapiclient` / Drive API v3): service-account auth in `build_client()`, a discriminated union of operations in `schema.py`, and **`action_specs`** so each API surface becomes a manifest action without duplicating boilerplate. + +### Step 1 — Define your schemas (`schema.py`) + +Each operation is a Pydantic model with an **`action`** field whose type is a `Literal["…"]` unique to that operation. Those models are combined into a **discriminated union** (and often wrapped in `RootModel` for a single top-level validator), which the runtime uses to pick the correct handler. + +```python +# src/node_wire_google_drive/schema.py (conceptual excerpt) +from __future__ import annotations + +from typing import Annotated, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, RootModel + + +class BaseDriveOperation(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class FilesListOperation(BaseDriveOperation): + action: Literal["files.list"] + page_size: int = Field(10, ge=1, le=100) + query: Optional[str] = None + fields: Optional[str] = None + page_token: Optional[str] = None + + +class FilesUploadOperation(BaseDriveOperation): + action: Literal["files.upload"] + name: str + mime_type: str + parents: Optional[list[str]] = None + content: Optional[str] = None + content_base64: Optional[str] = None + + +# …other operations (files.create, files.get, …) — see the repo. + +_GoogleDriveOperationUnion = Annotated[ + Union[ + FilesListOperation, + FilesUploadOperation, + # … FilesCreateOperation, FilesGetOperation, … + ], + Field(discriminator="action"), +] + +GoogleDriveOperationInput = RootModel[_GoogleDriveOperationUnion] + + +class GoogleDriveOperationOutput(BaseModel): + raw: dict + description: str +``` + +When a connector only has **one** action, the `action` field is still required — the runtime always validates through the discriminated union. + +### Step 2 — Map operations to the SDK (`action_spec.py`) + +**`SdkActionSpec`** describes how to turn a validated model into a single SDK call: resource path (`resource_segments`), HTTP-style method name (`method_name`), and how to build `body` / keyword arguments from the model. The full Drive registry lives in [`src/node_wire_google_drive/action_spec.py`](../src/node_wire_google_drive/action_spec.py). + +```python +# src/node_wire_google_drive/action_spec.py (illustrative) +from node_wire_runtime.sdk_action_spec import SdkActionSpec + +from .schema import FilesCreateOperation, FilesListOperation + +# def _build_files_list_kwargs(drive, model): ... + +# Real module builds this dict via register helpers — see repo for uploads, permissions, etc. + +GOOGLE_DRIVE_ACTION_SPECS: dict[str, SdkActionSpec] = { + "files.list": SdkActionSpec( + resource_segments=("files",), + method_name="list", + build_kwargs=_build_files_list_kwargs, # optional: defaults, shared drives flags + input_model=FilesListOperation, + ), + "files.create": SdkActionSpec( + resource_segments=("files",), + method_name="create", + body_from_model={"name": "name", "mime_type": "mimeType", "parents": "parents"}, + constant_kwargs={"fields": "id, name, webViewLink", "supportsAllDrives": True}, + input_model=FilesCreateOperation, + ), +} +``` + +`googleapiclient` is **synchronous**. The shared helper **`execute_spec_in_thread`** runs the generated `.execute()` call in a thread pool so the connector’s public API stays async. + +### Step 3 — Implement the connector class (`logic.py`) + +Subclass `BaseConnector`, set **`connector_id`**, **`output_model`**, and **`action_specs`**. The base class **generates** one async `@nw_action` handler per spec. Override **`_execute_action_spec`** to add logging, thread offload, and translation of vendor exceptions (e.g. `HttpError` → your `error_map` types). + +```python +# src/node_wire_google_drive/logic.py (conceptual excerpt) +from __future__ import annotations + +import json +from typing import Any + +from google.oauth2 import service_account +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError + +from node_wire_runtime import BaseConnector +from node_wire_runtime.models import ErrorCategory +from node_wire_runtime.sdk_action_spec import execute_spec_in_thread + +from .action_spec import GOOGLE_DRIVE_ACTION_SPECS +from .exceptions import GoogleDriveAuthError, GoogleDriveRateLimitError # + other mapped types +from .schema import GoogleDriveOperationOutput + + +class GoogleDriveConnector(BaseConnector): + connector_id = "google_drive" + output_model = GoogleDriveOperationOutput + action_specs = GOOGLE_DRIVE_ACTION_SPECS + + error_map = { + GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), + GoogleDriveRateLimitError: (ErrorCategory.RETRYABLE, "GDRIVE_RATE_LIMIT"), + # … + } + + def build_client(self) -> Any: + raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") + info = json.loads(raw_sa) # or path to a JSON file — see production code + creds = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/drive"], + ) + return build("drive", "v3", credentials=creds) + + async def _execute_action_spec( + self, + action_name: str, + params: Any, + *, + trace_id: str, + log_extra: dict[str, Any] | None = None, + ) -> GoogleDriveOperationOutput: + spec = GOOGLE_DRIVE_ACTION_SPECS[action_name] + drive = self.get_client() + try: + raw = await execute_spec_in_thread(drive, spec, params) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=raw, + description=f"Successfully executed {action_name}", + ) + + +## Connector Authentication + +Node Wire provides a shared **`AuthProvider`** abstraction (`src/node_wire_runtime/auth/`) that handles token acquisition, JWT construction (for SMART on FHIR), caching, and expiry. This ensures that connector logic (`logic.py`) does not need to handle raw credentials or IdP-specific handshake details. + +### Using Auth in a Connector + +To use authentication, call **`await self.get_auth_headers()`** (inherited from `BaseConnector`). This returns a dictionary of headers (e.g. `{"Authorization": "Bearer "}`) injected by the configured provider. + +```python +# logic.py usage +async def read_resource(self, params: In, *, trace_id: str) -> Out: + base_url = self._get_base_url() + headers = await self.get_auth_headers() # Fetched/cached by provider + + async with httpx.AsyncClient() as client: + resp = await client.get(f"{base_url}/resource", headers=headers) + resp.raise_for_status() + ... +``` + +### Supported Provider Types + +Choose a provider in your **`connectors.yaml`** via the `auth:` block: + +| Type | Description | +|------|-------------| +| **`none`** | (Default) No auth headers added. | +| **`static_token`** | Uses a fixed token from a secret (Bearer, Basic, or custom). Supports refresh. | +| **`oauth2`** | Full Client Credentials flow. Supports `private_key_jwt` (RS384) and `client_secret_post`. Handles caching and expiry automagically. | + +### Configuration (`connectors.yaml`) + +```yaml +connectors: + fhir_epic: + enabled: true + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: EPIC_TOKEN_URL + client_id_secret: EPIC_CLIENT_ID + private_key_secret: EPIC_PRIVATE_KEY + kid_secret: EPIC_KID + algorithm: RS384 + + stripe: + enabled: true + auth: + provider: static_token + secret_key: STRIPE_API_KEY +``` + +--- +``` + +Key points: +- **`connector_id`** — unique string; used for routing, config, and registry lookup. +- **`output_model`** — the Pydantic class returned by every action (Drive uses one shared envelope with `raw` + `description`). +- **`error_map`** — maps exception types to `(ErrorCategory, error_code)`. Entries are registered with `ErrorMapper` automatically at class definition time. +- **`build_client()`** — override to create the Google API client. `get_client()` caches the result in `self._client`. +- **`action_specs`** — each key becomes a manifest action (e.g. `files.list`). Do **not** also add a manual `@nw_action` with the same name. +- **`_execute_action_spec`** — **required** when using **`action_specs`**: each generated handler delegates here. Typically call **`execute_spec_in_thread`** for blocking SDKs (such as `googleapiclient`). Connectors that only use hand-written `@nw_action` methods do not implement this hook. + +**Adding a new Drive operation:** add a Pydantic variant and extend the union in `schema.py`, register a new `SdkActionSpec` in `action_spec.py`, and rely on auto-generated handlers (see [`src/node_wire_google_drive/README.md`](../src/node_wire_google_drive/README.md)). + +### Step 4 — Register in `config/connectors.yaml` + +```yaml +connectors: + google_drive: + enabled: true + exposed_via: + - rest + - grpc + - mcp +``` + +`exposed_via` controls which bindings surface the connector. Use any subset of **`rest`**, **`grpc`**, and **`mcp`** (omit protocols you do not need). + +### Step 5 — Auto-registration (nothing extra needed) + +`BaseConnector.__init_subclass__` adds your class to `_CONNECTOR_REGISTRY[connector_id]` as soon as `logic.py` is imported. **`node_wire_runtime.connector_registry.auto_register()`** performs those imports at startup. **No manual factory branch is required.** + +--- + +## Single-action connector example + +A connector with one action is identical in structure — just add one `@nw_action` method: + +```python +# src/node_wire_sms/schema.py +from __future__ import annotations +from typing import Literal +from pydantic import BaseModel + +class SmsSendInput(BaseModel): + action: Literal["send"] = "send" + to: str + message: str + +class SmsSendOutput(BaseModel): + message_sid: str + status: str +``` + +```python +# src/node_wire_sms/logic.py +from __future__ import annotations + +from node_wire_runtime import BaseConnector, nw_action +from .schema import SmsSendInput, SmsSendOutput + + +class SmsConnector(BaseConnector): + connector_id = "sms" + output_model = SmsSendOutput + + @nw_action("send") + async def send(self, params: SmsSendInput, *, trace_id: str) -> SmsSendOutput: + api_key = self.secret_provider.get_secret("sms_api_key") + # ... call SMS vendor API ... + return SmsSendOutput(message_sid="SM123", status="queued") +``` + +--- + +## Calling a connector directly (in-process) + +Use `connector.run(dict)` for the full pipeline (validation, policy, retries, error mapping): + +```python +from node_wire_runtime.connector_registry import auto_register +from bindings.factory import ConnectorFactory + +auto_register() +factory = ConnectorFactory() +factory.load() + +connector = factory.get_for_protocol("google_drive", "rest", action="files.list") +response = await connector.run( + {"action": "files.list", "page_size": 10, "query": "mimeType = 'application/vnd.google-apps.folder'"} +) + +if response.success: + print(response.data) # {"raw": {"files": [...], ...}, "description": "Successfully executed files.list"} +else: + print(response.error_code, response.message) +``` + +For composing actions within a connector, use **`self.call_action`**. It routes through **`connector.run`** so **policy hooks**, **resilience**, and the **`ConnectorResponse`** error path apply (including MCP scope policy). It returns the nested action’s **output model** on success (validated from `run()`’s `data`). On policy denial it raises **`PolicyDenied`**, which the outer `run()` maps like any other action failure. + +Optional keyword args `principal`, `tenant_id`, and `scopes` override the caller identity for the nested call. When omitted, **`call_action` inherits** identity from the outer `run()` (MCP/REST with JWT or scoped API key), so nested actions receive the same authorization as a direct tool call. + +```python +from node_wire_runtime import BaseConnector, nw_action + +@nw_action("upload_then_describe") +async def upload_then_describe( + self, params: MyInput, *, trace_id: str +) -> GoogleDriveOperationOutput: + created = await self.call_action( + "files.create", + {"action": "files.create", "name": params.name, "mime_type": params.mime_type}, + ) + file_id = created.raw["id"] + return await self.call_action( + "files.get", + {"action": "files.get", "file_id": file_id}, + ) +``` + +--- + +## Integrating with binding layers + +The factory and manifest drive all bindings. Once a connector is registered and `load()` is called, REST, gRPC, and MCP discover enabled connectors according to `exposed_via`. + +### Optional: MCP under `src/agents/` (ToolHive / stdio) + +The repo also ships **stdio MCP servers** for agents and ToolHive under `src/agents/` (e.g. `python -m agents.mcp_entrypoint`, per-connector modules). Those are separate from `MODE=MCP` on `node-wire`; see **[mcp-servers.md](mcp-servers.md)** for images, env, and registration. Wiring a connector in `config/connectors.yaml` does not by itself add a ToolHive image — follow **mcp-servers.md** when you need a dedicated MCP deployment. + +### REST binding + +`src/bindings/rest_api/app.py` calls `build_manifest(connectors)` and registers a `POST /connectors/{connector_id}/{action}` route for every manifest entry: + +``` +POST /connectors/google_drive/files.list +Content-Type: application/json + +{ "page_size": 10, "query": "name contains 'report'" } +``` + +The `action` field in the body is optional for REST — the binding injects it from the URL path (see `src/node_wire_runtime/ingress.py`). Per-action **argument normalizers** (`mcp_normalize` on each action) run on the JSON body the same way as MCP, so LLM-friendly aliases work for REST as well. If the body includes an `action` field, it **must** match the path segment; otherwise the API returns **400**. + +The runtime then performs full Pydantic validation and returns a `ConnectorResponse`. + +**Response envelope:** + +```json +{ + "success": true, + "data": { + "raw": { "files": [{ "id": "...", "name": "...", "mimeType": "..." }], "nextPageToken": null }, + "description": "Successfully executed files.list" + }, + "trace_id": "4f3a...", + "error_code": null, + "error_category": null, + "message": null +} +``` + +HTTP status codes are mapped from `ErrorCategory`: + +| `ErrorCategory` | HTTP status | +|-----------------|-------------| +| `BUSINESS` | 400 | +| `AUTH` | 401 | +| `RETRYABLE` | 503 | +| `FATAL` / other | 500 | + +### MCP binding + +`src/bindings/mcp_server/server.py` registers one **MCP tool** per manifest entry. Tool names follow the pattern `{connector_id}.{action}` (e.g. `google_drive.files.list`, `google_drive.files.upload`). + +The MCP server calls `connector.run(args_dict)` and serialises the `ConnectorResponse` as the tool result. + +The **tool name** (`.`) is authoritative: after normalizers run, the binding sets `action` from the tool name. A conflicting `action` in the payload is rejected (see `enforce_authoritative_action` in `src/node_wire_runtime/ingress.py`). + +Optional per-action **argument normalizers** (`mcp_normalize` on `@sdk_action` / `SdkActionSpec`) run before `connector.run` to map LLM aliases to canonical fields. Actions default to **strict** JSON Schema (`additionalProperties: false`); set `alias_tolerant=True` only where extra keys must pass MCP SDK validation before normalization. + +Published **`input_schema` omits the `action` property** (manifest contract v2+): clients must not rely on sending `action` inside tool arguments; the MCP tool name (or REST path) is authoritative. + +**FHIR `search_encounter` (Epic/Cerner):** normalizers map root-level `patient` / `patientId` to `patient_id`, and `sort` → `_sort` (via `search_params`). Encounter search **requires** a patient filter (`patient_id` or `patient` in `search_params`) before any outbound FHIR call. + +### Manifest + +`build_manifest(connectors)` is the single source of truth for both bindings (by default it strips `action` from each entry’s `input_schema`). It returns one entry per `@sdk_action`: + +```python +[ + { + "connector_id": "weather", + "action": "current_weather", + "input_schema": { ... }, # JSON Schema from CurrentWeatherInput (action not required) + "output_schema": { ... }, # ConnectorResponse envelope; data typed to the action output model (nullable on errors) + }, + { + "connector_id": "google_drive", + "action": "files.upload", + ... + } +] +``` + +--- + +## Connector inventory + +| Connector | Primary actions | +|-----------|-----------------| +| `http_generic` | `request` | +| `smtp` | `send_email` | +| `stripe` | `charge` | +| `salesforce` | `create_lead`, `read_lead`, `update_lead`, `delete_lead`, `create_contact`, `read_contact`, `update_contact`, `delete_contact` | +| `google_drive` | `files.list`, `files.upload`, … (see `action_specs`) | + +| `fhir_epic` | `read_patient`, `search_patients`, `search_encounter`, `create_document_reference`, `search_document_reference` | +| `fhir_cerner` | Same family as Epic with Cerner-specific schemas | +| `slack` | `post_message`, `send_direct_message`, `upload_file` | + +MCP tool names: **`.`** (e.g. `fhir_epic.read_patient`). See [`docs/mcp-servers.md`](mcp-servers.md). + +--- + +## Adding a new connector (checklist) + +3. In `logic.py`: subclass `BaseConnector`, set `connector_id` and `output_model`, then add `@nw_action` methods or wire `action_specs`. +4. **Authentication**: Delegate all header construction to **`self.get_auth_headers()`**. Do not hardcode secret lookups or IdP handshakes and ensure sensitive fields are removed from your `input_schema`. +5. For SDK-style connectors, add an `action_spec.py` (or similar) with `SdkActionSpec` entries and use **`execute_spec_in_thread`** when the vendor client is blocking. +6. Optionally add `error_map` and/or `registration.py` for custom exception handling. +7. Add the connector to **`config/connectors.yaml`** with `enabled: true`, the desired `exposed_via` protocols, and an **`auth:`** block. +8. That's it — `auto_register()` handles the rest. No factory branch required. + +--- + +## Configuration reference + +### `config/connectors.yaml` + +```yaml +connectors: + : + enabled: true # false → connector not instantiated + exposed_via: # controls which bindings surface this connector + - rest + - grpc + - mcp + # connector-specific keys passed via SecretProvider or connector __init__ +``` + +### `ConnectorFactory` API + +| Method | Description | +|--------|-------------| +| `load()` | Reads YAML, instantiates all enabled connectors from `_CONNECTOR_REGISTRY`. | +| `get_for_protocol(id, protocol, action=None)` | Returns connector if enabled and exposed for that protocol; `None` otherwise. | +| `list_for_protocol(protocol)` | All connectors exposed for a given protocol. | + +--- + +## Security (REST, plugins, secrets) + +**MCP (`bindings.mcp_server`)** — Configure **`NW_MCP_API_KEY_SCOPES`** (and optionally **`NW_MCP_ACTION_SCOPE_MAP_JSON`**) so `tools/list` and `tools/call` align with the same scope rules. Production baseline is **`NW_MCP_SCOPE_POLICY_DEFAULT=deny`**. Optional guardrail **`NW_MCP_SCOPE_POLICY_STRICT=true`** fails startup when scope policy would otherwise be effectively disabled (default allow + empty map). API key wildcard (`"*"`) is explicit and intentionally bypasses per-action scope restrictions; use only for deliberate super-user keys. JWTs continue to use claim `scopes` / `scope`. + +**REST API (`bindings.rest_api`)** — `GET /health` is unauthenticated. All other routes (`/connectors/...`, `/playground/...`, `/scenarios/...`, OpenAPI) require **`NW_REST_API_KEY`** via `Authorization: Bearer ` or `X-API-Key: `, optional **`NW_REST_JWT_SECRET`** for HS256 JWTs. API key scopes use **`NW_REST_API_KEY_SCOPES`** (same format as MCP). Set **`NW_REST_AUTH_DISABLED=true`** only for local development. Production: set **`NW_REST_LOAD_DOTENV=false`** so secrets are not read from a `.env` file on disk. + +**HTTP Generic outbound policy** — `http_generic.request` allows only `GET`, `POST`, `PUT`, `PATCH`, `DELETE`, and input methods are normalized to uppercase before validation. URLs targeting internal destinations are rejected (`localhost`, loopback, private/link-local IP ranges, metadata endpoints). Connector logs sanitize URL fields by dropping query strings and fragments so only scheme/host/path are retained. + +**Connector entry points** — Any installed distribution may register `node_wire.connectors`. For production, set **`NW_ALLOWED_CONNECTORS`** to a comma-separated list of entry point names (e.g. `fhir_epic,http_generic`). **`NW_CONNECTOR_MODULE_PREFIX`** defaults to `node_wire_`; modules not under that prefix are skipped. + +**Secrets** — `EnvSecretProvider` looks up the key **as given**, then **`key.upper()`** (e.g. `my_key` then `MY_KEY`). It raises **`SecretNotFoundError`** when a variable is missing (fail-closed). Set **`NW_ENV_SECRET_LEGACY_EMPTY=true`** only if you need legacy empty-string behaviour. **`NW_SECRET_BACKEND=aws_env`** with **`NW_AWS_SECRETS_MANAGER_SECRET_ID`** composes AWS Secrets Manager JSON + env fallback via `ChainedSecretProvider` (see `bindings.factory._build_secret_provider`). + +--- + +## Related documentation + +- [packaging.md](packaging.md) — Wheel build lifecycle, PyPI publish flow, client install model, secrets config, and pre-publish checklist. +- [mcp-servers.md](mcp-servers.md) — MCP images, ToolHive, env vars. +- [google_drive_connector.md](google_drive_connector.md) — Drive REST API and setup. +- [salesforce_connector.md](salesforce_connector.md) — Salesforce CRM operations and playground. +- [slack_connector.md](slack_connector.md) — Slack bot token and setup. +- Per-connector READMEs under `src/node_wire_*/README.md` where present. + diff --git a/docs/google_drive_connector.md b/docs/google_drive_connector.md index a66f116..08c9e9d 100644 --- a/docs/google_drive_connector.md +++ b/docs/google_drive_connector.md @@ -1,3 +1,9 @@ + + # Google Drive Connector This document covers the Google Drive connector under `connectors/google_drive` in two parts: @@ -5,7 +11,7 @@ This document covers the Google Drive connector under `connectors/google_drive` 1. **[Google Drive service account setup](#google-drive-service-account-setup)** — Create a GCP service account, enable the Drive API, configure `.env`, share a folder, and verify connectivity. 2. **[REST API reference](#rest-api-reference)** — The `execute` action, all seven operations, request/response shapes, and the platform error taxonomy. -For **MCP** (e.g. ToolHive), the connector is exposed as the `google_drive_upload_file` tool. End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). +For **MCP** (e.g. ToolHive), tools are named `google_drive.` from the connector manifest (e.g. `google_drive.files.upload`). End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). --- @@ -19,7 +25,7 @@ This guide walks you through creating a Google Cloud service account and connect - A Google account - Access to [Google Cloud Console](https://console.cloud.google.com/) -- Node Wire installed and configured (see [Setup.md](../Setup.md)) +- Node Wire installed and configured (see [Installation Guide](installation.md)) ### Step 1: Create or Select a GCP Project @@ -339,7 +345,7 @@ The service account must have edit permission on the file. #### files.upload -Create a new file with text content. +Create a new file with content (text or binary). Request body: @@ -353,14 +359,21 @@ Request body: } ``` +For **MCP** (`google_drive.files.upload`), omit `action` in the tool arguments object; the server injects `files.upload` from the tool name. The published `inputSchema` does not include an `action` property. + Fields: - `name` (string, required). - `mime_type` (string, required). - `parents` (array of string, optional). -- `content` (string, required): UTF-8 text content that will be uploaded. +- `content` (string, optional): UTF-8 text content that will be uploaded. +- `content_base64` (string, optional): base64-encoded binary content (e.g. PDFs, images). + +Exactly one of `content` or `content_base64` must be provided (enforced at validation). + +Content is uploaded using `MediaInMemoryUpload`; this is suitable for small payloads. -Content is uploaded using `MediaInMemoryUpload`; this is suitable for small text payloads. +> For MCP callers (e.g. ToolHive): use canonical fields (`content` / `content_base64`). Legacy `media` / `media_body` shapes are normalized when possible but are not part of the public schema. Legacy `action: "upload"` in the payload is deprecated; set `NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=reject` to hard-fail during rollout. #### files.delete diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..f2b8bff --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,110 @@ +# Installation Guide + +## Prerequisites + +| Requirement | Version | Notes | +|-------------|---------|-------| +| Python | 3.11+ | `python --version` to check | +| pip or uv | Latest | `pip install --upgrade pip` | +| Git | Any | To clone the repo | +| Docker | Latest | Only needed for ToolHive MCP deployment | +| Node.js | Any LTS | Only needed for MCP Inspector | + +--- + +## Installation Steps + +### 1. Clone the repository +```bash +git clone +cd +``` + +### 2. Configure +Copy the sample environment file and add your `NW_ALLOWED_CONNECTORS`: +```bash +# Linux/macOS/PowerShell +cp sample.env .env + +# Windows (CMD) +copy sample.env .env +``` +*(Edit `.env` and set `NW_ALLOWED_CONNECTORS=http_generic` or others)* + +Node Wire uses a fail-closed connector allowlist. If `NW_ALLOWED_CONNECTORS` is missing or empty, no connectors are loaded even when they are enabled in `config/connectors.yaml`. + +### 3. Install dependencies + +**Using `uv` (recommended):** +```bash +uv sync --extra agents +``` + +**Using `pip`:** +- Full install (including AI agents): `pip install -e ".[agents]"` +- Minimal install (REST/gRPC only): `pip install -e .` +- Dev install (linting/tests): `pip install -e ".[dev,agents]"` + +### 4. Verify the installation +```bash +uv run node-wire --help +``` + +--- + +## Running the Platform + +Node Wire supports REST, gRPC, and MCP entry modes: + +| Mode | Command | Default port / transport | Use case | +|------|---------|--------------------------|----------| +| REST API | `uv run node-wire` | `8000` | HTTP clients, Swagger UI, playground | +| gRPC | `MODE=GRPC uv run node-wire` | `50051` | gRPC clients | +| MCP | `python -m agents.mcp_entrypoint` | `stdio` or HTTP | AI agents, ToolHive, Inspector | + +### REST quick start + +```bash +# Local development only +export NW_REST_AUTH_DISABLED=true + +# Start the API +uv run node-wire +``` + +Once it is running: + +- Health check: `GET http://localhost:8000/health` +- Swagger UI: `http://localhost:8000/docs` +- Playground: `http://localhost:8000/playground/` + +### MCP notes + +For MCP transport modes, Inspector usage, and multi-server deployment: + +- See [mcp.md](mcp.md) for transport setup and local MCP usage. +- See [mcp-servers.md](mcp-servers.md) for per-connector images, ToolHive, and Docker-based MCP deployment. + +--- + +## Development Setup + +### Code Quality (Linting & Formatting) +We use **Ruff** for linting/formatting and **Mypy** for type checking. + +- **Check:** `ruff check .` +- **Fix:** `ruff check --fix . && ruff format .` +- **Types:** `mypy` + +`mypy` defaults to the `[tool.mypy].files` targets from `pyproject.toml`. To include tests explicitly, run `mypy src tests`. + +### Pre-commit Hooks +```bash +pre-commit install +``` + +### Running Tests +```bash +pytest tests/ -v +``` +Integration tests are skipped unless the relevant environment variables (secrets) are set. diff --git a/docs/local-packages-to-images.md b/docs/local-packages-to-images.md new file mode 100644 index 0000000..b66361a --- /dev/null +++ b/docs/local-packages-to-images.md @@ -0,0 +1,148 @@ + + +# Local package -> Docker image workflow + +This guide walks through building Node Wire packages locally (as wheels) and using those wheels to build the Docker images in `docker/`. + +The Dockerfiles in this repo install local wheel artifacts from `packages/**/dist/*.whl`, so **you must build wheels first**. + +--- + +## Prerequisites + +- Python 3.12 available in your shell +- Docker installed and running +- Build tooling installed: + +```bash +python -m pip install --upgrade build cython wheel +``` + +Run all commands from the repository root: + + + +--- + +## 1) Build wheel packages locally + +Build all runtime + connector wheels: + +```bash +bash scripts/build-packages.sh +``` + +Build only specific packages (faster when iterating): + +```bash +bash scripts/build-packages.sh \ + packages/runtime \ + packages/connectors/smtp \ + packages/connectors/stripe +``` + +The script (`scripts/build-packages.sh` in default mode, not `--all`): +- builds host wheels and Linux-compatible wheels (via Docker), +- writes artifacts under each package's `dist/` folder, +- fails if any `.py` source files leak into a wheel. + +For optional local `cibuildwheel` builds (broader wheel matrix on your host), see **Optional: broader wheels** in [docs/packaging.md](packaging.md). + +--- + +## 2) Confirm wheel artifacts exist + +Quick check (example for SMTP): + +```bash +ls packages/runtime/dist/*.whl +ls packages/connectors/smtp/dist/*.whl +ls packages/connectors/stripe/dist/*.whl +``` + +If `ls` fails, rebuild that package before continuing. + +--- + +## 3) Build Docker images from local wheels + +### Build all MCP connector images + +```bash +./scripts/build-mcp-images.sh +``` + +With explicit version tag: + +```bash +./scripts/build-mcp-images.sh --version 0.1.0 +``` + +This builds: +- `nw-google-drive` +- `nw-smartonfhir-epic` +- `nw-smartonfhir-cerner` +- `nw-smtp` +- `nw-stripe` + +### Build one image manually + +```bash +docker build -f docker/smtp/Dockerfile -t nw-smtp:local . +``` + +--- + +## Wheel requirements by image + +Each Dockerfile expects specific wheel files to exist in `dist/`: + +| Image | Required wheels | +|---|---| +| `docker/smtp/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/smtp/dist/*.whl` | +| `docker/google-drive/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/google_drive/dist/*.whl` | +| `docker/fhir-epic/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/fhir_epic/dist/*.whl` | +| `docker/fhir-cerner/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/fhir_cerner/dist/*.whl` | +| `docker/stripe/Dockerfile` | `packages/runtime/dist/*.whl`, `packages/connectors/stripe/dist/*.whl` | +| `Dockerfile` (unified MCP server) | runtime + all connector wheels (`http_generic`, `stripe`, `smtp`, `google_drive`, `fhir_epic`, `fhir_cerner`) | + +--- + +## Common failures and fixes + +### `COPY ... dist/*.whl` failed: no source files were specified + +A required wheel is missing. Re-run `scripts/build-packages.sh` for the missing package(s), then rebuild the image. + +### Docker build cannot find `src/` or `config/` + +Use repo root as build context (`.`): + +```bash +docker build -f docker/smtp/Dockerfile -t nw-smtp:local . +``` + +Do not run `docker build` from inside `docker//`. + +### Docker daemon not running + +Start Docker Desktop (or daemon) and retry package/image builds. + +--- + +## Recommended local loop + +```bash +# 1) Rebuild changed packages +bash scripts/build-packages.sh packages/runtime packages/connectors/smtp + +# 2) Build image(s) +docker build -f docker/smtp/Dockerfile -t nw-smtp:local . + +# 3) Verify image exists +docker images --filter reference=nw-smtp +``` diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 1dfe8de..03fe2c5 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -1,3 +1,9 @@ + + # Node Wire — Individual MCP Servers This document covers everything needed to build, run, configure, and integrate the per-connector MCP servers with ToolHive and the Agentic Workflow. @@ -8,6 +14,8 @@ This document covers everything needed to build, run, configure, and integrate t - [Architecture](#architecture) - [Naming conventions](#naming-conventions) +- [Shifting between transport modes](#shifting-between-transport-modes) +- [Testing with MCP Inspector](#testing-with-mcp-inspector) - [Environment configuration](#environment-configuration) - [Build images](#build-images) - [Run with docker-compose](#run-with-docker-compose) @@ -30,23 +38,218 @@ flowchart TD Epic[nw-smartonfhir-epic] Cerner[nw-smartonfhir-cerner] SMTP[nw-smtp] + Stripe[nw-stripe] + Salesforce[nw-salesforce] + Slack[nw-slack] end Agent -->|"TOOLHIVE_MCP_URLS"| GDrive Agent -->|"TOOLHIVE_MCP_URLS"| Epic Agent -->|"TOOLHIVE_MCP_URLS"| Cerner Agent -->|"TOOLHIVE_MCP_URLS"| SMTP + Agent -->|"TOOLHIVE_MCP_URLS"| Stripe + Agent -->|"TOOLHIVE_MCP_URLS"| Salesforce + Agent -->|"TOOLHIVE_MCP_URLS"| Slack ``` --- ## Naming conventions -| Connector | Python entrypoint | Docker image | ToolHive name | MCP tool(s) exposed | +| Connector | Python entrypoint | Docker image | ToolHive name | MCP tools exposed | |---|---|---|---|---| -| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_upload_file` | -| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient` | -| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | -| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp_send_email` | +| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | All manifest actions for `google_drive` (names `google_drive.`, e.g. `google_drive.files.upload`) | +| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | All manifest actions for `fhir_epic` (e.g. `fhir_epic.read_patient`) | +| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | All manifest actions for `fhir_cerner` (e.g. `fhir_cerner.read_patient`) | +| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | +| Stripe | `python -m agents.stripe_mcp` | `nw-stripe` | `nw-stripe` | All manifest actions for `stripe` (e.g., `stripe.charge`) | +| Salesforce | `python -m agents.salesforce_mcp` | `nw-salesforce` | `nw-salesforce` | All manifest actions for `salesforce` (e.g., `salesforce.create_lead`) | +| Slack | `python -m agents.slack_mcp` | `nw-slack` | `nw-slack` | All manifest actions for `slack` (e.g. `slack.post_message`) | + + +The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, `stripe.create_payment_intent`, `stripe.create_subscription`, `stripe.cancel_subscription`, `stripe.issue_refund`, plus the rows above). + +### Tool arguments and security + +- Tool name (`.`) determines the routed action; do not rely on a separate `action` field in the JSON body to select a different operation. +- Per-action normalizers in `src/node_wire_runtime/mcp_normalizers.py` map common LLM mistakes to canonical schema fields; see also `src/node_wire_runtime/ingress.py` for shared MCP/REST behavior. +- **`tools/list` input schemas** omit the `action` field (manifest contract v2+). Pass only the fields shown in `inputSchema`; the server injects `action` from the tool name. + +**Legacy rollout (Google Drive `google_drive.files.upload` only):** + +| Variable | Values | Purpose | +|----------|--------|---------| +| `NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD` | `warn` (default), `allow` (same mapping, no deprecation log), `reject` | Legacy payload `action: "upload"` for `google_drive.files.upload`. Use `reject` once all clients omit `action` or use canonical `files.upload` only in pre-invoke validation paths. | + +--- + +## Shifting between transport modes + +Node Wire now supports two ways to expose tools to AI agents. By default, it uses `stdio`, but you can easily shift to the native `streamable-http` mode for web-native deployments. + +### Comparison: stdio vs. streamable-http + +| Feature | stdio (Default) | streamable-http | +|---|---|---| +| **Best For** | ToolHive, local development, subprocess-based clients | Direct web integration, persistent servers, remote MCP clients | +| **Connectivity** | Standard input/output | HTTP POST plus server streaming support | +| **Port Management** | Not applicable | Requires an open port (default: 8081) | +| **Playground behavior** | Buffered agent response after the backend finishes | Tool steps appear as they complete; final answer streams into the UI | + +### How to configure and shift modes + +You can switch modes and ports instantly using environment variables. No code changes are required. + +#### 1. Running in stdio mode (Default) +No extra variables are needed. This is the mode expected by local stdio clients and ToolHive-style stdio wrapping. + +```bash +python -m agents.mcp_entrypoint +``` + +PowerShell: + +```powershell +$env:NW_MCP_TRANSPORT="stdio" +# Using uv +uv run node-wire + +# Using python +python -m bindings_entrypoint +``` + +#### 2. Shifting to native HTTP mode (Port 8081) +To run as a standalone HTTP server on port 8081: + +**PowerShell (Windows):** +```powershell +$env:NW_MCP_TRANSPORT="streamable-http" +$env:NW_MCP_HOST="127.0.0.1" +$env:NW_MCP_PORT="8081" +$env:NW_MCP_PATH="/mcp" +# Using uv +uv run node-wire + +# Using python +python -m bindings_entrypoint +``` + +**Bash (Linux/macOS):** +```bash +export NW_MCP_TRANSPORT="streamable-http" +export NW_MCP_HOST="127.0.0.1" +export NW_MCP_PORT="8081" +export NW_MCP_PATH="/mcp" +# Using uv +uv run node-wire + +# Using python +python -m bindings_entrypoint +``` + +The native HTTP endpoint will be: + +```text +http://127.0.0.1:8081/mcp +``` + +### Protocol-level requirements +When running in `streamable-http` mode, clients must comply with the strict MCP Streamable-HTTP specification: +- **Headers**: Clients must send `Accept: application/json, text/event-stream` on all requests. +- **Handshake**: The server will respond with a `Mcp-Session-Id` header which must be forwarded in all subsequent messages for that session. +- **Auth boundary**: Node Wire enforces MCP auth at the HTTP edge for the streamable endpoint (`/mcp`) before MCP handler dispatch. Missing/invalid credentials are rejected early with 401/403/503. + +### Production authz baseline (recommended) + +Use these settings for production-style posture: + +```env +NW_MCP_AUTH_DISABLED=false +NW_MCP_SCOPE_POLICY_DEFAULT=deny +# Optional guardrail: fail startup if scope policy would be disabled +NW_MCP_SCOPE_POLICY_STRICT=true +``` + +Notes: +- `NW_MCP_SCOPE_POLICY_DEFAULT=deny` enforces fallback scope `mcp:.` even when no explicit action map is present. +- Keep `NW_MCP_ACTION_SCOPE_MAP_JSON` for custom scope names across tools. +- API keys with `NW_MCP_API_KEY_SCOPES=*` are super-user keys by design and bypass per-action scope checks. + +### Playground transport indicator + +The browser playground reads `/scenarios/agent-transport` and displays the current mode in the Agentic Workflow panel: + +- `Transport: stdio`: chat uses the buffered `/scenarios/agent-chat` endpoint. Tool cards and the final answer appear after the backend agent run completes. +- `Transport: Streamable HTTP`: chat uses `/scenarios/agent-chat-stream`. Tool cards appear as each MCP tool finishes, and the final answer is appended to the assistant bubble as streamed chunks. + +If you switch `NW_MCP_TRANSPORT`, restart the API server and hard refresh the browser so the latest `app.js` is loaded. + +--- + +## Testing with MCP Inspector + +MCP Inspector is the official browser-based developer tool for testing and debugging MCP servers. It runs with `npx` and opens a local UI, usually at `http://localhost:6274`. + +### Inspect stdio mode + +Use stdio mode when you want Inspector to launch the Python MCP server process itself: + +```powershell +$env:NW_MCP_TRANSPORT="stdio" +npx @modelcontextprotocol/inspector uv run python -m agents.mcp_entrypoint +``` + +Per-connector examples: + +```powershell +npx @modelcontextprotocol/inspector uv run nw-google-drive +npx @modelcontextprotocol/inspector uv run nw-smartonfhir-epic +npx @modelcontextprotocol/inspector uv run nw-smartonfhir-cerner +npx @modelcontextprotocol/inspector uv run python -m agents.smtp_mcp +``` + +In the Inspector UI: + +1. Select `stdio` transport if it is not already selected. +2. Click `Connect`. +3. Open the `Tools` tab. +4. Click `List Tools`. +5. Pick a safe tool and run it with valid JSON arguments. + +### Inspect streamable-http mode + +Start the MCP server first: + +```powershell +$env:NW_MCP_TRANSPORT="streamable-http" +$env:NW_MCP_HOST="127.0.0.1" +$env:NW_MCP_PORT="8081" +$env:NW_MCP_PATH="/mcp" +python -m agents.mcp_entrypoint +``` + +Then start Inspector in another terminal: + +```powershell +npx @modelcontextprotocol/inspector +``` + +In the Inspector UI: + +1. Set transport type to `Streamable HTTP`. +2. Set URL to `http://127.0.0.1:8081/mcp`. +3. Click `Connect`. +4. Open `Tools`. +5. Click `List Tools`. +6. Run a tool call with valid arguments. + +For reusable client config, a streamable HTTP server entry should look like: + +```json +{ + "type": "streamable-http", + "url": "http://127.0.0.1:8081/mcp" +} +``` --- @@ -58,6 +261,12 @@ Copy `sample.env` to `.env` and fill in the sections for the servers you plan to cp sample.env .env ``` +### Shared Required Variables + +| Variable | Description | +|---|---| +| `NW_ALLOWED_CONNECTORS` | **Required.** Comma-separated list of allowed connector names (e.g. `fhir_epic,google_drive`). Individual servers still check this allowlist before loading. | + ### Per-server required variables #### `nw-google-drive` @@ -118,7 +327,7 @@ Register your application at the [Cerner Developer Portal](https://code.cerner.c #### `nw-smtp` -The SMTP MCP server exposes one tool: `smtp_send_email`. When running under ToolHive, inject these as secrets: +The SMTP MCP server exposes one tool: `smtp.send_email`. When running under ToolHive, inject these as secrets: | Variable | Description | |---|---| @@ -138,6 +347,47 @@ SMTP_PASSWORD=your-gmail-app-password FROM_EMAIL=your-email@gmail.com ``` +#### `nw-stripe` + +| Variable | Description | +|---|---| +| `STRIPE_API_KEY` | Your Stripe secret API key (starts with `sk_test_` or `sk_live_`) | + +```env +STRIPE_API_KEY=sk_test_4eC39HqLyjWDarjtT1zdp7dc +``` + +#### `nw-salesforce` + +| Variable | Description | +|---|---| +| `SALESFORCE_INSTANCE_URL` | Your Salesforce instance URL (e.g., `https://domain.my.salesforce.com`) | +| `SALESFORCE_TOKEN_URL` | OAuth2 token endpoint (usually `https://login.salesforce.com/services/oauth2/token`) | +| `SALESFORCE_CLIENT_ID` | Connected App Client ID | +| `SALESFORCE_CLIENT_SECRET` | Connected App Client Secret | +| `SALESFORCE_REFRESH_TOKEN` | Refresh token with `refresh_token` and `api` scopes | + +```env +SALESFORCE_INSTANCE_URL=https://nodenet.my.salesforce.com +SALESFORCE_TOKEN_URL=https://login.salesforce.com/services/oauth2/token +SALESFORCE_CLIENT_ID=your-client-id +SALESFORCE_CLIENT_SECRET=your-client-secret +SALESFORCE_REFRESH_TOKEN=your-refresh-token +``` + + +#### `nw-slack` + +| Variable | Description | +|---|---| +| `SLACK_BOT_TOKEN` | Slack Bot User OAuth Token (`xoxb-...`) | +| `NW_SLACK_ATTACHMENTS_DIR` | Optional: sandboxed directory for uploads (default: `/slack_attachments`) | + +```env +SLACK_BOT_TOKEN=xoxb-your-bot-token +NW_SLACK_ATTACHMENTS_DIR=/slack_attachments +``` + ### ToolHive / Agent settings | Variable | Description | @@ -165,7 +415,9 @@ GROQ_API_KEY=your-groq-api-key ## Build images -All four images are built from the repository root using the automation script: +Before building images, build local wheels first. See [docs/local-packages-to-images.md](local-packages-to-images.md) for the full package -> image workflow and required wheel artifacts per image. + +All MCP server images are built from the repository root using the automation script: ```bash ./scripts/build-mcp-images.sh @@ -185,6 +437,9 @@ This produces images tagged as both `latest` and the version string: | `nw-smartonfhir-epic` | `nw-smartonfhir-epic:latest`, `nw-smartonfhir-epic:0.1.0` | | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner:latest`, `nw-smartonfhir-cerner:0.1.0` | | `nw-smtp` | `nw-smtp:latest`, `nw-smtp:0.1.0` | +| `nw-stripe` | `nw-stripe:latest`, `nw-stripe:0.1.0` | +| `nw-salesforce` | `nw-salesforce:latest`, `nw-salesforce:0.1.0` | +| `nw-slack` | `nw-slack:latest`, `nw-slack:0.1.0` | To build a single image manually from the repo root: @@ -200,6 +455,15 @@ docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner:latest . # SMTP only docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . + +# Stripe only +docker build -f docker/stripe/Dockerfile -t nw-stripe:latest . + +# Salesforce only +docker build -f docker/salesforce/Dockerfile -t nw-salesforce:latest . + +# Slack only +docker build -f docker/slack/Dockerfile -t nw-slack:latest . ``` > **Note:** The build context must be the repository root (`.`) so the `COPY src/` and `COPY config/` instructions resolve correctly. @@ -208,17 +472,18 @@ docker build -f docker/smtp/Dockerfile -t nw-smtp:latest . ## Run with docker-compose -`docker-compose.mcp.yml` starts all three MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +`docker-compose.mcp.yml` starts all MCP servers as stdio containers in one command. This is useful for local validation before configuring ToolHive. +Each service pins `NW_ALLOWED_CONNECTORS` to its own connector so a broad value in `.env` does not make per-connector images import optional dependencies they do not contain. ```bash -# Ensure your .env is populated, then: -docker compose -f docker-compose.mcp.yml up +# Ensure local wheels exist and your .env is populated, then: +docker compose -f docker-compose.mcp.yml up --build ``` To start only a specific server: ```bash -docker compose -f docker-compose.mcp.yml up nw-smartonfhir-epic +docker compose -f docker-compose.mcp.yml up --build nw-smartonfhir-epic ``` --- @@ -262,6 +527,25 @@ thv run --name nw-smtp --transport stdio \ --secret SMTP_PASSWORD,target=SMTP_PASSWORD \ --secret FROM_EMAIL,target=FROM_EMAIL \ nw-smtp:latest + +# Stripe +thv run --name nw-stripe --transport stdio \ + --secret STRIPE_API_KEY,target=STRIPE_API_KEY \ + nw-stripe:latest + +# Salesforce +thv run --name nw-salesforce --transport stdio \ + --secret SALESFORCE_INSTANCE_URL,target=SALESFORCE_INSTANCE_URL \ + --secret SALESFORCE_CLIENT_ID,target=SALESFORCE_CLIENT_ID \ + --secret SALESFORCE_CLIENT_SECRET,target=SALESFORCE_CLIENT_SECRET \ + --secret SALESFORCE_USERNAME,target=SALESFORCE_USERNAME \ + --secret SALESFORCE_PASSWORD,target=SALESFORCE_PASSWORD \ + nw-salesforce:latest + +# Slack +thv run --name nw-slack --transport stdio \ + --secret SLACK_BOT_TOKEN,target=SLACK_BOT_TOKEN \ + nw-slack:latest ``` > **Google Drive + ToolHive:** Set `GOOGLE_DRIVE_SA_JSON` to the JSON *contents* (not a file path) when storing in ToolHive secrets, because ToolHive injects secrets as string values. @@ -342,3 +626,21 @@ python -m agents.toolhive --local --patient-id 12724066 --recipient-email you@ex | `fhir_cerner connector not configured` | Missing Cerner env vars | Ensure all `CERNER_*` variables are set and non-empty | | Docker build fails with `COPY src/ not found` | Wrong build context | Always run `docker build` from the **repository root**, not from `docker//` | | Image healthcheck fails | Import error at startup | Run `docker logs ` to see the Python traceback; usually a missing env var | + +## Rollout verification checklist + +Use this checklist when promoting streamable-http MCP to production: + +1. Confirm edge auth gate behavior: + - No token -> `401 MCP_AUTH_REQUIRED` + - Invalid token -> `403 MCP_AUTH_INVALID` + - Valid token -> request reaches MCP handlers +2. Confirm scope baseline: + - `NW_MCP_SCOPE_POLICY_DEFAULT=deny` is set in deployed env + - Optionally enforce `NW_MCP_SCOPE_POLICY_STRICT=true` +3. Confirm authorization telemetry: + - Track trends for 401/403 and `POLICY_DENIED` responses after rollout + - Verify expected tool visibility changes in `tools/list` for scoped identities +4. Confirm privileged-key controls: + - Any API key with wildcard scope (`*`) is documented and approved + - Non-admin API keys use minimal scopes only diff --git a/docs/mcp.md b/docs/mcp.md new file mode 100644 index 0000000..3efc06f --- /dev/null +++ b/docs/mcp.md @@ -0,0 +1,117 @@ +# Model Context Protocol (MCP) in Node Wire + +Node Wire integrates with the Model Context Protocol to allow AI agents (like Claude or custom LLM orchestrators) to discover and use connectors as tools. + +## Transport Modes + +Switch between transports using the `NW_MCP_TRANSPORT` environment variable. + +### 1. `stdio` (Default) +Communicates via standard I/O. Best for local development and subprocess-based clients. +**Bash (Linux/macOS):** +```bash +# Using uv +NW_MCP_TRANSPORT=stdio uv run python -m agents.mcp_entrypoint + +# Using python +NW_MCP_TRANSPORT=stdio python -m agents.mcp_entrypoint +``` + +**PowerShell (Windows):** +```powershell +# Using uv +$env:NW_MCP_TRANSPORT="stdio"; uv run python -m agents.mcp_entrypoint + +# Using python +$env:NW_MCP_TRANSPORT="stdio"; python -m agents.mcp_entrypoint +``` + +### 2. `streamable-http` +Native HTTP MCP server using SSE (Server-Sent Events). +**Bash (Linux/macOS):** +```bash +# Using uv +NW_MCP_TRANSPORT=streamable-http NW_MCP_HOST=127.0.0.1 NW_MCP_PORT=8081 NW_MCP_PATH=/mcp uv run python -m agents.mcp_entrypoint + +# Using python +NW_MCP_TRANSPORT=streamable-http NW_MCP_HOST=127.0.0.1 NW_MCP_PORT=8081 NW_MCP_PATH=/mcp python -m agents.mcp_entrypoint +``` + +**PowerShell (Windows):** +```powershell +# Using uv +$env:NW_MCP_TRANSPORT="streamable-http"; $env:NW_MCP_HOST="127.0.0.1"; $env:NW_MCP_PORT="8081"; $env:NW_MCP_PATH="/mcp"; uv run python -m agents.mcp_entrypoint + +# Using python +$env:NW_MCP_TRANSPORT="streamable-http"; $env:NW_MCP_HOST="127.0.0.1"; $env:NW_MCP_PORT="8081"; $env:NW_MCP_PATH="/mcp"; python -m agents.mcp_entrypoint +``` + +### Streaming Features +- **Configurable Buffering (`NW_STREAM_BUFFER_MS`)**: When streaming, output can be buffered to reduce event spam. Set to the duration in milliseconds (e.g., `2000` for a 2-second batching window). Default is `0` (no buffering). +- **Completion Signals**: The core runtime emits structured "done" signals (`stream_completion_log`) via Python logging when streaming ends, allowing package consumers to easily detect when a stream finishes. + +--- + +## Testing with MCP Inspector + +The [MCP Inspector](https://github.com/modelcontextprotocol/inspector) is the best way to validate your MCP tools locally. + +### Testing stdio +```bash +npx @modelcontextprotocol/inspector uv run python -m agents.mcp_entrypoint + +# Using python +npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint +``` + +### Testing streamable-http +1. Start the server (as shown above). +2. Run the inspector: +```bash +npx @modelcontextprotocol/inspector +``` +3. In the UI, select **Streamable HTTP** and connect to `http://127.0.0.1:8081/mcp`. + +--- + +## Deployment Modes + +Node Wire supports two ways to expose tools via MCP: + +### 1. Combined MCP Server +All connectors enabled for MCP in `config/connectors.yaml` are exposed from a single process. +```bash +# Using uv +uv run python -m agents.mcp_entrypoint + +# Using python +python -m agents.mcp_entrypoint +``` + +### 2. Individual MCP Servers +Each connector runs as its own independent MCP server (often in a dedicated Docker container). This is preferred for modular, scalable deployments. +- **Full Guide:** [Individual MCP Servers (Docker)](mcp-servers.md) + +--- + +## FHIR Tool Arguments (Cerner / Epic) + +Tool names follow the pattern `fhir_cerner.` and `fhir_epic.`. The MCP server normalizes common LLM aliases (e.g., `patientId` → `resource_id`). + +| Action | When to use | Example arguments | +|--------|-------------|-------------------| +| `read_patient` | You have a Patient ID | `{"resource_id": "12724066"}` | +| `search_patients` | No ID, or name-based search | `{"given_name": "Nancy", "family_name": "Smart"}` | +| `search_encounter` | Find medical visits | `{"patient_id": "12724066"}` | + +--- + +## Connector Manifests + +Each connector defines a manifest that MCP uses to understand available tools. +- Tool names follow the pattern: `.` (e.g., `google_drive.files.list`). +- The runtime handles argument normalization, so LLM-friendly aliases often work automatically. + +## Related Docs +- [Individual MCP Servers (Docker)](mcp-servers.md) +- [ToolHive Agent Scenario](toolhive_agent_scenario.md) diff --git a/docs/packaging.md b/docs/packaging.md new file mode 100644 index 0000000..a9c176d --- /dev/null +++ b/docs/packaging.md @@ -0,0 +1,221 @@ + + +# Packaging & Publishing + +Node Wire ships as **seven independent PyPI packages** built from a single monorepo. All wheels are binary-only (Cython-compiled `.so`/`.pyd` files) — no `.py` source is included in any published wheel. + +--- + +## Package inventory + +| PyPI name | Source path | Entry-point key | +|---|---|---| +| `node-wire-runtime` | `src/node_wire_runtime/` | — (no entry point; this is the runtime) | +| `node-wire-fhir-cerner` | `src/node_wire_fhir_cerner/` | `fhir_cerner` | +| `node-wire-fhir-epic` | `src/node_wire_fhir_epic/` | `fhir_epic` | +| `node-wire-google-drive` | `src/node_wire_google_drive/` | `google_drive` | +| `node-wire-http-generic` | `src/node_wire_http_generic/` | `http_generic` | +| `node-wire-smtp` | `src/node_wire_smtp/` | `smtp` | +| `node-wire-stripe` | `src/node_wire_stripe/` | `stripe` | + +Each connector's `pyproject.toml` lives at `packages/connectors//pyproject.toml`; the runtime's is at `packages/runtime/pyproject.toml`. + +--- + +## Python package build lifecycle + +Prerequisites: `pip install build cython wheel` (and a usable `python` on the host). Run `bash scripts/build-packages.sh --help` for usage. + +### Build all packages (default) + +```bash +bash scripts/build-packages.sh +``` + +Default mode builds each of the **seven** known package paths (see inventory above): `python -m build --wheel` on the **host**, then again inside **Docker** (`python:3.12-slim`) so you get Linux-tagged wheels suitable for containers. **Docker must be installed and the daemon running.** After each package, the script scans every produced wheel and fails if any `.py` file appears inside the archive. + + +### Artifact layout and safe command usage + +`scripts/build-packages.sh` writes wheels per package under `packages/**/dist/` (there is no single repo-root `dist/` output). + +Before using wildcard wheel commands, clear old wheel artifacts so commands do not accidentally match stale versions: + +```bash +rm -f packages/runtime/dist/*.whl +rm -f packages/connectors/stripe/dist/*.whl +``` + +### Build a single package + +```bash +bash scripts/build-packages.sh packages/connectors/stripe +``` + +### Optional: broader wheels with cibuildwheel (`--all`) + +For additional platform wheels from your **current machine** (whatever `cibuildwheel` can target there), install it and use the same script: + +```bash +python -m pip install 'cibuildwheel>=2.16.0' +bash scripts/build-packages.sh --all +bash scripts/build-packages.sh --all packages/runtime +``` + +`CIBW_BUILD` / `CIBW_SKIP` default to the same patterns as `.github/workflows/publish.yml` unless you override them in the environment. Full Linux + macOS + Windows coverage is still best done in CI, not guaranteed from one laptop. + +### Inspect wheel contents + +After building, confirm no source leaks: + +```bash +unzip -l packages/connectors/stripe/dist/node_wire_stripe-*.whl +# Must show .so/.pyd files only — no .py files +``` + +### Install from wheels and verify entry points + +```bash +# Install into an active (clean) virtual env +pip install \ + packages/runtime/dist/node_wire_runtime-*.whl \ + packages/connectors/stripe/dist/node_wire_stripe-*.whl + +# Confirm entry points registered +python -c " +from importlib.metadata import entry_points +print(list(entry_points(group='node_wire.connectors'))) +" +``` + +### Verify connector loading + +```bash +python -c " +from node_wire_runtime.connector_registry import auto_register +loaded = auto_register() +print('Loaded:', loaded) +" +``` + +--- + +## Client consumption model + +A downstream client installs only what it needs: + +```bash +pip install node-wire-runtime node-wire-stripe node-wire-fhir-epic +``` + +At startup, `auto_register()` discovers all installed connectors via the `node_wire.connectors` [entry-point group](https://packaging.python.org/en/latest/specifications/entry-points/) — no explicit import list required. + +### Runtime loading knobs + +| Env var | Default | Purpose | +|---|---|---| +| `NW_ALLOWED_CONNECTORS` | _(all discovered)_ | Comma-separated allowlist of entry-point names (e.g. `stripe,fhir_epic`). In development, leave unset to load everything. In production, set explicitly. | +| `NW_CONNECTOR_MODULE_PREFIX` | `node_wire_` | Connectors whose target module doesn't start with this prefix are skipped with a warning. Set to `""` to disable the check. | + +--- + +## `connectors.yaml` and secrets + +### Minimal `connectors.yaml` + +```yaml +connectors: + stripe: + enabled: true + exposed_via: ["mcp"] + fhir_epic: + enabled: false + exposed_via: [] +``` + +`enabled` gates whether the connector is instantiated. `exposed_via` controls which protocols (`rest`, `grpc`, `mcp`) surface it. A connector that is installed but `enabled: false` will not run. + +See `config/connectors.yaml` for the full working example and `src/node_wire_runtime/connectors.yaml.sample` for a commented template with all supported fields. + +For per-connector detail (operations, env vars, request/response shapes) see `docs/connectors.md` and each connector's `README.md` under `src/node_wire_/`. + +### Secret backend (`NW_SECRET_BACKEND`) + +| Value | Behavior | +|---|---| +| `env` _(default)_ | Reads from process environment. Raises `SecretNotFoundError` for absent keys (fail-closed). | +| `aws_env` | Tries AWS Secrets Manager JSON bundle first; falls back to env on `SecretNotFoundError`. Propagates `SecretProviderError` immediately (broken provider is never silently swallowed). | + +Required env vars for `aws_env`: + +- `NW_AWS_SECRETS_MANAGER_SECRET_ID` — secret name or ARN (required) +- `AWS_REGION` — defaults to `us-east-1` + +**Legacy flag:** `NW_ENV_SECRET_LEGACY_EMPTY=true` returns `""` for missing keys instead of raising. This exists for backwards compatibility only — do not use in production. + +Additional cloud backends (`vault`, `azure`, `gcp`) ship as optional extras in `node-wire-runtime` but are not currently wired into the factory: + +```bash +pip install "node-wire-runtime[aws]" # boto3 +pip install "node-wire-runtime[vault]" # hvac +pip install "node-wire-runtime[azure]" # azure-keyvault-secrets +pip install "node-wire-runtime[gcp]" # google-cloud-secret-manager +``` + +--- + +## CI publish flow (Trusted Publisher) + +**Workflow:** `.github/workflows/publish.yml` — manual `workflow_dispatch`. + +**Required inputs:** + +| Input | Example | Notes | +|---|---|---| +| `package_path` | `packages/connectors/stripe` | Must match an entry in the workflow's allowlist | +| `version` | `0.1.0` | Must match `[project].version` in the package's `pyproject.toml` | + +**Pipeline steps:** + +1. Validate `package_path` against allowlist (prevents path traversal) +2. Matrix-build wheels on Ubuntu, macOS, Windows via `cibuildwheel` (Python 3.11, 3.12; Linux manylinux + aarch64, macOS x86_64 + arm64, Windows amd64) +3. Post-build gate: verify zero `.py` files per wheel; record SHA256 checksums +4. Merge artifacts; `pip-audit --fail-on HIGH` CVE gate +5. Generate SBOM via `cyclonedx-py` +6. Publish to PyPI via OIDC Trusted Publisher with Sigstore attestations (all action SHAs pinned for immutability) + +--- + +## Docker demo images + +The `docker/*/Dockerfile` images are **demonstration templates** for packaging a single connector as a standalone MCP server. They are not production orchestration artefacts. + +For a local end-to-end walkthrough (build wheels first, then build Docker images that consume those wheels), see [docs/local-packages-to-images.md](local-packages-to-images.md). + +```bash +docker build -f docker/smtp/Dockerfile -t nw-smtp . +docker build -f docker/google-drive/Dockerfile -t nw-google-drive . +docker build -f docker/fhir-epic/Dockerfile -t nw-smartonfhir-epic . +docker build -f docker/fhir-cerner/Dockerfile -t nw-smartonfhir-cerner . +docker build -f docker/stripe/Dockerfile -t nw-stripe . +``` + +For compose and ToolHive registration see `docs/mcp-servers.md`. + +--- + +## Pre-PyPI local validation checklist + +Run these gates before triggering the CI publish workflow (default `build-packages.sh` is enough; `--all` is optional for broader local wheels): + +- [ ] `bash scripts/build-packages.sh` exits 0 +- [ ] `unzip -l packages//dist/*.whl` shows no `.py` files +- [ ] Install wheels into a clean venv; confirm entry points resolve +- [ ] `auto_register()` loads expected connectors +- [ ] `pytest tests/test_connector_registry.py tests/test_connectors_basic.py` passes +- [ ] Wheel SHA256 checksums recorded and match expected values +- [ ] `package_path` and `version` inputs match the allowlist and `pyproject.toml` version before dispatching the workflow diff --git a/docs/privacy.md b/docs/privacy.md new file mode 100644 index 0000000..bfe883e --- /dev/null +++ b/docs/privacy.md @@ -0,0 +1,34 @@ +# Privacy Policy and Compliance + +The Node-Wire project is committed to ensuring privacy and secure data handling out-of-the-box. As a framework facilitating the orchestration of integrations between Large Language Models (LLMs) and various enterprise/healthcare systems, Node-Wire adheres to strict principles to prevent inadvertent data exposure. + +## Core Privacy Principles + +1. **No Telemetry or Phone Home:** + The Node-Wire open-source framework does not collect, transmit, or store any usage data, telemetry, or analytics. It operates entirely within the infrastructure where it is deployed. + +2. **No Data Persistence by Default:** + Node-Wire acts as an orchestration and routing layer. It does not contain a built-in database for persistent storage of transaction data, logs, or payloads. Any data persistence must be explicitly configured by the user via connectors (e.g., storing a file in Google Drive). + +3. **Zero PII/PHI in Source Control:** + The repository is routinely audited to ensure no Personally Identifiable Information (PII) or Protected Health Information (PHI) is committed to source control. + +## Testing and Dummy Data + +All unit tests, integration tests, and example scenarios within the `tests/` and `playground/` directories strictly utilize fabricated placeholder data. + +- **Dummy Emails:** `doc@example.com`, `patient@example.com`, `noreply@node-wire.local` +- **Dummy Patient IDs:** `12724066`, `eXYZ123` +- **Dummy Credentials:** Credentials in tests use explicit `dummy` or `test` prefixes (e.g., `sk_test_dummy`). + +If you are contributing to Node-Wire, you **must** ensure that no real data from your environment is included in your commits. + +## Logging + +By default, Node-Wire logging is configured to provide operational visibility without exposing sensitive payloads. However, when running the MCP Server or REST API in `DEBUG` mode, certain raw HTTP requests and responses may be logged for troubleshooting. + +**Guidance:** Do not run Node-Wire in `DEBUG` logging mode in production environments to prevent the accidental leakage of sensitive data into system logs. + +## Security Disclosures + +If you discover a potential privacy or security vulnerability within Node-Wire, please do not disclose it publicly. Refer to our [Security Policy](security-gap-report.md) for instructions on how to securely report issues to the maintainers. diff --git a/docs/quality-security-gates.md b/docs/quality-security-gates.md new file mode 100644 index 0000000..3858c9a --- /dev/null +++ b/docs/quality-security-gates.md @@ -0,0 +1,179 @@ +# Quality and security gates + +This document defines how Node Wire enforces security scanning and SonarQube analysis in CI, plus the SonarQube Community Edition setup required for centralized reporting. + +This repository enforces security gates at both PR time and publish time. + +## CI quality gates + +Workflow: `.github/workflows/quality-gates.yml` + +Runs on every pull request and on pushes to `main`/`master`. + +Required jobs: + +- `bandit`: writes `bandit-report.json` (with `--exit-zero` so low/medium findings do not fail the job before the gate), prints a log summary, uploads the artifact, then fails only on **high**-severity findings in the enforce step. +- `test`: runs `pytest` and produces `coverage.xml`. +- `sonar`: runs SonarQube scan and waits for quality gate result (runs after `bandit` and `test`). + +Required checks to add in branch protection: + +- `Quality gates / Bandit security scan` +- `Quality gates / Tests and coverage` +- `Python package security PR checks / Vulnerability scan (packages/runtime)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/http_generic)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/stripe)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/smtp)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/google_drive)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/fhir_cerner)` +- `Python package security PR checks / Vulnerability scan (packages/connectors/fhir_epic)` + +Configure branch protection so pull requests cannot merge unless all required checks pass. + +## CVE scanning policy + +- PR and push-to-main scanning runs in `.github/workflows/security-pr.yml`. +- Release-time scanning remains in `.github/workflows/publish.yml` as defense in depth. +- `pip-audit --fail-on HIGH` is the vulnerability gate threshold. +- Scheduled scans catch newly disclosed CVEs even when code does not change. + +**Monorepo install note:** Connector packages under `packages/connectors/*` declare `node-wire-runtime>=0.1.0` as a normal PyPI dependency name. The security workflow installs `packages/runtime` from the checkout **together with** each matrix package (`pip install packages/runtime ""`) so `pip` can resolve `node-wire-runtime` without requiring a published wheel on PyPI. Locally, mirror that when auditing a single connector: `pip install packages/runtime packages/connectors/`. + +## Run checks locally + +```bash +# Install dev tools +pip install -e ".[dev,agents]" + +# Security gate (matches CI failure threshold) +bandit -c pyproject.toml -r src --severity-level high + +# Optional: JSON report + same summary as CI logs +bandit -c pyproject.toml -r src -f json -o bandit-report.json --exit-zero +python scripts/bandit_report_summary.py bandit-report.json + +# Tests + coverage.xml (required by SonarQube) +pytest tests/ -v +``` + +## Deterministic pytest environment + +To keep pytest collection and REST app startup deterministic, `tests/conftest.py` sets a fixed environment before imports: + +- `NW_REST_LOAD_DOTENV=false` so REST startup does not merge a repo-root `.env` over test variables. +- `NW_CONFIG_PATH=tests/fixtures/connectors_for_tests.yaml` so optional connectors outside the pytest allowlist remain `enabled: false` (for example `slack` and `salesforce`). +- `NW_ALLOWED_CONNECTORS=http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner` so only the supported test connector set is loaded during collection. + +Do not rely on `.env` values during pytest collection. The test harness intentionally overrides them so local developer state does not affect CI or test outcomes. + +### Pre-commit + +```bash +pre-commit install +pre-commit run --all-files +``` + +## Local Sonar scan with Docker + +After generating `coverage.xml`, run scanner from the repository root: + +```bash +docker run --rm \ + -e SONAR_TOKEN=YOUR_TOKEN \ + -v "G:\SPACE\node-wire:/usr/src" \ + -w /usr/src \ + sonarsource/sonar-scanner-cli \ + -Dsonar.host.url=http://host.docker.internal:9000 \ + -Dsonar.token=YOUR_TOKEN +``` + +## SonarQube configuration + +The repository includes `sonar-project.properties` and CI expects these GitHub secrets: + +- `SONAR_HOST_URL` (example: `https://sonarqube.company.internal`) +- `SONAR_TOKEN` (project analysis token) + +For server setup and quality gate policy details, see this document's [SonarQube Community Edition setup](#sonarqube-community-edition-setup) section. + +## Bandit policy + +Bandit is configured in `pyproject.toml` under `[tool.bandit]`. + +### Exit codes and CI behavior + +By default, **Bandit exits with a non-zero status whenever it reports any finding**, including low and medium severity. That affects `-f json -o ...` the same as text output. + +CI splits responsibilities: + +1. **JSON artifact + log summary** — `bandit ... -f json -o bandit-report.json --exit-zero` so the workflow always produces the report and runs `scripts/bandit_report_summary.py` for readable logs. Low/medium issues are visible here and in Sonar/import without failing the job. +2. **Enforcement** — `bandit ... --severity-level high` fails the job only on high-severity findings (matches branch-protection intent). + +Locally, mirror CI with the commands in [Run checks locally](#run-checks-locally). + +### Scope + +Policy: + +- Scan target: `src/` (runtime, bindings, in-tree connector implementations installed via the root package). +- Exclude: `.venv`, `venv`, `tests`, `playground`, `dist`, `htmlcov`. +- CI enforcement threshold: `--severity-level high`. +- **Packages tree:** connector distributions under `packages/connectors/*` are audited for CVEs in `.github/workflows/security-pr.yml` (`pip-audit`). Run Bandit against those paths separately if you need SAST on a standalone checkout. + +If legacy findings block adoption, create a baseline once and track deltas: + +```bash +bandit -c pyproject.toml -r src -f json -o bandit-baseline.json --exit-zero +bandit -c pyproject.toml -r src --baseline bandit-baseline.json --severity-level high +``` + +## SonarQube Community Edition setup + +### 1) Run SonarQube CE (example Docker) + +```bash +docker volume create sonarqube_data +docker volume create sonarqube_logs +docker volume create sonarqube_extensions + +docker run -d --name sonarqube \ + -p 9000:9000 \ + -v sonarqube_data:/opt/sonarqube/data \ + -v sonarqube_logs:/opt/sonarqube/logs \ + -v sonarqube_extensions:/opt/sonarqube/extensions \ + sonarqube:lts-community +``` + +For production, place SonarQube behind HTTPS/reverse proxy and persistent backup strategy. + +### 2) Create project and token + +1. Open SonarQube UI (`http://:9000`). +2. Create project key `node-wire` (or update `sonar-project.properties` if using a different key). +3. Generate project analysis token. + +### 3) Configure GitHub secrets + +In repository settings, add: + +- `SONAR_HOST_URL` +- `SONAR_TOKEN` + +### 4) Configure quality gate + +Create or update a quality gate to enforce at minimum: + +- No new blocker issues. +- No new critical vulnerabilities. +- Coverage on new code >= 80%. + +Attach the gate to the Node Wire project. + +## Acceptance criteria mapping + +- Security scan runs on every PR: enforced by `quality-gates.yml` (Bandit). +- Builds fail on high-severity Bandit findings: Bandit gate in CI. +- SonarQube dashboard visible: SonarQube CE project + scanner upload from CI. +- Coverage visible in SonarQube: `pytest-cov` generates `coverage.xml`, scanner consumes it via `sonar.python.coverage.reportPaths`. +- Developers run checks locally: documented commands and pre-commit (Bandit). +- Config version-controlled: `pyproject.toml`, `.pre-commit-config.yaml`, `sonar-project.properties`, workflow file. diff --git a/docs/salesforce_connector.md b/docs/salesforce_connector.md new file mode 100644 index 0000000..264c107 --- /dev/null +++ b/docs/salesforce_connector.md @@ -0,0 +1,95 @@ +# Salesforce Connector (`src/node_wire_salesforce`) + +The Salesforce connector provides a secure, asynchronous interface for managing CRM records (Leads and Contacts). It leverages Node Wire's `OAuth2AuthProvider` to handle token refresh automatically, allowing for seamless integration into agentic workflows and medical-to-CRM pipelines. + +## Capabilities + +The connector exposes full CRUD (Create, Read, Update, Delete) operations for the two most common Salesforce objects used in healthcare and enterprise outreach: + +| Action | Description | +|---|---| +| `create_lead` | Create a new Lead record. Requires `LastName` and `Company`. | +| `read_lead` | Fetch a single Lead record by ID. | +| `update_lead` | Update specific fields on an existing Lead. | +| `delete_lead` | Remove a Lead record. | +| `create_contact` | Create a new Contact record. Requires `LastName`. | +| `read_contact` | Fetch a single Contact record by ID. | +| `update_contact` | Update specific fields on an existing Contact. | +| `delete_contact` | Remove a Contact record. | + +## Configuration + +Add the following to your `config/connectors.yaml`: + +```yaml +connectors: + salesforce: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN +``` + +## Environment Variables + +The following secrets must be provided (e.g., in `.env` or via your secret manager): + +| Variable | Example | +|---|---| +| `SALESFORCE_INSTANCE_URL` | `https://your-domain.my.salesforce.com` | +| `SALESFORCE_TOKEN_URL` | `https://login.salesforce.com/services/oauth2/token` | +| `SALESFORCE_CLIENT_ID` | `3MVG9...` | +| `SALESFORCE_CLIENT_SECRET` | `A1B2...` | +| `SALESFORCE_REFRESH_TOKEN` | `5Aep...` | + +## Example Usage + +### REST API + +```bash +curl -X POST http://localhost:8000/connectors/salesforce/create_lead \ + -H "X-API-Key: your-key" \ + -H "Content-Type: application/json" \ + -d '{ + "LastName": "Doe", + "Company": "Acme Corp", + "Email": "john.doe@example.com", + "Status": "Open - Not Contacted" + }' +``` + +### Agentic (MCP) + +If registered via MCP, the agent can call `salesforce.create_lead` with the following arguments: + +```json +{ + "LastName": "Smith", + "Company": "HealthTech", + "Email": "jane@smith.com" +} +``` + +## Playground Interface + +The Node Wire playground includes a **CRM Synchronization** panel specifically for Salesforce. This interface allows you to: + +1. **Toggle between Lead and Contact management**: Use the action dropdown to switch contexts. +2. **Execute full CRUD operations**: The form dynamically adjusts based on whether you are creating, reading, updating, or deleting a record. +3. **Real-time Pipeline Visualization**: Watch the synchronization steps (Authentication → Fetch/Update → Verification) in real-time. +4. **Instant Record Validation**: See the exact Salesforce resource IDs and data returned by the API. + +Access the playground at `http://localhost:8000/playground` (when running locally). + +## Security Note + +- **OAuth2**: Tokens are never stored in plain text in logs. Node Wire's `AuthProvider` handles encryption and secure memory storage. +- **Refresh Token Support**: The connector is configured to use `grant_method: refresh_token`, ensuring it can stay authenticated for long-running agentic tasks. +- **Traceability**: All actions are logged with a `trace_id` for auditing and idempotency tracking. +- **PII Protection**: Ensure your logging levels are set correctly; by default, the connector logs the metadata of the transaction but not the full PII payload. + diff --git a/docs/security-gap-report.md b/docs/security-gap-report.md new file mode 100644 index 0000000..e9ee129 --- /dev/null +++ b/docs/security-gap-report.md @@ -0,0 +1,365 @@ + + +# Security & Architecture Gap Report — Node Wire MCP Platform + +> **Perspective:** Secure MCP server platform integrator reviewing the runtime and connectors for production readiness. +> **Date:** 2026-04-01 +> **Branch:** `feature/python-packages` + +--- + +## Executive Summary + +The platform has a solid foundation: clean layered architecture (runtime → connectors → bindings), Pydantic-enforced input validation, OpenTelemetry observability, and resilience patterns (circuit breaker, retry). However, **critical security gaps must be addressed before production use**, particularly around authentication, credential management, PHI/PII logging, and network security. + +**Finding counts:** 5 Critical · 10 High · 7 Medium · 5 Low + +--- + +## Severity Definitions + +| Severity | Meaning | +|----------|---------| +| **CRITICAL** | Exploit path exists now; immediate remediation required | +| **HIGH** | Significant attack surface; address before production | +| **MEDIUM** | Increases risk; address in next sprint | +| **LOW** | Best-practice gap; backlog item | + +--- + +## CRITICAL Findings + +### C1 — Credentials in `.env` Committed to Repository + +- **Location:** `.env` (repository root) +- **What's exposed:** `EPIC_PRIVATE_KEY` (RSA private key), `CERNER_PRIVATE_KEY`, `EPIC_CLIENT_ID`, `GROQ_API_KEY`, `SMTP_PASSWORD`, path to `connectorplatform-*.json` service account file +- **Impact:** Anyone with read access to the repo can impersonate the Epic/Cerner OAuth client, read Google Drive, send email as the platform, and call Groq +- **Fix:** + 1. Revoke all exposed credentials immediately and rotate + 2. Add `.env` and `connectorplatform-*.json` to `.gitignore` + 3. Move secrets to a secrets manager (HashiCorp Vault, AWS Secrets Manager, K8s Secrets) + +--- + +### C2 — PHI Logged in Error Paths (HIPAA Violation) + +- **Location:** `src/connectors/fhir_epic/logic.py` (~line 485), `src/connectors/fhir_cerner/logic.py` (~line 592) +- **What's logged:** Full FHIR `DocumentReference` payload on failure — contains patient names, birthdates, MRNs, diagnoses +- **Code pattern:** + ```python + logger.error("... sent_payload=%s", json.dumps(doc_ref)) + ``` +- **Impact:** Violates HIPAA § 164.312(b) audit controls; PHI written to log aggregation systems in plaintext +- **Fix:** Log only resource type, resource ID, and HTTP status code. Implement a `PHIScrubber` log filter for all healthcare connectors + +--- + +### C3 — No Authentication on REST API or gRPC Binding + +- **Location:** `src/bindings/rest_api/app.py`, `src/bindings/grpc_server/server.py` +- **What's missing:** Zero authentication or authorization on any endpoint +- **gRPC uses an insecure port:** + ```python + server.add_insecure_port(f"[::]:{port}") # no TLS, no mTLS + ``` +- **Impact:** Any network-adjacent caller can invoke any connector action with no audit trail +- **Fix:** + - REST: Add API key or OAuth2 bearer token middleware to FastAPI + - gRPC: Switch to `add_secure_port` with TLS credentials; enforce mTLS for service-to-service + +--- + +### C4 — SSRF via HTTP Generic Connector + +- **Location:** `src/connectors/http_generic/schema.py`, `src/connectors/http_generic/logic.py` +- **What's missing:** `HttpUrl` validates URL format but not destination host +- **Attack path:** + ```json + { "url": "http://169.254.169.254/latest/meta-data", "method": "GET" } + ``` + → Returns AWS instance metadata including IAM credentials +- **Fix:** Block RFC-1918, loopback (`127.0.0.0/8`), and link-local (`169.254.0.0/16`) address ranges at the schema validator level; optionally implement an egress allowlist + +--- + +### C5 — Configurable Secret Key Names in SMTP Connector + +- **Location:** `src/connectors/smtp/schema.py` (fields `username_secret_key`, `password_secret_key`) +- **Attack path:** Caller provides `"username_secret_key": "STRIPE_API_KEY"` → connector fetches the Stripe key and uses it as an SMTP credential; SMTP auth error may reveal whether the key value exists or its format +- **Impact:** Secret enumeration and partial exfiltration via error side-channel +- **Fix:** Hardcode secret key names inside the connector; remove `username_secret_key` and `password_secret_key` from the public input schema entirely + +--- + +## HIGH Findings + +### H1 — OAuth Error Response Body Logged + +- **Location:** `src/connectors/fhir_epic/logic.py` (~line 130), `src/connectors/fhir_cerner/logic.py` +- **What's logged:** Full `token_response.text` on OAuth failure — may include client credential reflections, token hints, or infrastructure error details +- **Fix:** Log only the `error` and `error_description` fields from the JSON response + +--- + +### H2 — Unvalidated HTTP Method in Generic Connector + +- **Location:** `src/connectors/http_generic/schema.py` +- **Current:** `method: str` — accepts any string value +- **Risk:** Arbitrary or non-standard HTTP methods forwarded to target servers; undefined server behavior +- **Fix:** + ```python + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"] + ``` + +--- + +### H3 — Stripe Global API Key (Race Condition) + +- **Location:** `src/connectors/stripe/logic.py` (~line 48) +- **Pattern:** `stripe.api_key = api_key` mutates global module state +- **Risk:** Concurrent requests clobber each other's API key; Stripe exceptions may include the key value in tracebacks +- **Fix:** Pass `api_key=` explicitly to each Stripe API call rather than setting global state + +--- + +### H4 — Auto-Discovery Loads All Connector Modules Without Allowlist + +- **Location:** `src/connectors/__init__.py` (`auto_register()`) +- **Risk:** Any file placed at `src/connectors/*/logic.py` is imported and executed automatically on startup — no explicit allowlist +- **Fix:** Validate each discovered connector ID against an explicit allowlist from `connectors.yaml`; skip unknown modules with a logged warning + +--- + +### H5 — No Rate Limiting on Any Binding + +- **Location:** All binding layers (REST, gRPC, MCP) +- **Risk:** Unlimited invocation rate; no protection against DoS, credential-stuffing, or API quota exhaustion at upstream services +- **Fix:** Add `slowapi` (or equivalent) rate-limiting middleware; define per-tenant quotas in connector configuration + +--- + +### H6 — Circuit Breaker Shared Across Tenants + +- **Location:** `src/runtime/base_connector.py` (constructor — `self._breaker = CircuitBreaker(...)`) +- **Risk:** One tenant triggering repeated failures opens the circuit breaker for all tenants — classic noisy-neighbor DoS +- **Fix:** Key the circuit breaker on `(connector_id, tenant_id)` rather than connector class alone + +--- + +### H7 — Unvalidated Base64 Content in FHIR Connectors + +- **Location:** `src/connectors/fhir_cerner/logic.py`, `src/connectors/fhir_epic/logic.py` +- **Pattern:** + ```python + attachment["data"] = params.data # no format check, no size limit + ``` +- **Risk:** Malformed base64 forwarded to EHR; unbounded payload size enables memory exhaustion +- **Fix:** Add a Pydantic `field_validator` that calls `base64.b64decode(v, validate=True)` and enforces a max size (e.g., 10 MB) + +--- + +### H8 — No Dependency Vulnerability Scanning + +- **Location:** `pyproject.toml`, `uv.lock` +- **Gap:** No `pip-audit`, `safety`, or Dependabot configured; several deps use unbounded `>=VERSION` (e.g., `tenacity>=8.2.0`, `fastapi>=0.111.0`, `uvicorn>=0.30.0`) +- **Fix:** Add `pip-audit` to CI pipeline; add upper bounds to all runtime dependencies; add `detect-secrets` as a pre-commit hook + +--- + +### H9 — No Structured Audit Trail for Policy Hook Decisions + +- **Location:** `src/runtime/base_connector.py` (policy hook execution block) +- **Gap:** Policy denials are logged at `WARNING` level but not emitted as structured security events queryable by a SIEM +- **Fix:** Emit a structured `POLICY_DENIED` event containing `principal`, `tenant_id`, `connector_id`, `action`, and `reason` to a dedicated audit log sink + +--- + +### H10 — Zero Security Test Coverage + +- **Location:** `tests/` (all files) +- **Gaps:** No tests exist for authentication failures, SSRF attempts, malformed payloads, credential leakage in error messages, multi-tenant isolation, rate limit enforcement, or concurrent state safety +- **Fix:** Add a `tests/security/` suite covering at minimum: + - SSRF via HTTP Generic (`127.0.0.1`, `169.254.169.254`) + - Secret enumeration via SMTP `username_secret_key` + - PHI absence in FHIR error log output + - Circuit breaker isolation across tenants + +--- + +## MEDIUM Findings + +### M1 — Stripe Input Has No Validation Bounds + +- **Location:** `src/connectors/stripe/schema.py` +- `amount: int` — no minimum (1 cent) or maximum cap +- `currency: str` — no ISO 4217 pattern check +- **Fix:** + ```python + amount: int = Field(..., ge=1, le=999_999_999) + currency: str = Field(..., pattern=r'^[A-Z]{3}$') + ``` + +--- + +### M2 — Config Variable Substitution (`${VAR:default}`) Not Implemented + +- **Location:** `src/bindings/factory.py`; `config/connectors.yaml` uses `${VAR:default}` syntax in comments and values +- **Current behavior:** Variables are loaded as literal strings — the `${...}` is never expanded +- **Fix:** Implement regex-based substitution in the YAML loader; raise at startup if a required variable is unset + +--- + +### M3 — Factory Returns `None` on Missing Connector (Silent Failure) + +- **Location:** `src/bindings/factory.py` (`get_for_protocol()`) +- **Risk:** A misconfigured connector silently returns `None`; failures surface at request time rather than startup +- **Fix:** Validate all enabled connectors during factory initialization and raise immediately on any misconfiguration + +--- + +### M4 — Hardcoded Timeouts and Circuit Breaker Parameters + +- **Location:** `src/connectors/http_generic/logic.py` (`timeout=30.0`), `src/runtime/base_connector.py` (`fail_max=5, reset_timeout=30`) +- **Risk:** A 30-second timeout is inappropriate for large FHIR document uploads; a 5-failure threshold may be too sensitive for high-traffic deployments +- **Fix:** Expose these as per-connector configuration keys in `connectors.yaml` + +--- + +### M5 — OpenTelemetry Trace Data May Export PHI + +- **Location:** `src/runtime/observability.py` +- **Risk:** Span attributes populated by FHIR connectors may include patient identifiers that flow unfiltered to the OTLP collector +- **Fix:** Add a `SpanSanitizer` processor that removes or hashes known PHI field names before export + +--- + +### M6 — Google Drive `query` Parameter Accepts Arbitrary String + +- **Location:** `src/connectors/google_drive/schema.py` (`query: Optional[str]`) +- **Risk:** No client-side validation; the platform relies entirely on Google's server-side handling of malformed or adversarial query strings +- **Fix:** Document and enforce the allowed query syntax subset; reject queries that don't match a safe pattern + +--- + +### M7 — Service Account File Path Not Sandboxed + +- **Location:** `src/connectors/google_drive/logic.py` +- **Pattern:** `Credentials.from_service_account_file(raw_sa.strip())` — path is fully controlled by the env var +- **Risk:** Path traversal if the environment variable is tampered with +- **Fix:** Resolve the path with `Path.resolve()` and assert it falls within the application directory before opening + +--- + +## LOW Findings + +### L1 — SMTP Connector Logs Recipient Email Addresses + +- **Location:** `src/connectors/smtp/logic.py` +- `"from_email": str(params.from_email)` written to structured log output +- **Risk:** Email addresses are PII; log aggregators retain them indefinitely, creating a compliance liability +- **Fix:** Log only recipient count and domain (e.g., `example.com`), never full addresses + +--- + +### L2 — MCP Manifest Lacks Security Metadata + +- **Location:** `src/connectors/manifest.py` +- **Missing:** Required OAuth scopes, auth requirements, per-action rate limits, deprecation status +- **Impact:** LLM clients have no way to determine required permissions before invoking a tool +- **Fix:** Add an optional `security` block to each manifest entry describing required scopes and auth type + +--- + +### L3 — No MCP Prompt Templates Defined + +- **Location:** `src/bindings/mcp_server/server.py` +- **Gap:** The MCP spec supports pre-built prompt templates to guide safe, correct tool use +- **Risk:** Without templates, LLM clients must independently discover correct multi-step usage patterns (e.g., FHIR patient lookup → document create) +- **Fix:** Define prompt templates for common connector flows + +--- + +### L4 — No Sampling or Pagination Limits in MCP Binding + +- **Location:** `src/bindings/mcp_server/server.py` +- **Gap:** A single `files.list` or FHIR search with a large page size could return megabytes of data in one tool response +- **Fix:** Enforce maximum page sizes at the MCP binding layer; add streaming for large result sets + +--- + +### L5 — PEM Key Reconstruction Is Brittle + +- **Location:** `src/connectors/fhir_cerner/logic.py`, `src/connectors/fhir_epic/logic.py` +- **Pattern:** `private_key_str.replace("\\n", "\n")` to reconstruct a PEM key from env var +- **Risk:** Silently produces an invalid key if the env var format is wrong; error only surfaces at JWT signing time +- **Fix:** Parse and validate the key with the `cryptography` library at connector startup; reject the connector if the key is unparseable + +--- + +## Summary Table + +| ID | Category | Issue | Severity | +|----|----------|-------|----------| +| C1 | Credentials | `.env` with real secrets committed to repo | CRITICAL | +| C2 | Privacy | PHI logged in FHIR error paths | CRITICAL | +| C3 | AuthN/AuthZ | No authentication on REST or gRPC bindings | CRITICAL | +| C4 | Network | SSRF via HTTP Generic connector | CRITICAL | +| C5 | AuthN | Configurable secret key names in SMTP | CRITICAL | +| H1 | Privacy | OAuth error response body logged | HIGH | +| H2 | Validation | Unvalidated HTTP method in generic connector | HIGH | +| H3 | Concurrency | Stripe global API key mutation (race condition) | HIGH | +| H4 | Supply Chain | All connector modules auto-loaded without allowlist | HIGH | +| H5 | DoS | No rate limiting on any binding | HIGH | +| H6 | Isolation | Circuit breaker shared across all tenants | HIGH | +| H7 | Validation | Unvalidated base64 content in FHIR connectors | HIGH | +| H8 | Dependencies | No CVE scanning; unbounded version ranges | HIGH | +| H9 | Audit | Policy hook denials not structured/auditable | HIGH | +| H10 | Testing | Zero security test coverage | HIGH | +| M1 | Validation | Stripe amount/currency unbounded | MEDIUM | +| M2 | Config | `${VAR}` substitution not implemented | MEDIUM | +| M3 | Reliability | Silent `None` on missing connector config | MEDIUM | +| M4 | Config | Hardcoded timeouts and circuit breaker params | MEDIUM | +| M5 | Privacy | OTel traces may export PHI to collector | MEDIUM | +| M6 | Validation | Drive query accepts arbitrary string | MEDIUM | +| M7 | Path Safety | Service account file path not sandboxed | MEDIUM | +| L1 | Privacy | SMTP logs full recipient email addresses | LOW | +| L2 | MCP | Manifest lacks security metadata (scopes, auth) | LOW | +| L3 | MCP | No MCP prompt templates defined | LOW | +| L4 | MCP | No sampling/pagination limits in MCP binding | LOW | +| L5 | Reliability | Brittle PEM key reconstruction from env var | LOW | + +--- + +## Recommended Remediation Order + +### Immediate — before any external network access + +1. Revoke all credentials in `.env`; rotate FHIR private keys, Groq key, SMTP password (C1) +2. Remove PHI from FHIR error log lines (C2) +3. Add API key middleware to REST binding (C3) +4. Block RFC-1918 / loopback hosts in HTTP Generic URL validator (C4) +5. Hardcode SMTP secret key names; remove from input schema (C5) + +### Before production + +6. Add TLS + mTLS to gRPC server (C3 continuation) +7. Add connector allowlist validation in `auto_register()` (H4) +8. Add rate limiting middleware to all bindings (H5) +9. Add `pip-audit` to CI and pin dependency upper bounds (H8) +10. Write `tests/security/` suite (H10) + +### Next sprint + +11. Implement per-tenant circuit breakers (H6) +12. Add OTLP `SpanSanitizer` for PHI fields (M5) +13. Implement `${VAR:default}` config substitution (M2) +14. Add Stripe `amount`/`currency` field validators (M1) +15. Add manifest `security` metadata block (L2) + +--- + +*Generated: 2026-04-01 | Branch: feature/python-packages* diff --git a/docs/slack_connector.md b/docs/slack_connector.md new file mode 100644 index 0000000..e69c4a5 --- /dev/null +++ b/docs/slack_connector.md @@ -0,0 +1,153 @@ +# Slack Connector + +This document covers the Slack connector under `src/node_wire_slack` in two parts: + +1. **[Slack Bot Setup](#slack-bot-setup)** — Create a Slack app, configure OAuth scopes, and obtain your bot token. +2. **[REST API Reference](#rest-api-reference)** — Connector actions, request/response shapes, and flexible channel resolution. + +For **MCP** (e.g. ToolHive), tools are named `slack.` from the connector manifest (e.g. `slack.post_message`). + +--- + +## Slack Bot Setup + +The Slack connector uses a **Bot User OAuth Token** to interact with your workspace. + +### Prerequisites + +- A Slack workspace where you have permission to install apps. +- [Slack API Dashboard](https://api.slack.com/apps) access. + +### Step 1: Create a Slack App + +1. Go to [api.slack.com/apps](https://api.slack.com/apps) and click **Create New App**. +2. Select **From scratch**. +3. Give your app a name (e.g., `Node-Wire Connector`) and select your workspace. +4. Click **Create App**. + +### Step 2: Configure Scopes + +1. In the left sidebar, go to **OAuth & Permissions**. +2. Scroll down to **Scopes > Bot Token Scopes**. +3. Add the following scopes: + - `chat:write` — Send messages to channels and DMs. + - `files:write` — Upload and share files. + - `im:write` — Start direct messages with users. + - `groups:read` (optional) — If you need to post to private channels the bot is invited to. + - `channels:read` (optional) — If you need to resolve channel names. + +### Step 3: Install and Get Token + +1. Scroll back up to the top of the **OAuth & Permissions** page. +2. Click **Install to Workspace**. +3. Click **Allow** to authorize the bot. +4. Copy the **Bot User OAuth Token** (it starts with `xoxb-`). + +### Step 4: Configure the Connector + +Add the token to your `.env` file: + +```env +SLACK_BOT_TOKEN=xoxb-your-token-here +``` + +### Step 5: Invite the Bot (Important) + +Slack bots cannot "see" private channels unless they are explicitly invited. + +1. Go to the Slack channel you want the bot to use. +2. Type `/invite @YourAppName` and press Enter. + +--- + +## REST API Reference + +The connector exposes actions as standard REST endpoints. Channel identifiers are flexible and automatically resolved. + +### Operations overview + +- Connector ID: `slack` +- Base REST path: `POST /connectors/slack/{action}` + +### Actions + +#### `post_message` + +Send a message to a channel, group, or user. + +**Request body:** + +```json +{ + "channel": "#general", + "message": "Clinical alert: Patient summary available.", + "blocks": [ + { + "type": "section", + "text": { "type": "mrkdwn", "text": "*Emergency Update*: BP 180/110" } + } + ] +} +``` + +**Channel Resolution:** +- **Channel Name**: Starts with `#` (e.g., `#general`). +- **Channel ID**: Starts with `C` or `G` (e.g., `C12345`). +- **User ID**: Starts with `U` or `W` (e.g., `U12345`). Automatically resolved to a DM channel. + +#### `send_direct_message` + +A specialized action for DMs. If targeted at a User ID, the connector ensures the DM channel is open before posting. + +**Request body:** + +```json +{ + "channel": "U12345678", + "message": "You have a new lab result to review." +} +``` + +#### `upload_file` + +Uploads a file to a Slack channel or DM. + +**Request body (Base64):** + +```json +{ + "channel": "C12345678", + "filename": "labs.pdf", + "content_base64": "JVBER...", + "initial_comment": "Here is the PDF summary." +} +``` + +**Request body (Filesystem):** + +```json +{ + "channel": "U12345678", + "filename": "summary.pdf", + "filepath": "/slack_attachments/p_123.pdf" +} +``` + +> **Note:** `filepath` must be within the directory defined by `NW_SLACK_ATTACHMENTS_DIR` (default `/slack_attachments`). + +### Error Taxonomy + +| Category | Platform Code | Cause | +|---|---|---| +| `AUTH` | `SLACK_AUTH_ERROR` | Invalid or revoked token | +| `AUTH` | `SLACK_PERMISSION_ERROR` | Missing OAuth scope | +| `RETRYABLE` | `SLACK_RATE_LIMIT` | Slack rate limit (429) | +| `BUSINESS` | `SLACK_MESSAGE_ERROR` | Channel not found or invalid payload | +| `BUSINESS` | `SLACK_UPLOAD_ERROR` | File too large or bad content | + +--- + +### Related + +- Individual MCP Servers: [docs/mcp-servers.md](mcp-servers.md) +- Connector Architecture: [docs/connectors.md](connectors.md) diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index 666c39b..7b2e58b 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -1,3 +1,9 @@ + + # ToolHive Agent Scenario: FHIR → Google Drive → Email > **End-to-end guide for running Node Wire as an MCP server on ToolHive, and connecting an AI agent to orchestrate healthcare and enterprise workflows.** @@ -36,10 +42,11 @@ This guide walks you through running the platform as an MCP server using ToolHiv ``` ToolHive UI ────────────────────────────────────────────────────── │ MCP Server (Docker): node-wire │ -│ ├── Tool: fhir_cerner_read_patient ← fetch patient from Cerner │ -│ ├── Tool: fhir_epic_read_patient ← fetch patient from Epic │ -│ ├── Tool: google_drive_upload_file ← write file to Drive │ -│ └── Tool: smtp_send_email ← email the summary │ +│ ├── Tool: fhir_cerner.read_patient ← fetch patient from Cerner │ +│ ├── Tool: fhir_epic.read_patient ← fetch patient from Epic │ +│ ├── Tool: google_drive.files.upload ← write file to Drive │ +│ ├── Tool: stripe.charge ← process payment │ +│ └── Tool: smtp.send_email ← email the summary │ │ ↕ stdio → HTTP proxy │ ────────────────────────────────────────────────────────────────── ↕ MCP JSON-RPC over HTTP @@ -62,6 +69,7 @@ For modular deployments, each connector can be run as an independent MCP server - `nw-google-drive` (Google Drive) - `nw-smartonfhir-epic` (Epic SMART on FHIR) - `nw-smartonfhir-cerner` (Cerner SMART on FHIR) +- `nw-smtp` (SMTP email) When running multiple MCP servers, configure the agent with **`TOOLHIVE_MCP_URLS`** (comma-separated list of ToolHive proxy URLs). The agent will merge tools across servers. @@ -82,14 +90,15 @@ You can think of it as a local "MCP server manager" — you register your server ## What does the Node Wire MCP server expose? -When running as an MCP server, the platform exposes 4 tools that AI agents can discover and call: +When running **this scenario’s** minimal multi-connector stack (one MCP server per connector image registered in ToolHive), agents typically see **four** tools (Cerner read patient, Epic read patient, Drive upload, SMTP send). The **unified** MCP server (`python -m agents.mcp_entrypoint`) exposes **all** manifest actions for every connector enabled for MCP in `config/connectors.yaml` (often 18+ tools). This section describes the **four-tool** happy path; see [mcp-servers.md](mcp-servers.md) for the full surface. | Tool | Description | |---|---| -| `fhir_cerner_read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | -| `fhir_epic_read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | -| `google_drive_upload_file` | Create and upload a text file to Google Drive | -| `smtp_send_email` | Send an email via SMTP | +| `fhir_cerner.read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | +| `fhir_epic.read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | +| `google_drive.files.upload` | Create and upload a text file to Google Drive | +| `stripe.charge` | Process a payment | +| `smtp.send_email` | Send an email via SMTP | The agent uses an LLM's tool-calling capability to decide which tools to call, in what order, and with what parameters. @@ -135,7 +144,7 @@ Below is the full set of environment variables used by the connector platform an | `GROQ_API_KEY` | LLM (Groq) | Your Groq API key | | `GROQ_MODEL` | LLM | Example: `openai/gpt-oss-120b` | | `MCP_TRANSPORT` | ToolHive / local | `stdio` when running in ToolHive container | -| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `d:\connector-platform\src` locally | +| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `**/node-wire/src` locally | | `SMTP_HOST` | SMTP connector | Example: `sandbox.smtp.mailtrap.io` | | `SMTP_PORT` | SMTP connector | Example: `2525` | | `SMTP_USERNAME` | SMTP connector | Mailtrap / SMTP user | @@ -160,7 +169,7 @@ Option A — Recommended: ToolHive UI (no code) Option B — Local quick run (Windows PowerShell) -Prerequisite: Install Python 3.10+ and Git. If you cannot install, ask an administrator to run Option A. +Prerequisite: Install Python 3.11+ and Git. If you cannot install, ask an administrator to run Option A. 1. Open PowerShell and clone or navigate to the project folder. 2. Create a simple `.env` file in the project root (replace placeholder values): @@ -204,8 +213,6 @@ Notes for non-developers: From the root of the repository: ```bash -cd connector-platform - docker build -t node-wire:latest . ``` @@ -317,7 +324,7 @@ In the ToolHive UI under **Installed**, you should see: |---|---| | Name | `node-wire-connectors` | | Status | `Running` | -| Tools | `fhir_cerner_read_patient`, `fhir_epic_read_patient`, `google_drive_upload_file`, `smtp_send_email` | +| Tools | Manifest-driven `.` (e.g. `fhir_cerner.read_patient`, `fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`; unified server also lists Stripe, HTTP generic, and other MCP-enabled connectors) | | Endpoint | `http://localhost:/sse` | --- @@ -404,11 +411,11 @@ I have completed all three steps: 3. Sent a summary email to your-email@example.com with a link to the file. Steps executed (3): - ✓ Step 1: fhir_cerner_read_patient + ✓ Step 1: fhir_cerner.read_patient result : {"patient_id": "123*****", "full_name": "Nancy Smart", ...} - ✓ Step 2: google_drive_upload_file + ✓ Step 2: google_drive.files.upload result : {"file_id": "1XYZ...", "web_view_link": "https://docs.google.com/..."} - ✓ Step 3: smtp_send_email + ✓ Step 3: smtp.send_email result : {"sent": true} ``` @@ -504,7 +511,7 @@ In Cursor's MCP settings, add the same endpoint URL. The tools will appear in th | `google_drive connector: authentication failed` | `GOOGLE_DRIVE_SA_JSON` is a file path, not JSON content | For ToolHive, paste the actual JSON *contents* of the file (not the file path) as the secret value; for local `.env`, use an absolute path to the JSON file per [Google Drive service account setup](google_drive_connector.md#google-drive-service-account-setup) | | `SMTP authentication failed` | Wrong username or password | For Gmail, use an App Password not your regular password; confirm `SMTP_USERNAME` includes `@` | | `groq SDK not installed` | Missing optional dependency | `pip install -e ".[agents]"` | -| Agent loops forever without completing | LLM reasoning issue | Try increasing `--max-steps`; try a different LLM provider; check that all four tools are visible in ToolHive | +| Agent loops forever without completing | LLM reasoning issue | Try increasing `--max-steps`; try a different LLM provider; check that the expected tools are visible in ToolHive (`tools/list`); refresh after MCP image upgrades | | `docker: Cannot connect to the Docker daemon` | Docker not running | Start Docker Desktop | | Container starts but shows 0 tools | MCP server failed to start | Check container logs: `docker logs `; verify the image built successfully | @@ -519,33 +526,21 @@ pip install -e ".[dev,agents]" pytest tests/test_toolhive_agent.py -v ``` -Expected output (all 9 tests passing): - -``` -tests/test_toolhive_agent.py::test_llm_factory_groq_created PASSED -tests/test_toolhive_agent.py::test_llm_factory_openai_created PASSED -tests/test_toolhive_agent.py::test_llm_factory_unknown_raises PASSED -tests/test_toolhive_agent.py::test_llm_factory_case_insensitive PASSED -tests/test_toolhive_agent.py::test_agent_runs_three_tool_sequence PASSED -tests/test_toolhive_agent.py::test_agent_respects_max_steps PASSED -tests/test_toolhive_agent.py::test_agent_handles_tool_error_graceful PASSED -tests/test_toolhive_agent.py::test_agent_fails_when_mcp_unreachable PASSED -tests/test_toolhive_agent.py::test_mcp_entrypoint_registers_three_to PASSED -``` +Expect every test collected from `tests/test_toolhive_agent.py` to pass (names and count change as the suite evolves). If a test fails, re-run with `-v` and compare against the current definitions in that file. --- ## File layout (`agents`) ``` -connector-platform/ +node-wire/ ├── Dockerfile ← Docker image for ToolHive ├── pyproject.toml ← [agents] extras added ├── sample.env ← env var reference └── src/ └── agents/ ├── __init__.py - ├── mcp_entrypoint.py ← FastMCP server (4 tools) + ├── mcp_entrypoint.py ← MCP stdio server (manifest; all MCP connectors) ├── toolhive.py ← ReAct agent + CLI ├── llm_factory.py ← Provider factory └── providers/ @@ -559,5 +554,6 @@ connector-platform/ ## Related documentation -- [Setup.md](../Setup.md) — Full platform setup guide +- [Installation Guide](installation.md) — Full platform setup guide +- [Configuration Guide](configuration.md) — Environment variables and settings - [google_drive_connector.md](google_drive_connector.md) — Google Drive service account setup and REST API reference diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 0000000..b864f5d --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,24 @@ +# Troubleshooting Guide + +## Common Errors & Fixes + +| Problem | Likely Cause | Fix | +|---------|--------------|-----| +| **Port 8000 in use** | Another process is using the default REST port. | Set `PORT=8001` (or any free port) before starting the platform. | +| **Connector "not configured"** | Connector is disabled or not exposed. | Confirm `enabled: true` and `exposed_via` include your protocol in `config/connectors.yaml`. | +| **Auth Failure (Google Drive)** | Incorrect credential format. | In ToolHive, `GOOGLE_DRIVE_SA_JSON` must be the JSON **contents**. Locally, it can be an absolute path. | +| **"Invalid port: PORT"** | Environment variable not parsed correctly. | Ensure `PORT` or `NW_MCP_PORT` is set to a valid integer (e.g., `8081`). | +| **No connectors loaded** | `NW_ALLOWED_CONNECTORS` is missing. | **Required.** Set `NW_ALLOWED_CONNECTORS` to a comma-separated list of connectors to enable. | + +--- + +## Logging & Debugging + +### REST API +Check the console output where `uv run node-wire` is running. It logs incoming requests and standard error taxonomy mappings. + +### MCP (stdio) +In `stdio` mode, the server communicates over standard I/O. Any `print()` statements in the code will break the protocol. Use Python's `logging` module to log to `stderr` or a file. + +### OpenTelemetry +If configured, check your OpenTelemetry collector (e.g., Jaeger) for traces with `trace_id` from the `ConnectorResponse`. diff --git a/grafana/Connector Logs & Status - Updated-1773917850709.json b/grafana/Connector Logs & Status - Updated-1773917850709.json deleted file mode 100644 index 08b2a3a..0000000 --- a/grafana/Connector Logs & Status - Updated-1773917850709.json +++ /dev/null @@ -1,304 +0,0 @@ -{ - "annotations": { - "list": [ - { - "builtIn": 1, - "datasource": { - "type": "grafana", - "uid": "-- Grafana --" - }, - "enable": true, - "hide": true, - "iconColor": "rgba(0, 211, 255, 1)", - "name": "Annotations & Alerts", - "type": "dashboard" - } - ] - }, - "description": "Real-time log monitoring with success rate and status distribution", - "editable": true, - "fiscalYearStartMonth": 0, - "graphTooltip": 0, - "links": [], - "panels": [ - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "fieldConfig": { - "defaults": { - "color": { - "mode": "thresholds" - }, - "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "red", - "value": 0 - }, - { - "color": "orange", - "value": 80 - }, - { - "color": "green", - "value": 95 - } - ] - }, - "unit": "percent" - }, - "overrides": [] - }, - "gridPos": { - "h": 6, - "w": 12, - "x": 0, - "y": 0 - }, - "id": 1, - "options": { - "colorMode": "value", - "graphMode": "none", - "justifyMode": "auto", - "orientation": "auto", - "percentChangeColorMode": "standard", - "reduceOptions": { - "calcs": [ - "lastNotNull" - ], - "fields": "", - "values": false - }, - "showPercentChange": false, - "textMode": "auto", - "wideLayout": true - }, - "pluginVersion": "12.4.1", - "targets": [ - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "direction": "backward", - "editorMode": "code", - "expr": "sum(count_over_time({service_name=\"aot-connector-platform\"} | logfmt | connector_id =~ \"$connector_type.*\" |= \"completed successfully\" [$__range])) / sum(count_over_time({service_name=\"aot-connector-platform\"} | logfmt | connector_id =~ \"$connector_type.*\" |= \"Starting connector execution\" [$__range])) * 100", - "queryType": "range", - "refId": "A" - } - ], - "title": "Success Rate", - "type": "stat" - }, - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "fieldConfig": { - "defaults": { - "color": { - "mode": "palette-classic" - }, - "custom": { - "hideFrom": { - "legend": false, - "tooltip": false, - "viz": false - } - }, - "mappings": [] - }, - "overrides": [ - { - "matcher": { - "id": "byName", - "options": "Success" - }, - "properties": [ - { - "id": "color", - "value": { - "fixedColor": "green", - "mode": "fixed" - } - } - ] - }, - { - "matcher": { - "id": "byName", - "options": "Error" - }, - "properties": [ - { - "id": "color", - "value": { - "fixedColor": "red", - "mode": "fixed" - } - } - ] - } - ] - }, - "gridPos": { - "h": 6, - "w": 12, - "x": 12, - "y": 0 - }, - "id": 2, - "options": { - "displayLabels": [ - "percent" - ], - "legend": { - "displayMode": "list", - "placement": "bottom", - "showLegend": true, - "values": [ - "value", - "percent" - ] - }, - "pieType": "donut", - "reduceOptions": { - "calcs": [ - "lastNotNull" - ], - "fields": "", - "values": false - }, - "sort": "desc", - "tooltip": { - "hideZeros": false, - "mode": "single", - "sort": "none" - } - }, - "pluginVersion": "12.4.1", - "targets": [ - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "direction": "backward", - "editorMode": "code", - "expr": "sum(count_over_time({service_name=\"aot-connector-platform\"} | logfmt | connector_id =~ \"$connector_type.*\" |~ \"(?i)completed successfully\" [$__range]))", - "legendFormat": "Success", - "queryType": "range", - "refId": "A" - }, - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "direction": "backward", - "editorMode": "code", - "expr": "sum(count_over_time({service_name=\"aot-connector-platform\"} |~ \"(?i)error\" [$__range]))", - "legendFormat": "Error", - "queryType": "range", - "refId": "B" - } - ], - "title": "Success vs Error Rate", - "type": "piechart" - }, - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "description": "Search logs using Ctrl+F or click on log lines for details", - "fieldConfig": { - "defaults": {}, - "overrides": [] - }, - "gridPos": { - "h": 12, - "w": 24, - "x": 0, - "y": 6 - }, - "id": 3, - "options": { - "dedupStrategy": "none", - "enableInfiniteScrolling": true, - "enableLogDetails": true, - "prettifyLogMessage": true, - "showControls": true, - "showLabels": true, - "showTime": true, - "sortOrder": "Descending", - "syntaxHighlighting": true, - "unwrappedColumns": false, - "wrapLogMessage": true - }, - "pluginVersion": "12.4.1", - "targets": [ - { - "datasource": { - "type": "loki", - "uid": "loki" - }, - "direction": "backward", - "editorMode": "code", - "expr": "{service_name=\"aot-connector-platform\"} | logfmt | connector_id =~ \"$connector_type.*\"", - "queryType": "range", - "refId": "A" - } - ], - "title": "All Connector Logs", - "type": "logs" - } - ], - "preload": false, - "refresh": "30s", - "schemaVersion": 42, - "tags": [ - "connectors", - "logs", - "monitoring" - ], - "templating": { - "list": [ - { - "allValue": ".*", - "allowCustomValue": false, - "current": { - "text": [ - "$__all" - ], - "value": [ - "$__all" - ] - }, - "includeAll": true, - "label": "Connector Type", - "multi": true, - "name": "connector_type", - "options": [], - "query": "fhir,google_drive", - "type": "custom", - "valuesFormat": "csv" - } - ] - }, - "time": { - "from": "now-30m", - "to": "now" - }, - "timepicker": {}, - "timezone": "browser", - "title": "Connector Logs & Status - Updated", - "uid": "connector_logs_fixed", - "version": 17, - "weekStart": "" -} \ No newline at end of file diff --git a/grafana/README.md b/grafana/README.md index d446754..4ebcae4 100644 --- a/grafana/README.md +++ b/grafana/README.md @@ -1,3 +1,9 @@ + + # Grafana Quick Guide ## What is included diff --git a/grafana/docker-compose.yml b/grafana/docker-compose.yml index 9fa73f9..2e2e20a 100644 --- a/grafana/docker-compose.yml +++ b/grafana/docker-compose.yml @@ -1,3 +1,7 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## services: otel-lgtm: image: grafana/otel-lgtm @@ -6,4 +10,4 @@ services: - "4317:4317" - "4318:4318" stdin_open: true - tty: true \ No newline at end of file + tty: true diff --git a/packages/connectors/fhir_cerner/__init__.py b/packages/connectors/fhir_cerner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/fhir_cerner/pyproject.toml b/packages/connectors/fhir_cerner/pyproject.toml new file mode 100644 index 0000000..53d715a --- /dev/null +++ b/packages/connectors/fhir_cerner/pyproject.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-fhir-cerner" +version = "0.1.0" +description = "Node Wire connector — Cerner FHIR R4 (read/search patients and encounters)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx[http2]>=0.27.0,<0.28.0", + "PyJWT[crypto]>=2.8.0", +] + +[project.entry-points."node_wire.connectors"] +fhir_cerner = "node_wire_fhir_cerner.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_fhir_cerner*"] diff --git a/packages/connectors/fhir_cerner/setup.py b/packages/connectors/fhir_cerner/setup.py new file mode 100644 index 0000000..0196aaa --- /dev/null +++ b/packages/connectors/fhir_cerner/setup.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../src/node_wire_fhir_cerner") +) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/fhir_epic/__init__.py b/packages/connectors/fhir_epic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/fhir_epic/pyproject.toml b/packages/connectors/fhir_epic/pyproject.toml new file mode 100644 index 0000000..db0987b --- /dev/null +++ b/packages/connectors/fhir_epic/pyproject.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-fhir-epic" +version = "0.1.0" +description = "Node Wire connector — Epic FHIR R4 (read/search patients and encounters)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx[http2]>=0.27.0,<0.28.0", + "PyJWT[crypto]>=2.8.0", +] + +[project.entry-points."node_wire.connectors"] +fhir_epic = "node_wire_fhir_epic.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_fhir_epic*"] diff --git a/packages/connectors/fhir_epic/setup.py b/packages/connectors/fhir_epic/setup.py new file mode 100644 index 0000000..4ed6d53 --- /dev/null +++ b/packages/connectors/fhir_epic/setup.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../src/node_wire_fhir_epic") +) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/google_drive/__init__.py b/packages/connectors/google_drive/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/google_drive/pyproject.toml b/packages/connectors/google_drive/pyproject.toml new file mode 100644 index 0000000..c356768 --- /dev/null +++ b/packages/connectors/google_drive/pyproject.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-google-drive" +version = "0.1.0" +description = "Node Wire connector — Google Drive API v3 (files and permissions)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "google-auth>=2.0.0", + "google-api-python-client>=2.100.0", +] + +[project.entry-points."node_wire.connectors"] +google_drive = "node_wire_google_drive.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_google_drive*"] diff --git a/packages/connectors/google_drive/setup.py b/packages/connectors/google_drive/setup.py new file mode 100644 index 0000000..077bbb6 --- /dev/null +++ b/packages/connectors/google_drive/setup.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../src/node_wire_google_drive") +) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/http_generic/__init__.py b/packages/connectors/http_generic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/http_generic/pyproject.toml b/packages/connectors/http_generic/pyproject.toml new file mode 100644 index 0000000..0786852 --- /dev/null +++ b/packages/connectors/http_generic/pyproject.toml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-http-generic" +version = "0.1.0" +description = "Node Wire connector — generic HTTP REST client (GET/POST/PUT/DELETE/PATCH)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx[http2]>=0.27.0,<0.28.0", +] + +[project.entry-points."node_wire.connectors"] +http_generic = "node_wire_http_generic.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_http_generic*"] diff --git a/packages/connectors/http_generic/setup.py b/packages/connectors/http_generic/setup.py new file mode 100644 index 0000000..e68e736 --- /dev/null +++ b/packages/connectors/http_generic/setup.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../src/node_wire_http_generic") +) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/salesforce/pyproject.toml b/packages/connectors/salesforce/pyproject.toml new file mode 100644 index 0000000..d13034b --- /dev/null +++ b/packages/connectors/salesforce/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "node-wire-salesforce" +version = "0.1.0" +description = "Node Wire connector — Salesforce CRM" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx>=0.27.0", +] + +[project.entry-points."node_wire.connectors"] +salesforce = "node_wire_salesforce.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_salesforce*"] diff --git a/packages/connectors/salesforce/setup.py b/packages/connectors/salesforce/setup.py new file mode 100644 index 0000000..a6ba329 --- /dev/null +++ b/packages/connectors/salesforce/setup.py @@ -0,0 +1,21 @@ +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../src/node_wire_salesforce") +) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/slack/pyproject.toml b/packages/connectors/slack/pyproject.toml new file mode 100644 index 0000000..199859e --- /dev/null +++ b/packages/connectors/slack/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "node-wire-slack" +version = "0.1.0" +description = "Node Wire connector — Slack API (messaging and file uploads)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "httpx>=0.27.0", +] + +[project.entry-points."node_wire.connectors"] +slack = "node_wire_slack.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_slack*"] diff --git a/packages/connectors/slack/setup.py b/packages/connectors/slack/setup.py new file mode 100644 index 0000000..cbbd861 --- /dev/null +++ b/packages/connectors/slack/setup.py @@ -0,0 +1,19 @@ +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../src/node_wire_slack")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/smtp/__init__.py b/packages/connectors/smtp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/smtp/pyproject.toml b/packages/connectors/smtp/pyproject.toml new file mode 100644 index 0000000..390557c --- /dev/null +++ b/packages/connectors/smtp/pyproject.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-smtp" +version = "0.1.0" +description = "Node Wire connector — SMTP email sending (async)" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "aiosmtplib>=3.0.1", + "email-validator>=2.0.0", +] + +[project.entry-points."node_wire.connectors"] +smtp = "node_wire_smtp.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_smtp*"] diff --git a/packages/connectors/smtp/setup.py b/packages/connectors/smtp/setup.py new file mode 100644 index 0000000..d5b7f65 --- /dev/null +++ b/packages/connectors/smtp/setup.py @@ -0,0 +1,23 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../src/node_wire_smtp")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/connectors/stripe/__init__.py b/packages/connectors/stripe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/connectors/stripe/pyproject.toml b/packages/connectors/stripe/pyproject.toml new file mode 100644 index 0000000..d511635 --- /dev/null +++ b/packages/connectors/stripe/pyproject.toml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-stripe" +version = "0.1.0" +description = "Node Wire connector — Stripe payments" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] + +dependencies = [ + "node-wire-runtime>=0.1.0", + "stripe>=10.0.0", +] + +[project.entry-points."node_wire.connectors"] +stripe = "node_wire_stripe.logic" + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../../src"] +include = ["node_wire_stripe*"] diff --git a/packages/connectors/stripe/setup.py b/packages/connectors/stripe/setup.py new file mode 100644 index 0000000..5657bfe --- /dev/null +++ b/packages/connectors/stripe/setup.py @@ -0,0 +1,23 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +import glob +import os +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../src/node_wire_stripe")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize(py_files, compiler_directives={"language_level": "3"}, build_dir="build"), +) diff --git a/packages/runtime/pyproject.toml b/packages/runtime/pyproject.toml new file mode 100644 index 0000000..cf477ff --- /dev/null +++ b/packages/runtime/pyproject.toml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "node-wire-runtime" +version = "0.1.0" +description = "Node Wire runtime — connector framework, resilience, observability, and pluggable secrets" +requires-python = ">=3.11" +authors = [{ name = "AOT", email = "dev@aot.local" }] +readme = "README.md" + +dependencies = [ + "pydantic>=2.6.0,<3.0.0", + "tenacity>=8.2.0", + "pybreaker>=1.0.0", + "opentelemetry-api>=1.24.0", + "opentelemetry-sdk>=1.24.0", + "opentelemetry-exporter-otlp>=1.24.0", + "traceloop-sdk>=0.53.0", + "pyyaml>=6.0.1", + "python-dotenv>=1.0.0", +] + +[project.optional-dependencies] +# Cloud secret backends — install only what your deployment needs +aws = ["boto3>=1.34.0"] +vault = ["hvac>=2.1.0"] +azure = ["azure-keyvault-secrets>=4.8.0", "azure-identity>=1.16.0"] +gcp = ["google-cloud-secret-manager>=2.20.0"] + +[build-system] +requires = ["setuptools>=69.0.0", "cython>=3.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["../../src"] +include = ["node_wire_runtime*"] + +# Ship connectors.yaml.sample as a non-Python package resource +[tool.setuptools.package-data] +node_wire_runtime = ["connectors.yaml.sample"] diff --git a/packages/runtime/setup.py b/packages/runtime/setup.py new file mode 100644 index 0000000..15a8dcc --- /dev/null +++ b/packages/runtime/setup.py @@ -0,0 +1,50 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +Cython build for node-wire-runtime. + +Compiles all .py files to .so/.pyd extensions and overrides build_py +so that source .py files are NOT copied into the wheel — only the compiled +binary extensions are included. + +Build with: + python -m build --wheel --no-isolation + +Verify no .py files leaked: + unzip -l dist/node_wire_runtime-*.whl | grep '\.py$' +""" + +import glob +import os + +from Cython.Build import cythonize +from setuptools import setup +from setuptools.command.build_py import build_py as _BuildPy + + +class NoPyBuild(_BuildPy): + """Override that skips copying .py source files into the build tree. + + Setuptools would normally copy every .py file into the wheel alongside + the compiled extension. Returning [] here ensures the wheel contains + only .so/.pyd binaries. + """ + + def find_package_modules(self, package, package_dir): + return [] + + +src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src/node_wire_runtime")) +py_files = glob.glob(os.path.join(src_root, "**", "*.py"), recursive=True) + +setup( + cmdclass={"build_py": NoPyBuild}, + ext_modules=cythonize( + py_files, + compiler_directives={"language_level": "3"}, + build_dir="build", + annotate=False, + ), +) diff --git a/playground/README.md b/playground/README.md index 1ba1290..6cbbf49 100644 --- a/playground/README.md +++ b/playground/README.md @@ -1,3 +1,9 @@ + + # Node Wire Playground This folder contains a fully functional playground for **Node Wire**, showcasing how it orchestrates complex workflows across disparate systems like Electronic Health Records (EHR) and IT Service Management (ITSM) tools. @@ -9,7 +15,7 @@ The demo provides a modern, interactive web interface to trigger, monitor, and v ### Core Technologies - **Frontend**: Vanilla HTML5, CSS3 (Glassmorphism), and Javascript. - **Backend API**: FastAPI (Python) serving orchestration logic via `playground/scenarios.py`. -- **Connector Layer**: Integrated with `connectors` using the `fhir_epic`, `fhir_cerner`, and `http_generic` bindings. +- **Connector layer**: Uses Node Wire connectors (`fhir_epic`, `fhir_cerner`, `http_generic`, and others) via the platform REST API and `ConnectorFactory` (see [docs/connectors.md](docs/connectors.md)). --- @@ -62,9 +68,75 @@ This scenario demonstrates the platform's highest level of abstraction: an auton 1. **Autonomous Reasoning**: The agent parses user intent (e.g., "Get Nancy Smart's record and email it to her") using a Large Language Model (LLM). 2. **Dynamic Tool Selection**: Automatically selects and sequences tools from the **Node Wire MCP Server**, including Cerner FHIR, Google Drive, and SMTP. 3. **Guardrailed Execution**: Follows strict healthcare-specific guardrails, asking for missing patient IDs or confirmation before performing sensitive actions. - 4. **Real-time Interaction**: Provides a chat interface with live step-by-step visibility into the agent's thought process and tool execution. + 4. **Transport-aware Interaction**: Shows the active MCP transport in the chat panel and adjusts rendering behavior to match it. * **Implementation**: Leverages the `agents` module, providing a unified interface for LLMs to interact with any connector in the platform via a standard MCP bridge. +### Scenario 6: External Patient Viewer (Read-Only Retrieval) +This scenario loads a source EHR chart on demand for target viewer workflows without duplicating chart data or creating new FHIR resources. + +* **Logic Flow**: + 1. **Patient Resolution**: Uses a direct FHIR Patient ID when available, or resolves identity with given name, family name, and optional birthdate. + 2. **Demographics Retrieval**: Calls `read_patient` against the selected source EHR and displays the resolved patient identity. + 3. **Encounter Retrieval**: Calls `search_encounter` for the resolved patient with a configurable result limit. + 4. **Document Metadata Retrieval**: Calls `search_document_reference` for available document metadata. When no `DocumentReference` records are returned, the workflow presents encounters as lightweight fallback document rows. + 5. **Chart Assembly**: Produces a unified external chart view containing demographics, encounters, documents, source system, trace ID, and read-only status. +* **Implementation**: Uses the existing Epic and Cerner FHIR connectors through `playground/scenarios.py` and the input schema in `playground/ext_patient_viewer/schema.py`. The workflow calls only read/search actions and reports `0 Writes` in the UI. +* **Endpoint**: `POST /scenarios/external-patient-viewer` +* **Supported Sources**: Epic FHIR R4 and Cerner FHIR R4. + +#### MCP transport behavior in the playground + +The Agentic Workflow panel displays the active transport as a pill: + +- `Transport: stdio`: the browser calls `/scenarios/agent-chat`. The UI shows the loader while the backend agent completes, then renders tool cards and the final response together. +- `Transport: Streamable HTTP`: the browser calls `/scenarios/agent-chat-stream`. Tool cards appear as each MCP tool finishes, and the final answer is appended progressively as streamed chunks arrive. + +Set the mode before starting the REST API: + +```powershell +# Buffered stdio mode +$env:NW_MCP_TRANSPORT="stdio" +uv run node-wire +``` + +```powershell +# Streamable HTTP mode +$env:NW_MCP_TRANSPORT="streamable-http" +$env:NW_MCP_HOST="127.0.0.1" +$env:NW_MCP_PORT="8081" +$env:NW_MCP_PATH="/mcp" +uv run node-wire +``` + +After changing `NW_MCP_TRANSPORT`, restart the backend and hard refresh the browser so the latest `app.js` and transport status are loaded. + +#### Testing the MCP server with Inspector + +Use MCP Inspector to validate tools outside the playground: + +```powershell +npx @modelcontextprotocol/inspector +``` + +For stdio inspection: + +```powershell +$env:NW_MCP_TRANSPORT="stdio" +npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint +``` + +For streamable HTTP inspection, start the MCP server first: + +```powershell +$env:NW_MCP_TRANSPORT="streamable-http" +$env:NW_MCP_HOST="127.0.0.1" +$env:NW_MCP_PORT="8081" +$env:NW_MCP_PATH="/mcp" +python -m agents.mcp_entrypoint +``` + +Then open Inspector, select `Streamable HTTP`, connect to `http://127.0.0.1:8081/mcp`, run `List Tools`, and call a safe tool with valid JSON arguments. + --- ## 🛠️ Advanced Platform Features @@ -90,15 +162,15 @@ The demo is pre-configured with mock/sandbox endpoints for immediate use. To tes ### Testing Real Epic/Cerner (EHR) 1. **Update Config**: Modify `config/connectors.yaml` to point to a real Epic/Cerner Sandbox or Production URL. 2. **Auth**: Ensure you have valid `CLIENT_ID` and `PRIVATE_KEY` for the EHR's Backend System OAuth2 flow (SMART on FHIR). -3. **Data**: Use real Patient IDs and Encounter IDs from your target environment. +3. **Data**: Use real Patient IDs and Encounter IDs from your target environment. - **Cerner Note**: Ensure you use numeric Practitioner IDs (e.g., `593923`) and valid CodeSet 72 codes. ### Testing Google Drive Vault (Manual End-to-End) To test the Google Drive integration manually, follow these specialized setup steps: 1. **Service Account**: Create a Service Account in the Google Cloud Console with the **Google Drive API** enabled. Download the JSON key. 2. **Secret Configuration**: - * Place the JSON key file in your project directory (e.g., `D:\connector-platform\service_account.json`). - * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=D:\connector-platform\service_account.json`. + * Place the JSON key file somewhere safe on your machine (e.g., `/path/to/service_account.json`). + * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=/path/to/service_account.json`. * *Note: The platform now supports direct file paths for easier local configuration.* 3. **Permissions**: If using a specific **Vault Folder ID**, ensure that folder is shared with the Service Account's email address (found in the JSON) with "Editor" or "Manager" permissions. 4. **Workflow Verification**: @@ -110,8 +182,9 @@ To test the Google Drive integration manually, follow these specialized setup st To enable the AI Agent chat, you need to configure an LLM provider: 1. **Select Provider**: Set `LLM_PROVIDER` to `groq` (default) or `openai` in your `.env`. 2. **Add API Key**: Provide the corresponding key, e.g., `GROQ_API_KEY=your_key_here`. -3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`) to enable the agent to send emails. -4. **MCP URL**: (Optional) If running the MCP server in a separate container, set `TOOLHIVE_MCP_URL` to point to the MCP proxy. +3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD`) to enable the agent to send emails. +4. **MCP URL**: In `streamable-http` mode, set `TOOLHIVE_MCP_URL` or `TOOLHIVE_MCP_URLS` to the HTTP MCP endpoint(s). In `stdio` mode, the playground ignores those URLs and uses local stdio. +5. **Allowed Connectors**: Ensure `NW_ALLOWED_CONNECTORS` in your `.env` includes the connectors used by the agent (e.g. `fhir_cerner,google_drive,smtp`). --- @@ -119,8 +192,14 @@ To enable the AI Agent chat, you need to configure an LLM provider: 1. Navigate to the project root. 2. Start the FastAPI server: - ```bash - set MODE=API&& python -m bindings_entrypoint - ``` + +```bash +# Using uv (recommended) +uv run node-wire + +# Using python +python -m bindings_entrypoint +``` + 3. Open your browser to `http://localhost:8000/playground/` (or the configured port). -4. Switch between **EHR**, **IT Ops**, **Cerner**, **Google Drive Vault**, and **AI Agent** tabs to explore the different workflows. +4. Switch between **EHR**, **IT Ops**, **Cerner**, **Google Drive Vault**, **AI Agent**, **Slack**, **Stripe**, **Salesforce** and **External Patient Viewer** cards to explore the different workflows. diff --git a/playground/__init__.py b/playground/__init__.py index e69de29..39bdade 100644 --- a/playground/__init__.py +++ b/playground/__init__.py @@ -0,0 +1,4 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/playground/app.js b/playground/app.js index 6be74be..b2c4952 100644 --- a/playground/app.js +++ b/playground/app.js @@ -1,3 +1,7 @@ +// SPDX-FileCopyrightText: 2026 AOT Technologies +// +// SPDX-License-Identifier: Apache-2.0 + document.addEventListener('DOMContentLoaded', () => { const escapeHTML = (str) => { if (str == null) return ''; @@ -45,13 +49,68 @@ document.addEventListener('DOMContentLoaded', () => { const gdriveUploadOnly = document.getElementById('gdrive-upload-only'); const gdriveListOnly = document.getElementById('gdrive-list-only'); const gdriveSubNav = document.getElementById('gdrive-sub-nav'); + const slackForm = document.getElementById('slack-form'); + const slackRunBtn = document.getElementById('slack-run-btn'); + const slackSpinner = slackRunBtn?.querySelector('.loading-spinner'); + const slackBtnText = slackRunBtn?.querySelector('.btn-lbl'); + const slackPanel = document.getElementById('slack-panel'); + const slackActionSelect = document.getElementById('slack-action-select'); + const slackMessageSection = document.getElementById('slack-message-section'); + const slackFileSection = document.getElementById('slack-file-section'); + const slackFileInput = document.getElementById('slack-file'); + const slackFileDropZone = document.getElementById('slack-file-drop-zone'); + const slackFileChosenPreview = document.getElementById('slack-file-chosen-preview'); + const slackPreviewName = slackFileChosenPreview?.querySelector('.preview-name'); + const slackRemoveFileBtn = slackFileChosenPreview?.querySelector('.remove-file-btn'); const fileDropZone = document.getElementById('file-drop-zone'); const fileChosenPreview = document.getElementById('file-chosen-preview'); const previewName = fileChosenPreview?.querySelector('.preview-name'); const removeFileBtn = fileChosenPreview?.querySelector('.remove-file-btn'); + const stripeForm = document.getElementById('stripe-form'); + const stripeRunBtn = document.getElementById('stripe-run-btn'); + const stripeSpinner = stripeRunBtn.querySelector('.loading-spinner'); + const stripeBtnText = stripeRunBtn.querySelector('.btn-lbl'); + const stripePanel = document.getElementById('stripe-panel'); + + const stripeActionSelect = document.getElementById('stripe-action-select'); + const stripeSections = { + charge: document.getElementById('stripe-section-charge'), + payment_intent: document.getElementById('stripe-section-pi'), + subscription: document.getElementById('stripe-section-sub'), + cancel_subscription: document.getElementById('stripe-section-cancel'), + refund: document.getElementById('stripe-section-refund') + }; + + const salesforceForm = document.getElementById('salesforce-form'); + const salesforceRunBtn = document.getElementById('salesforce-run-btn'); + const salesforceSpinner = salesforceRunBtn.querySelector('.loading-spinner'); + const salesforceBtnText = salesforceRunBtn.querySelector('.btn-lbl'); + const salesforcePanel = document.getElementById('salesforce-panel'); + const salesforceActionSelect = document.getElementById('salesforce-action-select'); + const salesforceSections = { + create_lead: document.getElementById('salesforce-section-lead'), + update_lead: document.getElementById('salesforce-section-lead'), + create_contact: document.getElementById('salesforce-section-contact'), + update_contact: document.getElementById('salesforce-section-contact'), + read_lead: document.getElementById('salesforce-section-id-only'), + delete_lead: document.getElementById('salesforce-section-id-only'), + read_contact: document.getElementById('salesforce-section-id-only'), + delete_contact: document.getElementById('salesforce-section-id-only') + }; + + // External Patient Viewer + const extViewerForm = document.getElementById('ext-viewer-form'); + const extViewerRunBtn = document.getElementById('ext-viewer-run-btn'); + const extViewerSpinner = extViewerRunBtn ? extViewerRunBtn.querySelector('.loading-spinner') : null; + const extViewerBtnText = extViewerRunBtn ? extViewerRunBtn.querySelector('.btn-lbl') : null; + const extViewerPanel = document.getElementById('ext-patient-viewer-panel'); + let currentSubMode = 'file'; + let currentStripeSubMode = 'charge'; + let currentSalesforceSubMode = 'create_lead'; const connectorStatus = document.getElementById('connector-status'); + const brandLabel = document.querySelector('.brand-text h1 span.accent'); const tagline = document.querySelector('.tagline'); const layoutMain = document.querySelector('.layout-main'); @@ -63,8 +122,10 @@ document.addEventListener('DOMContentLoaded', () => { const agentInput = document.getElementById('agent-input'); const agentSendBtn = document.getElementById('agent-send-btn'); const agentTyping = document.getElementById('agent-typing'); + const agentTransportStatus = document.getElementById('agent-transport-status'); let agentConversationHistory = []; let agentBusy = false; + let agentTransportMode = 'stdio'; const pipelineLabels = { ehr: [ @@ -102,7 +163,69 @@ document.addEventListener('DOMContentLoaded', () => { "Apply file update", "Verify file metadata", "Complete update" - ] + ], + slack: [ + "Format Slack Payload", + "Dispatch to Slack API", + "Verify Acknowledgment", + "Update Audit Trail", + ], + stripe_charge: [ + "Initialize Payment", + "Process Charge", + "Verify Transaction", + ], + stripe_payment_intent: [ + "Initialize Session", + "Create Payment Intent", + "Verify Allocation", + ], + stripe_subscription: [ + "Validate Customer", + "Create Subscription", + "Verify Provisioning", + ], + stripe_cancel_subscription: [ + "Locate Resource", + "Cancel Subscription", + "Verify Termination", + ], + stripe_refund: [ + "Validate Charge", + "Process Refund", + "Verify Refund" + ], + salesforce_create_lead: [ + "Initialize CRM Sync", + "Create Lead Record", + "Verify Lead Status" + ], + salesforce_create_contact: [ + "Initialize CRM Sync", + "Create Contact Record", + "Verify Contact Status" + ], + salesforce_read: [ + "Authenticate CRM", + "Fetch Record Metadata", + "Verify Data Integrity" + ], + salesforce_update: [ + "Authenticate CRM", + "Apply Partial Update", + "Verify State Change" + ], + salesforce_delete: [ + "Authenticate CRM", + "Execute Soft Delete", + "Verify Termination" + ], + ext_patient_viewer: [ + "Resolve Patient Identity", + "Retrieve Encounter History", + "Retrieve Document Metadata", + "Assemble External Chart View" + ], }; const nodes = [ @@ -168,7 +291,7 @@ document.addEventListener('DOMContentLoaded', () => { } const rootSelectionView = document.getElementById('root-selection-view'); - const selectionCards = document.querySelectorAll('.selection-card'); + const selectionCards = document.querySelectorAll('.selection-card, .app-card'); const rootTabContainer = document.querySelector('.root-tab-container'); const backToHomeBtn = document.getElementById('back-to-home'); @@ -188,28 +311,79 @@ document.addEventListener('DOMContentLoaded', () => { selectionCards.forEach(card => { card.addEventListener('click', () => { const view = card.dataset.target; - rootSelectionView.classList.add('hidden'); - layoutMain.classList.remove('hidden'); - headerActions.classList.remove('hidden'); - + if (view === 'agent') { + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); agentPanel.classList.remove('hidden'); connectorsView.classList.add('hidden'); layoutMain.classList.add('agent-mode'); connectorStatus.textContent = 'AI Agent Online'; tagline.textContent = 'Autonomous Healthcare Assistant'; document.documentElement.style.setProperty('--brand-accent', '#8b5cf6'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } log('Switched to AI Agent mode (MCP + LLM)', 'system'); + } else if (view === 'connector-apps-menu') { + rootSelectionView.classList.add('hidden'); + document.getElementById('connector-apps-selection-view').classList.remove('hidden'); + layoutMain.classList.add('hidden'); + headerActions.classList.remove('hidden'); + connectorStatus.textContent = 'Apps Marketplace'; + tagline.textContent = 'Ready-to-use experiences built on top of connectors'; + document.documentElement.style.setProperty('--brand-accent', '#0d9488'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } + log('Opened Connector Apps menu', 'system'); + } else if (view === 'ext-patient-viewer') { + document.getElementById('connector-apps-selection-view').classList.add('hidden'); + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); + agentPanel.classList.add('hidden'); + connectorsView.classList.remove('hidden'); + layoutMain.classList.remove('agent-mode'); + connectorsListPanel.classList.add('hidden'); + playgroundView.classList.remove('hidden'); + if (backToConnectorsBtn) backToConnectorsBtn.classList.add('hidden'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Apps + `; + } + setMode('ext-patient-viewer'); } else { + rootSelectionView.classList.add('hidden'); + layoutMain.classList.remove('hidden'); + headerActions.classList.remove('hidden'); agentPanel.classList.add('hidden'); connectorsView.classList.remove('hidden'); layoutMain.classList.remove('agent-mode'); connectorsListPanel.classList.remove('hidden'); playgroundView.classList.add('hidden'); - connectorStatus.textContent = 'Connectors Ready'; tagline.textContent = 'Enterprise Integration Suite'; document.documentElement.style.setProperty('--brand-accent', '#2563eb'); + if (backSelectionBtn) { + backSelectionBtn.classList.remove('hidden'); + backSelectionBtn.innerHTML = ` + + Back to Workspace + `; + } log('Switched to Connectors view', 'system'); } }); @@ -218,13 +392,38 @@ document.addEventListener('DOMContentLoaded', () => { const returnToHome = (e) => { if (e) e.preventDefault(); rootSelectionView.classList.remove('hidden'); + document.getElementById('connector-apps-selection-view').classList.add('hidden'); layoutMain.classList.add('hidden'); headerActions.classList.add('hidden'); + tagline.textContent = 'Autonomous Connector Orchestration Platform'; log('Returned to main selection screen', 'system'); }; backToHomeBtn.addEventListener('click', returnToHome); - if (backSelectionBtn) backSelectionBtn.addEventListener('click', returnToHome); + + if (backSelectionBtn) { + backSelectionBtn.addEventListener('click', (e) => { + if (e) e.preventDefault(); + const btnText = backSelectionBtn.textContent.trim(); + if (btnText.includes('Back to Apps')) { + // Return from patient viewer to apps marketplace directory + layoutMain.classList.add('hidden'); + document.getElementById('connector-apps-selection-view').classList.remove('hidden'); + connectorStatus.textContent = 'Apps Marketplace'; + tagline.textContent = 'Ready-to-use experiences built on top of connectors'; + document.documentElement.style.setProperty('--brand-accent', '#0d9488'); + log('Returned to Apps Marketplace directory', 'system'); + } else { + // Return from other direct pages to workspace selection home + returnToHome(e); + } + }); + } + + const appsBackBtn = document.getElementById('apps-back-btn'); + if (appsBackBtn) { + appsBackBtn.addEventListener('click', returnToHome); + } const newChatBtn = document.getElementById('new-chat-btn'); newChatBtn.addEventListener('click', () => { @@ -329,16 +528,98 @@ document.addEventListener('DOMContentLoaded', () => { } } + function stripePipelineLabelOverride() { + if (currentStripeSubMode === 'charge') return pipelineLabels.stripe_charge; + if (currentStripeSubMode === 'payment_intent') return pipelineLabels.stripe_payment_intent; + if (currentStripeSubMode === 'subscription') return pipelineLabels.stripe_subscription; + if (currentStripeSubMode === 'cancel_subscription') return pipelineLabels.stripe_cancel_subscription; + if (currentStripeSubMode === 'refund') return pipelineLabels.stripe_refund; + return pipelineLabels.stripe_charge; + } + + function salesforcePipelineLabelOverride() { + if (currentSalesforceSubMode.startsWith('create')) return pipelineLabels.salesforce_create_lead; + if (currentSalesforceSubMode.startsWith('read')) return pipelineLabels.salesforce_read; + if (currentSalesforceSubMode.startsWith('update')) return pipelineLabels.salesforce_update; + if (currentSalesforceSubMode.startsWith('delete')) return pipelineLabels.salesforce_delete; + return pipelineLabels.salesforce_create_lead; + } + + function syncSalesforceActionForm() { + Object.values(salesforceSections).forEach(sec => { + if (sec) sec.classList.add('hidden'); + }); + const activeSec = salesforceSections[currentSalesforceSubMode] || salesforceSections['create_lead']; + if (activeSec) activeSec.classList.remove('hidden'); + + // Handle record ID field visibility in Lead/Contact sections + const idFields = document.querySelectorAll('#salesforce-form .id-field'); + idFields.forEach(f => { + if (currentSalesforceSubMode.startsWith('update')) { + f.classList.remove('hidden'); + } else { + f.classList.add('hidden'); + } + }); + + // Handle generic ID label for read/delete + const idLabel = document.getElementById('sf-resource-id-label'); + if (idLabel) { + if (currentSalesforceSubMode.includes('lead')) { + idLabel.textContent = 'Lead Record ID'; + } else { + idLabel.textContent = 'Contact Record ID'; + } + } + + if (salesforceActionSelect) { + salesforceActionSelect.value = currentSalesforceSubMode; + } + } + + + + function syncStripeActionForm() { + Object.values(stripeSections).forEach(sec => { + if (sec) sec.classList.add('hidden'); + }); + const activeSec = stripeSections[currentStripeSubMode] || stripeSections['charge']; + if (activeSec) activeSec.classList.remove('hidden'); + + if (stripeActionSelect) { + stripeActionSelect.value = currentStripeSubMode; + } + } + function setMode(mode) { currentMode = mode; - + + if (backToConnectorsBtn) { + if (mode === 'ext-patient-viewer') { + backToConnectorsBtn.innerHTML = ` + + Back to Workspace + `; + } else { + backToConnectorsBtn.innerHTML = ` + + Back to All Connectors + `; + } + } + // Hide all panels first ehrPanel.classList.add('hidden'); itopsPanel.classList.add('hidden'); cernerPanel.classList.add('hidden'); gdrivePanel.classList.add('hidden'); + stripePanel.classList.add('hidden'); + salesforcePanel.classList.add('hidden'); + if (slackPanel) slackPanel.classList.add('hidden'); + if (extViewerPanel) extViewerPanel.classList.add('hidden'); if (mode === 'ehr') { + ehrPanel.classList.remove('hidden'); connectorStatus.textContent = 'Epic R4 Online'; tagline.textContent = 'Enterprise EHR Orchestration'; @@ -362,13 +643,46 @@ document.addEventListener('DOMContentLoaded', () => { tagline.textContent = 'Secure Vault Orchestration'; document.documentElement.style.setProperty('--brand-accent', '#10b981'); log('Switched to Secure Document Archival mode (Google Drive)', 'system'); + } else if (mode === 'stripe') { + stripePanel.classList.remove('hidden'); + connectorStatus.textContent = 'Stripe Online'; + tagline.textContent = 'Financial Infrastructure'; + document.documentElement.style.setProperty('--brand-accent', '#635bff'); + log('Switched to Stripe Payment Orchestration mode', 'system'); + } else if (mode === 'salesforce') { + salesforcePanel.classList.remove('hidden'); + connectorStatus.textContent = 'Salesforce Online'; + tagline.textContent = 'CRM Orchestration'; + document.documentElement.style.setProperty('--brand-accent', '#00A1E0'); + log('Switched to Salesforce CRM Orchestration mode', 'system'); + } else if (mode === 'slack') { + if (slackPanel) slackPanel.classList.remove('hidden'); + connectorStatus.textContent = 'Slack Online'; + tagline.textContent = 'Team Collaboration & Notifications'; + document.documentElement.style.setProperty('--brand-accent', '#4A154B'); + log('Switched to Slack Operations mode', 'system'); + } else if (mode === 'ext-patient-viewer') { + if (extViewerPanel) extViewerPanel.classList.remove('hidden'); + connectorStatus.textContent = 'EHR Source ─ Read-Only'; + tagline.textContent = 'External Chart Viewer'; + document.documentElement.style.setProperty('--brand-accent', '#0d9488'); + log('Switched to External Patient Viewer mode (read-only)', 'system'); } if (mode === 'gdrive') { syncGdriveActionForm(); resetUI(gdrivePipelineLabelOverride()); + } else if (mode === 'stripe') { + syncStripeActionForm(); + resetUI(stripePipelineLabelOverride()); + } else if (mode === 'salesforce') { + syncSalesforceActionForm(); + resetUI(salesforcePipelineLabelOverride()); + } else if (mode === 'ext-patient-viewer') { + resetUI(pipelineLabels.ext_patient_viewer); } else { resetUI(); } + } // Root Tab Switching (MCP Orchestration vs Connectors) @@ -392,7 +706,7 @@ document.addEventListener('DOMContentLoaded', () => { // By default show the list if we just switched to connectors tab connectorsListPanel.classList.remove('hidden'); playgroundView.classList.add('hidden'); - + connectorStatus.textContent = 'Connectors Ready'; tagline.textContent = 'Enterprise Integration Suite'; document.documentElement.style.setProperty('--brand-accent', '#2563eb'); @@ -408,19 +722,24 @@ document.addEventListener('DOMContentLoaded', () => { connectorsListPanel.classList.add('hidden'); playgroundView.classList.remove('hidden'); if (backSelectionBtn) backSelectionBtn.classList.add('hidden'); + if (backToConnectorsBtn) backToConnectorsBtn.classList.remove('hidden'); setMode(mode); }); }); // Back to Connectors List backToConnectorsBtn.addEventListener('click', () => { - playgroundView.classList.add('hidden'); - connectorsListPanel.classList.remove('hidden'); - if (backSelectionBtn) backSelectionBtn.classList.remove('hidden'); - connectorStatus.textContent = 'Connectors Ready'; - tagline.textContent = 'Enterprise Integration Suite'; - document.documentElement.style.setProperty('--brand-accent', '#2563eb'); - log('Returned to Connectors list', 'system'); + if (currentMode === 'ext-patient-viewer') { + returnToHome(); + } else { + playgroundView.classList.add('hidden'); + connectorsListPanel.classList.remove('hidden'); + if (backSelectionBtn) backSelectionBtn.classList.remove('hidden'); + connectorStatus.textContent = 'Connectors Ready'; + tagline.textContent = 'Enterprise Integration Suite'; + document.documentElement.style.setProperty('--brand-accent', '#2563eb'); + log('Returned to Connectors list', 'system'); + } }); // Google Drive Sub-mode Switching @@ -517,7 +836,7 @@ document.addEventListener('DOMContentLoaded', () => { if (step.data.beautiful_data && node.querySelector('.beautiful-response')) { const bData = step.data.beautiful_data; const bDiv = node.querySelector('.beautiful-response'); - + bDiv.innerHTML = `
@@ -559,7 +878,7 @@ document.addEventListener('DOMContentLoaded', () => {
`; - + if (step.data.raw) { responseDiv.textContent = JSON.stringify(step.data.raw, null, 2); responseBtn.classList.remove('hidden'); @@ -643,6 +962,138 @@ document.addEventListener('DOMContentLoaded', () => { await handleSubmission(payload, '/scenarios/cerner-post-consultation', cernerRunBtn, cernerBtnText, cernerSpinner, 'Sync to Cerner Chart'); }); + stripeForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(stripeForm); + const payload = Object.fromEntries(formData.entries()); + + let endpoint = '/scenarios/stripe-charge'; + let submitPayload = {}; + + if (currentStripeSubMode === 'charge' || !currentStripeSubMode) { + submitPayload = { + amount: parseInt(payload.charge_amount, 10), + currency: payload.charge_currency, + description: payload.charge_description + }; + endpoint = '/scenarios/stripe-charge'; + } else if (currentStripeSubMode === 'payment_intent') { + submitPayload = { + amount: parseInt(payload.pi_amount, 10), + currency: payload.pi_currency, + customer_id: payload.pi_customer || undefined, + payment_method: payload.pi_payment_method || undefined, + confirm: payload.pi_confirm === 'on' + }; + endpoint = '/scenarios/stripe-payment-intent'; + } else if (currentStripeSubMode === 'subscription') { + submitPayload = { + customer_id: payload.sub_customer, + price_id: payload.sub_price, + card_token: payload.sub_token || undefined + }; + endpoint = '/scenarios/stripe-subscription'; + } else if (currentStripeSubMode === 'cancel_subscription') { + submitPayload = { + subscription_id: payload.cancel_sub_id + }; + endpoint = '/scenarios/stripe-cancel-subscription'; + } else if (currentStripeSubMode === 'refund') { + const isPI = payload.refund_target_id.startsWith('pi_'); + submitPayload = { + charge_id: !isPI && payload.refund_target_id ? payload.refund_target_id : undefined, + payment_intent_id: isPI ? payload.refund_target_id : undefined, + amount: payload.refund_amount ? parseInt(payload.refund_amount, 10) : undefined + }; + endpoint = '/scenarios/stripe-refund'; + } + + await handleSubmission(submitPayload, endpoint, stripeRunBtn, stripeBtnText, stripeSpinner, 'Process Action'); + }); + + salesforceForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(salesforceForm); + const payload = Object.fromEntries(formData.entries()); + + let endpoint = '/scenarios/salesforce-create-lead'; + let submitPayload = {}; + + if (currentSalesforceSubMode === 'create_lead') { + submitPayload = { + first_name: payload.lead_first_name || undefined, + last_name: payload.lead_last_name, + company: payload.lead_company, + email: payload.lead_email || undefined + }; + endpoint = '/scenarios/salesforce-create-lead'; + } else if (currentSalesforceSubMode === 'update_lead') { + submitPayload = { + record_id: payload.lead_id, + first_name: payload.lead_first_name || undefined, + last_name: payload.lead_last_name || undefined, + company: payload.lead_company || undefined, + email: payload.lead_email || undefined + }; + endpoint = '/scenarios/salesforce-update-lead'; + } else if (currentSalesforceSubMode === 'read_lead') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-read-lead'; + } else if (currentSalesforceSubMode === 'delete_lead') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-delete-lead'; + } else if (currentSalesforceSubMode === 'create_contact') { + submitPayload = { + first_name: payload.contact_first_name || undefined, + last_name: payload.contact_last_name, + email: payload.contact_email || undefined, + account_id: payload.contact_account_id || undefined + }; + endpoint = '/scenarios/salesforce-create-contact'; + } else if (currentSalesforceSubMode === 'update_contact') { + submitPayload = { + record_id: payload.contact_id, + first_name: payload.contact_first_name || undefined, + last_name: payload.contact_last_name || undefined, + email: payload.contact_email || undefined, + account_id: payload.contact_account_id || undefined + }; + endpoint = '/scenarios/salesforce-update-contact'; + } else if (currentSalesforceSubMode === 'read_contact') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-read-contact'; + } else if (currentSalesforceSubMode === 'delete_contact') { + submitPayload = { record_id: payload.generic_record_id }; + endpoint = '/scenarios/salesforce-delete-contact'; + } + + await handleSubmission(submitPayload, endpoint, salesforceRunBtn, salesforceBtnText, salesforceSpinner, 'Execute Action'); + }); + + + if (salesforceActionSelect) { + salesforceActionSelect.addEventListener('change', (e) => { + const mode = e.target.value; + if (mode === currentSalesforceSubMode) return; + currentSalesforceSubMode = mode; + syncSalesforceActionForm(); + resetUI(salesforcePipelineLabelOverride()); + log(`Switched to Salesforce mode [${currentSalesforceSubMode}]`); + }); + } + + + if (stripeActionSelect) { + stripeActionSelect.addEventListener('change', (e) => { + const mode = e.target.value; + if (mode === currentStripeSubMode) return; + currentStripeSubMode = mode; + syncStripeActionForm(); + resetUI(stripePipelineLabelOverride()); + log(`Switched to Stripe mode [${currentStripeSubMode}]`); + }); + } + // File Preview Logic if (gdriveFileInput && fileChosenPreview && previewName && fileDropZone) { gdriveFileInput.addEventListener('change', () => { @@ -691,7 +1142,7 @@ document.addEventListener('DOMContentLoaded', () => { e.preventDefault(); const formData = new FormData(gdriveForm); const payload = Object.fromEntries(formData.entries()); - + const fileInput = document.getElementById('gdrive-file'); if (payload.action === 'files.list') { @@ -760,22 +1211,22 @@ document.addEventListener('DOMContentLoaded', () => { if (currentSubMode === 'file' && fileInput.files.length > 0) { const file = fileInput.files[0]; const reader = new FileReader(); - + // Re-use the UI update logic outside to show "Encrypting" immediately resetUI(); gdriveRunBtn.disabled = true; gdriveSpinner.classList.remove('hidden'); gdriveBtnText.textContent = 'Encrypting File...'; - + reader.onload = async (event) => { try { const base64Data = event.target.result.split(',')[1]; payload.file_base64 = base64Data; payload.file_mime_type = file.type || 'application/octet-stream'; - + // Auto-update document name to the real file name if sending a binary file payload.document_name = file.name; - + await handleSubmission(payload, '/scenarios/gdrive-archival', gdriveRunBtn, gdriveBtnText, gdriveSpinner, 'Encrypt & Archive'); } catch (error) { log(`File parsing error: ${error.message}`, 'error'); @@ -784,14 +1235,14 @@ document.addEventListener('DOMContentLoaded', () => { gdriveSpinner.classList.add('hidden'); } }; - + reader.onerror = () => { log('Failed to read binary file from memory.', 'error'); gdriveBtnText.textContent = 'System Error'; gdriveRunBtn.disabled = false; gdriveSpinner.classList.add('hidden'); }; - + reader.readAsDataURL(file); } else { // Standard text submission @@ -799,6 +1250,107 @@ document.addEventListener('DOMContentLoaded', () => { } }); + if (slackActionSelect) { + slackActionSelect.addEventListener('change', () => { + const action = slackActionSelect.value; + if (action === 'upload_file') { + if (slackMessageSection) slackMessageSection.classList.add('hidden'); + if (slackFileSection) slackFileSection.classList.remove('hidden'); + } else { + if (slackMessageSection) slackMessageSection.classList.remove('hidden'); + if (slackFileSection) slackFileSection.classList.add('hidden'); + } + }); + } + + if (slackFileInput && slackFileChosenPreview && slackPreviewName && slackFileDropZone) { + slackFileInput.addEventListener('change', () => { + if (slackFileInput.files.length > 0) { + const fileName = slackFileInput.files[0].name; + slackPreviewName.textContent = fileName; + slackFileChosenPreview.classList.remove('hidden'); + slackFileDropZone.classList.add('hidden'); + } + }); + } + + if (slackRemoveFileBtn && slackFileInput && slackFileChosenPreview && slackFileDropZone) { + slackRemoveFileBtn.addEventListener('click', (e) => { + e.stopPropagation(); + slackFileInput.value = ''; + slackFileChosenPreview.classList.add('hidden'); + slackFileDropZone.classList.remove('hidden'); + }); + } + + if (slackFileDropZone) { + slackFileDropZone.addEventListener('dragover', (e) => { + e.preventDefault(); + slackFileDropZone.style.borderColor = 'var(--brand-accent)'; + slackFileDropZone.style.background = 'rgba(255, 255, 255, 0.08)'; + }); + + slackFileDropZone.addEventListener('dragleave', () => { + slackFileDropZone.style.borderColor = ''; + slackFileDropZone.style.background = ''; + }); + + slackFileDropZone.addEventListener('drop', (e) => { + e.preventDefault(); + slackFileDropZone.style.borderColor = ''; + slackFileDropZone.style.background = ''; + if (slackFileInput && e.dataTransfer.files.length > 0) { + slackFileInput.files = e.dataTransfer.files; + slackFileInput.dispatchEvent(new Event('change')); + } + }); + } + + if (slackForm) { + slackForm.addEventListener('submit', async (e) => { + e.preventDefault(); + const formData = new FormData(slackForm); + const payload = Object.fromEntries(formData.entries()); + + if (payload.action === 'upload_file' && slackFileInput && slackFileInput.files.length > 0) { + const file = slackFileInput.files[0]; + const reader = new FileReader(); + + resetUI(); + if (slackRunBtn) slackRunBtn.disabled = true; + if (slackSpinner) slackSpinner.classList.remove('hidden'); + if (slackBtnText) slackBtnText.textContent = 'Formatting payload...'; + + reader.onload = async (event) => { + try { + const base64Data = event.target.result.split(',')[1]; + payload.content_base64 = base64Data; + // Always override filename with actual file name if uploaded directly + payload.filename = file.name; + + await handleSubmission(payload, '/scenarios/slack-messaging', slackRunBtn, slackBtnText, slackSpinner, 'Send to Slack'); + } catch (error) { + log(`File parsing error: ${error.message}`, 'error'); + if (slackBtnText) slackBtnText.textContent = 'System Error'; + if (slackRunBtn) slackRunBtn.disabled = false; + if (slackSpinner) slackSpinner.classList.add('hidden'); + } + }; + + reader.onerror = () => { + log('Failed to read binary file from memory.', 'error'); + if (slackBtnText) slackBtnText.textContent = 'System Error'; + if (slackRunBtn) slackRunBtn.disabled = false; + if (slackSpinner) slackSpinner.classList.add('hidden'); + }; + + reader.readAsDataURL(file); + } else { + await handleSubmission(payload, '/scenarios/slack-messaging', slackRunBtn, slackBtnText, slackSpinner, 'Send to Slack'); + } + }); + } + // ====================================================== // AI Agent Chat Logic // ====================================================== @@ -810,6 +1362,100 @@ document.addEventListener('DOMContentLoaded', () => { bubble.innerHTML = `
${escapeHTML(roleLabel)}

${escapeHTML(content)}

`; agentChatHistory.appendChild(bubble); agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + return bubble; + } + + function appendStreamingBubble(label = 'Agent Streaming') { + const bubble = document.createElement('div'); + bubble.className = 'chat-bubble assistant streaming-bubble'; + bubble.innerHTML = ` +
+ ${escapeHTML(label)} +

+
+ + + + Streaming response... 0.0s +
+
+ `; + agentChatHistory.appendChild(bubble); + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + return { + bubble, + text: bubble.querySelector('.streaming-text'), + loader: bubble.querySelector('.stream-tail-loader'), + timer: bubble.querySelector('.stream-running-timer') + }; + } + + function appendTraceBadge(traceId, transportLabel = '') { + if (!traceId) return; + const badge = document.createElement('div'); + badge.className = 'chat-trace-badge'; + const suffix = transportLabel ? ` | ${transportLabel}` : ''; + badge.textContent = `TRC-${traceId.toUpperCase().slice(0, 8)}${suffix}`; + agentChatHistory.appendChild(badge); + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + } + + function appendStreamEndMessage(message, success = true, finalTime = null) { + const end = document.createElement('div'); + end.className = `stream-end-message ${success ? 'success' : 'error'}`; + let displayMessage = message || (success ? 'Streaming completed.' : 'Streaming ended with an error.'); + if (finalTime) { + displayMessage += ` (Total Time: ${finalTime}s)`; + } + end.textContent = displayMessage; + agentChatHistory.appendChild(end); + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + } + + function updateAgentTransportStatus() { + if (!agentTransportStatus) return; + const label = agentTransportMode === 'streamable-http' ? 'Streamable HTTP' : 'stdio'; + agentTransportStatus.querySelector('.transport-status-label').textContent = `Transport: ${label}`; + } + + async function loadAgentTransportMode() { + try { + const response = await fetch('/scenarios/agent-transport'); + if (!response.ok) throw new Error(`Server returned ${response.status}`); + const data = await response.json(); + agentTransportMode = data.transport === 'streamable-http' ? 'streamable-http' : 'stdio'; + } catch (error) { + agentTransportMode = 'stdio'; + log(`Transport status unavailable; using stdio UI mode (${error.message})`, 'system'); + } + updateAgentTransportStatus(); + } + + async function readNdjsonStream(response, handlers) { + if (!response.body) throw new Error('Browser did not expose a readable response stream'); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let pending = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + pending += decoder.decode(value, { stream: true }); + const lines = pending.split('\n'); + pending = lines.pop() || ''; + + for (const line of lines) { + if (!line.trim()) continue; + const event = JSON.parse(line); + if (handlers[event.type]) handlers[event.type](event); + } + } + + if (pending.trim()) { + const event = JSON.parse(pending); + if (handlers[event.type]) handlers[event.type](event); + } } function appendStepCard(step) { @@ -834,10 +1480,44 @@ document.addEventListener('DOMContentLoaded', () => {
${escapeHTML(argsStr)}
${resultPreview ? `
${resultIcon} ${escapeHTML(resultPreview)}
` : ''} `; - agentChatHistory.appendChild(card); + + const streamingBubble = agentChatHistory.querySelector('.streaming-bubble'); + if (streamingBubble) { + agentChatHistory.insertBefore(card, streamingBubble); + } else { + agentChatHistory.appendChild(card); + } agentChatHistory.scrollTop = agentChatHistory.scrollHeight; } + async function readNdjsonStream(response, handlers) { + if (!response.body) throw new Error('Browser did not expose a readable response stream'); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let pending = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + pending += decoder.decode(value, { stream: true }); + const lines = pending.split('\n'); + pending = lines.pop() || ''; + + for (const line of lines) { + if (!line.trim()) continue; + const event = JSON.parse(line); + if (handlers[event.type]) handlers[event.type](event); + } + } + + if (pending.trim()) { + const event = JSON.parse(pending); + if (handlers[event.type]) handlers[event.type](event); + } + } + async function sendAgentMessage() { const message = agentInput.value.trim(); if (!message || agentBusy) return; @@ -856,7 +1536,107 @@ document.addEventListener('DOMContentLoaded', () => { log(`Agent Chat: Sending message...`, 'system'); + let timerInterval = null; + let streamView = null; + const startTime = Date.now(); + try { + if (agentTransportMode === 'streamable-http') { + // Instantly display the streaming bubble and start the active timer + streamView = appendStreamingBubble(); + agentTyping.classList.add('hidden'); // Hide generic typing dot loader + + function startRunningTimer() { + if (timerInterval) return; + timerInterval = setInterval(() => { + if (streamView && streamView.timer) { + const elapsed = ((Date.now() - startTime) / 1000).toFixed(1); + streamView.timer.textContent = `${elapsed}s`; + } + }, 100); + } + startRunningTimer(); + + const response = await fetch('/scenarios/agent-chat-stream', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + message: message, + history: agentConversationHistory.slice(0, -1) + }) + }); + + if (!response.ok) throw new Error(`Server returned ${response.status}`); + + let finalText = ''; + let traceId = ''; + let success = true; + let doneMessage = ''; + + await readNdjsonStream(response, { + meta: (event) => { + traceId = event.trace_id || traceId; + }, + status: (event) => { + log(`Agent Stream: ${event.message}`, 'system'); + }, + step: (event) => { + agentTyping.classList.add('hidden'); + appendStepCard({ + tool: event.tool, + args: event.args || {}, + result: event.result || '' + }); + }, + final_chunk: (event) => { + agentTyping.classList.add('hidden'); + finalText += event.content || ''; + streamView.text.textContent = finalText; + agentChatHistory.scrollTop = agentChatHistory.scrollHeight; + }, + error: (event) => { + success = false; + agentTyping.classList.add('hidden'); + finalText += event.message || ''; + streamView.text.textContent = finalText; + }, + done: (event) => { + traceId = event.trace_id || traceId; + success = Boolean(event.success); + doneMessage = event.message || `Streaming ${success ? 'completed' : 'failed'}. trace_id=${traceId}`; + + if (timerInterval) { + clearInterval(timerInterval); + timerInterval = null; + } + const finalElapsed = ((Date.now() - startTime) / 1000).toFixed(2); + streamView.loader.classList.add('hidden'); + appendStreamEndMessage(doneMessage, success, finalElapsed); + } + }); + + agentTyping.classList.add('hidden'); + if (!doneMessage) { + if (timerInterval) { + clearInterval(timerInterval); + timerInterval = null; + } + const finalElapsed = ((Date.now() - startTime) / 1000).toFixed(2); + streamView.loader.classList.add('hidden'); + doneMessage = `Streaming connection closed before done event. trace_id=${traceId || 'unknown'}`; + appendStreamEndMessage(doneMessage, false, finalElapsed); + success = false; + } + if (!finalText) { + finalText = success ? 'Completed.' : 'The stream ended before a final answer was returned.'; + if (streamView) streamView.text.textContent = finalText; + } + agentConversationHistory.push({ role: 'assistant', content: finalText }); + appendTraceBadge(traceId, 'streamable-http'); + log(`Agent Chat: ${success ? 'Stream complete' : 'Stream failed'} | ${doneMessage}`, success ? 'success' : 'error'); + return; + } + const response = await fetch('/scenarios/agent-chat', { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -885,16 +1665,13 @@ document.addEventListener('DOMContentLoaded', () => { agentConversationHistory.push({ role: 'assistant', content: data.reply }); // Add trace badge - if (data.trace_id) { - const badge = document.createElement('div'); - badge.className = 'chat-trace-badge'; - badge.textContent = `TRC-${data.trace_id.toUpperCase().slice(0, 8)}`; - agentChatHistory.appendChild(badge); - } + appendTraceBadge(data.trace_id); log(`Agent Chat: ${data.success ? 'Success' : 'Responded'} | steps=${data.steps ? data.steps.length : 0}`, data.success ? 'success' : 'system'); } catch (error) { + if (timerInterval) clearInterval(timerInterval); + if (streamView && streamView.loader) streamView.loader.classList.add('hidden'); agentTyping.classList.add('hidden'); appendChatBubble('assistant', `Sorry, I couldn't reach the server: ${error.message}. Please check that the backend is running.`); log(`Agent Chat Error: ${error.message}`, 'error'); @@ -906,6 +1683,7 @@ document.addEventListener('DOMContentLoaded', () => { } // Event listeners for chat + loadAgentTransportMode(); agentSendBtn.addEventListener('click', sendAgentMessage); agentInput.addEventListener('keydown', (e) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -914,6 +1692,47 @@ document.addEventListener('DOMContentLoaded', () => { } }); + // ─── External Patient Viewer — form submission ─────────────────────────── + if (extViewerForm) { + extViewerForm.addEventListener('submit', async (e) => { + e.preventDefault(); + + const fd = new FormData(extViewerForm); + const patientId = (fd.get('patient_id') || '').trim(); + const givenName = (fd.get('patient_given') || '').trim(); + const familyName = (fd.get('patient_family') || '').trim(); + const birthdate = (fd.get('patient_birthdate')|| '').trim(); + const sourceSystem = fd.get('source_system') || 'epic'; + const maxEnc = parseInt(fd.get('max_encounters') || '5', 10); + const maxDocs = parseInt(fd.get('max_documents') || '10', 10); + + // Client-side guard: need at least one identity field + if (!patientId && !givenName && !familyName) { + log('Viewer: provide a Patient ID or at least a given/family name.', 'error'); + return; + } + + const payload = { + source_system: sourceSystem, + max_encounters: maxEnc, + max_documents: maxDocs, + }; + if (patientId) payload.patient_id = patientId; + if (givenName) payload.patient_given = givenName; + if (familyName) payload.patient_family = familyName; + if (birthdate) payload.patient_birthdate = birthdate; + + await handleSubmission( + payload, + '/scenarios/external-patient-viewer', + extViewerRunBtn, + extViewerBtnText, + extViewerSpinner, + 'Load External Chart', + pipelineLabels.ext_patient_viewer + ); + }); + } // Initial Load UI State resetUI(); diff --git a/playground/ext_patient_viewer/__init__.py b/playground/ext_patient_viewer/__init__.py new file mode 100644 index 0000000..77a99ae --- /dev/null +++ b/playground/ext_patient_viewer/__init__.py @@ -0,0 +1 @@ +# External Patient Viewer (Read-Only Retrieval) - playground sub-module diff --git a/playground/ext_patient_viewer/schema.py b/playground/ext_patient_viewer/schema.py new file mode 100644 index 0000000..6e9be22 --- /dev/null +++ b/playground/ext_patient_viewer/schema.py @@ -0,0 +1,45 @@ +""" +External Patient Viewer — Pydantic input/output schemas. + +Read-only: no writes occur in any scenario using these models. +""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel + + +class ExternalPatientViewerInput(BaseModel): + """ + Payload for the 'Load External Chart' workflow. + + The viewer resolves the patient via direct ID (preferred) or falls back + to identity-layer search using name + birthdate. Both paths remain + read-only; no FHIR resources are created or mutated. + """ + + # --- Identity Resolution --- + patient_id: Optional[str] = None + """Direct FHIR Patient resource ID. When supplied, name/DOB are ignored.""" + + patient_family: Optional[str] = None + """Family (last) name used for identity-layer search when patient_id is absent.""" + + patient_given: Optional[str] = None + """Given (first) name used for identity-layer search when patient_id is absent.""" + + patient_birthdate: Optional[str] = None + """ISO-8601 birthdate (YYYY-MM-DD) for identity-layer search disambiguation.""" + + # --- Source System --- + source_system: str = "epic" + """EHR source to query: 'epic' (default) or 'cerner'.""" + + # --- Retrieval Scope --- + max_encounters: int = 5 + """Maximum number of recent encounters to retrieve (1–20).""" + + max_documents: int = 10 + """Maximum number of document references to retrieve (1–50).""" diff --git a/playground/index.html b/playground/index.html index 46978e8..0882d12 100644 --- a/playground/index.html +++ b/playground/index.html @@ -1,10 +1,16 @@ + + - Node-wire Playground + node-wire Playground @@ -28,7 +34,7 @@
-

Node-Wire

+

node-wire

Autonomous Connector Orchestration Platform

@@ -51,7 +57,7 @@

Node-Wire

-
+
@@ -82,6 +88,43 @@

Connectors

+ +
+
+
+ +
+
+

Connector Apps

+

Ready-to-use experiences built on top of connectors

+
+
+ +
+
+
+
+
+ + @@ -93,7 +136,7 @@

Connectors

-

Node-Wire MCP via ToolHive

+

node-wire MCP via ToolHive

MCP Agent — Guardrailed
+
+
+ + Transport: stdio +
+
+
@@ -138,7 +188,7 @@

Node-Wire MCP via ToolHive

System Connectors

Pre-built Enterprise Integrations
- +
@@ -198,10 +248,48 @@

Google Drive

Secure clinical document archival with IAM-governed access and encryption.

-
+ +
+
+ + + +
+
+

Slack

+

Intelligent Team Notifications & File Uploads.

+
+
+ +
+
+ + + + +
+
+

Stripe

+

Financial transaction and subscription management infrastructure.

+
+
+ +
+
+ + + + +
+
+

Salesforce

+

Lead and contact management for CRM-driven enterprise workflows.

+
+
+
- - +
- +
@@ -305,20 +393,20 @@
- + - + - +
@@ -355,25 +443,25 @@ Plan: Increase Metformin to 1000mg twice daily. Refer to dietitian. Follow-up in 4 weeks for labs.
- + - + - +
@@ -494,15 +582,380 @@
- + + + + + + + + + +

Smart Pipeline

@@ -565,7 +1018,7 @@

Smart Pipeline

- + - +

Technical Audit

@@ -599,4 +1052,4 @@

Technical Audit

- \ No newline at end of file + diff --git a/playground/scenario_post_visit.py b/playground/scenario_post_visit.py index 9581c66..434e671 100644 --- a/playground/scenario_post_visit.py +++ b/playground/scenario_post_visit.py @@ -1,57 +1,74 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# import asyncio import base64 import logging import os import sys from datetime import datetime, timezone -from typing import Any, Dict, Optional +from typing import Dict # Add 'src' to sys.path so we can import core components if running from root sys.path.append(os.path.join(os.getcwd(), "src")) -from connectors.fhir_epic.logic import FhirEpicConnector -from connectors.fhir_epic.schema import ( +from node_wire_fhir_epic.logic import FhirEpicConnector +from node_wire_fhir_epic.schema import ( FhirPatientReadInput, FhirEncounterSearchInput, - FhirDocumentReferenceCreateInput + FhirDocumentReferenceCreateInput, ) -from runtime import SecretProvider +from node_wire_runtime import SecretProvider # Set up basic logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger("scenario_post_visit") + class EnvSecretProvider(SecretProvider): """Simple provider that could pull from env or a dict.""" + def __init__(self, secrets: Dict[str, str]): self.secrets = secrets + def get_secret(self, key: str) -> str: return self.secrets.get(key, "") + async def run_scenario(): """ Real-world Scenario: Post-Consultation Clinical Note Upload. - + Workflow: 1. Search for a Patient by demographics. 2. Find the most recent 'finished' Encounter for that patient. 3. Upload a clinical note (DocumentReference) tied to that Encounter. """ - + # Load env vars for secrets from dotenv import load_dotenv + load_dotenv() secrets = { - "epic_fhir_base_url": os.getenv("EPIC_FHIR_BASE_URL", "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4"), - "epic_token_url": os.getenv("EPIC_TOKEN_URL", "https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token"), + "epic_fhir_base_url": os.getenv( + "EPIC_FHIR_BASE_URL", "https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4" + ), + "epic_token_url": os.getenv( + "EPIC_TOKEN_URL", "https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token" + ), "epic_client_id": os.getenv("EPIC_CLIENT_ID", "CLIENT_ID_HERE"), "epic_kid": os.getenv("EPIC_KID", "KID_HERE"), - "epic_private_key": os.getenv("EPIC_PRIVATE_KEY", "PRIVATE_KEY_HERE") + "epic_private_key": os.getenv("EPIC_PRIVATE_KEY", "PRIVATE_KEY_HERE"), } - + if "CLIENT_ID_HERE" in secrets.values(): - logger.warning("Using placeholder secrets. Ensure .env is populated with real Epic Sandbox credentials.") + logger.warning( + "Using placeholder secrets. Ensure .env is populated with real Epic Sandbox credentials." + ) connector = FhirEpicConnector(secret_provider=EnvSecretProvider(secrets)) trace_id = "scenario-trace-123" @@ -59,12 +76,11 @@ async def run_scenario(): print("\n=== STEP 1: Patient Discovery ===") patient_search_params = {"family": "Smith", "given": "Jason", "birthdate": "1985-01-01"} logger.info(f"Searching for patient: {patient_search_params}") - + try: - patient_action = connector.get_action("read_patient") - patient_result = await patient_action.internal_execute( - FhirPatientReadInput(search_params=patient_search_params), - trace_id=trace_id + patient_result = await connector.internal_execute( + FhirPatientReadInput(action="read_patient", search_params=patient_search_params), + trace_id=trace_id, ) patient_id = patient_result.resource.get("id") logger.info(f"Found Patient ID: {patient_id}") @@ -74,31 +90,31 @@ async def run_scenario(): print("\n=== STEP 2: Encounter Identification ===") today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_params = { - "patient": patient_id, - "status": "finished", - "date": today - } + encounter_params = {"patient": patient_id, "status": "finished", "date": today} logger.info(f"Finding encounter for patient {patient_id} on {today}") - + try: - encounter_action = connector.get_action("search_encounter") - enc_result = await encounter_action.internal_execute( - FhirEncounterSearchInput(search_params=encounter_params), - trace_id=trace_id + enc_result = await connector.internal_execute( + FhirEncounterSearchInput(action="search_encounter", search_params=encounter_params), + trace_id=trace_id, ) - + if not enc_result.resources: - logger.warning("No encounters found for this patient today. Falling back to most recent.") - enc_result = await encounter_action.internal_execute( - FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), - trace_id=trace_id + logger.warning( + "No encounters found for this patient today. Falling back to most recent." + ) + enc_result = await connector.internal_execute( + FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": patient_id, "status": "finished"}, + ), + trace_id=trace_id, ) - + if not enc_result.resources: logger.error("No finished encounters found for this patient.") return - + encounter_id = enc_result.resources[0].get("id") logger.info(f"Selected Encounter ID: {encounter_id}") except Exception as e: @@ -107,30 +123,33 @@ async def run_scenario(): print("\n=== STEP 3: Clinical Note Upload ===") note_content = "Patient Jason Smith presented for follow-up. Vital signs stable. Plan: Continue current medication." - encoded_note = base64.b64encode(note_content.encode('utf-8')).decode('utf-8') - + encoded_note = base64.b64encode(note_content.encode("utf-8")).decode("utf-8") + doc_input = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": f"NOTE-{datetime.now().timestamp()}"}], status="current", - type={"coding": [{"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"}]}, + type={ + "coding": [ + {"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"} + ] + }, subject=f"Patient/{patient_id}", data=encoded_note, content_type="text/plain", - author=[{"reference": "Practitioner/ebmR9M-H9f6.dummy", "display": "Dr. Automated"}], + author=[{"reference": "Practitioner/ebmR9M-H9f6.dummy", "display": "Dr. Automated"}], description="Automated Post-Consultation Note Demo", - context={ - "encounter": [{"reference": f"Encounter/{encounter_id}"}] - } + context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]}, ) - + logger.info(f"Uploading clinical note for Encounter {encounter_id}") try: - doc_action = connector.get_action("create_document_reference") - doc_result = await doc_action.internal_execute(doc_input, trace_id=trace_id) + doc_result = await connector.internal_execute(doc_input, trace_id=trace_id) logger.info(f"SUCCESS! Created DocumentReference: {doc_result.resource_id}") print(f"\nWorkflow Complete. Resource Created: {doc_result.resource_id}") except Exception as e: logger.error(f"Document upload failed: {e}") + if __name__ == "__main__": asyncio.run(run_scenario()) diff --git a/playground/scenarios.py b/playground/scenarios.py index 89ded8a..883c7d3 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1,48 +1,81 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import base64 +import json import logging import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel, ValidationError, model_validator from dotenv import load_dotenv import os - -load_dotenv() - -from runtime.errors import ErrorMapper -from runtime.models import ErrorCategory - -ErrorMapper.register(ValidationError, ErrorCategory.BUSINESS, code="UNSUPPORTED_OPERATION") - -from connectors.fhir_epic.logic import FhirEpicConnector -from connectors.fhir_epic.schema import ( +import asyncio +from node_wire_runtime.errors import ErrorMapper +from node_wire_runtime.models import ErrorCategory +from node_wire_fhir_epic.logic import FhirEpicConnector +from node_wire_fhir_epic.schema import ( FhirDocumentReferenceCreateInput, FhirDocumentReferenceSearchInput, FhirEncounterSearchInput, FhirPatientReadInput, ) -from connectors.fhir_cerner.schema import ( +from node_wire_fhir_cerner.schema import ( FhirCernerDocumentReferenceCreateInput, FhirCernerDocumentReferenceSearchInput, FhirCernerEncounterSearchInput, FhirCernerPatientReadInput, ) -from connectors.google_drive.schema import ( +from node_wire_google_drive.schema import ( GoogleDriveOperationInput, - FilesUploadOperation, PermissionsCreateOperation, FilesGetOperation, FilesListOperation, FilesUpdateOperation, ) +from node_wire_salesforce.logic import SalesforceConnector +from node_wire_salesforce.schema import ( + CreateLeadInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + CreateContactInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, +) + + +from node_wire_slack.schema import ( + SlackPostMessageInput, + SlackSendDirectMessageInput, + SlackUploadFileInput, +) +from .ext_patient_viewer.schema import ExternalPatientViewerInput + +load_dotenv() + + +ErrorMapper.register(ValidationError, ErrorCategory.BUSINESS, code="UNSUPPORTED_OPERATION") +ErrorMapper.register(ValueError, ErrorCategory.BUSINESS, code="VALIDATION_ERROR") + + +load_dotenv() + + +ErrorMapper.register(ValidationError, ErrorCategory.BUSINESS, code="UNSUPPORTED_OPERATION") + logger = logging.getLogger("playground.scenarios") router = APIRouter(prefix="/scenarios", tags=["scenarios"]) + class PostConsultationInput(BaseModel): patient_id: Optional[str] = None patient_family: Optional[str] = None @@ -52,6 +85,7 @@ class PostConsultationInput(BaseModel): note_text: str visit_date: Optional[str] = None + class IncidentReportInput(BaseModel): title: str severity: str @@ -59,6 +93,38 @@ class IncidentReportInput(BaseModel): description: str reported_by: str = "Demo User" + +class StripeChargeInput(BaseModel): + amount: int + currency: str + description: Optional[str] = None + source: str = "tok_visa" + + +class StripePaymentIntentInputPlayground(BaseModel): + amount: int + currency: str + customer_id: Optional[str] = None + payment_method: Optional[str] = None + confirm: bool = False + + +class StripeSubscriptionInputPlayground(BaseModel): + customer_id: str + price_id: str + card_token: Optional[str] = None + + +class StripeCancelSubscriptionInputPlayground(BaseModel): + subscription_id: str + + +class StripeRefundInputPlayground(BaseModel): + charge_id: Optional[str] = None + payment_intent_id: Optional[str] = None + amount: Optional[int] = None + + class CernerPostConsultationInput(BaseModel): patient_id: Optional[str] = None patient_family: Optional[str] = None @@ -68,6 +134,7 @@ class CernerPostConsultationInput(BaseModel): note_text: str visit_date: Optional[str] = None + class GoogleDriveArchivalInput(BaseModel): document_name: Optional[str] = None recipient_email: Optional[str] = None @@ -112,62 +179,122 @@ def require_upload_fields_when_not_list(self) -> "GoogleDriveArchivalInput": dn = (self.document_name or "").strip() em = (self.recipient_email or "").strip() if not dn or not em: - raise ValueError("document_name and recipient_email are required for archival upload actions") + raise ValueError( + "document_name and recipient_email are required for archival upload actions" + ) return self + +class SalesforceLeadInputPlayground(BaseModel): + last_name: str + company: str + first_name: Optional[str] = None + email: Optional[str] = None + status: str = "Open - Not Contacted" + + +class SalesforceContactInputPlayground(BaseModel): + last_name: str + first_name: Optional[str] = None + email: Optional[str] = None + account_id: Optional[str] = None + + +class SalesforceGenericIdInputPlayground(BaseModel): + record_id: str + + +class SalesforceUpdateLeadInputPlayground(BaseModel): + record_id: str + first_name: Optional[str] = None + last_name: Optional[str] = None + company: Optional[str] = None + email: Optional[str] = None + + +class SalesforceUpdateContactInputPlayground(BaseModel): + record_id: str + first_name: Optional[str] = None + last_name: Optional[str] = None + email: Optional[str] = None + account_id: Optional[str] = None + + +class SlackPlaygroundInput(BaseModel): + action: str = "post_message" + channel: str = "" + message: Optional[str] = None + filename: Optional[str] = None + initial_comment: Optional[str] = None + content_base64: Optional[str] = None + + class ScenarioStep(BaseModel): name: str status: str # "pending", "success", "error" details: Optional[str] = None - display_name: Optional[str] = None # For "Plain English" UI labels + display_name: Optional[str] = None # For "Plain English" UI labels data: Optional[Any] = None retries: int = 0 + class ScenarioResponse(BaseModel): success: bool steps: List[ScenarioStep] final_resource_id: Optional[str] = None - human_summary: Optional[str] = None # Business-value summary + human_summary: Optional[str] = None # Business-value summary error_message: Optional[str] = None trace_id: str -def _safe_error_return(e: Exception, steps: List[ScenarioStep], trace_id: str, step_msg: str) -> ScenarioResponse: - from runtime.errors import ErrorMapper - from runtime.models import ErrorCategory +def _safe_error_return( + e: Exception, steps: List[ScenarioStep], trace_id: str, step_msg: str +) -> ScenarioResponse: + from node_wire_runtime.errors import ErrorMapper + from node_wire_runtime.models import ErrorCategory import logging - import asyncio + log = logging.getLogger("playground.scenarios") - + mapped_err = ErrorMapper.resolve(e) - safe_msg = str(e) if mapped_err.category != ErrorCategory.FATAL else "An internal system error occurred." - + safe_msg = ( + str(e) + if mapped_err.category != ErrorCategory.FATAL + else "An internal system error occurred." + ) + if hasattr(e, "errors") and callable(getattr(e, "errors", None)): try: safe_msg = e.errors()[0].get("msg", "Schema validation failed") except Exception: pass - + steps[-1].status = "error" steps[-1].details = f"[{mapped_err.category.value}] {safe_msg}" - + # Provide structured error data steps[-1].data = { - "error_code": mapped_err.code, + "error_code": mapped_err.code, "error_category": mapped_err.category.value, - "raw": {"error": safe_msg} + "raw": {"error": safe_msg}, } - + if mapped_err.category == ErrorCategory.BUSINESS: log.warning(f"{step_msg}: {safe_msg}") else: log.error(f"{step_msg}: {e}", exc_info=True) - + return ScenarioResponse(success=False, steps=steps, trace_id=trace_id, error_message=step_msg) -import asyncio -async def execute_with_retry(action: Any, input_data: Any, trace_id: str, step: ScenarioStep, max_retries: int = 3, base_delay: float = 1.0) -> Any: +async def execute_with_retry( + action: Any, + input_data: Any, + trace_id: str, + step: ScenarioStep, + max_retries: int = 3, + base_delay: float = 1.0, +) -> Any: last_exception = None delay = base_delay for attempt in range(max_retries + 1): @@ -176,7 +303,9 @@ async def execute_with_retry(action: Any, input_data: Any, trace_id: str, step: except Exception as e: last_exception = e if attempt < max_retries: - logger.warning(f"Action failed (attempt {attempt+1}/{max_retries+1}): {e}. Retrying in {delay}s...") + logger.warning( + f"Action failed (attempt {attempt + 1}/{max_retries + 1}): {e}. Retrying in {delay}s..." + ) step.retries += 1 await asyncio.sleep(delay) delay *= 2 @@ -185,105 +314,140 @@ async def execute_with_retry(action: Any, input_data: Any, trace_id: str, step: raise last_exception +# Single shared factory for playground scenarios (matches REST: enabled + exposed_via includes "rest"). +_playground_factory: Optional[Any] = None + + +def get_playground_factory() -> Any: + """Lazily load connector config once; same pattern as bindings REST `get_factory`.""" + global _playground_factory + if _playground_factory is None: + from bindings.factory import ConnectorFactory + from node_wire_runtime.connector_registry import auto_register + + _playground_factory = ConnectorFactory() + auto_register() + _playground_factory.load() + return _playground_factory + + +def resolve_connector(connector_id: str, action: Optional[str] = None) -> Any: + """Resolve a connector via public factory API (protocol-aware).""" + factory = get_playground_factory() + return factory.get_for_protocol(connector_id, "rest", action=action) + + def get_fhir_connector() -> FhirEpicConnector: - # Use global accessor instead of circular import - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("fhir_epic") + connector = resolve_connector("fhir_epic") if not connector: raise HTTPException(status_code=500, detail="FHIR Epic connector not configured") - return connector + return connector # type: ignore[return-value] + def get_http_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("http_generic") + # Manifest action for http_generic is "request"; pass it for parity with REST routing. + connector = resolve_connector("http_generic", action="request") if not connector: raise HTTPException(status_code=500, detail="Generic HTTP connector not configured") return connector -def get_cerner_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("fhir_cerner") +def get_cerner_connector(): + connector = resolve_connector("fhir_cerner") if not connector: raise HTTPException(status_code=500, detail="FHIR Cerner connector not configured") return connector -def get_google_drive_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("google_drive") +def get_google_drive_connector(): + connector = resolve_connector("google_drive") if not connector: raise HTTPException(status_code=500, detail="Google Drive connector not configured") return connector +def get_slack_connector(): + connector = resolve_connector("slack") + if not connector: + raise HTTPException(status_code=500, detail="Slack connector not configured") + return connector + + +def get_stripe_connector(): + connector = resolve_connector("stripe") + if not connector: + raise HTTPException(status_code=500, detail="Stripe connector not configured") + return connector + + +def get_salesforce_connector(): + connector = resolve_connector("salesforce") + if not connector: + raise HTTPException(status_code=500, detail="Salesforce connector not configured") + return connector + + @router.post("/post-consultation", response_model=ScenarioResponse) async def post_consultation_scenario( - payload: PostConsultationInput, - connector: FhirEpicConnector = Depends(get_fhir_connector) + payload: PostConsultationInput, connector: FhirEpicConnector = Depends(get_fhir_connector) ) -> ScenarioResponse: trace_id = str(uuid.uuid4()) steps: List[ScenarioStep] = [] - + # helper to add steps - def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): - steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Performing direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, - FhirPatientReadInput(resource_id=payload.patient_id), - trace_id, - steps[-1] + connector, FhirPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] ) patient_id = payload.patient_id else: patient_search_params = { - "family": payload.patient_family, - "given": payload.patient_given, - "birthdate": payload.patient_birthdate + k: v + for k, v in { + "family": payload.patient_family, + "given": payload.patient_given, + "birthdate": payload.patient_birthdate, + }.items() + if v is not None } logger.info(f"Searching for patient: {patient_search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirPatientReadInput(search_params=patient_search_params), trace_id, - steps[-1] + steps[-1], ) patient_id = p_res.resource.get("id") if not patient_id: raise ValueError("Patient not found") - - patient_display = f"{payload.patient_given} {payload.patient_family}" if payload.patient_family else patient_id + + patient_display = ( + f"{payload.patient_given} {payload.patient_family}" + if payload.patient_family + else patient_id + ) steps[-1].status = "success" steps[-1].details = f"Verified: {patient_display}" steps[-1].display_name = f"Identity Verified: {patient_display}" - steps[-1].data = {"patient_id": patient_id, "display_name": patient_display, "raw": p_res.resource} + steps[-1].data = { + "patient_id": patient_id, + "display_name": patient_display, + "raw": p_res.resource, + } except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 1 failed") @@ -291,124 +455,176 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", add_step("Encounter Identification", "pending", display_name="Locate Medical Visit") try: if payload.encounter_id: - logger.info(f"Using manual Encounter ID: {payload.encounter_id}", extra={"trace_id": trace_id}) + logger.info( + f"Using manual Encounter ID: {payload.encounter_id}", extra={"trace_id": trace_id} + ) encounter_id = payload.encounter_id enc_type = "Manual" enc_status = "verified" else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") - logger.info(f"Searching for encounter... patient={patient_id}, date={visit_date}", extra={"trace_id": trace_id}) + logger.info( + f"Searching for encounter... patient={patient_id}, date={visit_date}", + extra={"trace_id": trace_id}, + ) enc_res = await execute_with_retry( - encounter_action, - FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished", "date": visit_date}), + connector, + FhirEncounterSearchInput( + search_params={"patient": patient_id, "status": "finished", "date": visit_date} + ), trace_id, - steps[-1] + steps[-1], ) - + resources = enc_res.resources if not resources: # Fallback to any finished encounter enc_res = await execute_with_retry( - encounter_action, - FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), + connector, + FhirEncounterSearchInput( + search_params={"patient": patient_id, "status": "finished"} + ), trace_id, - steps[-1] + steps[-1], ) resources = enc_res.resources if not resources: raise ValueError("No finished encounters found for this patient") - + selected_enc = resources[0] encounter_id = selected_enc.get("id") enc_type = selected_enc.get("type", [{}])[0].get("text", "Unknown") enc_status = selected_enc.get("status", "Unknown") - + if not encounter_id: - logger.error(f"Encounter found but missing 'id' field: {selected_enc}", extra={"trace_id": trace_id}) + logger.error( + f"Encounter found but missing 'id' field: {selected_enc}", + extra={"trace_id": trace_id}, + ) raise ValueError("The found Encounter resource is missing a valid FHIR ID.") - - logger.info(f"Selected Encounter: ID={encounter_id}, Type={enc_type}, Status={enc_status}", extra={"trace_id": trace_id}) - + + logger.info( + f"Selected Encounter: ID={encounter_id}, Type={enc_type}, Status={enc_status}", + extra={"trace_id": trace_id}, + ) + steps[-1].status = "success" steps[-1].details = f"Linked to {enc_type} Encounter: {encounter_id}" steps[-1].display_name = f"Visit Found: {enc_type} ({encounter_id})" - steps[-1].data = {"encounter_id": encounter_id, "type": enc_type, "status": enc_status, "raw": selected_enc if not payload.encounter_id else {"id": encounter_id, "note": "Manual ID used"}} + steps[-1].data = { + "encounter_id": encounter_id, + "type": enc_type, + "status": enc_status, + "raw": selected_enc + if not payload.encounter_id + else {"id": encounter_id, "note": "Manual ID used"}, + } except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 2 failed") # STEP 3: Clinical Note Upload add_step("Clinical Note Upload", "pending", display_name="Secure Sync to EHR") try: - encoded_note = base64.b64encode(payload.note_text.encode('utf-8')).decode('utf-8') + encoded_note = base64.b64encode(payload.note_text.encode("utf-8")).decode("utf-8") doc_input = FhirDocumentReferenceCreateInput( - identifier=[{"system": "urn:oid:1.2.3", "value": f"DEMO-{int(datetime.now().timestamp())}"}], + identifier=[ + {"system": "urn:oid:1.2.3", "value": f"DEMO-{int(datetime.now().timestamp())}"} + ], status="current", - type={"coding": [{"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"}]}, - category=[{"coding": [{"system": "http://hl7.org/fhir/us/core/CodeSystem/us-core-documentreference-category", "code": "clinical-note", "display": "Clinical Note"}]}], + type={ + "coding": [ + {"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"} + ] + }, + category=[ + { + "coding": [ + { + "system": "http://hl7.org/fhir/us/core/CodeSystem/us-core-documentreference-category", + "code": "clinical-note", + "display": "Clinical Note", + } + ] + } + ], subject=f"Patient/{patient_id}", data=encoded_note, content_type="text/plain", author=[{"reference": "Practitioner/ebmR9M-H9f6", "display": "Dr. Automated"}], description="Professional Demo Upload", - context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]} + context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]}, ) - - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) - + + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) + steps[-1].status = "success" steps[-1].details = f"EHR Updated. ID: {doc_res.resource_id}" steps[-1].display_name = "Note Synced Successfully" - steps[-1].data = {"resource_id": doc_res.resource_id, "raw": doc_res.resource if (hasattr(doc_res, 'resource') and doc_res.resource) else {"id": doc_res.resource_id, "status": "created", "note": "Resource payload not returned by Epic integration."}} - + steps[-1].data = { + "resource_id": doc_res.resource_id, + "raw": doc_res.resource + if (hasattr(doc_res, "resource") and doc_res.resource) + else { + "id": doc_res.resource_id, + "status": "created", + "note": "Resource payload not returned by Epic integration.", + }, + } + # STEP 4: Verification / Visualization add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, - FhirDocumentReferenceSearchInput(search_params={"patient": patient_id, "_id": doc_res.resource_id}), + connector, + FhirDocumentReferenceSearchInput( + search_params={"patient": patient_id, "_id": doc_res.resource_id} + ), trace_id, - steps[-1] + steps[-1], ) - + resources = verify_res.resources if not resources: - raise ValueError("Document was created but could not be verified in the EHR.") - + raise ValueError("Document was created but could not be verified in the EHR.") + verified_doc = resources[0] - + # Extract beautiful presentation data doc_date = verified_doc.get("date", "Unknown Date") doc_type_text = verified_doc.get("type", {}).get("text", "Clinical Note") if not doc_type_text and verified_doc.get("type", {}).get("coding"): - doc_type_text = verified_doc.get("type", {}).get("coding")[0].get("display", "Clinical Note") - + doc_type_text = ( + verified_doc.get("type", {}).get("coding")[0].get("display", "Clinical Note") + ) + doc_author = "Unknown Author" if verified_doc.get("author"): doc_author = verified_doc.get("author")[0].get("display", "System Orchestrator") - + doc_status = verified_doc.get("status", "current") - + # Extract more beautiful presentation data doc_category = "Clinical Note" if verified_doc.get("category") and verified_doc["category"][0].get("coding"): - doc_category = verified_doc["category"][0]["coding"][0].get("display", "Clinical Note") - + doc_category = verified_doc["category"][0]["coding"][0].get( + "display", "Clinical Note" + ) + doc_description = verified_doc.get("description", "Automated Clinical Note") doc_identifier = verified_doc.get("identifier", [{}])[0].get("value", "Unknown ID") # Decode base64 data for better display in beautiful view ONLY decoded_text = "No content available." try: - if verified_doc.get("content") and verified_doc["content"][0].get("attachment", {}).get("data"): + if verified_doc.get("content") and verified_doc["content"][0].get( + "attachment", {} + ).get("data"): b64_data = verified_doc["content"][0]["attachment"]["data"] decoded_text = base64.b64decode(b64_data).decode("utf-8") except Exception as e: logger.warning(f"Failed to decode base64 document content: {e}") - + beautiful_data = { "id": doc_res.resource_id, "identifier": doc_identifier, @@ -420,14 +636,14 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", "status": doc_status, "patient_name": patient_display, "encounter_id": encounter_id, - "content_text": decoded_text + "content_text": decoded_text, } - + steps[-1].status = "success" - steps[-1].details = f"Verified in Patient Chart" + steps[-1].details = "Verified in Patient Chart" steps[-1].display_name = f"Verified: {doc_type_text}" steps[-1].data = {"raw": verified_doc, "beautiful_data": beautiful_data} - + except Exception as e: logger.error(f"Verification Step 4 failed: {e}", extra={"trace_id": trace_id}) # We don't fail the whole scenario if verification fails, just mark the step @@ -436,25 +652,31 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", steps[-1].data = {"raw": {"error": str(e)}} return ScenarioResponse( - success=True, - steps=steps, - final_resource_id=doc_res.resource_id, + success=True, + steps=steps, + final_resource_id=doc_res.resource_id, human_summary="Medical record successfully updated in Epic. 15 minutes of manual entry automated in 2 seconds.", - trace_id=trace_id + trace_id=trace_id, ) except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 3 failed") + @router.post("/report-incident", response_model=ScenarioResponse) async def report_incident_scenario( - payload: IncidentReportInput, - connector: Any = Depends(get_http_connector) + payload: IncidentReportInput, connector: Any = Depends(get_http_connector) ) -> ScenarioResponse: trace_id = str(uuid.uuid4()) steps: List[ScenarioStep] = [] - - def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): - steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) # STEP 1: Format Payload add_step("Payload Formatting", "pending", display_name="Format Incident Payload") @@ -467,13 +689,13 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", "priority": payload.severity.lower(), "custom_fields": [ {"id": 12345, "value": payload.component}, - {"id": 67890, "value": ts} + {"id": 67890, "value": ts}, ], - "requester": {"name": payload.reported_by} + "requester": {"name": payload.reported_by}, } } steps[-1].status = "success" - steps[-1].details = f"Standard ITSM schema generated." + steps[-1].details = "Standard ITSM schema generated." steps[-1].display_name = "Payload Ready" steps[-1].data = {"raw": ticket_payload} except Exception as e: @@ -482,22 +704,23 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 2: Dispatch Webhook add_step("Dispatch Webhook", "pending", display_name="Dispatch Webhook") try: - from connectors.http_generic.schema import HttpRequestInput - + from node_wire_http_generic.schema import HttpRequestInput + # Using httpbin.org to simulate a real REST endpoint request_input = HttpRequestInput( url="https://httpbin.org/post", method="POST", headers={"X-Demo-Source": "node-wire"}, - body=ticket_payload + body=ticket_payload, ) - + http_action = connector response = await execute_with_retry(http_action, request_input, trace_id, steps[-1]) - + import json + resp_body = json.loads(response.body) - + steps[-1].status = "success" steps[-1].details = f"HTTP {response.status_code} Success" steps[-1].display_name = "Webhook Dispatched" @@ -510,23 +733,26 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", try: # httpbin echoes back our data in 'json' field incident_id = f"INC-{uuid.uuid4().hex[:8].upper()}" - + beautiful_data = { "id": incident_id, "type": "IT Service Incident", "date": datetime.now().isoformat(), "status": "OPEN", - "patient_name": payload.reported_by, + "patient_name": payload.reported_by, "author": "AOT-Automator", "category": payload.component, "description": payload.title, - "content_text": f"Incident documented and routed to Level 2 Support. Ref: {incident_id}\n\nDescription: {payload.description}" + "content_text": f"Incident documented and routed to Level 2 Support. Ref: {incident_id}\n\nDescription: {payload.description}", } - + steps[-1].status = "success" steps[-1].details = f"Incident {incident_id} Active" steps[-1].display_name = "Ticket Verified" - steps[-1].data = {"raw": {"incident_id": incident_id, "upstream_status": "accepted"}, "beautiful_data": beautiful_data} + steps[-1].data = { + "raw": {"incident_id": incident_id, "upstream_status": "accepted"}, + "beautiful_data": beautiful_data, + } except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 3 failed") @@ -535,18 +761,19 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", try: # Simulate background task import asyncio + await asyncio.sleep(0.4) - + steps[-1].status = "success" steps[-1].details = "System Audit Recorded" steps[-1].display_name = "Audit Log Updated" - + return ScenarioResponse( success=True, steps=steps, final_resource_id=incident_id, human_summary=f"IT Incident {incident_id} has been successfully created, routed, and audited.", - trace_id=trace_id + trace_id=trace_id, ) except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 4 failed") @@ -554,42 +781,49 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", @router.post("/cerner-post-consultation", response_model=ScenarioResponse) async def cerner_post_consultation_scenario( - payload: CernerPostConsultationInput, - connector: Any = Depends(get_cerner_connector) + payload: CernerPostConsultationInput, connector: Any = Depends(get_cerner_connector) ) -> ScenarioResponse: """4-step Cerner FHIR R4 post-consultation clinical note sync demo.""" trace_id = str(uuid.uuid4()) steps: List[ScenarioStep] = [] - def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): - steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Cerner: direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(resource_id=payload.patient_id), trace_id, - steps[-1] + steps[-1], ) patient_id = payload.patient_id else: - search_params = {k: v for k, v in { - "family": payload.patient_family, - "given": payload.patient_given, - "birthdate": payload.patient_birthdate, - }.items() if v} + search_params = { + k: v + for k, v in { + "family": payload.patient_family, + "given": payload.patient_given, + "birthdate": payload.patient_birthdate, + }.items() + if v + } logger.info(f"Cerner: searching for patient: {search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(search_params=search_params), trace_id, - steps[-1] + steps[-1], ) patient_id = p_res.resource.get("id") @@ -598,12 +832,17 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", patient_display = ( f"{payload.patient_given} {payload.patient_family}" - if payload.patient_family else patient_id + if payload.patient_family + else patient_id ) steps[-1].status = "success" steps[-1].details = f"Verified: {patient_display}" steps[-1].display_name = f"Identity Verified: {patient_display}" - steps[-1].data = {"patient_id": patient_id, "display_name": patient_display, "raw": p_res.resource} + steps[-1].data = { + "patient_id": patient_id, + "display_name": patient_display, + "raw": p_res.resource, + } except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 1 failed") @@ -617,26 +856,25 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", selected_enc = {"id": encounter_id, "note": "Manual ID used"} else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": patient_id, "status": "finished", "date": visit_date} ), trace_id, - steps[-1] + steps[-1], ) resources = enc_res.resources if not resources: # Fallback: any finished encounter for this patient enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": patient_id, "status": "finished"} ), trace_id, - steps[-1] + steps[-1], ) resources = enc_res.resources @@ -654,7 +892,12 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", steps[-1].status = "success" steps[-1].details = f"Linked to {enc_type} Encounter: {encounter_id}" steps[-1].display_name = f"Visit Found: {enc_type} ({encounter_id})" - steps[-1].data = {"encounter_id": encounter_id, "type": enc_type, "status": enc_status, "raw": selected_enc} + steps[-1].data = { + "encounter_id": encounter_id, + "type": enc_type, + "status": enc_status, + "raw": selected_enc, + } except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 2 failed") @@ -668,7 +911,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # Cerner requires CodeSet 72 proprietary system — NOT a raw LOINC system URL. # The tenant ID is embedded in the connector's FHIR base URL path segment. try: - base_url_secret = connector._secret_provider.get_secret("cerner_fhir_base_url") + base_url_secret = connector.secret_provider.get_secret("cerner_fhir_base_url") # Extract tenant from URL: .../r4/{tenant_id} or similar parts = [p for p in base_url_secret.rstrip("/").split("/") if p] tenant_id = parts[-1] if parts else "your-tenant-id" @@ -686,12 +929,14 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", status="current", doc_status="final", type={ - "coding": [{ - "system": codeset72_system, - "code": "2820507", # Admission Note Physician in Cerner CodeSet 72 - "display": "Admission Note Physician", - "userSelected": True, - }], + "coding": [ + { + "system": codeset72_system, + "code": "2820507", # Admission Note Physician in Cerner CodeSet 72 + "display": "Admission Note Physician", + "userSelected": True, + } + ], "text": "Admission Note Physician", }, subject=f"Patient/{patient_id}", @@ -704,45 +949,47 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", custodian={"reference": "Organization/675844"}, context={ "encounter": [{"reference": "Encounter/97957281"}], - "period": {"start": period_start, "end": period_end} + "period": {"start": period_start, "end": period_end}, }, ) - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) steps[-1].status = "success" steps[-1].details = f"Cerner EHR Updated. ID: {doc_res.resource_id}" steps[-1].display_name = "Note Synced to Cerner" steps[-1].data = { "resource_id": doc_res.resource_id, - "raw": doc_res.resource if (hasattr(doc_res, "resource") and doc_res.resource) - else {"id": doc_res.resource_id, "status": "created", "note": "Location header only — Cerner does not return body on create."}, + "raw": doc_res.resource + if (hasattr(doc_res, "resource") and doc_res.resource) + else { + "id": doc_res.resource_id, + "status": "created", + "note": "Location header only — Cerner does not return body on create.", + }, } # STEP 4: Verification add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, - FhirCernerDocumentReferenceSearchInput( - search_params={"_id": doc_res.resource_id} - ), + connector, + FhirCernerDocumentReferenceSearchInput(search_params={"_id": doc_res.resource_id}), trace_id, - steps[-1] + steps[-1], ) resources = verify_res.resources if not resources: - raise ValueError("Document created but could not be verified in Cerner. Indexing may be delayed.") + raise ValueError( + "Document created but could not be verified in Cerner. Indexing may be delayed." + ) verified_doc = resources[0] doc_date = verified_doc.get("date", now_iso) - doc_type_text = ( - verified_doc.get("type", {}).get("text") - or (verified_doc.get("type", {}).get("coding", [{}])[0].get("display", "Progress Note")) + doc_type_text = verified_doc.get("type", {}).get("text") or ( + verified_doc.get("type", {}).get("coding", [{}])[0].get("display", "Progress Note") ) doc_author = "Unknown Author" if verified_doc.get("author"): @@ -750,14 +997,18 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", doc_status = verified_doc.get("status", "current") doc_category = "Clinical Note" if verified_doc.get("category") and verified_doc["category"][0].get("coding"): - doc_category = verified_doc["category"][0]["coding"][0].get("display", "Clinical Note") + doc_category = verified_doc["category"][0]["coding"][0].get( + "display", "Clinical Note" + ) # Decode attachment content for display decoded_text = "No content available." try: content = verified_doc.get("content", []) if content and content[0].get("attachment", {}).get("data"): - decoded_text = base64.b64decode(content[0]["attachment"]["data"]).decode("utf-8") + decoded_text = base64.b64decode(content[0]["attachment"]["data"]).decode( + "utf-8" + ) except Exception: pass @@ -794,107 +1045,408 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", f"Clinical progress note successfully written to Cerner EHR for {patient_display}. " "15 minutes of manual chart entry automated in under 3 seconds." ), - trace_id=trace_id + trace_id=trace_id, ) except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 3 failed") -@router.post("/gdrive-archival", response_model=ScenarioResponse) -async def gdrive_archival_scenario( - payload: GoogleDriveArchivalInput, - connector: Any = Depends(get_google_drive_connector) + +@router.post("/stripe-charge", response_model=ScenarioResponse) +async def stripe_charge_scenario( + payload: StripeChargeInput, connector: Any = Depends(get_stripe_connector) ) -> ScenarioResponse: - """4-step Google Drive archival and sharing demo.""" trace_id = str(uuid.uuid4()) steps: List[ScenarioStep] = [] - def add_step(name: str, status: str, details: str = "", display_name: str = "", data: Any = None): - steps.append(ScenarioStep(name=name, status=status, details=details, display_name=display_name, data=data)) - - if payload.action == "files.list": - add_step("Drive List", "pending", display_name="List Drive Files") - try: - raw_ps = payload.list_page_size - page_size = 10 if raw_ps is None else int(raw_ps) - page_size = max(1, min(100, page_size)) - q = (payload.list_query or "").strip() or None - fields = (payload.list_fields or "").strip() or None - list_op = FilesListOperation( - action="files.list", - page_size=page_size, - query=q, - fields=fields, - ) - list_input = GoogleDriveOperationInput.model_validate(list_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, list_input, trace_id, steps[-1]) - n = len(res.raw.get("files") or []) - steps[-1].status = "success" - steps[-1].details = f"Retrieved {n} file(s) (page_size={page_size})" - steps[-1].display_name = "Files Listed" - steps[-1].data = {"raw": res.raw} - return ScenarioResponse( - success=True, - steps=steps, - final_resource_id=None, - human_summary=f"Listed {n} file(s) from Google Drive (page size {page_size}).", - trace_id=trace_id, + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data ) - except Exception as e: - return _safe_error_return(e, steps, trace_id, "List failed") + ) - if payload.action == "files.get": - add_step("Drive Get", "pending", display_name="Get file metadata") - try: - fid = (payload.get_file_id or "").strip() - if not fid: - raise ValueError("get_file_id is required") - gf = (payload.get_fields or "").strip() or None - get_op = FilesGetOperation( - action="files.get", - file_id=fid, - fields=gf, - ) - get_input = GoogleDriveOperationInput.model_validate(get_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) - got_id = res.raw.get("id") or fid - name = res.raw.get("name", "") - steps[-1].status = "success" - steps[-1].details = f"Retrieved metadata for file id {got_id}" - steps[-1].display_name = "File metadata retrieved" - steps[-1].data = {"raw": res.raw} - return ScenarioResponse( - success=True, - steps=steps, - final_resource_id=got_id if isinstance(got_id, str) else str(got_id), - human_summary=f"Fetched Google Drive file metadata{f' ({name})' if name else ''}.", - trace_id=trace_id, - ) - except Exception as e: - return _safe_error_return(e, steps, trace_id, "files.get failed") + # STEP 1: Process Payment Intent + add_step("Process Payment Intent", "pending", display_name="Initialize Payment") + try: + steps[-1].status = "success" + steps[-1].details = "Payment initialization verified." + steps[-1].display_name = "Payment Initialized" + steps[-1].data = {"amount": payload.amount, "currency": payload.currency} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") - if payload.action == "files.update": - fid = (payload.update_file_id or "").strip() - if not fid: - raise ValueError("update_file_id is required") - add_ids = [ - x.strip() - for x in (payload.update_add_parents or "").split(",") - if x.strip() - ] or None - remove_ids = [ - x.strip() - for x in (payload.update_remove_parents or "").split(",") - if x.strip() - ] or None - new_name = (payload.update_name or "").strip() or None - new_mime = (payload.update_mime_type or "").strip() or None + # STEP 2: Confirm Charge + add_step("Confirm Charge", "pending", display_name="Process Charge") + try: + from node_wire_stripe.schema import ChargeInput - add_step("Update Prepare", "pending", display_name="Prepare update request") - preview = { - "file_id": fid, - "name": new_name, - "mime_type": new_mime, - "add_parents": add_ids, + charge_input = ChargeInput( + amount=payload.amount, + currency=payload.currency, + source=payload.source, + description=payload.description, + ) + + charge_res = await execute_with_retry(connector, charge_input, trace_id, steps[-1]) + + steps[-1].status = "success" + steps[-1].details = f"Charge Processed: {charge_res.charge_id}" + steps[-1].display_name = "Charge Successful" + steps[-1].data = {"raw": charge_res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + # STEP 3: Verify Transaction + add_step("Verify Transaction", "pending", display_name="Verify Receipt") + try: + beautiful_data = { + "id": charge_res.charge_id, + "type": "Payment Receipt", + "date": datetime.now().isoformat(), + "status": charge_res.status, + "patient_name": "Demo User", + "author": "Stripe Gateway", + "category": "Financial", + "description": payload.description or "No description", + "content_text": f"Charge of {payload.amount / 100:.2f} {payload.currency.upper()} processed successfully. Receipt: {charge_res.receipt_url or 'N/A'}", + } + steps[-1].status = "success" + steps[-1].details = "Transaction Verified" + steps[-1].display_name = "Transaction Verified" + steps[-1].data = {"beautiful_data": beautiful_data, "raw": {"status": "Verified"}} + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=charge_res.charge_id, + human_summary=f"Successfully processed {payload.amount / 100:.2f} {payload.currency.upper()} charge.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + +@router.post("/stripe-payment-intent", response_model=ScenarioResponse) +async def stripe_payment_intent_scenario( + payload: StripePaymentIntentInputPlayground, connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Initialize Session", "pending", display_name="Initialize PI") + try: + steps[-1].status = "success" + steps[-1].details = f"Initialized PI session for {payload.amount} {payload.currency}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Create Payment Intent", "pending", display_name="Create Intent") + try: + from node_wire_stripe.schema import CreatePaymentIntentInput + + pi_input = CreatePaymentIntentInput( + amount=payload.amount, + currency=payload.currency, + customer_id=payload.customer_id, + payment_method=payload.payment_method, + confirm=payload.confirm, + ) + res = await execute_with_retry(connector, pi_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Created Intent: {res.payment_intent_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Allocation", "pending", display_name="Verify Allocation") + try: + steps[-1].status = "success" + steps[-1].details = "Allocation verified" + steps[-1].display_name = "Allocation Verified" + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.payment_intent_id, + human_summary=f"Successfully created payment intent {res.payment_intent_id}.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + +@router.post("/stripe-subscription", response_model=ScenarioResponse) +async def stripe_subscription_scenario( + payload: StripeSubscriptionInputPlayground, connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Validate Customer", "pending", display_name="Validate Params") + try: + steps[-1].status = "success" + steps[-1].details = f"Validated inputs for Customer: {payload.customer_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Create Subscription", "pending", display_name="Create Sub") + try: + from node_wire_stripe.schema import CreateSubscriptionInput + + sub_input = CreateSubscriptionInput( + customer_id=payload.customer_id, + price_id=payload.price_id, + card_token=payload.card_token, + ) + res = await execute_with_retry(connector, sub_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Subscription Created: {res.subscription_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Provisioning", "pending", display_name="Verify Sub") + try: + steps[-1].status = "success" + steps[-1].details = f"Subscription {res.subscription_id} is {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.subscription_id, + human_summary="Successfully provisioned subscription for customer.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + +@router.post("/stripe-cancel-subscription", response_model=ScenarioResponse) +async def stripe_cancel_subscription_scenario( + payload: StripeCancelSubscriptionInputPlayground, connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Locate Resource", "pending", display_name="Locate Sub") + try: + steps[-1].status = "success" + steps[-1].details = f"Targeting subscription: {payload.subscription_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Cancel Subscription", "pending", display_name="Cancel Sub") + try: + from node_wire_stripe.schema import CancelSubscriptionInput + + can_input = CancelSubscriptionInput(subscription_id=payload.subscription_id) + res = await execute_with_retry(connector, can_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Cancelled Sub: {res.subscription_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Termination", "pending", display_name="Verify Cancel") + try: + steps[-1].status = "success" + steps[-1].details = f"Cancellation verified. Status: {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.subscription_id, + human_summary="Successfully canceled subscription.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + +@router.post("/stripe-refund", response_model=ScenarioResponse) +async def stripe_refund_scenario( + payload: StripeRefundInputPlayground, connector: Any = Depends(get_stripe_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Validate Charge", "pending", display_name="Validate Params") + try: + steps[-1].status = "success" + steps[ + -1 + ].details = f"Refund targeted for ID: {payload.charge_id or payload.payment_intent_id}" + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Process Refund", "pending", display_name="Issue Refund") + try: + from node_wire_stripe.schema import IssueRefundInput + + ref_input = IssueRefundInput( + charge_id=payload.charge_id, + payment_intent_id=payload.payment_intent_id, + amount=payload.amount, + ) + res = await execute_with_retry(connector, ref_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = f"Refund Processed: {res.refund_id}" + steps[-1].data = {"raw": res.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Refund", "pending", display_name="Verify Receipt") + try: + steps[-1].status = "success" + steps[-1].details = f"Refund recorded properly. Status: {res.status}" + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=res.refund_id, + human_summary="Successfully issued refund.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + +@router.post("/gdrive-archival", response_model=ScenarioResponse) +async def gdrive_archival_scenario( + payload: GoogleDriveArchivalInput, connector: Any = Depends(get_google_drive_connector) +) -> ScenarioResponse: + """4-step Google Drive archival and sharing demo.""" + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + if payload.action == "files.list": + add_step("Drive List", "pending", display_name="List Drive Files") + try: + raw_ps = payload.list_page_size + page_size = 10 if raw_ps is None else raw_ps + page_size = max(1, min(100, page_size)) + q = (payload.list_query or "").strip() or None + fields = (payload.list_fields or "").strip() or None + list_op = FilesListOperation( + action="files.list", + page_size=page_size, + query=q, + fields=fields, + ) + list_input = GoogleDriveOperationInput.model_validate( + list_op.model_dump(exclude_none=True) + ) + res = await execute_with_retry(connector, list_input, trace_id, steps[-1]) + n = len(res.raw.get("files") or []) + steps[-1].status = "success" + steps[-1].details = f"Retrieved {n} file(s) (page_size={page_size})" + steps[-1].display_name = "Files Listed" + steps[-1].data = {"raw": res.raw} + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=None, + human_summary=f"Listed {n} file(s) from Google Drive (page size {page_size}).", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "List failed") + + if payload.action == "files.get": + add_step("Drive Get", "pending", display_name="Get file metadata") + try: + fid = (payload.get_file_id or "").strip() + if not fid: + raise ValueError("get_file_id is required") + gf = (payload.get_fields or "").strip() or None + get_op = FilesGetOperation( + action="files.get", + file_id=fid, + fields=gf, + ) + get_input = GoogleDriveOperationInput.model_validate( + get_op.model_dump(exclude_none=True) + ) + res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) + got_id = res.raw.get("id") or fid + name = res.raw.get("name", "") + steps[-1].status = "success" + steps[-1].details = f"Retrieved metadata for file id {got_id}" + steps[-1].display_name = "File metadata retrieved" + steps[-1].data = {"raw": res.raw} + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=got_id if isinstance(got_id, str) else str(got_id), + human_summary=f"Fetched Google Drive file metadata{f' ({name})' if name else ''}.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "files.get failed") + + if payload.action == "files.update": + fid = (payload.update_file_id or "").strip() + if not fid: + raise ValueError("update_file_id is required") + add_ids = [ + x.strip() for x in (payload.update_add_parents or "").split(",") if x.strip() + ] or None + remove_ids = [ + x.strip() for x in (payload.update_remove_parents or "").split(",") if x.strip() + ] or None + new_name = (payload.update_name or "").strip() or None + new_mime = (payload.update_mime_type or "").strip() or None + + add_step("Update Prepare", "pending", display_name="Prepare update request") + preview = { + "file_id": fid, + "name": new_name, + "mime_type": new_mime, + "add_parents": add_ids, "remove_parents": remove_ids, } steps[-1].status = "success" @@ -956,9 +1508,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", success=True, steps=steps, final_resource_id=rid if isinstance(rid, str) else str(rid), - human_summary=( - f"Updated Google Drive file{f' ({fname})' if fname else f' ({rid})'}." - ), + human_summary=(f"Updated Google Drive file{f' ({fname})' if fname else f' ({rid})'}."), trace_id=trace_id, ) @@ -973,7 +1523,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", "archived_at": ts, "recipient": payload.recipient_email, "folder_id": payload.folder_id, - "has_binary_payload": bool(payload.file_base64) + "has_binary_payload": bool(payload.file_base64), } steps[-1].status = "success" steps[-1].details = f"Archival schema generated for {payload.document_name}" @@ -1002,7 +1552,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", res = await execute_with_retry(connector, upload_input, trace_id, steps[-1]) file_id = res.raw.get("id") - + if not file_id: raise ValueError("File upload failed, no ID returned") @@ -1022,11 +1572,11 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", file_id=file_id, role="reader", email_address=payload.recipient_email, - type="user" + type="user", ) ) perm_res = await execute_with_retry(connector, perm_input, trace_id, steps[-1]) - + steps[-1].status = "success" steps[-1].details = f"Read access granted to {payload.recipient_email}" steps[-1].display_name = "Access Control Applied" @@ -1041,22 +1591,24 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", FilesGetOperation( action="files.get", file_id=file_id, - fields="id, name, mimeType, webViewLink, size, createdTime, owners" + fields="id, name, mimeType, webViewLink, size, createdTime, owners", ) ) - get_res = await connector.internal_execute(get_input, trace_id=trace_id) + get_res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) file_metadata = get_res.raw - + beautiful_data = { "id": file_id, "type": "Secure Archived Document", "date": file_metadata.get("createdTime", datetime.now().isoformat()), "status": "SECURED", - "patient_name": payload.recipient_email, # Mimicking patient name for UI schema - "author": file_metadata.get("owners", [{}])[0].get("displayName", "Service Account") if file_metadata.get("owners") else "Service Account", + "patient_name": payload.recipient_email, # Mimicking patient name for UI schema + "author": file_metadata.get("owners", [{}])[0].get("displayName", "Service Account") + if file_metadata.get("owners") + else "Service Account", "category": file_metadata.get("mimeType", "text/plain"), "description": file_metadata.get("name"), - "content_text": f"Document successfully archived and shared.\n\nWeb Link: {file_metadata.get('webViewLink')}\nSize: {file_metadata.get('size')} bytes" + "content_text": f"Document successfully archived and shared.\n\nWeb Link: {file_metadata.get('webViewLink')}\nSize: {file_metadata.get('size')} bytes", } steps[-1].status = "success" @@ -1069,29 +1621,131 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", steps=steps, final_resource_id=file_id, human_summary=f"Success! Document '{payload.document_name}' archived to Google Drive and shared with {payload.recipient_email}.", - trace_id=trace_id + trace_id=trace_id, ) except Exception as e: return _safe_error_return(e, steps, trace_id, "Step 4 failed") -# --------------------------------------------------------------------------- -# AI Agent Chat endpoint -# --------------------------------------------------------------------------- +@router.post("/slack-messaging", response_model=ScenarioResponse) +async def slack_scenario( + payload: SlackPlaygroundInput, connector: Any = Depends(get_slack_connector) +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] -class AgentChatMessage(BaseModel): - role: str # "user" or "assistant" - content: str + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) -class AgentChatInput(BaseModel): - message: str - history: List[Dict[str, str]] = [] # [{"role": "user/assistant", "content": "..."}] + add_step("Format Slack Payload", "pending", display_name="Format Slack Payload") + try: + if payload.action == "upload_file": + input_model = SlackUploadFileInput( + action="upload_file", + channel=payload.channel, + filename=payload.filename or "file.txt", + initial_comment=payload.initial_comment or "", + content_base64=payload.content_base64 or "", + ) + elif payload.action == "send_direct_message": + input_model = SlackSendDirectMessageInput( + action="send_direct_message", channel=payload.channel, message=payload.message or "" + ) + else: + input_model = SlackPostMessageInput( + action="post_message", channel=payload.channel, message=payload.message or "" + ) -class AgentChatStepResponse(BaseModel): + steps[-1].status = "success" + steps[-1].details = "Payload structured correctly" + steps[-1].display_name = "Payload Ready" + steps[-1].data = {"raw": input_model.model_dump()} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 failed") + + add_step("Dispatch to Slack API", "pending", display_name="Dispatch to Slack API") + try: + slack_res = await execute_with_retry(connector, input_model, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "API Accepted Request" + steps[-1].display_name = "Dispatched via API" + steps[-1].data = {"raw": slack_res.raw} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 failed") + + add_step("Verify Acknowledgment", "pending", display_name="Verify Acknowledgment") + try: + ref_id = ( + slack_res.ts + if hasattr(slack_res, "ts") and slack_res.ts + else getattr(slack_res, "file_id", "unknown") + ) + + beautiful_data = { + "id": ref_id, + "type": "Slack Notification", + "date": datetime.now().isoformat(), + "status": "DELIVERED", + "patient_name": payload.channel, + "author": "Slack Connector", + "category": payload.action, + "description": payload.filename if payload.action == "upload_file" else "Slack Message", + "content_text": payload.message if payload.message else f"Uploaded {payload.filename}", + } + + steps[-1].status = "success" + steps[-1].details = f"Acknowledged by Slack (Ref: {ref_id})" + steps[-1].display_name = "Verified Delivered" + steps[-1].data = {"raw": {"reference_id": ref_id}, "beautiful_data": beautiful_data} + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 failed") + + add_step("Update Audit Trail", "pending", display_name="Update Audit Trail") + try: + # Simulate latency + await asyncio.sleep(0.3) + steps[-1].status = "success" + steps[-1].details = "Audit logged securely" + steps[-1].display_name = "Audit Complete" + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=ref_id, + human_summary=f"Successfully sent Slack ({payload.action}) to {payload.channel}.", + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 4 failed") + + +# --------------------------------------------------------------------------- +# AI Agent Chat endpoint +# --------------------------------------------------------------------------- + + +class AgentChatMessage(BaseModel): + role: str # "user" or "assistant" + content: str + + +class AgentChatInput(BaseModel): + message: str + history: List[Dict[str, str]] = [] # [{"role": "user/assistant", "content": "..."}] + + +class AgentChatStepResponse(BaseModel): tool: str args: Dict[str, Any] result: Optional[str] = None + class AgentChatResponse(BaseModel): reply: str steps: List[AgentChatStepResponse] = [] @@ -1099,16 +1753,52 @@ class AgentChatResponse(BaseModel): success: bool +def _current_agent_transport() -> str: + transport = os.environ.get("NW_MCP_TRANSPORT", "stdio").strip().lower() or "stdio" + return transport if transport in {"stdio", "streamable-http"} else "stdio" + + +def _build_agent_chat_task(payload: AgentChatInput) -> str: + history_text_parts = [] + for msg in payload.history: + role = msg.get("role", "user") + content = msg.get("content", "") + history_text_parts.append(f"{role.upper()}: {content}") + + if history_text_parts: + return ( + "Previous conversation:\n" + + "\n".join(history_text_parts) + + f"\n\nUSER (latest): {payload.message}" + ) + return payload.message + + +@router.get("/agent-transport") +async def agent_transport() -> Dict[str, str]: + transport = _current_agent_transport() + return { + "transport": transport, + "label": "Streamable HTTP" if transport == "streamable-http" else "stdio", + } + + AGENT_GUARDRAIL_PROMPT = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `.` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" - " 1. First turn: Search for the patient. (If you have a Patient ID, you DO NOT need their name or birthdate).\n" - " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call the search tool with a guessed or hallucinated ID like '12345'.\n" + " 1. First turn: Obtain patient demographics from the EHR.\n" + ' - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{"resource_id": ""}` (use Epic when the ID starts with \'e\'). Do NOT use search_patients for a known ID.\n' + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. `given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -1118,7 +1808,7 @@ class AgentChatResponse(BaseModel): " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" "GUARDRAILS:\n" @@ -1134,6 +1824,25 @@ class AgentChatResponse(BaseModel): ) +def _build_agent_chat_task(payload: AgentChatInput) -> str: + history_text_parts = [] + for msg in payload.history: + role = msg.get("role", "user") + content = msg.get("content", "") + history_text_parts.append(f"{role.upper()}: {content}") + + if history_text_parts: + return ( + "Previous conversation:\n" + + "\n".join(history_text_parts) + + f"\n\nUSER (latest): {payload.message}" + ) + return payload.message + + +def _current_agent_transport() -> str: + transport = os.environ.get("NW_MCP_TRANSPORT", "stdio").strip().lower() or "stdio" + return transport if transport in {"stdio", "streamable-http"} else "stdio" @router.post("/agent-chat", response_model=AgentChatResponse) @@ -1147,8 +1856,11 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: import sys trace_id = str(uuid.uuid4()) - logger.info("Agent Chat request | trace_id=%s | provider=%s", - trace_id, os.environ.get("LLM_PROVIDER", "groq")) + logger.info( + "Agent Chat request | trace_id=%s | provider=%s", + trace_id, + os.environ.get("LLM_PROVIDER", "groq"), + ) if not payload.message.strip(): return AgentChatResponse( @@ -1159,39 +1871,29 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: ) try: - from agents.llm_factory import LLMProviderFactory, LLMMessage + from agents.llm_factory import LLMProviderFactory from agents.toolhive import ( MultiMcpClient, ToolHiveAgent, ToolHiveMcpClient, StdioMcpClient, resolve_mcp_urls, + resolve_max_tool_failures, ) provider_name = os.environ.get("LLM_PROVIDER", "groq") logger.info("Agent Chat | creating LLM provider: %s", provider_name) llm_provider = LLMProviderFactory.create_from_env() - # Build the task from the conversation history + current message - # The agent will see the full context - history_text_parts = [] - for msg in payload.history: - role = msg.get("role", "user") - content = msg.get("content", "") - history_text_parts.append(f"{role.upper()}: {content}") - - if history_text_parts: - task = ( - "Previous conversation:\n" - + "\n".join(history_text_parts) - + f"\n\nUSER (latest): {payload.message}" - ) - else: - task = payload.message + task = _build_agent_chat_task(payload) # Determine MCP transport — try proxy first, fallback to local stdio - urls = resolve_mcp_urls() + transport = _current_agent_transport() + urls = resolve_mcp_urls() if transport == "streamable-http" else [] run_result = None + fallback_to_stdio = ( + os.environ.get("PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO", "false").lower() == "true" + ) if urls: logger.info("Agent Chat | trying ToolHive proxy URL(s): %s", ",".join(urls)) @@ -1200,7 +1902,12 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: mcp_client = ToolHiveMcpClient(urls[0]) else: mcp_client = MultiMcpClient([ToolHiveMcpClient(u) for u in urls]) - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) # Fallback to local stdio if: @@ -1208,38 +1915,71 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: # (b) agent "succeeded" but called zero tools (LLM gave up because # only a subset of tools was discoverable via the proxy) proxy_incomplete = ( - (not run_result.success and run_result.error and ( + not run_result.success + and run_result.error + and ( "Failed to list MCP tools" in run_result.error or "not in request.tools" in run_result.error - )) - or (run_result.success and not run_result.steps) + ) ) if proxy_incomplete: - logger.warning("Agent Chat | proxy incomplete, falling back to local stdio") - run_result = None + if fallback_to_stdio: + logger.warning("Agent Chat | proxy incomplete, falling back to local stdio") + run_result = None + else: + logger.warning( + "Agent Chat | proxy incomplete, returning proxy error to UI " + "(set PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=true to fallback)" + ) except Exception as proxy_err: - logger.warning("Agent Chat | proxy error: %s — falling back to local stdio", proxy_err) - run_result = None + if fallback_to_stdio: + logger.warning( + "Agent Chat | proxy error: %s — falling back to local stdio", proxy_err + ) + run_result = None + else: + logger.warning( + "Agent Chat | proxy error: %s — returning error to UI " + "(set PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=true to fallback)", + proxy_err, + ) + return AgentChatResponse( + reply=f"MCP proxy error: {proxy_err}", + steps=[], + trace_id=trace_id, + success=False, + ) if run_result is None: # Use local stdio transport logger.info("Agent Chat | using local stdio MCP transport") cmd = [sys.executable, "-m", "agents.mcp_entrypoint"] async with StdioMcpClient(cmd) as mcp_client: - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) # Map agent steps to response format chat_steps = [] for s in run_result.steps: - chat_steps.append(AgentChatStepResponse( - tool=s.tool_called or "unknown", - args=s.tool_args, - result=s.tool_result, - )) + chat_steps.append( + AgentChatStepResponse( + tool=s.tool_called or "unknown", + args=s.tool_args, + result=s.tool_result, + ) + ) - reply = run_result.final_answer or run_result.error or "I encountered an issue. Please try again." + reply = ( + run_result.final_answer + or run_result.error + or "I encountered an issue. Please try again." + ) return AgentChatResponse( reply=reply, @@ -1256,3 +1996,780 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: trace_id=trace_id, success=False, ) + + +@router.post("/agent-chat-stream") +async def agent_chat_stream(payload: AgentChatInput) -> Any: + """ + Stream agent progress and final-answer chunks to web clients. + + The terminal ``done`` event includes ``trace_id`` and ``message``. Clients + should stop their streaming loader only when that event arrives. + """ + + async def stream_events(): + try: + import sys + + from agents.llm_factory import LLMProviderFactory + from agents.toolhive import ( + MultiMcpClient, + StdioMcpClient, + ToolHiveAgent, + ToolHiveMcpClient, + resolve_mcp_urls, + resolve_max_tool_failures, + ) + + if not payload.message.strip(): + trace_id = str(uuid.uuid4()) + yield ( + json.dumps( + { + "type": "final_chunk", + "content": "Please type a message to get started.", + } + ) + + "\n" + ) + yield ( + json.dumps( + { + "type": "done", + "trace_id": trace_id, + "success": False, + "message": f"Streaming failed. trace_id={trace_id}", + } + ) + + "\n" + ) + return + + llm_provider = LLMProviderFactory.create_from_env() + task = _build_agent_chat_task(payload) + transport = _current_agent_transport() + urls = resolve_mcp_urls() if transport == "streamable-http" else [] + + if urls: + if len(urls) == 1: + mcp_client = ToolHiveMcpClient(urls[0]) + else: + mcp_client = MultiMcpClient([ToolHiveMcpClient(u) for u in urls]) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) + agent._system_prompt = AGENT_GUARDRAIL_PROMPT + async for event in agent.run_events(task): + yield json.dumps(event) + "\n" + return + + cmd = [sys.executable, "-m", "agents.mcp_entrypoint"] + async with StdioMcpClient(cmd) as mcp_client: + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) + agent._system_prompt = AGENT_GUARDRAIL_PROMPT + async for event in agent.run_events(task): + yield json.dumps(event) + "\n" + + except Exception as exc: + logger.error("Agent Chat stream failed: %s", exc, exc_info=True) + trace_id = str(uuid.uuid4()) + yield ( + json.dumps( + { + "type": "final_chunk", + "content": f"Sorry, I encountered an error: {exc}. Please check the server configuration and try again.", + } + ) + + "\n" + ) + yield ( + json.dumps( + { + "type": "done", + "trace_id": trace_id, + "success": False, + "message": f"Streaming failed. trace_id={trace_id}", + } + ) + + "\n" + ) + + return StreamingResponse(stream_events(), media_type="application/x-ndjson") + + +@router.post("/salesforce-create-lead", response_model=ScenarioResponse) +async def salesforce_create_lead_scenario( + payload: SalesforceLeadInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Create Lead", "pending", display_name="Create Salesforce Lead") + + sf_input = CreateLeadInput( + LastName=payload.last_name, + Company=payload.company, + FirstName=payload.first_name, + Email=payload.email, + Status=payload.status, + ) + + try: + res = await execute_with_retry(connector, sf_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Lead record created" + steps[-1].data = {"resource_id": res.resource_id, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=res.resource_id, + human_summary=f"Salesforce Lead created successfully with ID: {res.resource_id}", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Lead creation failed") + + +@router.post("/salesforce-create-contact", response_model=ScenarioResponse) +async def salesforce_create_contact_scenario( + payload: SalesforceContactInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + add_step("Create Contact", "pending", display_name="Create Salesforce Contact") + + sf_input = CreateContactInput( + LastName=payload.last_name, + FirstName=payload.first_name, + Email=payload.email, + AccountId=payload.account_id, + ) + + try: + res = await execute_with_retry(connector, sf_input, trace_id, steps[-1]) + steps[-1].status = "success" + steps[-1].details = "Contact record created" + steps[-1].data = {"resource_id": res.resource_id, "raw": res.data} + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=res.resource_id, + human_summary=f"Salesforce Contact created successfully with ID: {res.resource_id}", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Contact creation failed") + + +@router.post("/salesforce-read-lead", response_model=ScenarioResponse) +async def salesforce_read_lead_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Read Lead", "pending", "Fetching Lead Details") + try: + res = await execute_with_retry( + connector, ReadLeadInput(record_id=payload.record_id), trace_id, steps[-1] + ) + steps[-1].status = "success" + steps[-1].details = "Lead data retrieved" + steps[-1].data = res.data + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + human_summary=f"Lead data retrieved for {payload.record_id}", + final_resource_id=payload.record_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Read failed") + + +@router.post("/salesforce-update-lead", response_model=ScenarioResponse) +async def salesforce_update_lead_scenario( + payload: SalesforceUpdateLeadInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Update Lead", "pending", "Updating Lead Record") + fields = {k: v for k, v in payload.model_dump().items() if v is not None and k != "record_id"} + # Map to SF internal names + sf_fields = {} + if "first_name" in fields: + sf_fields["FirstName"] = fields["first_name"] + if "last_name" in fields: + sf_fields["LastName"] = fields["last_name"] + if "company" in fields: + sf_fields["Company"] = fields["company"] + if "email" in fields: + sf_fields["Email"] = fields["email"] + + try: + res = await execute_with_retry( + connector, + UpdateLeadInput(record_id=payload.record_id, fields=sf_fields), + trace_id, + steps[-1], + ) + steps[-1].status = "success" + steps[-1].details = "Lead updated" + # Salesforce PATCH returns 204 No Content, so we show the sent fields as confirmation + steps[-1].data = { + "record_id": payload.record_id, + "updated_fields": sf_fields, + "raw": res.data, + } + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Lead {payload.record_id} updated successfully.", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Update failed") + + +@router.post("/salesforce-delete-lead", response_model=ScenarioResponse) +async def salesforce_delete_lead_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Delete Lead", "pending", "Removing Lead Record") + try: + await execute_with_retry( + connector, DeleteLeadInput(record_id=payload.record_id), trace_id, steps[-1] + ) + steps[-1].status = "success" + steps[-1].details = "Lead deleted" + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Lead {payload.record_id} deleted.", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Delete failed") + + +@router.post("/salesforce-read-contact", response_model=ScenarioResponse) +async def salesforce_read_contact_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Read Contact", "pending", "Fetching Contact Details") + try: + res = await execute_with_retry( + connector, ReadContactInput(record_id=payload.record_id), trace_id, steps[-1] + ) + steps[-1].status = "success" + steps[-1].details = "Contact data retrieved" + steps[-1].data = res.data + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + human_summary=f"Contact data retrieved for {payload.record_id}", + final_resource_id=payload.record_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Read failed") + + +@router.post("/salesforce-update-contact", response_model=ScenarioResponse) +async def salesforce_update_contact_scenario( + payload: SalesforceUpdateContactInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Update Contact", "pending", "Updating Contact Record") + fields = {k: v for k, v in payload.model_dump().items() if v is not None and k != "record_id"} + sf_fields = {} + if "first_name" in fields: + sf_fields["FirstName"] = fields["first_name"] + if "last_name" in fields: + sf_fields["LastName"] = fields["last_name"] + if "email" in fields: + sf_fields["Email"] = fields["email"] + if "account_id" in fields: + sf_fields["AccountId"] = fields["account_id"] + + try: + res = await execute_with_retry( + connector, + UpdateContactInput(record_id=payload.record_id, fields=sf_fields), + trace_id, + steps[-1], + ) + steps[-1].status = "success" + steps[-1].details = "Contact updated" + # Salesforce PATCH returns 204 No Content, so we show the sent fields as confirmation + steps[-1].data = { + "record_id": payload.record_id, + "updated_fields": sf_fields, + "raw": res.data, + } + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Contact {payload.record_id} updated successfully.", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Update failed") + + +@router.post("/salesforce-delete-contact", response_model=ScenarioResponse) +async def salesforce_delete_contact_scenario( + payload: SalesforceGenericIdInputPlayground, + connector: SalesforceConnector = Depends(get_salesforce_connector), +) -> ScenarioResponse: + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + + def add_step(name, status, display_name): + steps.append(ScenarioStep(name=name, status=status, display_name=display_name)) + + add_step("Delete Contact", "pending", "Removing Contact Record") + try: + await execute_with_retry( + connector, DeleteContactInput(record_id=payload.record_id), trace_id, steps[-1] + ) + steps[-1].status = "success" + steps[-1].details = "Contact deleted" + return ScenarioResponse( + success=True, + trace_id=trace_id, + steps=steps, + final_resource_id=payload.record_id, + human_summary=f"Contact {payload.record_id} deleted.", + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Delete failed") + + +# --------------------------------------------------------------------------- +# External Patient Viewer — Read-Only Retrieval +# --------------------------------------------------------------------------- + + +def _get_viewer_connector(source_system: str) -> Any: + """Return the correct FHIR connector based on source_system string.""" + if source_system.lower() == "cerner": + return get_cerner_connector() + return get_fhir_connector() + + +@router.post("/external-patient-viewer", response_model=ScenarioResponse) +async def external_patient_viewer_scenario( + payload: ExternalPatientViewerInput, +) -> ScenarioResponse: + """ + 4-step read-only workflow: resolve patient identity, retrieve demographics, + retrieve encounter history, retrieve document metadata. + + No FHIR resource is created or mutated during this workflow. + Encounter-as-document fallback is applied when document_references are absent. + """ + trace_id = str(uuid.uuid4()) + steps: List[ScenarioStep] = [] + is_epic = payload.source_system.lower() != "cerner" + + def add_step( + name: str, status: str, details: str = "", display_name: str = "", data: Any = None + ): + steps.append( + ScenarioStep( + name=name, status=status, details=details, display_name=display_name, data=data + ) + ) + + connector = _get_viewer_connector(payload.source_system) + system_label = "Epic FHIR R4" if is_epic else "Cerner FHIR R4" + + # ── STEP 1: Patient Resolution ────────────────────────────────────────── + add_step("Patient Resolution", "pending", display_name="Resolve Patient Identity") + try: + if payload.patient_id: + logger.info( + "[ExtViewer] Direct Patient ID lookup: %s on %s", + payload.patient_id, + system_label, + extra={"trace_id": trace_id}, + ) + if is_epic: + p_res = await execute_with_retry( + connector, + FhirPatientReadInput(resource_id=payload.patient_id), + trace_id, + steps[-1], + ) + else: + p_res = await execute_with_retry( + connector, + FhirCernerPatientReadInput(resource_id=payload.patient_id), + trace_id, + steps[-1], + ) + patient_id = payload.patient_id + patient_resource = p_res.resource or {} + else: + # Identity-layer search: resolve via name + birthdate + if not (payload.patient_family or payload.patient_given): + raise ValueError( + "Provide either patient_id or at least one name field (given/family) " + "to resolve patient identity." + ) + search_params = { + k: v + for k, v in { + "family": payload.patient_family, + "given": payload.patient_given, + "birthdate": payload.patient_birthdate, + }.items() + if v + } + logger.info( + "[ExtViewer] Identity-layer search: %s on %s", + search_params, + system_label, + extra={"trace_id": trace_id}, + ) + if is_epic: + p_res = await execute_with_retry( + connector, + FhirPatientReadInput(search_params=search_params), + trace_id, + steps[-1], + ) + else: + p_res = await execute_with_retry( + connector, + FhirCernerPatientReadInput(search_params=search_params), + trace_id, + steps[-1], + ) + patient_resource = p_res.resource or {} + patient_id = patient_resource.get("id") + + if not patient_id: + raise ValueError("Patient could not be resolved. No matching record found.") + + # Extract display name from FHIR resource + name_obj = patient_resource.get("name", [{}]) + if name_obj and isinstance(name_obj, list): + official = next((n for n in name_obj if n.get("use") == "official"), name_obj[0]) + else: + official = {} + given_parts = official.get("given", []) + family_part = official.get("family", "") + patient_display = f"{' '.join(given_parts)} {family_part}".strip() or ( + f"{payload.patient_given or ''} {payload.patient_family or ''}".strip() or patient_id + ) + patient_dob = patient_resource.get("birthDate", "Unknown") + patient_gender = patient_resource.get("gender", "Unknown") + + steps[-1].status = "success" + steps[-1].details = f"Resolved: {patient_display} (ID: {patient_id})" + steps[-1].display_name = f"Identity Resolved: {patient_display}" + steps[-1].data = { + "patient_id": patient_id, + "display_name": patient_display, + "dob": patient_dob, + "gender": patient_gender, + "source_system": system_label, + "raw": patient_resource, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 1 — Patient Resolution failed") + + # ── STEP 2: Encounter History ──────────────────────────────────────────── + add_step("Encounter History", "pending", display_name="Retrieve Encounter History") + encounters: List[Any] = [] + try: + max_enc = max(1, min(20, payload.max_encounters)) + + if is_epic: + enc_res = await execute_with_retry( + connector, + FhirEncounterSearchInput( + search_params={"patient": patient_id, "_count": str(max_enc)} + ), + trace_id, + steps[-1], + ) + else: + enc_res = await execute_with_retry( + connector, + FhirCernerEncounterSearchInput( + search_params={"patient": patient_id, "_count": str(max_enc)} + ), + trace_id, + steps[-1], + ) + + encounters = enc_res.resources or [] + enc_count = len(encounters) + + most_recent_enc: dict = {} + if encounters: + most_recent_enc = encounters[0] + recent_enc_type = ( + (most_recent_enc.get("type") or [{}])[0].get("text", "Encounter") + if most_recent_enc + else "None" + ) + recent_enc_date = ( + most_recent_enc.get("period", {}).get("start", "Unknown date") + if most_recent_enc + else "N/A" + ) + + steps[-1].status = "success" + steps[-1].details = ( + f"Retrieved {enc_count} encounter(s). " + f"Most recent: {recent_enc_type} on {recent_enc_date}" + if enc_count + else "No encounters found for this patient." + ) + steps[-1].display_name = ( + f"Encounter History: {enc_count} record(s)" if enc_count else "No Encounters Found" + ) + steps[-1].data = { + "encounter_count": enc_count, + "encounters": encounters, + "raw": {"total": enc_count, "entries": encounters}, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 2 — Encounter Retrieval failed") + + # ── STEP 3: Document Metadata ──────────────────────────────────────────── + add_step("Document Metadata", "pending", display_name="Retrieve Document Metadata") + documents: List[Any] = [] + doc_source = "fhir" + try: + max_docs = max(1, min(50, payload.max_documents)) + + if is_epic: + doc_res = await execute_with_retry( + connector, + FhirDocumentReferenceSearchInput( + search_params={"patient": patient_id, "_count": str(max_docs)} + ), + trace_id, + steps[-1], + ) + else: + doc_res = await execute_with_retry( + connector, + FhirCernerDocumentReferenceSearchInput( + search_params={"patient": patient_id, "_count": str(max_docs)} + ), + trace_id, + steps[-1], + ) + + documents = doc_res.resources or [] + + # Encounter-as-document fallback: when no DocumentReference exists, + # synthesise a lightweight document record from each encounter entry. + if not documents and encounters: + doc_source = "encounter_fallback" + for enc in encounters[:max_docs]: + enc_id = enc.get("id", "unknown") + enc_type_text = (enc.get("type") or [{}])[0].get("text", "Clinical Encounter") + enc_date = enc.get("period", {}).get("start", "Unknown") + enc_status = enc.get("status", "unknown") + documents.append( + { + "id": f"ENC-{enc_id}", + "resourceType": "EncounterFallback", + "status": enc_status, + "type": {"text": enc_type_text}, + "date": enc_date, + "description": "Encounter summary (no DocumentReference found)", + "subject": {"reference": f"Patient/{patient_id}"}, + "_synthetic": True, + } + ) + logger.info( + "[ExtViewer] No DocumentReferences found; using %d encounter fallback record(s)", + len(documents), + extra={"trace_id": trace_id}, + ) + + doc_count = len(documents) + fallback_note = " (encounter-fallback)" if doc_source == "encounter_fallback" else "" + + steps[-1].status = "success" + steps[-1].details = ( + f"Retrieved {doc_count} document(s){fallback_note}." + if doc_count + else "No documents or encounters available for this patient." + ) + steps[-1].display_name = ( + f"Documents: {doc_count} record(s){fallback_note}" + if doc_count + else "No Documents Found" + ) + steps[-1].data = { + "document_count": doc_count, + "source": doc_source, + "documents": documents, + "raw": {"total": doc_count, "source": doc_source, "entries": documents}, + } + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 3 — Document Retrieval failed") + + # ── STEP 4: Viewer Assembly ────────────────────────────────────────────── + add_step("Chart Assembly", "pending", display_name="Assemble External Chart View") + try: + enc_lines = [] + for i, enc in enumerate(encounters[:5]): + enc_type = (enc.get("type") or [{}])[0].get("text", "Encounter") + enc_date = enc.get("period", {}).get("start", "Unknown") + enc_status = enc.get("status", "?") + enc_lines.append(f" [{i + 1}] {enc_type} | {enc_date} | Status: {enc_status}") + + doc_lines = [] + for i, doc in enumerate(documents[:5]): + d_type = doc.get("type", {}).get("text") or doc.get("description", "Document") + d_date = doc.get("date") or doc.get("period", {}).get("start", "Unknown") + d_status = doc.get("status", "?") + d_synth = " [enc-fallback]" if doc.get("_synthetic") else "" + doc_lines.append(f" [{i + 1}] {d_type}{d_synth} | {d_date} | Status: {d_status}") + + content_lines = ( + [ + f"=== External Patient Chart ({system_label}) ===", + f"Patient : {patient_display}", + f"FHIR ID : {patient_id}", + f"DOB : {patient_dob}", + f"Gender : {patient_gender}", + "", + f"--- Encounter History ({len(encounters)} record(s)) ---", + ] + + (enc_lines if enc_lines else [" No encounters found."]) + + [ + "", + f"--- Documents ({len(documents)} record(s)" + + (" — encounter fallback" if doc_source == "encounter_fallback" else "") + + ") ---", + ] + + (doc_lines if doc_lines else [" No documents found."]) + + [ + "", + "[READ-ONLY] No data was written to the source system.", + ] + ) + + beautiful_data = { + "id": f"CHART-{patient_id}", + "type": "External Patient Chart", + "date": datetime.now(tz=timezone.utc).isoformat(), + "status": "READ-ONLY", + "patient_name": patient_display, + "author": system_label, + "category": "Clinical Chart View", + "description": ( + f"{len(encounters)} Encounter(s) · " + f"{len(documents)} Document(s)" + + (" [enc-fallback]" if doc_source == "encounter_fallback" else "") + ), + "content_text": "\n".join(content_lines), + } + + steps[-1].status = "success" + steps[-1].details = ( + f"Chart assembled. {len(encounters)} encounter(s), " + f"{len(documents)} document(s). Read-only — 0 writes." + ) + steps[-1].display_name = "Chart Ready (Read-Only)" + steps[-1].data = { + "patient_id": patient_id, + "encounter_count": len(encounters), + "document_count": len(documents), + "document_source": doc_source, + "read_only": True, + "raw": { + "patient_id": patient_id, + "source_system": system_label, + "encounters": len(encounters), + "documents": len(documents), + "document_source": doc_source, + }, + "beautiful_data": beautiful_data, + } + + return ScenarioResponse( + success=True, + steps=steps, + final_resource_id=patient_id, + human_summary=( + f"External chart loaded for {patient_display} from {system_label}. " + f"{len(encounters)} encounter(s) and {len(documents)} document(s) retrieved. " + "No data was written to the source system." + ), + trace_id=trace_id, + ) + except Exception as e: + return _safe_error_return(e, steps, trace_id, "Step 4 — Chart Assembly failed") diff --git a/playground/style.css b/playground/style.css index 28f669b..934f763 100644 --- a/playground/style.css +++ b/playground/style.css @@ -1,3 +1,9 @@ +/* + * SPDX-FileCopyrightText: 2026 AOT Technologies + * + * SPDX-License-Identifier: Apache-2.0 + */ + :root { --brand-primary: #0f172a; --brand-accent: #2563eb; @@ -647,13 +653,14 @@ textarea:focus { .completion-icon { width: 2rem; height: 2rem; - background: rgba(16, 185, 129, 0.1); - color: var(--success); + background: var(--success); + color: white; border-radius: 50%; display: flex; align-items: center; justify-content: center; flex-shrink: 0; + box-shadow: 0 6px 14px rgba(16, 185, 129, 0.24); } .completion-icon svg { @@ -688,8 +695,9 @@ textarea:focus { } .completion-card.error-toast .completion-icon { - background: rgba(244, 63, 94, 0.1); - color: var(--error); + background: var(--error); + color: white; + box-shadow: 0 6px 14px rgba(244, 63, 94, 0.24); } /* Logs */ @@ -915,7 +923,7 @@ textarea:focus { margin-top: 1.5rem; } -.connector-card { +.connector-card, .app-card { background: var(--card-bg); backdrop-filter: blur(12px); border: 1px solid var(--border); @@ -931,7 +939,7 @@ textarea:focus { overflow: hidden; } -.connector-card:hover { +.connector-card:hover, .app-card:hover { transform: translateY(-5px); border-color: var(--brand-accent); box-shadow: 0 15px 30px rgba(0, 0, 0, 0.06); @@ -974,6 +982,109 @@ textarea:focus { border: 1px solid rgba(0,0,0,0.07); } +/* Slack: purple brand */ +.bg-slack { + background: #4A154B; +} + +/* Stripe: indigo-purple brand */ +.bg-stripe { + background: #635BFF; +} + +/* Salesforce: sky blue brand */ +.bg-salesforce { + background: #00A1E0; +} + +/* External Patient Viewer: teal read-only indicator */ +.bg-ext-viewer { + background: linear-gradient(135deg, #0d9488, #0891b2); +} + +/* Read-only badge used inside the chart viewer */ +.readonly-badge { + display: inline-flex; + align-items: center; + gap: 0.35rem; + background: rgba(13, 148, 136, 0.12); + color: #0d9488; + border: 1px solid rgba(13, 148, 136, 0.25); + padding: 0.2rem 0.6rem; + border-radius: 999px; + font-size: 0.7rem; + font-weight: 700; + text-transform: uppercase; + letter-spacing: 0.04em; +} + +/* Viewer scope controls (range slider row) */ +.viewer-scope-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1.5rem; + margin-bottom: 1.5rem; +} + +.scope-field { + display: flex; + flex-direction: column; + gap: 0.4rem; +} + +.scope-field label { + font-size: 0.875rem; + font-weight: 600; + color: var(--text-main); + display: flex; + justify-content: space-between; + align-items: center; +} + +.scope-field label span.scope-val { + font-size: 0.8rem; + color: #0d9488; + font-weight: 700; +} + +input[type="range"] { + -webkit-appearance: none; + appearance: none; + width: 100%; + height: 6px; + background: linear-gradient(to right, #0d9488 0%, #0d9488 var(--pct, 25%), #e2e8f0 var(--pct, 25%), #e2e8f0 100%); + border-radius: 999px; + outline: none; + border: none; + padding: 0; + cursor: pointer; +} + +input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + width: 18px; + height: 18px; + border-radius: 50%; + background: #0d9488; + cursor: pointer; + box-shadow: 0 0 0 3px rgba(13,148,136,0.2); + transition: box-shadow 0.2s; +} + +input[type="range"]::-webkit-slider-thumb:hover { + box-shadow: 0 0 0 5px rgba(13,148,136,0.3); +} + +/* Viewer action button teal variant */ +.action-btn.btn-viewer { + background: linear-gradient(135deg, #0d9488, #0891b2); +} + +.action-btn.btn-viewer:hover { + background: #0f172a; + box-shadow: 0 15px 30px -10px rgba(13, 148, 136, 0.4); +} + .connector-details h3 { font-family: 'Outfit', sans-serif; font-size: 1.25rem; @@ -1269,6 +1380,88 @@ textarea:focus { margin-top: 0.5rem; } +.transport-status-bar { + display: flex; + justify-content: flex-start; + align-items: center; + margin: -0.75rem 0 1rem; + padding: 0.7rem 0.75rem; + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 0.875rem; +} + +.transport-status-pill { + display: flex; + align-items: center; + gap: 0.55rem; + color: var(--brand-accent); + background: white; + border: 1px solid #e2e8f0; + border-radius: 999px; + padding: 0.5rem 0.85rem; + font-weight: 700; + font-size: 0.78rem; + box-shadow: 0 8px 18px rgba(15, 23, 42, 0.04); +} + +.transport-status-dot { + width: 0.55rem; + height: 0.55rem; + border-radius: 999px; + background: var(--success); + box-shadow: 0 0 0 4px rgba(16, 185, 129, 0.14); +} + +.streaming-bubble p { + white-space: pre-wrap; +} + +.stream-tail-loader { + display: inline-flex; + align-items: center; + gap: 0.45rem; + margin-top: 0.65rem; + color: var(--text-muted); + font-size: 0.78rem; + font-weight: 600; +} + +.stream-running-timer { + font-family: monospace; + font-size: 0.8rem; + font-weight: 700; + color: var(--brand-accent); + background: rgba(37, 99, 235, 0.08); + padding: 0.1rem 0.4rem; + border-radius: 4px; + margin-left: 0.25rem; +} + +.stream-end-message { + align-self: flex-start; + max-width: 85%; + font-family: monospace; + font-size: 0.68rem; + padding: 0.35rem 0.55rem; + border-radius: 0.45rem; + border: 1px solid rgba(59, 130, 246, 0.12); + background: rgba(59, 130, 246, 0.06); + color: var(--text-muted); +} + +.stream-end-message.success { + border-color: rgba(16, 185, 129, 0.18); + background: rgba(16, 185, 129, 0.08); + color: #047857; +} + +.stream-end-message.error { + border-color: rgba(244, 63, 94, 0.18); + background: rgba(244, 63, 94, 0.08); + color: var(--error); +} + /* Final Responsive Overrides */ @media (max-width: 1100px) { .playground-layout { @@ -1341,7 +1534,11 @@ textarea:focus { .bg-agent { background: linear-gradient(135deg, #8B5CF6, #7C3AED); } .bg-connectors { background: linear-gradient(135deg, #2563EB, #1D4ED8); } - +.bg-epic { background: #2B5BE0; } +.bg-itops { background: #6366F1; } +.bg-cerner { background: #C74634; } +.bg-gdrive { background: #ffffff; border: 1px solid rgba(0,0,0,0.07); } +.bg-slack { background: #4A154B; } .selection-details h3 { font-family: "Outfit", sans-serif; font-size: 1.75rem; @@ -1450,13 +1647,22 @@ textarea:focus { flex-direction: column; align-items: center; justify-content: center; - min-height: 80vh; - padding: 2rem; - gap: 4rem; + min-height: calc(100vh - 120px); + padding: 1.5rem 0 3rem; + gap: 2.5rem; /* background: radial-gradient(circle at 10% 20%, rgba(139, 92, 246, 0.05) 0%, transparent 40%), */ /* radial-gradient(circle at 90% 80%, rgba(37, 99, 235, 0.05) 0%, transparent 40%); */ } +.apps-selection-view { + justify-content: flex-start; + padding-top: 4rem; +} + +.apps-selection-view .selection-grid { + justify-content: flex-start; +} + .selection-welcome h1 { font-size: 3.5rem; font-weight: 700; @@ -1467,22 +1673,25 @@ textarea:focus { .selection-grid { display: grid; - grid-template-columns: repeat(2, 1fr); - gap: 3rem; + grid-template-columns: repeat(3, minmax(280px, 1fr)); + align-items: stretch; + gap: 2rem; width: 100%; - max-width: 900px; + max-width: 1240px; } .selection-card { + width: 100%; background: white; border: 1px solid #e2e8f0; - border-radius: 20px; + border-radius: 1.8rem; cursor: pointer; transition: all 0.5s cubic-bezier(0.4, 0, 0.2, 1); - box-shadow: 0 20px 50px rgba(0, 0, 0, 0.04); + box-shadow: 0 18px 48px rgba(148, 163, 184, 0.18); position: relative; overflow: hidden; display: flex; + min-height: 356px; } .card-inner { @@ -1490,53 +1699,72 @@ textarea:focus { display: flex; flex-direction: column; align-items: center; - padding: 3rem 2rem 0; + justify-content: center; + padding: 2.2rem 2rem 2rem; } .selection-card:hover { transform: translateY(-8px); - box-shadow: 0 40px 80px rgba(0, 0, 0, 0.08); + box-shadow: 0 30px 70px rgba(148, 163, 184, 0.28); } .selection-icon { - width: 80px; - height: 80px; + width: 64px; + height: 64px; border-radius: 50%; display: flex; align-items: center; justify-content: center; - margin-bottom: 2rem; + margin-bottom: 1.25rem; transition: all 0.3s; } -.card-mcp .selection-icon { background: #f5f3ff; color: #8b5cf6; } -.card-connectors .selection-icon { background: #eff6ff; color: #2563eb; } +.card-mcp .selection-icon { background: #f3e8ff; color: #7c3aed; } +.card-connectors .selection-icon { background: #e0f2fe; color: #0284c7; } +.card-ext-viewer .selection-icon, +.card-apps-directory .selection-icon { background: #ccfbf1; color: #0d9488; } + +.selection-details { + text-align: center; + width: 100%; + display: flex; + flex-direction: column; + align-items: center; +} .selection-details h3 { - font-size: 1.75rem; + font-size: 1.5rem; font-family: "Outfit", sans-serif; color: #1e293b; - margin-bottom: 0.75rem; + margin-bottom: 0.7rem; + line-height: 1.2; } .selection-details p { - font-size: 1rem; + font-size: 0.92rem; color: #64748b; - margin-bottom: 2.5rem; /* Space before action bar */ + margin-bottom: 1.4rem; + line-height: 1.4; + max-width: 310px; + min-height: 2.6em; + padding: 0 0.5rem; } .action-bar { - width: calc(100% + 4rem); - margin: 0 -2rem; - height: 80px; + width: 100%; + max-width: 240px; + height: 60px; display: flex; align-items: center; justify-content: center; transition: all 0.3s; + margin-top: auto; } -.card-mcp .action-bar { background: #f5f3ff; color: #8b5cf6; } -.card-connectors .action-bar { background: #eff6ff; color: #2563eb; } +.card-mcp .action-bar { background: #f3e8ff; color: #7c3aed; } +.card-connectors .action-bar { background: #e0f2fe; color: #0284c7; } +.card-ext-viewer .action-bar, +.card-apps-directory .action-bar { background: #ccfbf1; color: #0d9488; } .selection-card:hover .action-bar { filter: brightness(0.95); @@ -1556,9 +1784,26 @@ textarea:focus { .card-mcp:hover { border-color: #8b5cf6; } .card-connectors { border: 2px solid transparent; } .card-connectors:hover { border-color: #2563eb; } +.card-ext-viewer, +.card-apps-directory { border: 2px solid transparent; } +.card-ext-viewer:hover, +.card-apps-directory:hover { border-color: #0d9488; } @media (max-width: 900px) { - .selection-grid { grid-template-columns: 1fr; } + .root-selection-view { + min-height: auto; + padding-top: 1rem; + } + + .selection-grid { + grid-template-columns: 1fr; + max-width: 420px; + } + + .selection-card { + min-height: 332px; + } + .selection-welcome h1 { font-size: 2.5rem; } } @@ -1600,4 +1845,53 @@ textarea:focus { .chat-reset-btn svg { width: 16px; height: 16px; -} \ No newline at end of file +} + +.transport-status-bar { + display: flex; + justify-content: flex-start; + align-items: center; + gap: 1rem; + margin: -0.75rem 0 1rem; + padding: 0.7rem 0.75rem; + background: #f8fafc; + border: 1px solid #e2e8f0; + border-radius: 0.875rem; +} + +.transport-status-pill { + display: flex; + align-items: center; + gap: 0.55rem; + color: var(--brand-accent); + background: white; + border: 1px solid #e2e8f0; + border-radius: 999px; + padding: 0.5rem 0.85rem; + font-weight: 700; + font-size: 0.78rem; + box-shadow: 0 8px 18px rgba(15, 23, 42, 0.04); +} + +.transport-status-dot { + width: 0.55rem; + height: 0.55rem; + border-radius: 999px; + background: var(--success); + box-shadow: 0 0 0 4px rgba(16, 185, 129, 0.14); +} + +.streaming-bubble p { + white-space: pre-wrap; +} + +@media (max-width: 768px) { + .transport-status-bar { + align-items: stretch; + flex-direction: column; + } + + .transport-status-pill { + justify-content: center; + } +} diff --git a/pyproject.toml b/pyproject.toml index c275864..2784685 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +# Copyright 2026 AOT Technologies +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 [project] name = "node-wire" version = "0.1.0" @@ -42,6 +46,11 @@ nw-smartonfhir-cerner = "agents.fhir_cerner_mcp:main" dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", + "pytest-cov>=5.0.0", + "ruff>=0.3.5", + "mypy>=1.9.0", + "bandit[toml]>=1.7.9", + "pre-commit>=4.0.0", ] agents = [ "mcp>=1.6.0", # Official Python MCP SDK (includes FastMCP) @@ -51,6 +60,20 @@ agents = [ "anthropic>=0.28.0", # Claude SDK ] +# Development entry points — register all connectors so auto_register() works +# in the dev editable install (pip install -e .). +# Each published connector package also declares its own entry point in its +# packages/connectors/{name}/pyproject.toml. +[project.entry-points."node_wire.connectors"] +http_generic = "node_wire_http_generic.logic" +smtp = "node_wire_smtp.logic" +stripe = "node_wire_stripe.logic" +google_drive = "node_wire_google_drive.logic" +fhir_epic = "node_wire_fhir_epic.logic" +fhir_cerner = "node_wire_fhir_cerner.logic" +salesforce = "node_wire_salesforce.logic" +slack = "node_wire_slack.logic" + [tool.setuptools.packages.find] where = ["src"] @@ -58,3 +81,68 @@ where = ["src"] requires = ["setuptools>=69.0.0", "wheel"] build-backend = "setuptools.build_meta" +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-asyncio>=1.3.0", + "pytest-cov>=7.1.0", + "pip-licenses>=5.0.0", + "bandit>=1.7.0", + "pip-audit>=2.7.0", + "ruff>=0.3.5", + "mypy>=1.9.0", + "bandit[toml]>=1.7.9", + "pre-commit>=4.0.0", + "pytest-playwright>=0.4.0", +] + +[tool.uv] +default-groups = ["dev"] + +[tool.pytest.ini_options] +pythonpath = ["src", "."] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +addopts = [ + "--cov=src", + "--cov-report=term-missing", + "--cov-report=html:htmlcov", + "--cov-report=xml:coverage.xml", + "--ignore=tests/playground", +] + +[tool.coverage.run] +source = ["src"] +branch = false + +[tool.coverage.report] +omit = [ + "*/__pycache__/*", + "*/node_wire.egg-info/*", +] + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.mypy] +python_version = "3.11" +strict = false +ignore_missing_imports = true +# Default targets when invoking `mypy` with no paths (CI/pre-commit). Avoids repo-root `mypy .` +# pulling in packages/*/setup.py (duplicate module name "setup"). Packaging glue is excluded below. +files = ["src"] +exclude = [ + ".*packages/.*/setup\\.py", +] +[[tool.mypy.overrides]] +module = "playground.*" +ignore_errors = true + +# SAST: scan the editable-install tree under src/ (runtime, bindings, bundled +# connector logic). Publishable connector packages under packages/connectors/* +# are additionally covered by pip-audit in .github/workflows/security-pr.yml. +[tool.bandit] +targets = ["src"] +exclude_dirs = [".venv", "venv", "tests", "playground", "dist", "htmlcov"] +skips = ["B101"] diff --git a/sample.env b/sample.env index 996e670..5f2537b 100644 --- a/sample.env +++ b/sample.env @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 AOT Technologies +# +# SPDX-License-Identifier: Apache-2.0 + # Epic FHIR EPIC_FHIR_BASE_URL=https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4 EPIC_TOKEN_URL=https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token @@ -24,13 +28,47 @@ SMTP_USERNAME=your-email@gmail.com SMTP_PASSWORD=your-gmail-app-password # Stripe (optional / legacy demo) -stripe_api_key=sk_test_your_key_here +STRIPE_API_KEY=sk_test_your_key_here + +# Slack +NW_SLACK_API_BASE_URL =https://slack.com/api +NW_SLACK_SKIP_RESOLVE=true +# Bot Token from https://api.slack.com/apps (Bot Token Scopes: chat:write, files:write, im:write) +SLACK_BOT_TOKEN=xoxb-your-token-here +# Optional: sandboxed directory for filesystem-based file uploads (default /slack_attachments) +# NW_SLACK_ATTACHMENTS_DIR=/slack_attachments +# Optional: per-file upload size cap in MB (default 50, hard max 100) +# NW_SLACK_UPLOAD_LIMIT_MB=50 # ToolHive # Single-server (backward compatible) TOOLHIVE_MCP_URL=http://localhost:PORT/mcp # Multi-server (preferred for per-connector MCP servers) TOOLHIVE_MCP_URLS= +# Optional MCP auth credentials sent by the ToolHive client to MCP server +# TOOLHIVE_MCP_API_KEY=replace-with-your-mcp-api-key +# TOOLHIVE_MCP_BEARER_TOKEN=replace-with-jwt-or-api-key +# When false (recommended for demos), proxy errors are returned to UI directly. +# Set true to allow proxy failure fallback to local stdio MCP. +PLAYGROUND_AGENT_PROXY_FALLBACK_TO_STDIO=false +# Cap MCP tool JSON size sent back to the LLM (Groq on-demand TPM); default 12000 +# TOOLHIVE_MAX_TOOL_RESULT_CHARS=12000 + +# Stream buffering window in milliseconds (default: 0 = no buffering). +# Set to e.g. 2000 for a 2-second batching window on streamed results. +NW_STREAM_BUFFER_MS=0 +# Native MCP Transport (for agents.mcp_entrypoint and per-connector MCP servers) +# ----------------------------------------------------------------------------- +# NW_MCP_TRANSPORT: Selects the communication layer. +# - stdio: (Default) Required for ToolHive proxying and Claude Desktop. +# - streamable-http: Native HTTP/SSE transport for direct web integration. +NW_MCP_TRANSPORT=streamable-http +NW_MCP_HOST=127.0.0.1 +NW_MCP_PATH=/mcp + +# NW_MCP_PORT: The port used only when NW_MCP_TRANSPORT=streamable-http. +# - Default: 8081 in local demos. Ensure it does not conflict with REST API (port 8000). +NW_MCP_PORT=8081 # LLM Provider LLM_PROVIDER=groq @@ -45,8 +83,81 @@ OPENAI_MODEL=gpt-4o-mini # Google Gemini (optional) GEMINI_API_KEY=your-gemini-api-key -GEMINI_MODEL=gemini-2.0-flash +GEMINI_MODEL=gemini-2.5-flash # Anthropic / Claude (optional) ANTHROPIC_API_KEY=your-anthropic-api-key ANTHROPIC_MODEL=claude-3-5-haiku-20241022 + +# MCP auth — set NW_MCP_AUTH_DISABLED=true only for local development (matches NW_REST_AUTH_DISABLED). +# For production, omit it or set false so MCP auth is enforced when NW_MCP_API_KEY / JWT is set. +NW_MCP_AUTH_DISABLED=true +NW_MCP_API_KEY=replace-with-strong-random-value +# API key scopes (JSON array or space/comma-separated). Empty = no scopes; use "*" only for explicit full access. +# Wildcard API keys intentionally bypass per-action scope checks. +# NW_MCP_API_KEY_SCOPES=["mcp:smtp.send_email","mcp:http_generic.request"] +# NW_MCP_API_KEY_SCOPES=mcp:smtp.send_email mcp:http_generic.request +NW_MCP_JWT_SECRET=replace-with-hs256-secret +# Optional per-tool scope map; when unset, scope enforcement uses default mode below. +# NW_MCP_ACTION_SCOPE_MAP_JSON={"smtp.send_email":"mcp:smtp.send_email"} +# Recommended production baseline (closed by default): +NW_MCP_SCOPE_POLICY_DEFAULT=deny +# Optional strict guardrail: fail startup if scope policy would be effectively disabled. +# NW_MCP_SCOPE_POLICY_STRICT=true +# Use allow only for local experimentation: +# NW_MCP_SCOPE_POLICY_DEFAULT=allow +# Example for FHIR + Google Drive policy gating: +# NW_MCP_ACTION_SCOPE_MAP_JSON={"fhir_epic.read_patient":"mcp:fhir.read_patient","fhir_cerner.read_patient":"mcp:fhir.read_patient","google_drive.files.upload":"mcp:gdrive.files.upload"} +# Scope hook is runtime-level. With current bindings, strict scope enforcement applies +# to identity-aware MCP/REST calls. gRPC enforcement is deferred until gRPC identity +# propagation is implemented. +# ToolHive bearer token is sent to MCP as Authorization + X-API-Key + _meta aliases. +# TOOLHIVE_MCP_BEARER_TOKEN= + +# REST auth for Playground demo (disable for local UI testing) +NW_REST_AUTH_DISABLED=true +NW_REST_LOAD_DOTENV=true +# REST API key scopes (same format as NW_MCP_API_KEY_SCOPES). Empty = no scopes unless JWT carries scopes. +# NW_REST_API_KEY_SCOPES=["mcp:smtp.send_email"] +# REST JWTs (NW_REST_JWT_SECRET): claims sub, tenant_id, scopes propagate to connector.run(..., principal, tenant_id, scopes) for ScopePolicyHook + +# MCP contract (optional; Google Drive legacy payload `action: "upload"`) +# NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=warn +# NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=reject + +NW_REST_RATE_LIMIT_ENABLED=true +NW_REST_RATE_LIMIT_MAX_REQUESTS=120 +NW_REST_RATE_LIMIT_WINDOW_SECONDS=60 +# Resilience & Timeout Configurations +# AOT_CONNECTOR_TIMEOUT=30.0 +# AOT_CIRCUIT_BREAKER_FAIL_MAX=5 +# AOT_CIRCUIT_BREAKER_RESET_TIMEOUT=30 +# Plugin allowlist (fail-closed secure default) +# Add connector entry point names here to allow them to be loaded. +NW_ALLOWED_CONNECTORS=fhir_cerner,fhir_epic,google_drive,http_generic,salesforce,slack,smtp,stripe +# Salesforce CRM +SALESFORCE_INSTANCE_URL=https://your-instance.my.salesforce.com +SALESFORCE_TOKEN_URL=https://login.salesforce.com/services/oauth2/token +SALESFORCE_CLIENT_ID=your-client-id +SALESFORCE_CLIENT_SECRET=your-client-secret +SALESFORCE_REFRESH_TOKEN=your-refresh-token + + +# Playwright playground headed execution - set to "true" to view the browser and its activities +PLAYGROUND_HEADED=false + +# ----------------------------------------------------------------------------- +# Playground Integration Tests (pytest tests/playground/) +# ----------------------------------------------------------------------------- + +# Google Drive Playground Tests +GDRIVE_TEST_RECIPIENT_EMAIL=your-gdrive-test-recipient-email + +# Stripe Playground Tests +STRIPE_TEST_CUSTOMER_ID=your-stripe-test-customer-id +STRIPE_TEST_PRICE_ID=your-stripe-test-price-id + +# Slack Playground Tests +SLACK_TEST_CHANNEL=#your-slack-test-channel +SLACK_TEST_USER_ID=your-slack-test-user-id +SLACK_TEST_CHANNEL_ID=your-slack-test-channel-id diff --git a/scripts/add-license-headers.sh b/scripts/add-license-headers.sh new file mode 100644 index 0000000..a3fd32f --- /dev/null +++ b/scripts/add-license-headers.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# +# Adds Apache 2.0 copyright headers to all applicable files in the repository. + +set -e + +# Run licenseheaders tool from the root of the repository +cd "$(dirname "$0")/.." + +echo "Applying Apache 2.0 copyright headers..." + +uv run licenseheaders \ + -t .copyright.tmpl \ + -d . \ + -E .py .sh .yml .yaml .toml Dockerfile .proto .sample \ + --additional-extensions script=.toml script=Dockerfile script=.sample c=.proto \ + -x ".git/*" ".venv/*" "*/__pycache__/*" "packages/*/dist/*" "packages/*/build/*" "*.egg-info/*" "htmlcov/*" ".pytest_cache/*" ".ruff_cache/*" ".mypy_cache/*" + +echo "Done!" diff --git a/scripts/bandit_report_summary.py b/scripts/bandit_report_summary.py new file mode 100644 index 0000000..2feb2cb --- /dev/null +++ b/scripts/bandit_report_summary.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""Print a concise Bandit JSON report summary for CI logs (always exits 0). + +Bandit exits with a non-zero status when *any* severity finding exists, even if +the separate CI gate only enforces `--severity-level high`. Use `--exit-zero` +when generating JSON, then run this script to surface counts and a short list +without failing the job. +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any + + +def _load_report(path: Path) -> dict[str, Any]: + if not path.is_file(): + print(f"ERROR: Bandit report not found: {path}", file=sys.stderr) + sys.exit(2) + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e: + print(f"ERROR: Invalid Bandit JSON at {path}: {e}", file=sys.stderr) + sys.exit(2) + if not isinstance(data, dict): + print("ERROR: Bandit report root must be a JSON object", file=sys.stderr) + sys.exit(2) + return data + + +def main() -> None: + path = Path(sys.argv[1] if len(sys.argv) > 1 else "bandit-report.json") + data = _load_report(path) + + totals = data.get("metrics", {}).get("_totals", {}) + if not isinstance(totals, dict): + totals = {} + + def _int(key: str) -> int: + v = totals.get(key, 0) + return int(v) if isinstance(v, (int, float)) else 0 + + high = _int("SEVERITY.HIGH") + medium = _int("SEVERITY.MEDIUM") + low = _int("SEVERITY.LOW") + loc = _int("loc") + + results = data.get("results", []) + if not isinstance(results, list): + results = [] + + print("=== Bandit report summary ===") + print(f"Report: {path.resolve()}") + print(f"Lines scanned (approx): {loc}") + print(f"Findings by severity — HIGH: {high}, MEDIUM: {medium}, LOW: {low}") + print(f"Total result entries: {len(results)}") + print() + if results: + print("Findings (file:line [severity] test_id — short text):") + for r in results[:50]: + if not isinstance(r, dict): + continue + fn = r.get("filename", "?") + ln = r.get("line_number", "?") + sev = r.get("issue_severity", "?") + tid = r.get("test_id", "?") + text = str(r.get("issue_text", "")).replace("\n", " ")[:120] + print(f" {fn}:{ln} [{sev}] {tid} — {text}") + if len(results) > 50: + print(f" ... and {len(results) - 50} more (see full JSON artifact)") + else: + print("No findings in results[] (clean scan).") + print() + print( + "CI gate: the following step enforces high severity only " + "(`bandit ... --severity-level high`)." + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/build-mcp-images.sh b/scripts/build-mcp-images.sh index a3844c7..d193330 100755 --- a/scripts/build-mcp-images.sh +++ b/scripts/build-mcp-images.sh @@ -1,4 +1,8 @@ #!/usr/bin/env bash +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## set -euo pipefail ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" @@ -16,6 +20,8 @@ Images: - nw-smartonfhir-epic - nw-smartonfhir-cerner - nw-smtp + - nw-stripe + - nw-slack EOF } @@ -65,5 +71,20 @@ docker build -f docker/smtp/Dockerfile \ -t "nw-smtp:${VERSION}" \ . -echo "Done." +docker build -f docker/stripe/Dockerfile \ + -t nw-stripe:latest \ + -t "nw-stripe:${VERSION}" \ + . + +docker build -f docker/salesforce/Dockerfile \ + -t nw-salesforce:latest \ + -t "nw-salesforce:${VERSION}" \ + . + +docker build -f docker/slack/Dockerfile \ + -t nw-slack:latest \ + -t "nw-slack:${VERSION}" \ + . + +echo "Done." diff --git a/scripts/build-packages.sh b/scripts/build-packages.sh new file mode 100755 index 0000000..ed70352 --- /dev/null +++ b/scripts/build-packages.sh @@ -0,0 +1,300 @@ +#!/usr/bin/env bash +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# build-packages.sh — Build Node Wire packages as binary-only wheels. +# +# Default mode (host + Linux via Docker): +# scripts/build-packages.sh +# scripts/build-packages.sh packages/runtime +# +# All-platform mode (local cibuildwheel; see notes below): +# scripts/build-packages.sh --all +# scripts/build-packages.sh --all packages/runtime +# +# Prerequisites (default mode): +# python3 or python on PATH; pip install build cython wheel +# docker (for Linux wheels) +# +# Prerequisites (--all mode): +# python -m pip install 'cibuildwheel>=2.16.0' +# +# Security guarantee: +# Each wheel is verified to contain zero .py source files before printing "PASS". +# Any leaked .py files trigger an exit 1. + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +ALL_PACKAGES=( + packages/runtime + packages/connectors/google_drive + packages/connectors/fhir_epic + packages/connectors/fhir_cerner + packages/connectors/smtp + packages/connectors/stripe + packages/connectors/salesforce + packages/connectors/http_generic + packages/connectors/slack +) + + +usage() { + cat <<'USAGE' +Usage: + scripts/build-packages.sh [--help] + scripts/build-packages.sh [packages/...] + scripts/build-packages.sh --all [packages/...] + + Default: build each package on the host and again in Docker (Linux wheels). + --all: build with cibuildwheel (targets depend on host; for full OS matrix use CI publish.yml). + +Examples: + scripts/build-packages.sh + scripts/build-packages.sh packages/connectors/smtp + scripts/build-packages.sh --all + scripts/build-packages.sh --all packages/runtime +USAGE +} + +ALL_MODE=0 +PACKAGES=() +while [[ $# -gt 0 ]]; do + case "$1" in + --all) + ALL_MODE=1 + shift + ;; + --help|-h) + usage + exit 0 + ;; + *) + PACKAGES+=("$1") + shift + ;; + esac +done + +if [[ ${#PACKAGES[@]} -eq 0 ]]; then + PACKAGES=("${ALL_PACKAGES[@]}") +fi + +# Verify wheels contain no .py files (binary-only wheels). First arg: python binary. +verify_wheels_no_py() { + local py="$1" + shift + local -a wheels=("$@") + local whl + local py_leak + local pkg_failed=0 + + for whl in "${wheels[@]}"; do + py_leak=$("$py" - "$whl" <<'PYCHECK' +import sys +import zipfile + +wheel_path = sys.argv[1] +with zipfile.ZipFile(wheel_path) as zf: + leaked = [name for name in zf.namelist() if name.endswith(".py")] + +if leaked: + print("\n".join(leaked)) + sys.exit(1) +PYCHECK + 2>&1) || { + echo "SECURITY FAIL: .py files leaked into $whl:" >&2 + echo "$py_leak" >&2 + pkg_failed=1 + break + } + done + return "$pkg_failed" +} + +# ─── All-platform mode (cibuildwheel) ─────────────────────────────────────── +if [[ "$ALL_MODE" -eq 1 ]]; then + export CIBW_BUILD="${CIBW_BUILD:-cp311-* cp312-*}" + export CIBW_SKIP="${CIBW_SKIP:-*-win32 *-manylinux_i686 pp*}" + + echo "=== Node Wire — cibuildwheel build for ${#PACKAGES[@]} package(s) ===" + echo "CIBW_BUILD=$CIBW_BUILD" + echo "CIBW_SKIP=$CIBW_SKIP" + + if command -v python3 >/dev/null 2>&1; then + PYTHON=python3 + elif command -v python >/dev/null 2>&1; then + PYTHON=python + else + echo "ERROR: python or python3 is required but not found in PATH." >&2 + exit 1 + fi + + if ! "$PYTHON" -c "import cibuildwheel" >/dev/null 2>&1; then + echo "ERROR: cibuildwheel is not installed in the current Python environment." >&2 + echo "Install with: $PYTHON -m pip install --upgrade 'cibuildwheel>=2.16.0'" >&2 + exit 1 + fi + + shopt -s nullglob + FAILED=() + + for PKG in "${PACKAGES[@]}"; do + echo "" + echo "--- Building: $PKG ---" + + if [[ ! -d "$PKG" ]]; then + echo "ERROR: Package path not found: $PKG" >&2 + FAILED+=("$PKG (missing path)") + continue + fi + + if [[ ! -f "$PKG/pyproject.toml" ]]; then + echo "ERROR: Missing pyproject.toml in $PKG" >&2 + FAILED+=("$PKG (missing pyproject.toml)") + continue + fi + + mkdir -p "$PKG/dist" + rm -f "$PKG"/dist/*.whl + + if ! ( + cd "$PKG" + "$PYTHON" -m cibuildwheel --output-dir dist + ); then + echo "ERROR: cibuildwheel build failed for $PKG" >&2 + FAILED+=("$PKG (build failed)") + continue + fi + + WHEELS=("$PKG"/dist/*.whl) + if [[ ${#WHEELS[@]} -eq 0 ]]; then + echo "ERROR: No wheels produced for $PKG" >&2 + FAILED+=("$PKG (no wheels)") + continue + fi + + if ! verify_wheels_no_py "$PYTHON" "${WHEELS[@]}"; then + FAILED+=("$PKG (.py leak)") + continue + fi + + echo "PASS: ${#WHEELS[@]} wheel(s) for $PKG — no .py source files" + done + + echo "" + if [[ ${#FAILED[@]} -gt 0 ]]; then + echo "=== FAILED packages ===" + for F in "${FAILED[@]}"; do echo " - $F"; done + exit 1 + fi + + echo "=== All packages built and verified successfully ===" + echo "" + echo "Wheels are in:" + for PKG in "${PACKAGES[@]}"; do + ls "$PKG"/dist/*.whl 2>/dev/null || true + done + exit 0 +fi + +# ─── Default mode (host + Linux Docker) ─────────────────────────────────── +echo "=== Node Wire — building ${#PACKAGES[@]} package(s) (host + linux) ===" + +FAILED=() + +if command -v python3 >/dev/null 2>&1; then + PYTHON_HOST=python3 +elif command -v python >/dev/null 2>&1; then + PYTHON_HOST=python +else + echo "ERROR: python3 or python is required on the host to build wheels but neither was found in PATH." >&2 + exit 1 +fi + +# Validate paths first so typos fail without Docker installed or running. +for PKG in "${PACKAGES[@]}"; do + if [[ ! -d "$PKG" ]]; then + echo "ERROR: Package path not found: $PKG" >&2 + FAILED+=("$PKG (missing path)") + continue + fi + if [[ ! -f "$PKG/pyproject.toml" ]]; then + echo "ERROR: Missing pyproject.toml in $PKG" >&2 + FAILED+=("$PKG (missing pyproject.toml)") + continue + fi +done + +if [[ ${#FAILED[@]} -gt 0 ]]; then + echo "" + echo "=== FAILED packages ===" + for F in "${FAILED[@]}"; do echo " - $F"; done + exit 1 +fi + +if ! command -v docker >/dev/null 2>&1; then + echo "ERROR: docker is required to build Linux wheels but was not found in PATH." >&2 + exit 1 +fi + +if ! docker info >/dev/null 2>&1; then + echo "ERROR: Docker daemon is not running. Start Docker and retry." >&2 + exit 1 +fi + +FAILED=() + +for PKG in "${PACKAGES[@]}"; do + echo "" + echo "--- Building: $PKG ---" + + ( + cd "$PKG" + "$PYTHON_HOST" -m build --wheel --no-isolation + ) + + docker run --rm \ + -v "$ROOT_DIR:/work" \ + -w "/work/$PKG" \ + python:3.12-slim \ + bash -lc "apt-get update && apt-get install -y --no-install-recommends build-essential && rm -rf /var/lib/apt/lists/* && python -m pip install --no-cache-dir setuptools build cython wheel && python -m build --wheel --no-isolation" || { + echo "ERROR: Linux wheel build failed for $PKG" >&2 + FAILED+=("$PKG (linux build failed)") + continue + } + + shopt -s nullglob + WHEELS=("$PKG"/dist/*.whl) + shopt -u nullglob + if [[ ${#WHEELS[@]} -eq 0 ]]; then + echo "ERROR: No wheels produced for $PKG" >&2 + FAILED+=("$PKG (no wheels)") + continue + fi + + if ! verify_wheels_no_py "$PYTHON_HOST" "${WHEELS[@]}"; then + FAILED+=("$PKG (.py leak)") + continue + fi + + echo "PASS: ${#WHEELS[@]} wheel(s) for $PKG — no .py source files" +done + +echo "" +if [[ ${#FAILED[@]} -gt 0 ]]; then + echo "=== FAILED packages ===" + for F in "${FAILED[@]}"; do echo " - $F"; done + exit 1 +fi + +echo "=== All packages built and verified successfully ===" +echo "" +echo "Wheels are in:" +for PKG in "${PACKAGES[@]}"; do + ls "$PKG/dist/"*.whl 2>/dev/null || true +done diff --git a/scripts/run-compliance-checks.sh b/scripts/run-compliance-checks.sh new file mode 100644 index 0000000..eeaf33f --- /dev/null +++ b/scripts/run-compliance-checks.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +# Runs security, dependency, and open-source compliance checks. + +set -e + +cd "$(dirname "$0")/.." + +echo "=====================================" +echo "1. Generating DEPENDENCIES.md..." +echo "=====================================" + +cat << 'EOF' > DEPENDENCIES.md +# Node Wire Open Source Dependencies + +This file is automatically generated and contains an inventory of all third-party dependencies used in the Node Wire project. + +## License Classification Criteria +To maintain open-source compliance, dependencies are evaluated against the following criteria: +* **✅ Safe (Permissive):** MIT, Apache-2.0, BSD, PSF. These licenses are universally safe for our Apache 2.0 open-source release and can be freely used, modified, and distributed. +* **⚠️ Needs Review:** Custom or obscure licenses. These require manual review by the engineering team to ensure they don't impose conflicting obligations. +* **⛔ Risky (Copyleft):** GPLv2, GPLv3, AGPL. These licenses are strictly prohibited in the runtime application as they force derivative works to adopt the same open-source license. They may only be used as isolated, non-distributed Development/Linting tools. + +--- + +EOF + +uv run pip-licenses --format=markdown --with-urls >> DEPENDENCIES.md +echo "DEPENDENCIES.md generated successfully!" + +echo "" +echo "=====================================" +echo "2. Running Bandit (SAST Scanner)..." +echo "=====================================" +# We allow medium/low severity but want to output findings. +uv run bandit -r src/ packages/ playground/ tests/ -ll || true + +echo "" +echo "=====================================" +echo "3. Running pip-audit (Vulnerability Scanner)..." +echo "=====================================" +uv run pip-audit || true + +echo "" +echo "Compliance checks finished!" diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..c530b7d --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,12 @@ +sonar.projectKey=node-wire +sonar.projectName=Node Wire +sonar.sourceEncoding=UTF-8 +sonar.python.version=3.11 + +sonar.sources=src +sonar.tests=tests + +sonar.exclusions=**/__pycache__/**,**/*.pyc,htmlcov/**,dist/**,playground/**,grafana/** +sonar.test.inclusions=tests/**/*.py + +sonar.python.coverage.reportPaths=coverage.xml diff --git a/src/agents/README.md b/src/agents/README.md index 22834ab..90777f7 100644 --- a/src/agents/README.md +++ b/src/agents/README.md @@ -1,8 +1,14 @@ -# 🤖 Node Wire Agents & MCP Orchestration + + +# Node Wire Agents & MCP Orchestration This folder contains the core intelligence and orchestration layer of **Node Wire**, enabling autonomous AI agents to interact with healthcare systems and cloud services via the **Model Context Protocol (MCP)**. -## 🚀 Overview +## Overview The `agents` module transforms static connectors (EHR, Google Drive, SMTP) into dynamic, discoverable tools for Large Language Models (LLMs). By following the MCP standard, we provide a unified interface for "ReAct" style agents to perform end-to-end clinical workflows through natural language instructions. @@ -13,13 +19,13 @@ The `agents` module transforms static connectors (EHR, Google Drive, SMTP) into --- -## 🏗️ Core Architecture +## Core Architecture ### 1. **MCP Server (`mcp_entrypoint.py`)** -A high-performance server built on the [FastMCP](https://github.com/modelcontextprotocol/python-sdk) framework. -- **Dynamic Bindings**: Uses the `ConnectorFactory` to load platform connectors and expose them as MCP tools. -- **Data Protection**: Automatically extracts and summarizes raw FHIR resources to protect patient privacy and reduce LLM token consumption. -- **Flexible Transport**: Defaults to `stdio` transport for seamless integration with ToolHive, Claude Desktop, or custom proxies. +Stdio MCP server using the official [Model Context Protocol Python SDK](https://github.com/modelcontextprotocol/python-sdk). +- **Manifest-driven tools**: `McpServer` builds the tool list from connector metadata (`.`) and dispatches via `connector.run()`. +- **Unified entrypoint**: `python -m agents.mcp_entrypoint` exposes every connector enabled for MCP in `config/connectors.yaml`. +- **Per-connector images**: `fhir_cerner_mcp`, `fhir_epic_mcp`, `google_drive_mcp`, and `smtp_mcp` run the same server with a `connector_ids` filter. ### 2. **ToolHive Agent (`toolhive.py`)** A reference implementation of a ReAct-style agent designed for the **ToolHive** ecosystem. @@ -35,14 +41,18 @@ A modular factory system supporting diverse LLM backends: --- -## 🛠️ Available MCP Tools +## MCP tool naming + +Tools are named **`{connector_id}.{action}`** as defined by each connector’s manifest (see `connectors/manifest.py` and `bindings/mcp_server/server.py`). Examples: + +| Example tool name | Connector | +| :--- | :--- | +| `fhir_cerner.read_patient` | Cerner FHIR | +| `fhir_epic.read_patient` | Epic FHIR | +| `google_drive.files.upload` | Google Drive | +| `smtp.send_email` | SMTP | -| Tool Name | Description | Connector | -| :--- | :--- | :--- | -| `fhir_cerner_read_patient` | Fetches patient demographics (Name, DOB, ID) from Cerner FHIR R4. | `fhir_cerner` | -| `fhir_epic_read_patient` | Fetches patient demographics from Epic FHIR R4. (IDs usually start with 'e'). | `fhir_epic` | -| `google_drive_upload_file` | Securely uploads text summaries or reports to a designated folder. | `google_drive` | -| `smtp_send_email` | Dispatches notifications or clinical summaries via secure SMTP. | `smtp` | +Use **`tools/list`** for the exact names and JSON Schemas your deployment exposes. --- @@ -67,14 +77,14 @@ TOOLHIVE_MCP_URL=http://localhost:8000/mcp # Connector Secrets (Injected into MCP Server) CERNER_CLIENT_ID=... -GOOGLE_DRIVE_SA_JSON=D:\connector-platform\service_account.json +GOOGLE_DRIVE_SA_JSON=/path/to/service_account.json SMTP_USERNAME=... SMTP_PASSWORD=... ``` --- -## 🏃 Usage Guide +## Usage Guide ### **1. Launch the MCP Server (Local)** To verify tool discovery and execution via `stdio`: diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 96bd967..2ea89ee 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -1 +1,5 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """AOT Agents — LLM agent utilities for the Connector Platform.""" diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index 5628bd6..ddb228b 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -1,16 +1,13 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Cerner) -================================================= -Standalone MCP server exposing only the Cerner FHIR patient read tool. +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP Server — Cerner FHIR connector only. Usage: python -m agents.fhir_cerner_mcp""" -Usage: - python -m agents.fhir_cerner_mcp -""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,86 +18,18 @@ logger = logging.getLogger("agents.fhir_cerner_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-cerner") +def main() -> None: + from bindings.mcp_server.server import McpServer - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), + transport = os.getenv("NW_MCP_TRANSPORT", "stdio").strip().lower() + logger.info( + f"Starting nw-smartonfhir-cerner MCP server (transport={transport}, manifest-driven)" ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("read_patient") - - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirCernerPatientReadInput(search_params=search) - else: - raise ValueError("Provide patient_id OR at least family_name/given_name") - - result = await action.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - return mcp - - -def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-cerner MCP server (stdio transport)") - server.run() + McpServer( + server_name="nw-smartonfhir-cerner", + connector_ids=["fhir_cerner"], + ).run(transport=transport) if __name__ == "__main__": main() - diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index d7f6335..ec14e90 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -1,16 +1,13 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Epic) -=============================================== -Standalone MCP server exposing only the Epic FHIR patient read tool. +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP Server — Epic FHIR connector only. Usage: python -m agents.fhir_epic_mcp""" -Usage: - python -m agents.fhir_epic_mcp -""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,88 +18,16 @@ logger = logging.getLogger("agents.fhir_epic_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-epic") - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("read_patient") - - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirEpicPatientReadInput(search_params=search) - else: - raise ValueError("Provide patient_id OR at least family_name/given_name") - - result = await action.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-epic MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + transport = os.getenv("NW_MCP_TRANSPORT", "stdio").strip().lower() + logger.info(f"Starting nw-smartonfhir-epic MCP server (transport={transport}, manifest-driven)") + McpServer( + server_name="nw-smartonfhir-epic", + connector_ids=["fhir_epic"], + ).run(transport=transport) if __name__ == "__main__": main() - diff --git a/src/agents/google_drive_mcp.py b/src/agents/google_drive_mcp.py index 050a3ef..0eff85b 100644 --- a/src/agents/google_drive_mcp.py +++ b/src/agents/google_drive_mcp.py @@ -1,16 +1,13 @@ -""" -FastMCP Server Entrypoint — Google Drive -======================================= -Standalone MCP server exposing only the Google Drive tool. +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP Server — Google Drive connector only. Usage: python -m agents.google_drive_mcp""" -Usage: - python -m agents.google_drive_mcp -""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,69 +18,16 @@ logger = logging.getLogger("agents.google_drive_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.google_drive.schema import GoogleDriveOperationInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-google-drive") - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." - ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-google-drive MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + transport = os.getenv("NW_MCP_TRANSPORT", "stdio").strip().lower() + logger.info(f"Starting nw-google-drive MCP server (transport={transport}, manifest-driven)") + McpServer( + server_name="nw-google-drive", + connector_ids=["google_drive"], + ).run(transport=transport) if __name__ == "__main__": main() - diff --git a/src/agents/llm_factory.py b/src/agents/llm_factory.py index 9bf12d3..0d293b2 100644 --- a/src/agents/llm_factory.py +++ b/src/agents/llm_factory.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ LLM Provider Factory ==================== @@ -16,21 +20,24 @@ gemini — gemini-2.0-flash anthropic — claude-3-5-haiku-20241022 """ + from __future__ import annotations import os from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type # --------------------------------------------------------------------------- # Data models (provider-agnostic) # --------------------------------------------------------------------------- + @dataclass class ToolCall: """A single tool-call request returned by the LLM.""" + id: str name: str arguments: Dict[str, Any] @@ -39,19 +46,21 @@ class ToolCall: @dataclass class LLMMessage: """A single message in the conversation thread.""" - role: str # "system" | "user" | "assistant" | "tool" + + role: str # "system" | "user" | "assistant" | "tool" content: Optional[str] = None tool_calls: List[ToolCall] = field(default_factory=list) tool_call_id: Optional[str] = None # required for role="tool" responses - name: Optional[str] = None # tool name for role="tool" + name: Optional[str] = None # tool name for role="tool" @dataclass class LLMResponse: """Raw response from the LLM.""" + content: Optional[str] tool_calls: List[ToolCall] = field(default_factory=list) - stop_reason: str = "stop" # "stop" | "tool_calls" + stop_reason: str = "stop" # "stop" | "tool_calls" @property def wants_tool_call(self) -> bool: @@ -62,6 +71,7 @@ def wants_tool_call(self) -> bool: # Abstract base # --------------------------------------------------------------------------- + class BaseLLMProvider(ABC): """Common interface for all LLM providers.""" @@ -93,18 +103,25 @@ def chat_with_tools( # Factory # --------------------------------------------------------------------------- +# Optional provider classes when [agents] extras are not installed. +GroqProvider: Optional[Type[BaseLLMProvider]] = None +OpenAIProvider: Optional[Type[BaseLLMProvider]] = None +GeminiProvider: Optional[Type[BaseLLMProvider]] = None +AnthropicProvider: Optional[Type[BaseLLMProvider]] = None + try: - from agents.providers.groq_provider import GroqProvider - from agents.providers.openai_provider import OpenAIProvider - from agents.providers.gemini_provider import GeminiProvider - from agents.providers.anthropic_provider import AnthropicProvider + from agents.providers.groq_provider import GroqProvider as _GroqProvider + from agents.providers.openai_provider import OpenAIProvider as _OpenAIProvider + from agents.providers.gemini_provider import GeminiProvider as _GeminiProvider + from agents.providers.anthropic_provider import AnthropicProvider as _AnthropicProvider + + GroqProvider = _GroqProvider + OpenAIProvider = _OpenAIProvider + GeminiProvider = _GeminiProvider + AnthropicProvider = _AnthropicProvider except ImportError: - # These may fail if running in an environment without the full [agents] extras, - # but we handle this during instantiation if needed. - GroqProvider = None - OpenAIProvider = None - GeminiProvider = None - AnthropicProvider = None + # Leave all four as None; create() raises ImportError with a clear message. + pass class LLMProviderFactory: @@ -128,7 +145,7 @@ def create(cls, provider: str, **kwargs: Any) -> BaseLLMProvider: e.g. ``api_key``, ``model``. """ provider = provider.lower().strip() - + if provider == "groq": if GroqProvider is None: raise ImportError("GroqProvider could not be loaded. Check dependencies.") @@ -148,8 +165,7 @@ def create(cls, provider: str, **kwargs: Any) -> BaseLLMProvider: else: supported = ["groq", "openai", "gemini", "anthropic"] raise ValueError( - f"Unknown LLM provider {provider!r}. " - f"Supported: {', '.join(supported)}" + f"Unknown LLM provider {provider!r}. Supported: {', '.join(supported)}" ) @classmethod diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index dee264e..fa3f131 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -1,629 +1,33 @@ -""" -FastMCP Server Entrypoint -========================= -This module is the main entrypoint for the Node Wire MCP server. -When run, it exposes healthcare workflow tools via the MCP stdio transport: +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP Server — all connectors exposed via MCP. Usage: python -m agents.mcp_entrypoint""" - • fhir_cerner_read_patient — fetch a single patient from Cerner FHIR R4 - • fhir_cerner_search_patients — fetch multiple patients from Cerner (multi-ID or name) - • fhir_epic_read_patient — fetch a single patient from Epic FHIR R4 - • fhir_epic_search_patients — fetch multiple patients from Epic (multi-ID or name) - • google_drive_upload_file — write a file to Google Drive - • smtp_send_email — send an email via SMTP - -ToolHive manages the container lifecycle, injects secrets as environment -variables, and proxies the stdio MCP stream to HTTP/SSE for clients. - -Usage (run directly by ToolHive): - python -m agents.mcp_entrypoint - -Environment variables (injected by ToolHive via --secret flags): - CERNER_FHIR_BASE_URL, CERNER_CLIENT_ID, CERNER_KID, - CERNER_PRIVATE_KEY, CERNER_TOKEN_URL, CERNER_SCOPES - GOOGLE_DRIVE_SA_JSON - SMTP_USERNAME, SMTP_PASSWORD, SMTP_HOST, SMTP_PORT -""" from __future__ import annotations -import json import logging import os -import uuid + from dotenv import load_dotenv -# Load .env variables for local stdio transport -# Try both CWD and script's own folder to be safe -load_dotenv() -load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) +# Align with ``bindings.rest_api.app``: respect ``NW_REST_LOAD_DOTENV`` (pytest/CI +# set ``false``) and never override keys already in the environment — ``override=True`` +# here was stomping conftest and breaking ``monkeypatch.delenv`` restores. +if os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() not in ("0", "false", "no"): + load_dotenv(override=False) + load_dotenv(os.path.join(os.path.dirname(__file__), ".env"), override=False) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("agents.mcp_entrypoint") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError( - "mcp SDK not installed. Run: pip install 'node-wire[agents]'" - ) from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import ( - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - FhirCernerEncounterSearchInput, - ) - from connectors.fhir_epic.schema import ( - FhirPatientReadInput as FhirEpicPatientReadInput, - FhirPatientSearchInput as FhirEpicPatientSearchInput, - FhirEncounterSearchInput as FhirEpicEncounterSearchInput, - ) - from connectors.google_drive.schema import GoogleDriveOperationInput - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("Node Wire") - - # ------------------------------------------------------------------ - # Tool 1: Fetch patient from Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), - ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (direct lookup — use this if you have it). - family_name : str - Patient family/last name (used for search when no ID is known). - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("read_patient") - - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) - elif family_name or given_name or name: - params = FhirCernerPatientReadInput( - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await action.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Extract a clean summary for the LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - # Drastically simplify to keep token count low - ids = ", ".join([f"{i.get('system')}: {i.get('value')}" for i in resource.get("identifier", [])]) - phones = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "phone"]) - emails = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "email"]) - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - # ------------------------------------------------------------------ - # Tool 2: Fetch patient from Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (Epic specific, usually starts with 'e'). - family_name : str - Patient family/last name. - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("read_patient") - - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) - elif family_name or given_name or name: - params = FhirEpicPatientReadInput( - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await action.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Clean extract for LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - # ------------------------------------------------------------------ - # Tool 3: Search patients in Cerner (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_patients", - description=( - "Search for multiple patients in Cerner FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches." - ), - ) - async def fhir_cerner_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. '12345678,87654321'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("search_patients") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirCernerPatientSearchInput(resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirCernerPatientSearchInput( - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await action.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 4: Search patients in Epic (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_patients", - description=( - "Search for multiple patients in Epic FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. 'eABC,eDEF'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("search_patients") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirEpicPatientSearchInput(resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirEpicPatientSearchInput( - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await action.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "source": "Epic FHIR", - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 5: Search encounters in Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_encounters", - description=( - "Search for encounters in Cerner FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter." - ), - ) - async def fhir_cerner_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Cerner Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished', 'in-progress'). - date : str - Filter by date or date range (e.g. '2024', 'ge2023-01-01'). - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("search_encounter") - - params = FhirCernerEncounterSearchInput( - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await action.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - } - - # ------------------------------------------------------------------ - # Tool 6: Search encounters in Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_encounters", - description=( - "Search for encounters in Epic FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Epic Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished'). - date : str - Filter by date or date range. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("search_encounter") - - params = FhirEpicEncounterSearchInput( - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await action.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - "source": "Epic FHIR", - } - - # ------------------------------------------------------------------ - # Tool 7: Upload a file to Google Drive - # ------------------------------------------------------------------ - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." - ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - """ - Parameters - ---------- - file_name : str - Name for the file in Google Drive (e.g. 'patient_summary_12345.txt'). - content : str - UTF-8 text content to write into the file. - folder_id : str - Optional Google Drive folder ID to place the file in. - mime_type : str - MIME type (default: text/plain). - """ - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - # ------------------------------------------------------------------ - # Tool 4: Send email via SMTP - # ------------------------------------------------------------------ - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." - ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - """ - Parameters - ---------- - to_email : str - Recipient email address. - subject : str - Email subject line. - body : str - Plain-text email body. - from_email : str - Sender address — defaults to SMTP_USERNAME env var if empty. - """ - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - # Guardrail: Handle placeholder strings from LLM or empty input - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = (os.environ.get("FROM_EMAIL") or os.environ.get("SMTP_USERNAME") or "noreply@node-wire.local").strip(" '\"") - - # Pydantic EmailStr does not like "Name " - # Extract just the email part if needed - import re - def _extract_email(s: str) -> str: - match = re.search(r"<(.+?)>", s) - return match.group(1) if match else s.strip() - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": result.message_id} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting Node Wire MCP server (stdio transport)") - server.run() # stdio — ToolHive proxies this to HTTP/SSE + from bindings.mcp_server.server import McpServer + + transport = os.getenv("NW_MCP_TRANSPORT", "stdio").strip().lower() + logger.info(f"Starting Node Wire MCP server (transport={transport}, manifest-driven)") + McpServer(server_name="node-wire").run(transport=transport) if __name__ == "__main__": diff --git a/src/agents/providers/__init__.py b/src/agents/providers/__init__.py index 1170f25..55eae88 100644 --- a/src/agents/providers/__init__.py +++ b/src/agents/providers/__init__.py @@ -1 +1,5 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """Providers package for agents.""" diff --git a/src/agents/providers/anthropic_provider.py b/src/agents/providers/anthropic_provider.py index f48b875..8515a53 100644 --- a/src/agents/providers/anthropic_provider.py +++ b/src/agents/providers/anthropic_provider.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ Anthropic (Claude) LLM Provider ================================ @@ -7,11 +11,10 @@ Required env var: ANTHROPIC_API_KEY Optional env var: ANTHROPIC_MODEL (default: claude-3-5-haiku-20241022) """ + from __future__ import annotations -import json import logging -import uuid from typing import Any, Dict, List, Optional from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall @@ -45,31 +48,40 @@ def _messages_to_claude( if m.content: content.append({"type": "text", "text": m.content}) for tc in m.tool_calls: - content.append({ - "type": "tool_use", - "id": tc.id, - "name": tc.name, - "input": tc.arguments, - }) + content.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.arguments, + } + ) result.append({"role": "assistant", "content": content}) elif m.role == "tool": # Claude expects tool results as user messages with tool_result blocks - result.append({ - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": m.tool_call_id, - "content": m.content or "", - }], - }) + result.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": m.tool_call_id, + "content": m.content or "", + } + ], + } + ) return result, system_prompt +anthropic: Any = None try: - import anthropic + import anthropic as _anthropic + + anthropic = _anthropic except ImportError: - anthropic = None + pass class AnthropicProvider(BaseLLMProvider): @@ -77,9 +89,7 @@ class AnthropicProvider(BaseLLMProvider): def __init__(self, api_key: str, model: str = "claude-3-5-haiku-20241022") -> None: if anthropic is None: - raise ImportError( - "anthropic SDK not installed. Run: pip install 'node-wire[agents]'" - ) + raise ImportError("anthropic SDK not installed. Run: pip install 'node-wire[agents]'") self._anthropic = anthropic self._client = anthropic.Anthropic(api_key=api_key) self._model = model @@ -103,7 +113,9 @@ def chat_with_tools( if claude_tools: kwargs["tools"] = claude_tools - logger.debug("Anthropic request | model=%s | messages=%d", self._model, len(claude_messages)) + logger.debug( + "Anthropic request | model=%s | messages=%d", self._model, len(claude_messages) + ) response = self._client.messages.create(**kwargs) tool_calls: List[ToolCall] = [] @@ -111,11 +123,13 @@ def chat_with_tools( for block in response.content: if block.type == "tool_use": - tool_calls.append(ToolCall( - id=block.id, - name=block.name, - arguments=block.input if isinstance(block.input, dict) else {}, - )) + tool_calls.append( + ToolCall( + id=block.id, + name=block.name, + arguments=block.input if isinstance(block.input, dict) else {}, + ) + ) elif block.type == "text": text_parts.append(block.text) diff --git a/src/agents/providers/gemini_provider.py b/src/agents/providers/gemini_provider.py index 1184528..379dd8d 100644 --- a/src/agents/providers/gemini_provider.py +++ b/src/agents/providers/gemini_provider.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ Gemini LLM Provider =================== @@ -8,9 +12,9 @@ Required env var: GEMINI_API_KEY Optional env var: GEMINI_MODEL (default: gemini-2.0-flash) """ + from __future__ import annotations -import json import logging import uuid from typing import Any, Dict, List, Optional @@ -32,10 +36,13 @@ def _mcp_schema_to_gemini(schema: Dict[str, Any]) -> Dict[str, Any]: return cleaned +genai: Any = None try: - import google.generativeai as genai + import google.generativeai as _genai + + genai = _genai except ImportError: - genai = None + pass class GeminiProvider(BaseLLMProvider): @@ -68,11 +75,13 @@ def chat_with_tools( schema = _mcp_schema_to_gemini( t.get("input_schema", {"type": "object", "properties": {}}) ) - decls.append(FunctionDeclaration( - name=t["name"], - description=t.get("description", ""), - parameters=schema, - )) + decls.append( + FunctionDeclaration( + name=t["name"], + description=t.get("description", ""), + parameters=schema, + ) + ) gemini_tools = [Tool(function_declarations=decls)] # Translate conversation to Gemini Contents format @@ -93,22 +102,28 @@ def chat_with_tools( parts.append(m.content) if m.tool_calls: for tc in m.tool_calls: - parts.append(genai.protos.Part( - function_call=genai.protos.FunctionCall( - name=tc.name, args=tc.arguments + parts.append( + genai.protos.Part( + function_call=genai.protos.FunctionCall( + name=tc.name, args=tc.arguments + ) ) - )) + ) chat_history.append({"role": "model", "parts": parts}) elif m.role == "tool": - chat_history.append({ - "role": "function", - "parts": [genai.protos.Part( - function_response=genai.protos.FunctionResponse( - name=m.name or "tool", - response={"result": m.content or ""}, - ) - )], - }) + chat_history.append( + { + "role": "function", + "parts": [ + genai.protos.Part( + function_response=genai.protos.FunctionResponse( + name=m.name or "tool", + response={"result": m.content or ""}, + ) + ) + ], + } + ) model = genai.GenerativeModel( model_name=self._model_name, @@ -124,11 +139,13 @@ def chat_with_tools( for part in response.parts: if hasattr(part, "function_call") and part.function_call.name: fc = part.function_call - tool_calls.append(ToolCall( - id=str(uuid.uuid4()), - name=fc.name, - arguments=dict(fc.args), - )) + tool_calls.append( + ToolCall( + id=str(uuid.uuid4()), + name=fc.name, + arguments=dict(fc.args), + ) + ) elif hasattr(part, "text") and part.text: text_parts.append(part.text) diff --git a/src/agents/providers/groq_provider.py b/src/agents/providers/groq_provider.py index 034ad85..8105efb 100644 --- a/src/agents/providers/groq_provider.py +++ b/src/agents/providers/groq_provider.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ Groq LLM Provider ================= @@ -7,11 +11,12 @@ Required env var: GROQ_API_KEY Optional env var: GROQ_MODEL (default: llama-3.3-70b-versatile) """ + from __future__ import annotations import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, cast from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall @@ -34,33 +39,42 @@ def _messages_to_groq(messages: List[LLMMessage]) -> List[Dict[str, Any]]: result = [] for m in messages: if m.role == "tool": - result.append({ - "role": "tool", - "tool_call_id": m.tool_call_id, - "content": m.content or "", - }) + result.append( + { + "role": "tool", + "tool_call_id": m.tool_call_id, + "content": m.content or "", + } + ) elif m.tool_calls: - result.append({ + assistant_msg: Dict[str, Any] = { "role": "assistant", - "content": m.content, + "content": m.content if m.content is not None else "", "tool_calls": [ { "id": tc.id, "type": "function", - "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}, + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments), + }, } for tc in m.tool_calls ], - }) + } + result.append(cast(Dict[str, Any], assistant_msg)) else: result.append({"role": m.role, "content": m.content or ""}) return result +Groq: Any = None try: - from groq import Groq + from groq import Groq as _Groq + + Groq = _Groq except ImportError: - Groq = None + pass class GroqProvider(BaseLLMProvider): @@ -68,9 +82,7 @@ class GroqProvider(BaseLLMProvider): def __init__(self, api_key: str, model: str = "llama-3.3-70b-versatile") -> None: if Groq is None: - raise ImportError( - "groq SDK not installed. Run: pip install 'node-wire[agents]'" - ) + raise ImportError("groq SDK not installed. Run: pip install 'node-wire[agents]'") self._client = Groq(api_key=api_key) self._model = model logger.info("GroqProvider initialised | model=%s", model) @@ -88,8 +100,12 @@ def chat_with_tools( kwargs["tools"] = groq_tools kwargs["tool_choice"] = "auto" - logger.debug("Groq request | model=%s | messages=%d | tools=%d", - self._model, len(groq_messages), len(groq_tools)) + logger.debug( + "Groq request | model=%s | messages=%d | tools=%d", + self._model, + len(groq_messages), + len(groq_tools), + ) response = self._client.chat.completions.create(**kwargs) choice = response.choices[0] diff --git a/src/agents/providers/openai_provider.py b/src/agents/providers/openai_provider.py index 3c66d4b..29d1b50 100644 --- a/src/agents/providers/openai_provider.py +++ b/src/agents/providers/openai_provider.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ OpenAI LLM Provider =================== @@ -6,11 +10,12 @@ Required env var: OPENAI_API_KEY Optional env var: OPENAI_MODEL (default: gpt-4o-mini) """ + from __future__ import annotations import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, cast from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall @@ -32,33 +37,42 @@ def _messages_to_openai(messages: List[LLMMessage]) -> List[Dict[str, Any]]: result = [] for m in messages: if m.role == "tool": - result.append({ - "role": "tool", - "tool_call_id": m.tool_call_id, - "content": m.content or "", - }) + result.append( + { + "role": "tool", + "tool_call_id": m.tool_call_id, + "content": m.content or "", + } + ) elif m.tool_calls: - result.append({ + assistant_msg: Dict[str, Any] = { "role": "assistant", - "content": m.content, + "content": m.content if m.content is not None else "", "tool_calls": [ { "id": tc.id, "type": "function", - "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}, + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments), + }, } for tc in m.tool_calls ], - }) + } + result.append(cast(Dict[str, Any], assistant_msg)) else: result.append({"role": m.role, "content": m.content or ""}) return result +OpenAI: Any = None try: - from openai import OpenAI + from openai import OpenAI as _OpenAI + + OpenAI = _OpenAI except ImportError: - OpenAI = None + pass class OpenAIProvider(BaseLLMProvider): @@ -66,9 +80,7 @@ class OpenAIProvider(BaseLLMProvider): def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None: if OpenAI is None: - raise ImportError( - "openai SDK not installed. Run: pip install 'node-wire[agents]'" - ) + raise ImportError("openai SDK not installed. Run: pip install 'node-wire[agents]'") self._client = OpenAI(api_key=api_key) self._model = model logger.info("OpenAIProvider initialised | model=%s", model) diff --git a/src/agents/salesforce_mcp.py b/src/agents/salesforce_mcp.py new file mode 100644 index 0000000..31f8669 --- /dev/null +++ b/src/agents/salesforce_mcp.py @@ -0,0 +1,28 @@ +"""MCP Server — Salesforce connector only. Usage: python -m agents.salesforce_mcp""" + +from __future__ import annotations + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() +load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("agents.salesforce_mcp") + + +def main() -> None: + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-salesforce MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-salesforce", + connector_ids=["salesforce"], + ).run_stdio() + + +if __name__ == "__main__": + main() diff --git a/src/agents/slack_mcp.py b/src/agents/slack_mcp.py new file mode 100644 index 0000000..4851521 --- /dev/null +++ b/src/agents/slack_mcp.py @@ -0,0 +1,28 @@ +"""MCP Server — Slack connector only. Usage: python -m agents.slack_mcp""" + +from __future__ import annotations + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() +load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("agents.slack_mcp") + + +def main() -> None: + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-slack MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-slack", + connector_ids=["slack"], + ).run_stdio() + + +if __name__ == "__main__": + main() diff --git a/src/agents/smtp_mcp.py b/src/agents/smtp_mcp.py index 80c147c..456111f 100644 --- a/src/agents/smtp_mcp.py +++ b/src/agents/smtp_mcp.py @@ -1,25 +1,13 @@ -""" -FastMCP Server Entrypoint — SMTP -================================ -Standalone MCP server exposing only the SMTP email tool. +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP Server — SMTP connector only. Usage: python -m agents.smtp_mcp""" -Usage: - python -m agents.smtp_mcp - -Environment variables: - SMTP_HOST (default: smtp.gmail.com) - SMTP_PORT (default: 587) - SMTP_USE_TLS (default: true) - SMTP_USERNAME - SMTP_PASSWORD - FROM_EMAIL (optional; fallback sender address) -""" from __future__ import annotations import logging import os -import re -import uuid from dotenv import load_dotenv @@ -30,87 +18,16 @@ logger = logging.getLogger("agents.smtp_mcp") -def _extract_email(value: str) -> str: - # Pydantic EmailStr does not like "Name " - match = re.search(r"<(.+?)>", value) - return (match.group(1) if match else value).strip() - - -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smtp") - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." - ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = ( - os.environ.get("FROM_EMAIL") - or os.environ.get("SMTP_USERNAME") - or "noreply@node-wire.local" - ).strip(" '\"") - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": getattr(result, "message_id", None)} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smtp MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + transport = os.getenv("NW_MCP_TRANSPORT", "stdio").strip().lower() + logger.info(f"Starting nw-smtp MCP server (transport={transport}, manifest-driven)") + McpServer( + server_name="nw-smtp", + connector_ids=["smtp"], + ).run(transport=transport) if __name__ == "__main__": main() - diff --git a/src/agents/stripe_mcp.py b/src/agents/stripe_mcp.py new file mode 100644 index 0000000..574c00f --- /dev/null +++ b/src/agents/stripe_mcp.py @@ -0,0 +1,28 @@ +"""MCP Server — Stripe connector only. Usage: python -m agents.stripe_mcp""" + +from __future__ import annotations + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() +load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("agents.stripe_mcp") + + +def main() -> None: + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-stripe MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-stripe", + connector_ids=["stripe"], + ).run_stdio() + + +if __name__ == "__main__": + main() diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index 884f949..104b851 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -1,12 +1,16 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ ToolHive Agent ============== A ReAct-style AI agent that connects to an MCP server running in ToolHive, discovers its tools, and orchestrates a healthcare workflow: - 1. Fetch patient details via fhir_cerner_read_patient or fhir_epic_read_patient - 2. Write a patient summary file via google_drive_upload_file - 3. Email the summary via smtp_send_email + 1. Fetch patient details via fhir_cerner.read_patient / fhir_epic.read_patient (or search_* tools) + 2. Write a patient summary file via google_drive.files.upload + 3. Email the summary via smtp.send_email The LLM backend is fully configurable via the LLM_PROVIDER env var. @@ -23,12 +27,16 @@ Environment variables: TOOLHIVE_MCP_URL : MCP proxy URL from ToolHive UI (e.g. http://localhost:PORT/mcp) TOOLHIVE_MCP_URLS: Comma-separated MCP proxy URLs (multi-server) + TOOLHIVE_MCP_API_KEY: Optional inbound MCP auth key (sent as Bearer + X-API-Key) + TOOLHIVE_MCP_BEARER_TOKEN: Optional inbound MCP bearer token (JWT/API key) + TOOLHIVE_MAX_TOOL_FAILURES: Stop after this many failed invocations per tool name (default: 2) LLM_PROVIDER : groq | openai | gemini | anthropic (default: groq) GROQ_API_KEY : (when using groq) OPENAI_API_KEY : (when using openai) GEMINI_API_KEY : (when using gemini) ANTHROPIC_API_KEY: (when using anthropic) """ + from __future__ import annotations import argparse @@ -40,7 +48,8 @@ import uuid from contextlib import AsyncExitStack from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, AsyncIterator, Dict, List, Optional, Protocol, Union, runtime_checkable +import re from dotenv import load_dotenv @@ -52,10 +61,141 @@ logger = logging.getLogger("agents.toolhive") +_EMAIL_RE = re.compile(r"[^@\s]+@[^@\s]+\.[^@\s]+") +_SMTP_EMAIL_FIELDS = {"from_email", "to", "cc", "bcc", "reply_to", "sender"} + + +def _redact_tool_args_for_log(tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]: + """ + Return a copy of *args* safe for logging. + + For SMTP tools only: replace any email address value with '[REDACTED]' + so that recipient and sender identifiers are never written to logs. + All other tool args pass through unchanged. + """ + if not tool_name.startswith("smtp."): + return args + + scrubbed: Dict[str, Any] = {} + for key, value in args.items(): + if key in _SMTP_EMAIL_FIELDS: + if isinstance(value, list): + scrubbed[key] = ["[REDACTED]"] * len(value) + elif isinstance(value, str) and _EMAIL_RE.search(value): + scrubbed[key] = "[REDACTED]" + else: + scrubbed[key] = value + else: + scrubbed[key] = value + return scrubbed + + +def truncate_tool_result_for_llm(text: str) -> str: + """ + Cap tool output size sent to the LLM so providers with strict limits (e.g. Groq + on-demand TPM) do not fail with 413 / oversized requests after large FHIR payloads. + + Full raw output remains in AgentStep.tool_result for logging; only the message + passed back into the chat is truncated. + + Override with env TOOLHIVE_MAX_TOOL_RESULT_CHARS (default 12000). Use 0 to disable. + """ + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_RESULT_CHARS") or "12000").strip() + try: + max_chars = int(raw) + except ValueError: + max_chars = 12000 + if max_chars <= 0 or len(text) <= max_chars: + return text + omitted = len(text) - max_chars + return ( + text[:max_chars] + + "\n\n[... truncated " + + str(omitted) + + " characters for LLM context limits; use visible fields for next steps.]" + ) + + +def resolve_max_tool_failures(override: Optional[int] = None) -> int: + """ + Max failed tool invocations per tool name before aborting the agent run. + ``override`` wins; otherwise ``TOOLHIVE_MAX_TOOL_FAILURES`` (default 2). Minimum 1. + """ + if override is not None: + return max(1, int(override)) + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_FAILURES") or "2").strip() + try: + n = int(raw) + except ValueError: + n = 2 + return max(1, n) + + +def _is_tool_failure(tool_result: str) -> bool: + """True if MCP/connector reported a failed tool outcome (not empty success).""" + if not tool_result or not tool_result.strip(): + return False + t = tool_result.strip() + if t.startswith("ERROR:"): + return True + low = t.lower() + if "input validation error" in low: + return True + if "validation error" in low and "input" in low: + return True + if t.startswith("{"): + try: + data = json.loads(t) + if isinstance(data, dict) and data.get("success") is False: + return True + except json.JSONDecodeError: + pass + return False + + +def _tool_failure_abort_message(tool_name: str, max_failures: int) -> str: + return ( + f'The tool "{tool_name}" failed {max_failures} times in a row. ' + "Please check the parameters against the schema from tools/list, " + "or tell me if I should use a different tool or approach." + ) + + +def _chunk_agent_text(text: str, chunk_size: int = 180) -> List[str]: + """Split final assistant text into UI-friendly chunks for stream consumers.""" + if not text: + return [""] + chunks: List[str] = [] + current = "" + for part in text.split(" "): + candidate = f"{current} {part}".strip() + if current and len(candidate) > chunk_size: + chunks.append(current + " ") + current = part + else: + current = candidate + if current: + chunks.append(current) + return chunks + + +def _stream_done_event(trace_id: str, *, success: bool) -> Dict[str, Any]: + from node_wire_runtime.streaming import stream_completion_log + + stream_completion_log(trace_id, success, connector_id="agent", action="run_events") + return { + "type": "done", + "trace_id": trace_id, + "success": success, + "message": f"Streaming completed. trace_id={trace_id}", + } + + # --------------------------------------------------------------------------- # Result model # --------------------------------------------------------------------------- + @dataclass class AgentStep: step: int @@ -78,6 +218,8 @@ class AgentRunResult: # Lightweight async MCP client (SSE / streamable-HTTP transport) # --------------------------------------------------------------------------- + +@runtime_checkable class McpClient(Protocol): async def list_tools(self) -> List[Dict[str, Any]]: ... @@ -103,6 +245,42 @@ def __init__(self, base_url: str) -> None: self._base_url = base_url.rstrip("/") self._session_id: Optional[str] = None self._initialized: bool = False + self._auth_token: Optional[str] = os.environ.get( + "TOOLHIVE_MCP_BEARER_TOKEN" + ) or os.environ.get("TOOLHIVE_MCP_API_KEY") + + def _build_request_headers(self) -> Dict[str, str]: + headers: Dict[str, str] = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + # For MCP auth-gated servers, send both forms for compatibility. + if self._auth_token: + headers["Authorization"] = f"Bearer {self._auth_token}" + headers["X-API-Key"] = self._auth_token + return headers + + def _inject_auth_meta(self, params: Dict[str, Any]) -> Dict[str, Any]: + if not self._auth_token: + return dict(params) + out = dict(params) + meta = out.get("_meta") + if isinstance(meta, dict): + merged_meta = dict(meta) + else: + merged_meta = {} + # Include common aliases to maximize compatibility with MCP servers. + merged_meta.setdefault("authorization", f"Bearer {self._auth_token}") + merged_meta.setdefault("Authorization", f"Bearer {self._auth_token}") + merged_meta.setdefault("x-api-key", self._auth_token) + merged_meta.setdefault("X-API-Key", self._auth_token) + merged_meta.setdefault("token", self._auth_token) + merged_meta.setdefault("api_key", self._auth_token) + merged_meta.setdefault("apiKey", self._auth_token) + out["_meta"] = merged_meta + return out async def _initialize(self) -> None: """Send MCP initialize + initialized handshake; store session ID.""" @@ -122,7 +300,7 @@ async def _initialize(self) -> None: resp = await client.post( self._base_url, json=init_payload, - headers={"Content-Type": "application/json"}, + headers=self._build_request_headers(), ) resp.raise_for_status() session_id = resp.headers.get("Mcp-Session-Id") @@ -134,11 +312,8 @@ async def _initialize(self) -> None: # Send the initialized notification (fire-and-forget; no id = notification) notif = {"jsonrpc": "2.0", "method": "notifications/initialized"} - headers: Dict[str, str] = {"Content-Type": "application/json"} - if self._session_id: - headers["Mcp-Session-Id"] = self._session_id try: - await client.post(self._base_url, json=notif, headers=headers) + await client.post(self._base_url, json=notif, headers=self._build_request_headers()) except Exception: pass # Notifications have no response; ignore transport errors @@ -155,16 +330,14 @@ async def _rpc(self, method: str, params: Dict[str, Any]) -> Any: "id": str(uuid.uuid4()), "method": method, } - if params: - payload["params"] = params - - headers: Dict[str, str] = {"Content-Type": "application/json"} - if self._session_id: - headers["Mcp-Session-Id"] = self._session_id + # Always include params when auth token is present so _meta is sent even + # for methods like tools/list that otherwise pass {}. + if params or self._auth_token: + payload["params"] = self._inject_auth_meta(params) url = self._base_url async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(url, json=payload, headers=headers) + resp = await client.post(url, json=payload, headers=self._build_request_headers()) resp.raise_for_status() data = resp.json() if "error" in data: @@ -223,7 +396,9 @@ async def list_tools(self) -> List[Dict[str, Any]]: logger.info( "MultiMcpClient: %d/%d clients reachable, %d tools discovered", - success_count, len(self._clients), len(merged), + success_count, + len(self._clients), + len(merged), ) self._tool_to_client_idx = tool_to_idx return merged @@ -253,7 +428,7 @@ class StdioMcpClient: def __init__(self, command: List[str]) -> None: self._command = command self._exit_stack = AsyncExitStack() - self._session = None + self._session: Any = None async def __aenter__(self) -> StdioMcpClient: try: @@ -283,7 +458,10 @@ async def list_tools(self) -> List[Dict[str, Any]]: raise RuntimeError("Client not initialised. Use 'async with'") resp = await self._session.list_tools() # Convert to simple tool list - return [{"name": t.name, "description": t.description, "input_schema": t.inputSchema} for t in resp.tools] + return [ + {"name": t.name, "description": t.description, "input_schema": t.inputSchema} + for t in resp.tools + ] async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: if not self._session: @@ -297,6 +475,7 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: # The Agent # --------------------------------------------------------------------------- + class ToolHiveAgent: """ ReAct-style agent that uses an LLM + MCP tools from ToolHive. @@ -306,7 +485,7 @@ class ToolHiveAgent: 2. Enters a ReAct loop: send task + tools to LLM → if tool call → invoke tool → append result → repeat. 3. Stops when the LLM returns a final answer (no tool calls) or - ``max_steps`` is reached. + ``max_steps`` is reached, or the same tool fails ``max_tool_failures`` times. """ def __init__( @@ -314,20 +493,30 @@ def __init__( mcp_client: McpClient, llm_provider: Any, # BaseLLMProvider max_steps: int = 10, + max_tool_failures: Optional[int] = None, ) -> None: self._mcp = mcp_client self._llm = llm_provider self._max_steps = max_steps + self._max_tool_failures = resolve_max_tool_failures(max_tool_failures) self._system_prompt: str = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `.` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" - " 1. First turn: Search for the patient. (If you have a Patient ID, you DO NOT need their name or birthdate).\n" - " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call the search tool with a guessed or hallucinated ID like '12345'.\n" + " 1. First turn: Obtain patient demographics from the EHR.\n" + ' - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{"resource_id": ""}` (use Epic when the ID starts with \'e\'). Do NOT use search_patients for a known ID.\n' + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. `given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " For `google_drive.files.upload`, pass a flat JSON object: `name`, `mime_type` (snake_case — not `mimeType`), `parents`, and `content` (or `content_base64`). " + "If you include `action`, it must be exactly `files.upload`. Do not nest fields under a `file` object. Do NOT pass `media` / `media_body`.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -337,9 +526,14 @@ def __init__( " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" + "PAGINATION HANDLING — IMPORTANT:\n" + "- When tools return pagination metadata with 'next_page_token', you MUST call the same tool again with 'page_token' set to that value to get the next page.\n" + "- Always check for pagination info in tool results and continue fetching pages until there's no 'next_page_token'.\n" + "- For Google Drive tools: Use 'page_token' from previous result to get next page of files.\n" + "- For FHIR search tools: Use pagination tokens to get complete result sets.\n\n" "GUARDRAILS:\n" "- NEVER hallucinate or make up patient details. DO NOT guess IDs like '12345'. If missing, ask the user.\n" "- NEVER use placeholders like 'to be updated later' or ''.\n" @@ -352,8 +546,6 @@ def __init__( "- Keep responses concise and professional.\n" ) - - async def run(self, task: str) -> AgentRunResult: trace_id = str(uuid.uuid4()) logger.info("Agent run started | trace_id=%s", trace_id) @@ -383,6 +575,9 @@ async def run(self, task: str) -> AgentRunResult: ] # 3. ReAct loop + tool_failures: Dict[str, int] = {} + abort_after_tool_failures = False + for step_num in range(1, self._max_steps + 1): logger.info("Agent step %d / %d", step_num, self._max_steps) @@ -394,11 +589,13 @@ async def run(self, task: str) -> AgentRunResult: return result # Track the assistant turn - messages.append(LLMMessage( - role="assistant", - content=llm_resp.content, - tool_calls=llm_resp.tool_calls, - )) + messages.append( + LLMMessage( + role="assistant", + content=llm_resp.content, + tool_calls=llm_resp.tool_calls, + ) + ) if not llm_resp.wants_tool_call: # LLM finished @@ -409,7 +606,8 @@ async def run(self, task: str) -> AgentRunResult: # Execute each tool call for tc in llm_resp.tool_calls: - logger.info("Calling tool: %s | args=%s", tc.name, tc.arguments) + scrubbed_args = _redact_tool_args_for_log(tc.name, tc.arguments) + logger.info("Calling tool: %s | args=%s", tc.name, scrubbed_args) agent_step = AgentStep( step=step_num, tool_called=tc.name, @@ -420,7 +618,44 @@ async def run(self, task: str) -> AgentRunResult: try: tool_result_str = await self._mcp.call_tool(tc.name, tc.arguments) - logger.info("Tool %s returned: %.200s", tc.name, tool_result_str) + logger.info( + "Tool %s returned response of length: %d chars", + tc.name, + len(tool_result_str), + ) + + # --- AUTOMATIC PAGINATION TOKEN HANDLING --- + try: + result_data = json.loads(tool_result_str) + pagination_meta = result_data.get("data", {}).get( + "_server_pagination_metadata", {} + ) + next_token = pagination_meta.get("next_page_token") + + if next_token: + print("\n=== PAGINATION TOKEN DETECTED ===", file=sys.stderr) + + # Add pagination info to tool result for LLM to see + pagination_info = ( + f"\n\n[PAGINATION INFO]\n" + f"Items returned: {pagination_meta.get('items_returned')}\n" + f"Was truncated: {pagination_meta.get('was_truncated_by_server', False)}\n" + f"Next page token available: {next_token}\n" + f"To get next page, call the same tool with page_token='{next_token}'" + ) + tool_result_str += pagination_info + print("=== ADDED PAGINATION INFO TO RESULT ===", file=sys.stderr) + else: + print("\n=== NO PAGINATION TOKEN FOUND ===", file=sys.stderr) + except (json.JSONDecodeError, KeyError) as e: + print(f"Error parsing pagination metadata: {e}", file=sys.stderr) + + print( + "=================================================\n", + file=sys.stderr, + flush=True, + ) + except Exception as exc: tool_result_str = f"ERROR: {exc}" logger.error("Tool %s failed: %s", tc.name, exc) @@ -428,24 +663,189 @@ async def run(self, task: str) -> AgentRunResult: agent_step.tool_result = tool_result_str result.steps.append(agent_step) - messages.append(LLMMessage( - role="tool", - content=tool_result_str, - tool_call_id=tc.id, - name=tc.name, - )) + llm_tool_content = truncate_tool_result_for_llm(tool_result_str) + if len(llm_tool_content) < len(tool_result_str): + logger.info( + "Tool %s result truncated for LLM: %d -> %d chars", + tc.name, + len(tool_result_str), + len(llm_tool_content), + ) + + messages.append( + LLMMessage( + role="tool", + content=llm_tool_content, + tool_call_id=tc.id, + name=tc.name, + ) + ) + + if _is_tool_failure(tool_result_str): + tool_failures[tc.name] = tool_failures.get(tc.name, 0) + 1 + if tool_failures[tc.name] >= self._max_tool_failures: + msg = _tool_failure_abort_message(tc.name, self._max_tool_failures) + result.error = msg + result.final_answer = msg + logger.warning("Stopping agent: %s", msg) + abort_after_tool_failures = True + break + + if abort_after_tool_failures: + break else: # Hit max_steps without a final answer - result.error = f"Agent reached max_steps ({self._max_steps}) without completing the task." + result.error = ( + f"Agent reached max_steps ({self._max_steps}) without completing the task." + ) logger.warning(result.error) + from node_wire_runtime.streaming import stream_completion_log + + stream_completion_log(trace_id, result.success, connector_id="agent", action="run") return result + async def run_events(self, task: str) -> AsyncIterator[Dict[str, Any]]: + trace_id = str(uuid.uuid4()) + from node_wire_runtime.streaming import resolve_stream_buffer_ms, BufferedStreamIterator + + buffer_ms = resolve_stream_buffer_ms() + iterator = self._run_events_inner(task, trace_id) + + if buffer_ms > 0: + async for item in BufferedStreamIterator( + iterator, buffer_ms, trace_id, connector_id="agent", action="run_events" + ): + yield item + else: + async for item in iterator: + yield item + + async def _run_events_inner(self, task: str, trace_id: str) -> AsyncIterator[Dict[str, Any]]: + """ + Stream agent progress events for web clients. + + Contract: + - ``meta``: emitted once with ``trace_id``. + - ``status``: informational progress text. + - ``step``: emitted after each MCP tool call completes. + - ``final_chunk``: chunks of the final assistant answer. + - ``error``: recoverable terminal error text. + - ``done``: always emitted at terminal completion; clients should stop + loaders when this event arrives. + Stream agent progress events as the ReAct loop runs. + + The LLM providers currently return complete assistant messages, so final + answer chunks begin after the final LLM call completes. Tool-step events + are emitted immediately after each MCP tool call completes. + """ + logger.info("Streaming agent run started | trace_id=%s", trace_id) + logger.info("Task: %s", task) + + from agents.llm_factory import LLMMessage + + yield {"type": "meta", "trace_id": trace_id} + + try: + tools = await self._mcp.list_tools() + logger.info("Discovered %d MCP tools", len(tools)) + yield {"type": "status", "message": f"Discovered {len(tools)} MCP tools"} + except Exception as exc: + error = f"Failed to list MCP tools: {exc}" + logger.error(error) + yield {"type": "error", "trace_id": trace_id, "message": error} + yield _stream_done_event(trace_id, success=False) + return + + messages: List[LLMMessage] = [ + LLMMessage(role="system", content=self._system_prompt), + LLMMessage(role="user", content=task), + ] + tool_failures: Dict[str, int] = {} + + for step_num in range(1, self._max_steps + 1): + logger.info("Streaming agent step %d / %d", step_num, self._max_steps) + yield {"type": "status", "message": f"Agent reasoning step {step_num}"} + + try: + llm_resp = self._llm.chat_with_tools(messages, tools) + except Exception as exc: + error = f"LLM error at step {step_num}: {exc}" + logger.error(error) + yield {"type": "error", "trace_id": trace_id, "message": error} + yield _stream_done_event(trace_id, success=False) + return + + messages.append( + LLMMessage( + role="assistant", + content=llm_resp.content, + tool_calls=llm_resp.tool_calls, + ) + ) + + if not llm_resp.wants_tool_call: + for chunk in _chunk_agent_text(llm_resp.content or ""): + yield {"type": "final_chunk", "content": chunk} + yield _stream_done_event(trace_id, success=True) + return + + abort_message: Optional[str] = None + for tc in llm_resp.tool_calls: + scrubbed_args = _redact_tool_args_for_log(tc.name, tc.arguments) + logger.info("Calling tool: %s | args=%s", tc.name, scrubbed_args) + + try: + tool_result_str = await self._mcp.call_tool(tc.name, tc.arguments) + logger.info("Tool %s returned: %.200s", tc.name, tool_result_str) + except Exception as exc: + tool_result_str = f"ERROR: {exc}" + logger.error("Tool %s failed: %s", tc.name, exc) + + yield { + "type": "step", + "step": step_num, + "tool": tc.name, + "args": tc.arguments, + "result": tool_result_str, + } + + messages.append( + LLMMessage( + role="tool", + content=truncate_tool_result_for_llm(tool_result_str), + tool_call_id=tc.id, + name=tc.name, + ) + ) + + if _is_tool_failure(tool_result_str): + tool_failures[tc.name] = tool_failures.get(tc.name, 0) + 1 + if tool_failures[tc.name] >= self._max_tool_failures: + abort_message = _tool_failure_abort_message( + tc.name, self._max_tool_failures + ) + logger.warning("Stopping streaming agent: %s", abort_message) + break + + if abort_message: + for chunk in _chunk_agent_text(abort_message): + yield {"type": "final_chunk", "content": chunk} + yield _stream_done_event(trace_id, success=False) + return + + error = f"Agent reached max_steps ({self._max_steps}) without completing the task." + logger.warning(error) + for chunk in _chunk_agent_text(error): + yield {"type": "final_chunk", "content": chunk} + yield _stream_done_event(trace_id, success=False) + # --------------------------------------------------------------------------- # CLI entrypoint # --------------------------------------------------------------------------- + async def _run_agent(args: argparse.Namespace) -> None: from agents.llm_factory import LLMProviderFactory @@ -453,6 +853,7 @@ async def _run_agent(args: argparse.Namespace) -> None: logger.info("Creating LLM provider: %s", llm_provider_name) provider = LLMProviderFactory.create_from_env() + mcp_client_context: Union[StdioMcpClient, ToolHiveMcpClient, MultiMcpClient] if args.local: logger.info("Using local stdio transport (launching server as subprocess)") # Launch the mcp_entrypoint.py as a subprocess @@ -474,26 +875,40 @@ async def _run_agent(args: argparse.Namespace) -> None: # Use the client (handle async context for stdio) if isinstance(mcp_client_context, StdioMcpClient): async with mcp_client_context as mcp_client: - agent = ToolHiveAgent(mcp_client, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, "local-stdio") else: - agent = ToolHiveAgent(mcp_client_context, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client_context, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, ",".join(urls)) -async def _execute_task(agent: ToolHiveAgent, args: argparse.Namespace, provider_name: str, mcp_info: str) -> None: - +async def _execute_task( + agent: ToolHiveAgent, args: argparse.Namespace, provider_name: str, mcp_info: str +) -> None: # Build the task prompt task_parts = [ f"Patient ID: {args.patient_id}" if args.patient_id else "", - f"Patient name — family: {args.patient_family}, given: {args.patient_given}" if args.patient_family else "", - f"Please:", - f"1. Fetch the patient's details from Cerner FHIR or Epic FHIR (if the ID starts with 'e').", + f"Patient name — family: {args.patient_family}, given: {args.patient_given}" + if args.patient_family + else "", + "Please:", + "1. Fetch the patient's details from Cerner FHIR or Epic FHIR (if the ID starts with 'e').", f"2. Create a text file named 'patient_summary_{args.patient_id or args.patient_family}.txt' in Google Drive" - + (f" in folder {args.drive_folder_id}" if args.drive_folder_id else "") + ".", + + (f" in folder {args.drive_folder_id}" if args.drive_folder_id else "") + + ".", f"3. Send an email to {args.recipient_email} with the subject " f"'Patient Summary' and the patient details in the body.", - f"After completing all steps, confirm what was done.", + "After completing all steps, confirm what was done.", ] task = "\n".join(p for p in task_parts if p) @@ -532,16 +947,31 @@ def main() -> None: parser.add_argument("--patient-id", default="", help="Cerner or Epic FHIR Patient ID") parser.add_argument("--patient-family", default="", help="Patient family name (for search)") parser.add_argument("--patient-given", default="", help="Patient given name (for search)") - parser.add_argument("--recipient-email", required=True, help="Email address to send the summary to") - parser.add_argument("--drive-folder-id", default=os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), help="Google Drive folder ID (optional)") - parser.add_argument("--max-steps", type=int, default=10, help="Maximum agent steps (default: 10)") - parser.add_argument("--local", action="store_true", help="Run against local server via stdio (no proxy)") + parser.add_argument( + "--recipient-email", required=True, help="Email address to send the summary to" + ) + parser.add_argument( + "--drive-folder-id", + default=os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), + help="Google Drive folder ID (optional)", + ) + parser.add_argument( + "--max-steps", type=int, default=10, help="Maximum agent steps (default: 10)" + ) + parser.add_argument( + "--max-tool-failures", + type=int, + default=None, + help="Stop after this many failed calls per tool name (default: env TOOLHIVE_MAX_TOOL_FAILURES or 2)", + ) + parser.add_argument( + "--local", action="store_true", help="Run against local server via stdio (no proxy)" + ) args = parser.parse_args() if not args.patient_id and not args.patient_family: parser.error("Provide either --patient-id or --patient-family") - import sys asyncio.run(_run_agent(args)) diff --git a/src/bindings/__init__.py b/src/bindings/__init__.py index a37e4eb..c11239c 100644 --- a/src/bindings/__init__.py +++ b/src/bindings/__init__.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations """ @@ -9,4 +13,3 @@ - gRPC server binding - MCP server binding for AI agents """ - diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 8a28256..214425d 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -1,83 +1,160 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging +import os +import re from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional import yaml -from connectors.fhir_epic.logic import FhirEpicConnector -from connectors.fhir_cerner.logic import FhirCernerConnector -from connectors.http_generic.logic import HttpGenericConnector -from connectors.http_generic.schema import HttpRequestInput, HttpResponseOutput -from connectors.google_drive.logic import GoogleDriveConnector -from connectors.google_drive.schema import ( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, +from node_wire_runtime import BaseConnector, SecretProvider +from node_wire_runtime.base_connector import _CONNECTOR_REGISTRY +from node_wire_runtime.policy import PolicyHook +from node_wire_runtime.policies.mcp_scope_policy import ( + DEFAULT_SCOPE_MODE_DENY, + ScopePolicyHook, + load_scope_map_from_env, + load_scope_policy_default_from_env, ) -from connectors.smtp.logic import SmtpConnector -from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput -from runtime import BaseConnector, SecretProvider +from node_wire_runtime.secrets import ChainedSecretProvider, EnvSecretProvider logger = logging.getLogger("bindings.factory") -# Resolve default config relative to platform root so it works from any cwd. _PLATFORM_ROOT = Path(__file__).resolve().parent.parent.parent _DEFAULT_CONFIG_PATH = _PLATFORM_ROOT / "config" / "connectors.yaml" -@dataclass -class ConnectorConfig: - id: str - enabled: bool - exposed_via: List[str] - raw: Dict[str, Any] +def _resolve_env_vars(data: Any) -> Any: + if isinstance(data, dict): + return {k: _resolve_env_vars(v) for k, v in data.items()} + elif isinstance(data, list): + return [_resolve_env_vars(item) for item in data] + elif isinstance(data, str): + def replacer(match: Any) -> str: + var_name = match.group(1) + default = match.group(3) + if var_name in os.environ: + return os.environ[var_name] + elif default is not None: + return default + return match.group(0) -class EnvSecretProvider(SecretProvider): + return re.sub(r"\$\{([A-Za-z0-9_]+)(:(.*?))?\}", replacer, data) + return data + + +def _resolve_config_path(explicit: str | Path | None) -> str: + """Resolve connector config path with NW_CONFIG_PATH env var support. + + Priority order (first match wins): + 1. Explicit argument passed to ConnectorFactory() + 2. NW_CONFIG_PATH environment variable + 3. /config/connectors.yaml (existing default — no breakage) + 4. /config/connectors.yaml (existing fallback — no breakage) """ - Simple SecretProvider implementation backed by environment variables. + if explicit is not None: + return str(explicit) + env_path = os.getenv("NW_CONFIG_PATH") + if env_path: + return env_path + if _DEFAULT_CONFIG_PATH.is_file(): + return str(_DEFAULT_CONFIG_PATH) + return str(Path.cwd() / "config" / "connectors.yaml") + + +def _build_secret_provider() -> SecretProvider: + """Compose secret providers from ``NW_SECRET_BACKEND`` (default: ``env``). + + - ``env`` — :class:`EnvSecretProvider` only (fail-closed unless ``NW_ENV_SECRET_LEGACY_EMPTY``). + - ``aws_env`` — :class:`ChainedSecretProvider`( + :class:`~node_wire_runtime.secrets.aws.AwsSecretsManagerProvider`, + :class:`EnvSecretProvider`) for JSON bundle in AWS SM then env fallback. - Keys are looked up directly from os.environ for the POC. + Environment for ``aws_env``: + ``NW_AWS_SECRETS_MANAGER_SECRET_ID`` — Secrets Manager secret id or ARN + ``AWS_REGION`` — optional, default ``us-east-1`` """ + mode = os.environ.get("NW_SECRET_BACKEND", "env").strip().lower() + if mode in ("", "env"): + return EnvSecretProvider() + if mode == "aws_env": + secret_id = os.environ.get("NW_AWS_SECRETS_MANAGER_SECRET_ID") + if not secret_id: + raise ValueError( + "NW_SECRET_BACKEND=aws_env requires NW_AWS_SECRETS_MANAGER_SECRET_ID to be set" + ) + from node_wire_runtime.secrets.aws import AwsSecretsManagerProvider - def __init__(self) -> None: - import os + region = os.environ.get("AWS_REGION", "us-east-1") + return ChainedSecretProvider( + AwsSecretsManagerProvider(secret_name=secret_id, region=region), + EnvSecretProvider(), + ) + raise ValueError(f"Unknown NW_SECRET_BACKEND {mode!r}. Supported: env, aws_env.") - self._env = os.environ - def get_secret(self, key: str) -> str: - val = self._env.get(key) - if val is not None: - return val.strip(" '\"") - val = self._env.get(key.upper()) - if val is not None: - return val.strip(" '\"") - # Return empty string instead of raising RuntimeError for zero-config/local testing. - return "" +def _build_policy_hook() -> PolicyHook | None: + action_scope_map = load_scope_map_from_env() + default_mode = load_scope_policy_default_from_env() + strict_mode = os.environ.get("NW_MCP_SCOPE_POLICY_STRICT", "").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + logger.info( + "Evaluated MCP scope policy configuration", + extra={ + "scope_map_entries": len(action_scope_map), + "default_mode": default_mode, + "strict_mode": strict_mode, + }, + ) + if not action_scope_map and default_mode != DEFAULT_SCOPE_MODE_DENY: + msg = ( + "MCP scope policy is effectively disabled " + "(NW_MCP_ACTION_SCOPE_MAP_JSON empty and NW_MCP_SCOPE_POLICY_DEFAULT=allow). " + "Set NW_MCP_SCOPE_POLICY_DEFAULT=deny for production." + ) + if strict_mode: + raise ValueError(msg + " Strict mode is enabled via NW_MCP_SCOPE_POLICY_STRICT=true.") + logger.warning(msg) + logger.info("Policy hook disabled (no action scope map; default is allow)") + return None + logger.info( + "Policy hook enabled", + extra={ + "scope_map_entries": len(action_scope_map), + "default_mode": default_mode, + }, + ) + return ScopePolicyHook(action_scope_map, default_mode=default_mode) + + +@dataclass +class ConnectorConfig: + id: str + enabled: bool + exposed_via: List[str] + raw: Dict[str, Any] class ConnectorFactory: """ - Factory responsible for: - - Loading connector configuration from config/connectors.yaml - - Instantiating connector adapters - - Enforcing exposed_via rules per protocol + Loads connectors.yaml and instantiates connectors from the connector registry. """ def __init__(self, config_path: str | Path | None = None) -> None: - if config_path is not None: - self._config_path = str(config_path) - elif _DEFAULT_CONFIG_PATH.is_file(): - self._config_path = str(_DEFAULT_CONFIG_PATH) - else: - # Fallback when run from platform dir (e.g. package installed from wheel) - cwd_config = Path.cwd() / "config" / "connectors.yaml" - self._config_path = str(cwd_config) - self._secret_provider: SecretProvider = EnvSecretProvider() + self._config_path = _resolve_config_path(config_path) + self._secret_provider: SecretProvider = _build_secret_provider() + self._policy_hook: PolicyHook | None = _build_policy_hook() self._connectors: Dict[str, Any] = {} self._configs: Dict[str, ConnectorConfig] = {} @@ -91,6 +168,8 @@ def load(self) -> None: with open(path, "r", encoding="utf-8") as f: raw = yaml.safe_load(f) or {} + raw = _resolve_env_vars(raw) + connectors_cfg: Dict[str, Any] = raw.get("connectors", {}) for connector_id, cfg in connectors_cfg.items(): @@ -111,29 +190,115 @@ def load(self) -> None: ) continue - self._connectors[connector_id] = self._instantiate(connector_id) - - def _instantiate(self, connector_id: str) -> Any: - if connector_id == "http_generic": - return HttpGenericConnector(HttpRequestInput, HttpResponseOutput, secret_provider=self._secret_provider) - if connector_id == "smtp": - return SmtpConnector(SmtpSendInput, SmtpSendOutput, secret_provider=self._secret_provider) - if connector_id == "stripe": - return StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=self._secret_provider) - if connector_id == "google_drive": - return GoogleDriveConnector( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, + if connector_id not in _CONNECTOR_REGISTRY: + logger.warning( + "Connector enabled in configuration but not registered; skipping instantiation", + extra={ + "connector_id": connector_id, + "reason": "Filtered by NW_ALLOWED_CONNECTORS or not installed", + }, + ) + continue + + instance = self._instantiate(connector_id) + self._connectors[connector_id] = instance + + def _build_auth_provider(self, connector_id: str, cfg: dict) -> Any: + """Construct the appropriate AuthProvider from the connector's YAML ``auth:`` block. + + Falls back to :class:`NoAuthProvider` when the block is absent. + """ + from node_wire_runtime.auth import ( + NoAuthProvider, + OAuth2AuthProvider, + ServiceAccountAuthProvider, + StaticTokenAuthProvider, + ) + + auth_cfg = cfg.get("auth") or {} + provider_type = auth_cfg.get("provider", "none") + + if provider_type in ("none", ""): + return NoAuthProvider() + + if provider_type == "static_token": + return StaticTokenAuthProvider( + secret_provider=self._secret_provider, + secret_key=auth_cfg["secret_key"], + header_name=auth_cfg.get("header_name", "Authorization"), + prefix=auth_cfg.get("prefix", "Bearer"), + encoding=auth_cfg.get("encoding"), + ) + + if provider_type == "oauth2": + return OAuth2AuthProvider( secret_provider=self._secret_provider, + grant_method=auth_cfg.get("grant_method", "private_key_jwt"), + token_url_secret=auth_cfg["token_url_secret"], + client_id_secret=auth_cfg["client_id_secret"], + algorithm=auth_cfg.get("algorithm", "RS384"), + private_key_secret=auth_cfg.get("private_key_secret"), + kid_secret=auth_cfg.get("kid_secret"), + client_secret_secret=auth_cfg.get("client_secret_secret"), + refresh_token_secret=auth_cfg.get("refresh_token_secret"), + scopes=auth_cfg.get("scopes"), + scopes_secret=auth_cfg.get("scopes_secret"), + extra_content_type_headers=auth_cfg.get("extra_headers"), + buffer_secs=int(auth_cfg.get("buffer_secs", 60)), + jwt_ttl_secs=int(auth_cfg.get("jwt_ttl_secs", 300)), ) - if connector_id == "fhir_epic": - return FhirEpicConnector(secret_provider=self._secret_provider) - if connector_id == "fhir_cerner": - return FhirCernerConnector(secret_provider=self._secret_provider) - raise ValueError(f"Unknown connector id {connector_id!r}") + if provider_type == "service_account": + return ServiceAccountAuthProvider( + secret_provider=self._secret_provider, + sa_json_secret=auth_cfg["sa_json_secret"], + scopes=auth_cfg.get("scopes"), + ) - def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[str] = None) -> Optional[BaseConnector[Any, Any]]: + if provider_type == "static_credentials": + # SMTP-style: returns (username, password) tuple via get_client_credentials(). + # We use a lightweight wrapper around StaticTokenAuthProvider pair. + username_secret = auth_cfg.get("username_secret", "SMTP_USERNAME") + password_secret = auth_cfg.get("password_secret", "SMTP_PASSWORD") + from node_wire_runtime.auth.base import AuthProvider + + sp = self._secret_provider + + class _SmtpCredentialsProvider(AuthProvider): # type: ignore[misc] + async def get_headers(self) -> dict: + return {} + + async def get_client_credentials(self): # type: ignore[override] + return (sp.get_secret(username_secret), sp.get_secret(password_secret)) + + return _SmtpCredentialsProvider() + + logger.warning( + "Unknown auth provider type %r for connector %r — defaulting to NoAuthProvider", + provider_type, + connector_id, + ) + return NoAuthProvider() + + def _instantiate(self, connector_id: str) -> "BaseConnector | None": + connector_cls = _CONNECTOR_REGISTRY.get(connector_id) + if connector_cls is not None: + cfg = self._configs[connector_id] + auth_provider = self._build_auth_provider(connector_id, cfg.raw) + return connector_cls( + secret_provider=self._secret_provider, + auth_provider=auth_provider, + policy_hook=self._policy_hook, + ) + + raise RuntimeError( + f"Connector {connector_id!r} is enabled in config but not registered " + "(filtered by NW_ALLOWED_CONNECTORS or not installed)" + ) + + def get_for_protocol( + self, connector_id: str, protocol: str, action: Optional[str] = None + ) -> Optional[BaseConnector]: cfg = self._configs.get(connector_id) if cfg is None: logger.warning( @@ -160,19 +325,17 @@ def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[st if connector is None: return None - # Multi-action connectors (e.g. fhir_epic) expose a get_action() helper. - if action and hasattr(connector, "get_action"): - return connector.get_action(action) + if action: + logger.debug( + "get_for_protocol resolved connector", + extra={"connector_id": connector_id, "protocol": protocol, "action": action}, + ) return connector # type: ignore[return-value] - def list_for_protocol(self, protocol: str) -> List[BaseConnector[Any, Any]]: - result: List[BaseConnector[Any, Any]] = [] + def list_for_protocol(self, protocol: str) -> List[BaseConnector]: + result: List[BaseConnector] = [] for connector_id, connector in self._connectors.items(): if protocol in self._configs[connector_id].exposed_via: - # Multi-action connectors expose all their actions via list_actions(). - if hasattr(connector, "list_actions"): - result.extend(connector.list_actions()) - else: - result.append(connector) # type: ignore[arg-type] + result.append(connector) # type: ignore[arg-type] return result diff --git a/src/bindings/grpc_server/auth.py b/src/bindings/grpc_server/auth.py new file mode 100644 index 0000000..12067c3 --- /dev/null +++ b/src/bindings/grpc_server/auth.py @@ -0,0 +1,82 @@ +""" +gRPC API authentication (enterprise default: required API key or JWT). + +Environment: + NW_GRPC_API_KEY — shared secret; passed via metadata key 'authorization' or 'x-api-key'. + NW_GRPC_JWT_SECRET — optional HS256 secret; if set, tokens with three segments are verified as JWTs. + NW_GRPC_AUTH_DISABLED — if ``true``/``1``/``yes``, skip auth (local dev only; do not use in production). +""" + +from __future__ import annotations + +import os +from typing import Any, Callable + +import grpc +import jwt + + +def _truthy(val: str | None) -> bool: + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +def _extract_token(metadata: tuple[tuple[str, str], ...]) -> str | None: + for key, value in metadata: + k = key.lower() + if k == "authorization": + if value.lower().startswith("bearer "): + return value[7:].strip() + return value.strip() + if k == "x-api-key": + return value.strip() + return None + + +def _verify_token(token: str, *, api_key: str | None, jwt_secret: str | None) -> bool: + if api_key and token == api_key: + return True + if jwt_secret and token.count(".") == 2: + try: + jwt.decode(token, jwt_secret, algorithms=["HS256"]) + return True + except jwt.PyJWTError: + return False + return False + + +class GrpcAuthInterceptor(grpc.ServerInterceptor): + def intercept_service( + self, + continuation: Callable[[grpc.HandlerCallDetails], Any], + handler_call_details: grpc.HandlerCallDetails, + ) -> Any: + if _truthy(os.environ.get("NW_GRPC_AUTH_DISABLED")): + return continuation(handler_call_details) + + api_key = os.environ.get("NW_GRPC_API_KEY") + jwt_secret = os.environ.get("NW_GRPC_JWT_SECRET") + + def _abort_with_status(code: grpc.StatusCode, details: str) -> Any: + def abort(request: Any, context: grpc.ServicerContext) -> None: + context.abort(code, details) + + return grpc.unary_unary_rpc_method_handler(abort) + + if not api_key and not jwt_secret: + return _abort_with_status( + grpc.StatusCode.UNAVAILABLE, + "gRPC API authentication is not configured. Set NW_GRPC_API_KEY " + "(and optionally NW_GRPC_JWT_SECRET), or set NW_GRPC_AUTH_DISABLED=true " + "for local development only.", + ) + + token = _extract_token(handler_call_details.invocation_metadata or ()) + if not token: + return _abort_with_status(grpc.StatusCode.UNAUTHENTICATED, "Authentication required") + + if not _verify_token(token, api_key=api_key, jwt_secret=jwt_secret): + return _abort_with_status(grpc.StatusCode.UNAUTHENTICATED, "Invalid API key or token") + + return continuation(handler_call_details) diff --git a/src/bindings/grpc_server/connector.proto b/src/bindings/grpc_server/connector.proto index c88df06..3ab0bfa 100644 --- a/src/bindings/grpc_server/connector.proto +++ b/src/bindings/grpc_server/connector.proto @@ -1,3 +1,5 @@ +// SPDX-FileCopyrightText: 2026 AOT Technologies +// SPDX-License-Identifier: Apache-2.0 syntax = "proto3"; package aot.connectors; @@ -20,4 +22,3 @@ message InvokeResponse { string message = 5; string trace_id = 6; } - diff --git a/src/bindings/grpc_server/server.py b/src/bindings/grpc_server/server.py index c411cb1..1342930 100644 --- a/src/bindings/grpc_server/server.py +++ b/src/bindings/grpc_server/server.py @@ -1,18 +1,26 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import asyncio import json import logging +import os from concurrent import futures from typing import Any import grpc from bindings.factory import ConnectorFactory -from connectors import auto_register -from runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.connector_registry import auto_register +from node_wire_runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.ingress import normalize_mcp_tool_arguments +from node_wire_runtime.rate_limit import global_rate_limiter, RateLimitExceeded from . import connector_pb2, connector_pb2_grpc # type: ignore[attr-defined] +from .auth import GrpcAuthInterceptor logger = logging.getLogger("bindings.grpc_server") @@ -23,7 +31,20 @@ def __init__(self) -> None: self._factory = ConnectorFactory() self._factory.load() - async def _invoke_async(self, request: connector_pb2.InvokeRequest) -> connector_pb2.InvokeResponse: # type: ignore[name-defined] + async def _invoke_async( + self, request: connector_pb2.InvokeRequest + ) -> connector_pb2.InvokeResponse: # type: ignore[name-defined] + try: + await global_rate_limiter.acquire() + except RateLimitExceeded as e: + return connector_pb2.InvokeResponse( # type: ignore[name-defined] + success=False, + error_code="RATE_LIMIT_EXCEEDED", + error_category=ErrorCategory.RETRYABLE.value, + message=str(e), + trace_id="", + ) + connector = self._factory.get_for_protocol(request.connector_id, "grpc") if connector is None: return connector_pb2.InvokeResponse( # type: ignore[name-defined] @@ -36,12 +57,31 @@ async def _invoke_async(self, request: connector_pb2.InvokeRequest) -> connector payload: Any = {} if request.payload_json: - payload = json.loads(request.payload_json) + try: + payload = json.loads(request.payload_json) + except json.JSONDecodeError as e: + return connector_pb2.InvokeResponse( # type: ignore[name-defined] + success=False, + error_code="INVALID_JSON", + error_category=ErrorCategory.BUSINESS.value, + message=f"Failed to parse payload_json: {e}", + trace_id="", + ) + + if isinstance(payload, dict): + # The payload MUST include the action for Pydantic discriminated union validation to succeed + if request.action: + payload["action"] = request.action + + if payload.get("action"): + normalize_mcp_tool_arguments(connector, str(payload["action"]), payload) response: ConnectorResponse = await connector.run(payload) data_json = json.dumps(response.data) if response.data is not None else "" - error_category = response.error_category.value if response.error_category is not None else "" + error_category = ( + response.error_category.value if response.error_category is not None else "" + ) return connector_pb2.InvokeResponse( # type: ignore[name-defined] success=response.success, @@ -58,14 +98,34 @@ def Invoke(self, request, context): # type: ignore[override] def serve(port: int = 50051) -> None: - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + interceptor = GrpcAuthInterceptor() + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), interceptors=(interceptor,)) connector_pb2_grpc.add_ConnectorServiceServicer_to_server(ConnectorServiceServicer(), server) # type: ignore[attr-defined] - server.add_insecure_port(f"[::]:{port}") - logger.info("Starting gRPC server", extra={"port": port}) + + cert_path = os.environ.get("NW_GRPC_TLS_CERT_PATH") + key_path = os.environ.get("NW_GRPC_TLS_KEY_PATH") + + if cert_path and key_path: + # Load the TLS certificate and key + with open(key_path, "rb") as f: + private_key = f.read() + with open(cert_path, "rb") as f: + certificate_chain = f.read() + + server_credentials = grpc.ssl_server_credentials(((private_key, certificate_chain),)) + server.add_secure_port(f"[::]:{port}", server_credentials) + logger.info("Starting secure gRPC server (TLS enabled)", extra={"port": port}) + else: + server.add_insecure_port(f"[::]:{port}") + logger.warning( + "Starting insecure gRPC server (no TLS credentials found). " + "Traffic will be unencrypted.", + extra={"port": port}, + ) + server.start() server.wait_for_termination() if __name__ == "__main__": serve() - diff --git a/src/bindings/mcp_server/auth.py b/src/bindings/mcp_server/auth.py new file mode 100644 index 0000000..9a69ec9 --- /dev/null +++ b/src/bindings/mcp_server/auth.py @@ -0,0 +1,246 @@ +""" +MCP authentication (enterprise default: required API key or JWT). + +Environment: + NW_MCP_API_KEY — shared secret; send as ``Authorization: Bearer `` or ``X-API-Key: ``. + NW_MCP_JWT_SECRET — optional HS256 secret; if set, Bearer tokens with three segments are verified as JWTs. + NW_MCP_AUTH_DISABLED — if ``true``/``1``/``yes``, skip auth (local dev only; do not use in production). +""" + +from __future__ import annotations + +import os +import logging +from pathlib import Path +from typing import Any, Mapping + +import jwt +from dotenv import load_dotenv + +from node_wire_runtime.caller_identity import ( + CallerIdentity, + build_caller_identity, + parse_api_key_scopes_from_env, +) + +logger = logging.getLogger("bindings.mcp_server.auth") + +# Back-compat: callers may still import ``McpIdentity`` / ``build_identity`` from MCP auth. +McpIdentity = CallerIdentity + + +def _truthy(val: str | None) -> bool: + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +class McpAuthError(PermissionError): + def __init__( + self, + detail: str, + *, + status_code: int, + error_code: str, + www_authenticate: str | None = None, + ) -> None: + super().__init__(detail) + self.detail = detail + self.status_code = status_code + self.error_code = error_code + self.www_authenticate = www_authenticate + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "detail": self.detail, + "error_code": self.error_code, + "status_code": self.status_code, + } + if self.www_authenticate: + payload["www_authenticate"] = self.www_authenticate + return payload + + +class McpAuthRequiredError(McpAuthError): + def __init__(self) -> None: + super().__init__( + "Authentication required", + status_code=401, + error_code="MCP_AUTH_REQUIRED", + www_authenticate='Bearer realm="node-wire"', + ) + + +class McpAuthInvalidError(McpAuthError): + def __init__(self) -> None: + super().__init__( + "Invalid API key or token", + status_code=403, + error_code="MCP_AUTH_INVALID", + www_authenticate='Bearer realm="node-wire"', + ) + + +class McpAuthNotConfiguredError(McpAuthError): + def __init__(self) -> None: + super().__init__( + ( + "MCP authentication is not configured. Set NW_MCP_API_KEY " + "(and optionally NW_MCP_JWT_SECRET), or set NW_MCP_AUTH_DISABLED=true " + "for local development only." + ), + status_code=503, + error_code="MCP_AUTH_NOT_CONFIGURED", + ) + + +_mcp_auth_env_bootstrapped = False + + +def _bootstrap_mcp_auth_env() -> None: + global _mcp_auth_env_bootstrapped + if _mcp_auth_env_bootstrapped: + return + + # Some launch paths on Windows can miss .env loading for the MCP worker. + # If MCP auth vars are missing/empty, try loading project .env once. + if os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET"): + _mcp_auth_env_bootstrapped = True + return + + # Align with REST/bindings: when dotenv merge is disabled (pytest, CI, prod), + # never load repo `.env` with override=True — that stomps conftest env and + # monkeypatched values (e.g. NW_ALLOWED_CONNECTORS, NW_MCP_AUTH_DISABLED). + rest_dotenv = os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() + if rest_dotenv in ("0", "false", "no"): + # Keys may be injected later (tests); do not mark bootstrapped so we recheck. + return + + repo_root_env = Path(__file__).resolve().parents[3] / ".env" + load_dotenv(override=False) + load_dotenv(repo_root_env, override=False) + _mcp_auth_env_bootstrapped = True + + +def mcp_auth_disabled() -> bool: + disabled = os.environ.get("NW_MCP_AUTH_DISABLED") + if disabled is not None: + return _truthy(disabled) + + legacy_enabled = os.environ.get("NW_MCP_AUTH_ENABLED") + if legacy_enabled is not None: + logger.warning( + "NW_MCP_AUTH_ENABLED is deprecated; use NW_MCP_AUTH_DISABLED instead " + "(true disables auth). NW_MCP_AUTH_ENABLED will be removed in a future release." + ) + return _truthy(legacy_enabled) + + return False + + +def mcp_auth_configured() -> bool: + _bootstrap_mcp_auth_env() + return bool(os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET")) + + +def log_mcp_auth_startup_state() -> None: + """Log effective MCP auth posture once at server startup.""" + _bootstrap_mcp_auth_env() + disabled = mcp_auth_disabled() + configured = mcp_auth_configured() + state = "disabled" if disabled else "enabled" + logger.info("MCP authentication %s (configured=%s)", state, configured) + if disabled: + logger.warning("NW_MCP_AUTH_DISABLED is set — MCP auth is OFF; do not use in production") + + +def _get_meta_value(meta: Mapping[str, Any] | None, keys: tuple[str, ...]) -> str | None: + if not meta: + return None + for key in keys: + val = meta.get(key) + if isinstance(val, str) and val.strip(): + return val.strip() + return None + + +def extract_token( + *, + headers: Mapping[str, Any] | None = None, + meta: Mapping[str, Any] | None = None, +) -> str | None: + if headers: + auth = headers.get("authorization") or headers.get("Authorization") + if isinstance(auth, str) and auth.lower().startswith("bearer "): + return auth[7:].strip() + x_api_key = headers.get("x-api-key") or headers.get("X-API-Key") + if isinstance(x_api_key, str) and x_api_key.strip(): + return x_api_key.strip() + + auth_meta = _get_meta_value(meta, ("authorization", "Authorization")) + if auth_meta and auth_meta.lower().startswith("bearer "): + return auth_meta[7:].strip() + + return _get_meta_value(meta, ("x-api-key", "X-API-Key", "api_key", "apiKey", "token")) + + +def verify_mcp_token(token: str) -> tuple[dict[str, Any], str]: + api_key = os.getenv("NW_MCP_API_KEY") + jwt_secret = os.getenv("NW_MCP_JWT_SECRET") + + if api_key and token == api_key: + scopes = list(parse_api_key_scopes_from_env("NW_MCP_API_KEY_SCOPES")) + return ({"sub": "api-key-user", "tenant_id": None, "scopes": scopes}, "api_key") + + if jwt_secret and token.count(".") == 2: + try: + claims = jwt.decode(token, jwt_secret, algorithms=["HS256"]) + logger.info("MCP token verified as JWT") + return (claims, "jwt") + except jwt.PyJWTError as exc: + raise McpAuthInvalidError() from exc + + raise McpAuthInvalidError() + + +def build_identity(claims: Mapping[str, Any], auth_type: str) -> CallerIdentity: + """Deprecated alias for :func:`build_caller_identity`; prefer that name in new code.""" + return build_caller_identity(claims, auth_type) + + +def authenticate_mcp_request( + *, + headers: Mapping[str, Any] | None = None, + meta: Mapping[str, Any] | None = None, +) -> CallerIdentity | None: + logger.info( + "MCP auth gate status", + extra={ + "auth_disabled": mcp_auth_disabled(), + "auth_configured": mcp_auth_configured(), + "has_api_key": bool(os.environ.get("NW_MCP_API_KEY")), + "has_jwt_secret": bool(os.environ.get("NW_MCP_JWT_SECRET")), + }, + ) + if mcp_auth_disabled(): + return None + + if not mcp_auth_configured(): + raise McpAuthNotConfiguredError() + + token = extract_token(headers=headers, meta=meta) + if not token: + raise McpAuthRequiredError() + + claims, auth_type = verify_mcp_token(token) + identity = build_caller_identity(claims, auth_type) + logger.info( + "MCP auth accepted", + extra={ + "auth_type": identity.auth_type, + "principal": identity.principal, + "tenant_id": identity.tenant_id or "", + "scopes": list(identity.scopes), + }, + ) + return identity diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index ce98707..2ecc4ed 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -1,61 +1,577 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations +import contextvars import json import logging -from typing import Any, Dict, List +import os +import uuid +from contextvars import ContextVar +from typing import Any, Dict, List, Mapping, Optional, Tuple from bindings.factory import ConnectorFactory -from connectors import auto_register -from connectors.manifest import build_manifest +from bindings.mcp_server.auth import ( + McpAuthError, + authenticate_mcp_request, + log_mcp_auth_startup_state, +) +from node_wire_runtime.caller_identity import CallerIdentity +from node_wire_runtime.policies.mcp_scope_policy import ( + action_allowed_for_identity_scopes, + load_scope_map_from_env, + load_scope_policy_default_from_env, +) +from node_wire_runtime.connector_registry import auto_register +from node_wire_runtime.manifest import MCP_MANIFEST_CONTRACT_VERSION, build_manifest +from node_wire_runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.ingress import enforce_authoritative_action, normalize_mcp_tool_arguments +from node_wire_runtime.rate_limit import global_rate_limiter, RateLimitExceeded +from node_wire_runtime.streaming import stream_completion_log logger = logging.getLogger("bindings.mcp_server") +_streamable_http_identity_ctx: contextvars.ContextVar[CallerIdentity | None] = ( + contextvars.ContextVar( + "nw_streamable_http_identity", + default=None, + ) +) + +_http_request_headers: ContextVar[Mapping[str, str] | None] = ContextVar( + "mcp_http_request_headers", + default=None, +) + + +def _process_response_payload(data: Any, max_items: int) -> Tuple[Any, bool, int, Optional[str]]: + """ + Recursively search for large lists and truncate them. + Also tracks the maximum list size found and searches for pagination tokens. + Returns: (processed_data, was_truncated, max_list_size, next_page_token) + """ + next_page_token = None + max_list_size = 0 + was_truncated = False + + if isinstance(data, list): + current_len = len(data) + max_list_size = max(max_list_size, current_len) + + working_list = data + if current_len > max_items: + working_list = data[:max_items] + was_truncated = True + + out_list = [] + for item in working_list: + new_item, t, mls, npt = _process_response_payload(item, max_items) + out_list.append(new_item) + was_truncated = was_truncated or t + max_list_size = max(max_list_size, mls) + if npt and not next_page_token: + next_page_token = npt + + return out_list, was_truncated, max_list_size, next_page_token + + if isinstance(data, dict): + out_dict = {} + for k, v in data.items(): + if k in ( + "nextPageToken", + "pageToken", + "next_cursor", + "cursor", + "next_page_token", + ) and isinstance(v, str): + if not next_page_token: + next_page_token = v + + new_v, t, mls, npt = _process_response_payload(v, max_items) + out_dict[k] = new_v + was_truncated = was_truncated or t + max_list_size = max(max_list_size, mls) + if npt and not next_page_token: + next_page_token = npt + + return out_dict, was_truncated, max_list_size, next_page_token + + return data, False, 0, next_page_token class McpServer: """ - Minimal MCP-style server abstraction for the POC. + Manifest-driven MCP server: tools come from connector metadata; execution + dispatches through ConnectorFactory and connector.run(). - This does not implement the full Model Context Protocol over JSON-RPC, - but exposes two conceptual operations: - - list_tools(): returns connector/actions manifest - - invoke_tool(name, arguments): executes the corresponding connector + Use list_tools() / invoke_tool() for programmatic access, or run_stdio() + for a full MCP stdio transport. """ - def __init__(self) -> None: + def __init__( + self, + *, + server_name: str = "node-wire", + connector_ids: Optional[List[str]] = None, + ) -> None: + self._server_name = server_name + self._connector_ids: Optional[frozenset[str]] = ( + None if connector_ids is None else frozenset(connector_ids) + ) auto_register() self._factory = ConnectorFactory() self._factory.load() + try: + from importlib.metadata import version as pkg_version + + _pkg_ver = pkg_version("node-wire") + except Exception: # pragma: no cover + _pkg_ver = "unknown" + logger.info( + "MCP server initialized | server_name=%s | manifest_contract=%s | package=%s", + server_name, + MCP_MANIFEST_CONTRACT_VERSION, + _pkg_ver, + ) + log_mcp_auth_startup_state() - def list_tools(self) -> List[Dict[str, Any]]: + def list_tools(self, *, identity: CallerIdentity | None = None) -> List[Dict[str, Any]]: + identity = self._ensure_identity(identity=identity) + return self._list_tools_impl(identity=identity) + + def _list_tools_impl(self, *, identity: CallerIdentity | None = None) -> List[Dict[str, Any]]: + scope_map = load_scope_map_from_env() + default_mode = load_scope_policy_default_from_env() connectors = self._factory.list_for_protocol("mcp") manifest = build_manifest(connectors) tools: List[Dict[str, Any]] = [] for entry in manifest: + cid = entry["connector_id"] + if self._connector_ids is not None and cid not in self._connector_ids: + continue + if identity is not None: + if not action_allowed_for_identity_scopes( + connector_id=cid, + action=str(entry["action"]), + principal=identity.principal, + tenant_id=identity.tenant_id, + scopes=identity.scopes, + action_scope_map=scope_map, + default_mode=default_mode, + ): + continue + schema_desc = entry["input_schema"].get("description", "") + + security_lines = [] + if entry.get("requires_auth"): + security_lines.append("- Requires Auth: Yes") + scopes = entry.get("scopes") + if scopes: + security_lines.append(f"- Scopes: {', '.join(scopes)}") + rate_limit = entry.get("rate_limit") + if rate_limit: + security_lines.append(f"- Rate Limit: {rate_limit}") + if entry.get("deprecated"): + security_lines.append("- DEPRECATED: True") + + sec_block = "\n".join(security_lines) + if sec_block: + sec_block = f"\n\nSecurity & Limits:\n{sec_block}\n\n" + + tool_desc = ( + (f"{schema_desc}\n" if schema_desc else "") + + sec_block + + ( + f"Pass fields from inputSchema only; do not include an action field " + f"(it is injected from the tool name). " + f"Manifest contract v{MCP_MANIFEST_CONTRACT_VERSION}." + ) + ) tools.append( { - "name": f"{entry['connector_id']}.{entry['action']}", - "description": f"{entry['connector_id']} {entry['action']} connector action", + "name": f"{cid}.{entry['action']}", + "description": tool_desc, "input_schema": entry["input_schema"], + "output_schema": entry["output_schema"], } ) return tools - async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + def _ensure_identity( + self, + *, + identity: CallerIdentity | None, + meta: Mapping[str, Any] | None = None, + ) -> CallerIdentity | None: + if identity is not None: + return identity + request_identity = _streamable_http_identity_ctx.get() + if request_identity is not None: + return request_identity + return authenticate_mcp_request( + headers=_http_request_headers.get(), + meta=meta, + ) + + def _request_meta_from_context(self) -> Mapping[str, Any] | None: + try: + from mcp.server.lowlevel.server import request_ctx + + ctx = request_ctx.get() + except Exception: + return None + if ctx is None or ctx.meta is None: + return None + if hasattr(ctx.meta, "model_dump"): + dumped = ctx.meta.model_dump() # type: ignore[attr-defined] + if isinstance(dumped, dict): + return dumped + return None + if isinstance(ctx.meta, dict): + return ctx.meta + return None + + async def invoke_tool( + self, + name: str, + arguments: Dict[str, Any], + *, + identity: CallerIdentity | None = None, + ) -> Dict[str, Any]: + identity = self._ensure_identity(identity=identity) + try: + # Skip rate limiting if disabled + if os.environ.get("NW_RATE_LIMIT_DISABLED", "false").lower() not in ( + "true", + "1", + "yes", + ): + await global_rate_limiter.acquire() + except RateLimitExceeded as e: + raise ValueError(str(e)) + try: connector_id, action = name.split(".", 1) except ValueError: raise ValueError("Tool name must be in the form '.'") + if self._connector_ids is not None and connector_id not in self._connector_ids: + raise ValueError(f"Connector {connector_id!r} is not allowed on this MCP server.") + connector = self._factory.get_for_protocol(connector_id, "mcp") if connector is None: raise ValueError(f"Connector {connector_id!r} is not available via MCP.") - response = await connector.run(arguments) - return response.model_dump() + run_args = normalize_mcp_tool_arguments(connector, action, arguments) + enforce_authoritative_action(run_args, action) + run_args["action"] = action + + trace_id = run_args.get("trace_id") or str(uuid.uuid4()) + + # Proactively inject/clamp pagination parameters to prevent native token desync + # caused by the post-execution truncation guardrail + max_items = int(os.environ.get("NW_MCP_MAX_LIST_ITEMS", "50")) + meta = connector.sdk_action_metas().get(action) + clamped_params = {} + if meta and hasattr(meta.input_model, "model_fields"): + for page_param in ["page_size", "limit", "_count"]: + if page_param in meta.input_model.model_fields: + current_val = run_args.get(page_param) + if current_val is None: + run_args[page_param] = max_items + clamped_params[page_param] = max_items + else: + try: + val = int(current_val) + run_args[page_param] = min(val, max_items) + clamped_params[page_param] = run_args[page_param] + except (ValueError, TypeError): + pass + + try: + response = await connector.run( + run_args, + principal=identity.principal if identity else None, + tenant_id=identity.tenant_id if identity else None, + scopes=identity.scopes if identity else None, + ) + stream_completion_log(trace_id, True, connector_id=connector_id, action=action) + except Exception: + stream_completion_log(trace_id, False, connector_id=connector_id, action=action) + raise + + raw_response = response.model_dump() + + # Enforce MCP sampling guardrail + processed_payload, was_truncated, item_count, next_token = _process_response_payload( + raw_response, max_items + ) + + # Overwrite raw_response in place + raw_response.clear() + raw_response.update(processed_payload) + + # Add _system_pagination_used metadata (keeps old clients/MCP inspector working) + if clamped_params: + raw_response["_system_pagination_used"] = clamped_params + + # IMPORTANT: Inject metadata IN-BAND inside the "data" dictionary so client UIs + # (like Toolhive / Agent chat) that only render the `data` block will explicitly see it. + if "data" in raw_response and isinstance(raw_response["data"], dict): + pagination_meta: dict[str, Any] = {} + if clamped_params: + pagination_meta["coerced_parameters"] = clamped_params + pagination_meta["items_returned"] = item_count + if was_truncated: + pagination_meta["was_truncated_by_server"] = True + if next_token: + pagination_meta["next_page_token"] = next_token + # Prepend it visually for the LLM + raw_response["data"] = { + "_server_pagination_metadata": pagination_meta, + **raw_response["data"], + } + + # We also inject explicitly into the root if it doesn't have a data block + elif not isinstance(raw_response.get("data"), dict): + raw_response["_server_pagination_metadata"] = { + "coerced_parameters": clamped_params, + "items_returned": item_count, + "next_page_token": next_token, + } + + # Build dynamic system message + sys_msgs = [] + if clamped_params: + sys_msgs.append( + f"[System Pagination] Arguments coerced to safeguard limits: {json.dumps(clamped_params)}" + ) + + if item_count > 0: + count_msg = f"[System Guardrail] The connector returned {item_count} items." + if was_truncated: + count_msg += f" (truncated to {max_items} to preserve context)" + sys_msgs.append(count_msg) + + if next_token: + sys_msgs.append( + f"[System Pagination] nextPageToken available for next query: '{next_token}'" + ) + + if was_truncated and not next_token: + sys_msgs.append( + f"[System Guardrail WARNING] Data exceeded {max_items} items and was hard-truncated. " + "No native next page token was found! You MUST retry this query with an explicit " + f"`page_size` or limit parameter set to {max_items} to force the API to generate valid cursors." + ) + + if sys_msgs: + combined_sys_msgs = "\n".join(sys_msgs) + if raw_response.get("message"): + raw_response["message"] = f"{raw_response['message']}\n\n{combined_sys_msgs}" + else: + raw_response["message"] = combined_sys_msgs + + return raw_response + + def _setup_lowlevel_server(self) -> Any: + from mcp.server import Server as LowLevelServer + from mcp.types import Tool + + low = LowLevelServer(self._server_name) + + @low.list_tools() + async def handle_list_tools() -> list[Tool]: + meta = self._request_meta_from_context() + try: + identity = self._ensure_identity(identity=None, meta=meta) + except McpAuthError as exc: + logger.warning( + "MCP tools/list denied by authentication", + extra={ + "status_code": exc.status_code, + "error_code": exc.error_code, + }, + ) + raise RuntimeError(json.dumps(exc.to_payload())) from exc + if identity: + logger.info( + "MCP tools/list authorized", + extra={ + "principal": identity.principal, + "tenant_id": identity.tenant_id or "", + "auth_type": identity.auth_type, + }, + ) + out: list[Tool] = [] + for t in self._list_tools_impl(identity=identity): + kwargs: Dict[str, Any] = { + "name": t["name"], + "description": t["description"], + "inputSchema": t["input_schema"], + "outputSchema": t["output_schema"], + } + out.append(Tool(**kwargs)) + return out + + @low.call_tool() + async def handle_call_tool(tool_name: str, arguments: dict) -> dict: + meta = self._request_meta_from_context() + try: + identity = self._ensure_identity(identity=None, meta=meta) + except McpAuthError as exc: + logger.warning( + "MCP tools/call denied by authentication", + extra={ + "tool_name": tool_name, + "status_code": exc.status_code, + "error_code": exc.error_code, + }, + ) + return ConnectorResponse( + success=False, + data=None, + error_code=exc.error_code, + error_category=ErrorCategory.AUTH, + message=exc.detail, + trace_id=f"mcp-auth-{uuid.uuid4()}", + details=exc.to_payload(), + ).model_dump() + + if identity: + logger.info( + "MCP tools/call authorized", + extra={ + "tool_name": tool_name, + "principal": identity.principal, + "tenant_id": identity.tenant_id or "", + "auth_type": identity.auth_type, + }, + ) + return await self.invoke_tool(tool_name, arguments or {}, identity=identity) + + return low + + async def _run_stdio_async(self) -> None: + from mcp.server.stdio import stdio_server + from mcp.server import NotificationOptions + + low = self._setup_lowlevel_server() + + async with stdio_server() as (read_stream, write_stream): + await low.run( + read_stream, + write_stream, + low.create_initialization_options(notification_options=NotificationOptions()), + ) + + def run_stdio(self) -> None: + import anyio + + anyio.run(self._run_stdio_async) + + def _build_streamable_http_app(self, *, session_manager: Any, path: str) -> Any: + from contextlib import asynccontextmanager + + from starlette.applications import Starlette + from starlette.middleware.base import BaseHTTPMiddleware + from starlette.requests import Request + from starlette.responses import JSONResponse + from starlette.routing import Route + + @asynccontextmanager + async def lifespan(app: Starlette): + async with session_manager.run(): + yield + + class StreamableHttpAuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): # type: ignore[override] + if request.url.path != path: + return await call_next(request) + try: + identity = authenticate_mcp_request(headers=request.headers) + except McpAuthError as exc: + headers: Dict[str, str] = {} + if exc.www_authenticate: + headers["WWW-Authenticate"] = exc.www_authenticate + return JSONResponse( + status_code=exc.status_code, + content=exc.to_payload(), + headers=headers, + ) + + setattr(request.state, "nw_mcp_identity", identity) + token = _streamable_http_identity_ctx.set(identity) + try: + return await call_next(request) + finally: + _streamable_http_identity_ctx.reset(token) + + # Use a wrapper class to ensure Starlette treats this as an ASGI app + # without the automatic redirection logic of Mount(). + class _ASGIApp: + def __init__(self, handler): + self.handler = handler + + async def __call__(self, scope, receive, send): + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in scope.get("headers", []) + } + token = _http_request_headers.set(headers) + try: + await self.handler(scope, receive, send) + finally: + _http_request_headers.reset(token) + + starlette_app = Starlette( + lifespan=lifespan, + routes=[ + Route( + path, + endpoint=_ASGIApp(session_manager.handle_request), + methods=["GET", "POST"], + ) + ], + ) + starlette_app.add_middleware(StreamableHttpAuthMiddleware) + return starlette_app + + async def _run_streamable_http_async(self) -> None: + import os + from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + import uvicorn + + host = os.getenv("NW_MCP_HOST", "0.0.0.0") + port = int(os.getenv("NW_MCP_PORT", "8081")) + path = os.getenv("NW_MCP_PATH", "/mcp") + + low = self._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + starlette_app = self._build_streamable_http_app(session_manager=session_manager, path=path) + + logger.info(f"Starting MCP streamable-http server on {host}:{port}{path}") + config = uvicorn.Config(starlette_app, host=host, port=port, log_level="info") + server = uvicorn.Server(config) + await server.serve() + + def run_streamable_http(self) -> None: + import anyio + + anyio.run(self._run_streamable_http_async) + + def run(self, transport: str = "stdio") -> None: + transport = transport.strip().lower() + if transport == "stdio": + self.run_stdio() + elif transport == "streamable-http": + self.run_streamable_http() + else: + raise ValueError(f"Unsupported MCP transport: {transport}") if __name__ == "__main__": # Simple demo runner that prints tool list and exits. server = McpServer() print(json.dumps(server.list_tools(), indent=2)) - diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 7d27dcc..86e288e 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -1,43 +1,59 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging from typing import Any, Dict -from fastapi import Depends, FastAPI, HTTPException +import os +import sys +from pathlib import Path + +from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, create_model from dotenv import load_dotenv -load_dotenv() # Load environmental variables from .env +# Production: set NW_REST_LOAD_DOTENV=false to rely on injected env only (no .env file). +if os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() not in ("0", "false", "no"): + # Do not override existing os.environ keys (pytest/conftest injects values first). + load_dotenv(override=False) from bindings.factory import ConnectorFactory -from connectors import auto_register -from connectors.manifest import build_manifest -from runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.connector_registry import auto_register +from node_wire_runtime.manifest import build_manifest +from node_wire_runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.ingress import enforce_authoritative_action, normalize_mcp_tool_arguments from opentelemetry import trace from opentelemetry.trace import Status, StatusCode from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor -import os -import sys -from pathlib import Path +from node_wire_runtime.rate_limit import global_rate_limiter, RateLimitExceeded + +from bindings.rest_api.rate_limit import InMemoryRateLimiter +from bindings.rest_api.auth import ( + RestAuthMiddleware, + get_request_identity_key, + get_rest_caller_identity, +) # Add project root to sys.path to allow importing from 'playground' package PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent if str(PROJECT_ROOT) not in sys.path: sys.path.append(str(PROJECT_ROOT)) -from playground.scenarios import router as scenarios_router +from playground.scenarios import router as scenarios_router # noqa: E402 + logger = logging.getLogger("bindings.rest_api") tracer = trace.get_tracer("bindings.rest_api") app = FastAPI(title="Node Wire - REST API") FastAPIInstrumentor.instrument_app(app) - -import os -from pathlib import Path +# Auth runs outermost (added last): protects /connectors/*, /playground/*, /scenarios/*; /health is public. +app.add_middleware(RestAuthMiddleware) # Include the professional scenarios orchestrator app.include_router(scenarios_router) @@ -48,6 +64,9 @@ app.mount("/playground", StaticFiles(directory=str(DEMO_DIR), html=True), name="playground") _factory: ConnectorFactory | None = None +_rate_limiter: InMemoryRateLimiter | None = None +_rate_limiter_cfg: tuple[int, int] | None = None + def get_factory() -> ConnectorFactory: global _factory @@ -58,10 +77,20 @@ def get_factory() -> ConnectorFactory: return _factory +async def check_rate_limit() -> None: + try: + # Skip rate limiting if disabled + if os.environ.get("NW_RATE_LIMIT_DISABLED", "false").lower() not in ("true", "1", "yes"): + await global_rate_limiter.acquire() + except RateLimitExceeded as exc: + raise HTTPException(status_code=429, detail=str(exc)) + + @app.get("/health", tags=["system"]) async def health() -> Dict[str, str]: return {"status": "ok"} + def _http_status_for_category(category: ErrorCategory | None) -> int: if category is None: return 200 @@ -73,10 +102,37 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: return 503 return 500 -def _make_endpoint(cid: str, act: str) -> Any: + +def _truthy(value: str | None) -> bool: + if value is None: + return False + return value.strip().lower() in ("1", "true", "yes", "on") + + +def _rate_limit_enabled() -> bool: + return _truthy(os.environ.get("NW_REST_RATE_LIMIT_ENABLED")) + + +def _get_rate_limiter() -> InMemoryRateLimiter: + global _rate_limiter, _rate_limiter_cfg + max_requests = int(os.environ.get("NW_REST_RATE_LIMIT_MAX_REQUESTS", "120")) + window_seconds = int(os.environ.get("NW_REST_RATE_LIMIT_WINDOW_SECONDS", "60")) + cfg = (max_requests, window_seconds) + if _rate_limiter is None or _rate_limiter_cfg != cfg: + _rate_limiter = InMemoryRateLimiter( + max_requests=max_requests, + window_seconds=window_seconds, + ) + _rate_limiter_cfg = cfg + return _rate_limiter + + +def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( + request: Request, payload: Dict[str, Any], factory_dep: ConnectorFactory = Depends(get_factory), + _: None = Depends(check_rate_limit), ) -> JSONResponse: """ Concrete endpoint for a specific connector/action, e.g. @@ -85,13 +141,37 @@ async def endpoint( span = trace.get_current_span() span.set_attribute("connector.id", cid) span.set_attribute("connector.action", act) + if _rate_limit_enabled(): + limiter = _get_rate_limiter() + identity_key = get_request_identity_key(request) + rate_key = f"{cid}:{act}:{identity_key}" + result = limiter.consume(rate_key) + if not result.allowed: + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"}, + headers={"Retry-After": str(result.retry_after_seconds)}, + ) connector = factory_dep.get_for_protocol(cid, "rest", action=act) if connector is None: raise HTTPException(status_code=404, detail="Connector not available for REST") + run_payload = dict(payload) + run_payload = normalize_mcp_tool_arguments(connector, act, run_payload) + try: + enforce_authoritative_action(run_payload, act) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + run_payload["action"] = act # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. - response: ConnectorResponse = await connector.run(payload) + rest_id = get_rest_caller_identity(request) + response: ConnectorResponse = await connector.run( + run_payload, + principal=rest_id.principal if rest_id else None, + tenant_id=rest_id.tenant_id if rest_id else None, + scopes=rest_id.scopes if rest_id else None, + ) status = _http_status_for_category(response.error_category) if not response.success: @@ -105,10 +185,12 @@ async def endpoint( status_code=status, content=response.model_dump(), ) + return endpoint + def _build_dynamic_routes() -> None: - factory = get_factory() + factory = get_factory() connectors = factory.list_for_protocol("rest") manifest = build_manifest(connectors) @@ -120,7 +202,7 @@ def _build_dynamic_routes() -> None: # For REST, let the runtime perform full Pydantic validation. # We accept an arbitrary JSON object as the payload and forward it # directly to connector.run(...). - route_path = f"/connectors/{connector_id}/{action}" + route_path = f"/connectors/{connector_id}/{action}" app.post(route_path, name=f"{connector_id}_{action}")(_make_endpoint(connector_id, action)) diff --git a/src/bindings/rest_api/auth.py b/src/bindings/rest_api/auth.py new file mode 100644 index 0000000..84f9362 --- /dev/null +++ b/src/bindings/rest_api/auth.py @@ -0,0 +1,165 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +REST API authentication (enterprise default: required API key or JWT). + +Environment: + NW_REST_API_KEY — shared secret; send as ``Authorization: Bearer `` or ``X-API-Key: ``. + NW_REST_JWT_SECRET — optional HS256 secret; if set, Bearer tokens with three segments are verified as JWTs. + NW_REST_AUTH_DISABLED — if ``true``/``1``/``yes``, skip auth (local dev only; do not use in production). + +Public (unauthenticated): ``GET /health`` only. OpenAPI UI requires auth. + +After successful auth, normalized caller identity (principal / tenant_id / scopes) is stored on +``request.state.nw_rest_caller_identity`` and forwarded to ``connector.run`` for policy hooks. +""" + +from __future__ import annotations + +import hashlib +import os +from typing import Callable + +import jwt +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from node_wire_runtime.caller_identity import ( + CallerIdentity, + build_caller_identity, + parse_api_key_scopes_from_env, +) + + +REST_CALLER_STATE_KEY = "nw_rest_caller_identity" + + +def get_rest_caller_identity(request: Request) -> CallerIdentity | None: + """Return JWT/API-key caller identity attached by middleware, if any.""" + return getattr(request.state, REST_CALLER_STATE_KEY, None) + + +def _truthy(val: str | None) -> bool: + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +def _is_public_path(path: str) -> bool: + p = path.rstrip("/") or "/" + return p == "/health" + + +def _extract_bearer_or_api_key(request: Request) -> str | None: + auth = request.headers.get("authorization") + if auth: + auth_val = auth.strip() + if auth_val.lower().startswith("bearer "): + return auth_val[7:].strip() + x = request.headers.get("x-api-key") + if x and x.strip(): + return x.strip() + return None + + +def get_request_identity_key(request: Request) -> str: + """ + Return a stable, non-sensitive identity key for request-level controls. + + Preference order: + 1) Auth token/API key (fingerprinted, never returned raw) + 2) X-Forwarded-For first hop + 3) request.client.host + """ + token = _extract_bearer_or_api_key(request) + if token: + digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16] + return f"token:{digest}" + forwarded = (request.headers.get("x-forwarded-for") or "").split(",", maxsplit=1)[0].strip() + if forwarded: + return f"ip:{forwarded}" + client_host = request.client.host if request.client else "unknown" + return f"ip:{client_host}" + + +def verify_rest_token_and_identity( + token: str, + *, + api_key: str | None, + jwt_secret: str | None, +) -> tuple[bool, CallerIdentity | None]: + """ + Validate REST bearer/API-key token and build caller identity (same shape as MCP). + + Shared API key scopes come from ``NW_REST_API_KEY_SCOPES`` (JSON array or + comma/space-separated). Empty means no scopes; use explicit ``*`` only when + intended (JWT-style superuser for the policy hook). + """ + if api_key and token == api_key: + scopes = list(parse_api_key_scopes_from_env("NW_REST_API_KEY_SCOPES")) + ident = build_caller_identity( + {"sub": "api-key-user", "tenant_id": None, "scopes": scopes}, + auth_type="rest_api_key", + ) + return True, ident + + if jwt_secret and token.count(".") == 2: + try: + claims = jwt.decode(token, jwt_secret, algorithms=["HS256"]) + except jwt.PyJWTError: + return False, None + return True, build_caller_identity(claims, auth_type="jwt") + + return False, None + + +class RestAuthMiddleware(BaseHTTPMiddleware): + """Require API key or valid JWT for all routes except public paths.""" + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + if request.method == "OPTIONS": + return await call_next(request) + + path = request.url.path + if _is_public_path(path): + return await call_next(request) + + if _truthy(os.environ.get("NW_REST_AUTH_DISABLED")): + return await call_next(request) + + api_key = os.environ.get("NW_REST_API_KEY") + jwt_secret = os.environ.get("NW_REST_JWT_SECRET") + + if not api_key and not jwt_secret: + return JSONResponse( + status_code=503, + content={ + "detail": ( + "REST API authentication is not configured. Set NW_REST_API_KEY " + "(and optionally NW_REST_JWT_SECRET), or set NW_REST_AUTH_DISABLED=true " + "for local development only." + ) + }, + ) + + token = _extract_bearer_or_api_key(request) + if not token: + return JSONResponse( + status_code=401, + content={"detail": "Authentication required"}, + headers={"WWW-Authenticate": 'Bearer realm="node-wire"'}, + ) + + ok, identity = verify_rest_token_and_identity(token, api_key=api_key, jwt_secret=jwt_secret) + if not ok or identity is None: + return JSONResponse( + status_code=403, + content={"detail": "Invalid API key or token"}, + headers={"WWW-Authenticate": 'Bearer realm="node-wire"'}, + ) + + setattr(request.state, REST_CALLER_STATE_KEY, identity) + return await call_next(request) diff --git a/src/bindings/rest_api/rate_limit.py b/src/bindings/rest_api/rate_limit.py new file mode 100644 index 0000000..eee39a0 --- /dev/null +++ b/src/bindings/rest_api/rate_limit.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import math +import threading +from collections import defaultdict, deque +from dataclasses import dataclass +from time import monotonic + + +@dataclass(frozen=True) +class RateLimitResult: + allowed: bool + retry_after_seconds: int = 0 + + +class InMemoryRateLimiter: + """ + Sliding-window in-memory limiter. + + This is intentionally simple for single-process REST deployments. + """ + + def __init__(self, *, max_requests: int, window_seconds: int) -> None: + if max_requests <= 0: + raise ValueError("max_requests must be > 0") + if window_seconds <= 0: + raise ValueError("window_seconds must be > 0") + self._max_requests = max_requests + self._window_seconds = float(window_seconds) + self._buckets: dict[str, deque[float]] = defaultdict(deque) + self._lock = threading.Lock() + + def consume(self, key: str) -> RateLimitResult: + now = monotonic() + cutoff = now - self._window_seconds + with self._lock: + bucket = self._buckets[key] + while bucket and bucket[0] <= cutoff: + bucket.popleft() + if len(bucket) >= self._max_requests: + retry_after = max( + 1, + int(math.ceil((bucket[0] + self._window_seconds) - now)), + ) + return RateLimitResult(allowed=False, retry_after_seconds=retry_after) + bucket.append(now) + return RateLimitResult(allowed=True) diff --git a/src/bindings_entrypoint.py b/src/bindings_entrypoint.py index a7c5e17..29680ae 100644 --- a/src/bindings_entrypoint.py +++ b/src/bindings_entrypoint.py @@ -1,13 +1,23 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging import os import uvicorn +from dotenv import load_dotenv from bindings.rest_api.app import app as rest_app from bindings.mcp_server.server import McpServer -from runtime.observability import init_observability +from node_wire_runtime.observability import init_observability + +# Match ``bindings.rest_api.app``: honor ``NW_REST_LOAD_DOTENV`` and never override +# keys already set (pytest/conftest sets ``NW_REST_LOAD_DOTENV=false`` before imports). +if os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() not in ("0", "false", "no"): + load_dotenv(override=False) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("bindings.entrypoint") @@ -34,7 +44,10 @@ def main() -> None: # For the POC we just start a simple process that can be interacted # with manually or via a thin wrapper; a full JSON-RPC loop is out of scope. server = McpServer() - logger.info("MCP server ready (list_tools available)", extra={"tool_count": len(server.list_tools())}) + logger.info( + "MCP server ready (list_tools available)", + extra={"tool_count": len(server.list_tools())}, + ) import time while True: @@ -45,4 +58,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/connectors/__init__.py b/src/connectors/__init__.py deleted file mode 100644 index f9c7b0f..0000000 --- a/src/connectors/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -""" -Node Wire - Layer B: System Adapters. - -Each connector lives in its own subpackage and follows the three-file pattern: - - connector_name/ - schema.py - logic.py - registration.py - -Registration modules are auto-discovered so they can register system-specific -exceptions with the global ErrorMapper in Layer A. -""" - -from importlib import import_module -from pkgutil import iter_modules -from typing import Iterable, List - - -def auto_register() -> List[str]: - """ - Import all `registration` modules in connector subpackages. - - Returns the list of successfully imported module names. This should be - called once at process startup (e.g. by Layer C bindings) to ensure all - connector-specific error mappings are registered. - """ - imported: List[str] = [] - package_name = __name__ - - for module_info in iter_modules(__path__, prefix=f"{package_name}."): - # We only care about subpackages; each is expected to expose registration.py - if not module_info.ispkg: - continue - - registration_module = f"{module_info.name}.registration" - try: - import_module(registration_module) - imported.append(registration_module) - except ModuleNotFoundError: - # Connector without a registration module; skip silently. - continue - - return imported - - -__all__ = ["auto_register"] - diff --git a/src/connectors/fhir_epic/logic.py b/src/connectors/fhir_epic/logic.py deleted file mode 100644 index e9cc615..0000000 --- a/src/connectors/fhir_epic/logic.py +++ /dev/null @@ -1,567 +0,0 @@ -from __future__ import annotations - -import asyncio -import codecs -import json -import logging -import uuid -from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional - -import httpx -import jwt - -from runtime import BaseConnector, SecretProvider - -from .schema import ( - FhirDocumentReferenceCreateInput, - FhirDocumentReferenceCreateOutput, - FhirDocumentReferenceSearchInput, - FhirDocumentReferenceSearchOutput, - FhirEncounterSearchInput, - FhirEncounterSearchOutput, - FhirPatientReadInput, - FhirPatientReadOutput, - FhirPatientSearchInput, - FhirPatientSearchOutput, -) - -logger = logging.getLogger("connectors.fhir_epic") - - -class _FhirAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirEpicConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. - """ - - connector_id = "fhir_epic" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action # instance attribute, overrides absent class-level action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) - - -class FhirEpicConnector: - """ - Single FHIR/Epic connector. - - ``connector_id = "fhir_epic"``. All authentication helpers and action - implementations live here. The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict and are - normalised (stripped, lowercased for ``name`` token search). - """ - - connector_id = "fhir_epic" - - def __init__(self, *, secret_provider: SecretProvider) -> None: - self._secret_provider = secret_provider - - self._actions: Dict[str, _FhirAction] = { - "read_patient": _FhirAction( - "read_patient", FhirPatientReadInput, FhirPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirAction( - "search_patients", FhirPatientSearchInput, FhirPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirAction( - "search_encounter", FhirEncounterSearchInput, FhirEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirAction( - "create_document_reference", FhirDocumentReferenceCreateInput, FhirDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirAction( - "search_document_reference", FhirDocumentReferenceSearchInput, FhirDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # ------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ - - def list_actions(self) -> List[_FhirAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) - - def get_action(self, name: str) -> Optional[_FhirAction]: - """Return the action connector for the given action name.""" - return self._actions.get(name) - - # ------------------------------------------------------------------ - # Shared authentication helpers - # ------------------------------------------------------------------ - - def _get_base_url(self) -> str: - return self._secret_provider.get_secret("epic_fhir_base_url").rstrip("/") - - async def _get_auth_header(self) -> Dict[str, str]: - """ - Obtain an access token via Epic's SMART Backend Services (private_key_jwt) - and return ready-to-use request headers. - - Algorithm: RS384. Token lifetime: 5 minutes (Epic maximum). - Reference: https://fhir.epic.com/Documentation?docId=oauth2tutorial§ion=cloud-based-app - """ - headers = { - "Content-Type": "application/fhir+json", - "Accept": "application/fhir+json", - } - - private_key_str = self._secret_provider.get_secret("epic_private_key") - kid = self._secret_provider.get_secret("epic_kid") - client_id = self._secret_provider.get_secret("epic_client_id") - token_url = self._secret_provider.get_secret("epic_token_url") - - # Environment variables sometimes store newlines as escape sequences. - private_key_pem = codecs.decode(private_key_str, "unicode_escape") - - now = int(datetime.now(tz=timezone.utc).timestamp()) - jwt_token = jwt.encode( - { - "iss": client_id, "sub": client_id, "aud": token_url, - "jti": str(uuid.uuid4()), "iat": now, "nbf": now, "exp": now + 300, - }, - private_key_pem, - algorithm="RS384", - headers={"alg": "RS384", "typ": "JWT", "kid": kid}, - ) - - logger.debug("Exchanging JWT for Epic access token", extra={"token_url": token_url}) - - async with httpx.AsyncClient() as client: - token_response = await client.post( - token_url, - data={ - "grant_type": "client_credentials", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": jwt_token, - }, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ) - if token_response.status_code != 200: - logger.error( - "OAuth token exchange failed | status=%s | body=%s", - token_response.status_code, token_response.text, - ) - token_response.raise_for_status() - token_data = token_response.json() - - access_token = token_data.get("access_token") - if not access_token: - raise ValueError("Epic token response did not contain an access_token") - - headers["Authorization"] = f"Bearer {access_token}" - return headers - - # ------------------------------------------------------------------ - # Internal name-field helpers - # ------------------------------------------------------------------ - - @staticmethod - def _build_name_search_params( - given_name: Optional[str], - family_name: Optional[str], - name: Optional[str], - birthdate: Optional[str], - extra: Optional[Dict[str, str]] = None, - ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority so - callers can pass additional filters without overriding name fields. - """ - params: Dict[str, str] = dict(extra or {}) - - # Normalize: strip whitespace; FHIR name search is typically case-insensitive - # on compliant servers but we preserve original case per FHIR spec. - if given_name and given_name.strip(): - params["given"] = given_name.strip() - if family_name and family_name.strip(): - params["family"] = family_name.strip() - if name and name.strip() and "given" not in params and "family" not in params: - # Only fall back to the combined 'name' token when no split fields given - params["name"] = name.strip() - if birthdate and birthdate.strip(): - params["birthdate"] = birthdate.strip() - - return params - - @staticmethod - def _build_encounter_search_params( - patient_id: Optional[str], - status: Optional[str], - date: Optional[str], - extra: Optional[Dict[str, str]] = None, - ) -> Dict[str, str]: - """Build a FHIR search params dict for Encounter from explicit fields.""" - params: Dict[str, str] = dict(extra or {}) - - if patient_id and patient_id.strip(): - params["patient"] = patient_id.strip() - if status and status.strip(): - params["status"] = status.strip() - if date and date.strip(): - params["date"] = date.strip() - - return params - - # ------------------------------------------------------------------ - # Action: read_patient - # ------------------------------------------------------------------ - - async def _read_patient( - self, params: FhirPatientReadInput, *, trace_id: str - ) -> FhirPatientReadOutput: - base_url = self._get_base_url() - auth_header = await self._get_auth_header() - - if params.resource_id: - url = f"{base_url}/Patient/{params.resource_id}" - query_params: Optional[Dict[str, str]] = None - logger.info("FHIR Patient read by ID", extra={"trace_id": trace_id, "resource_id": params.resource_id}) - elif params.given_name or params.family_name or params.name: - url = f"{base_url}/Patient" - query_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, - ) - logger.info("FHIR Patient read by name fields", extra={"trace_id": trace_id, "query_params": query_params}) - elif params.search_params: - url = f"{base_url}/Patient" - query_params = params.search_params - logger.info("FHIR Patient read by search", extra={"trace_id": trace_id, "search_params": params.search_params}) - else: - raise ValueError( - "Provide resource_id, or name fields (given_name/family_name/name), " - "or search_params" - ) - - try: - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=auth_header, params=query_params, timeout=30.0) - response.raise_for_status() - except Exception as exc: - logger.error("FHIR Patient read failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) - raise - - data = response.json() - if data.get("resourceType") == "Bundle": - if data.get("entry"): - resource = data["entry"][0].get("resource", {}) - else: - raise ValueError("No patients found in search results") - else: - resource = data - - logger.info("FHIR Patient read completed", extra={"trace_id": trace_id, "status_code": response.status_code}) - return FhirPatientReadOutput(resource=resource) - - # ------------------------------------------------------------------ - # Action: search_patients (multi-ID fan-out OR name search) - # ------------------------------------------------------------------ - - async def _search_patients( - self, params: FhirPatientSearchInput, *, trace_id: str - ) -> FhirPatientSearchOutput: - base_url = self._get_base_url() - auth_header = await self._get_auth_header() - - # ---- Mode 1: Multi-ID fan-out ---- - if params.resource_ids: - ids = [rid.strip() for rid in params.resource_ids if rid.strip()] - if not ids: - raise ValueError("resource_ids list is empty") - - logger.info( - "FHIR Patient multi-ID lookup | count=%s", - len(ids), - extra={"trace_id": trace_id, "resource_ids": ids}, - ) - - async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: - """Return (rid, resource_or_None, error_or_None).""" - try: - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{base_url}/Patient/{rid}", - headers=auth_header, - timeout=30.0, - ) - resp.raise_for_status() - return rid, resp.json(), None - except Exception as exc: - logger.warning( - "FHIR Patient fetch failed | resource_id=%s | error=%s", - rid, str(exc), - extra={"trace_id": trace_id}, - ) - return rid, None, str(exc) - - results = await asyncio.gather(*[_fetch_one(rid) for rid in ids]) - - resources: List[Dict[str, Any]] = [] - errors: List[Dict[str, Any]] = [] - for rid, resource, error in results: - if resource is not None: - resources.append(resource) - else: - errors.append({"resource_id": rid, "error": error or "Unknown error"}) - - logger.info( - "FHIR Patient multi-ID lookup completed | found=%s | errors=%s", - len(resources), len(errors), - extra={"trace_id": trace_id}, - ) - return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) - - # ---- Mode 2: Name-based search (returns Bundle) ---- - name_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, - ) - if not name_params: - raise ValueError( - "Provide resource_ids for multi-ID lookup, or at least one of " - "given_name / family_name / name / birthdate / search_params for name-based search" - ) - - logger.info( - "FHIR Patient name search | params=%s", - name_params, - extra={"trace_id": trace_id}, - ) - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{base_url}/Patient", - headers=auth_header, - params=name_params, - timeout=30.0, - ) - response.raise_for_status() - except httpx.HTTPStatusError as exc: - logger.error( - "FHIR Patient name search failed | status=%s | body=%s", - exc.response.status_code, exc.response.text, - extra={"trace_id": trace_id}, - ) - raise - except Exception as exc: - logger.error( - "FHIR Patient name search failed | error=%s: %s", - type(exc).__name__, str(exc), - extra={"trace_id": trace_id}, - ) - raise - - data = response.json() - resources = [] - total = data.get("total") - if data.get("resourceType") == "Bundle" and data.get("entry"): - resources = [e["resource"] for e in data["entry"] if "resource" in e] - - logger.info( - "FHIR Patient name search completed | found=%s | total=%s", - len(resources), total, - extra={"trace_id": trace_id}, - ) - return FhirPatientSearchOutput(resources=resources, total=total) - - # ------------------------------------------------------------------ - # Action: search_encounter - # ------------------------------------------------------------------ - - async def _search_encounter( - self, params: FhirEncounterSearchInput, *, trace_id: str - ) -> FhirEncounterSearchOutput: - base_url = self._get_base_url() - auth_header = await self._get_auth_header() - - if params.patient_id or params.status or params.date: - query_params = self._build_encounter_search_params( - params.patient_id, params.status, params.date, params.search_params - ) - logger.info("FHIR Encounter search by explicit fields", extra={"trace_id": trace_id, "query_params": query_params}) - elif params.search_params: - query_params = params.search_params - logger.info("FHIR Encounter search by raw params", extra={"trace_id": trace_id, "search_params": params.search_params}) - else: - raise ValueError("Provide at least patient_id, status, date OR search_params") - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=30.0, - ) - response.raise_for_status() - except httpx.HTTPStatusError as exc: - logger.error("FHIR Encounter search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) - raise - except Exception as exc: - logger.error("FHIR Encounter search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) - raise - - data = response.json() - resources: list[Dict[str, Any]] = [] - total = data.get("total") - if data.get("resourceType") == "Bundle" and data.get("entry"): - resources = [e["resource"] for e in data["entry"] if "resource" in e] - - logger.info("FHIR Encounter search completed | found=%s", len(resources), extra={"trace_id": trace_id}) - return FhirEncounterSearchOutput(resources=resources, total=total) - - # ------------------------------------------------------------------ - # Action: create_document_reference - # ------------------------------------------------------------------ - - async def _create_document_reference( - self, params: FhirDocumentReferenceCreateInput, *, trace_id: str - ) -> FhirDocumentReferenceCreateOutput: - base_url = self._get_base_url() - auth_header = await self._get_auth_header() - - doc_ref: Dict[str, Any] = { - "resourceType": "DocumentReference", - "identifier": params.identifier, - "status": params.status, - "type": params.type, - "subject": {"reference": params.subject}, - "date": datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), - "content": [{"attachment": {"contentType": params.content_type or "text/plain", "data": params.data}}], - } - if params.category: - doc_ref["category"] = params.category - if params.author: - doc_ref["author"] = params.author - if params.description: - doc_ref["description"] = params.description - if params.context: - doc_ref["context"] = params.context - if params.additional_fields: - doc_ref.update(params.additional_fields) - - logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=30.0, - ) - response.raise_for_status() - except httpx.HTTPStatusError as exc: - try: - resp_json = exc.response.json() - diagnostics = [] - if resp_json.get("resourceType") == "OperationOutcome": - for issue in resp_json.get("issue", []): - if "diagnostics" in issue: - diagnostics.append(issue["diagnostics"]) - error_detail = " | ".join(diagnostics) if diagnostics else exc.response.text - except Exception: - error_detail = exc.response.text - - logger.error( - "FHIR DocumentReference create failed | status=%s | epic_error=%s | sent_payload=%s", - exc.response.status_code, error_detail, json.dumps(doc_ref), - extra={"trace_id": trace_id}, - ) - # Raise a more descriptive error for the API to catch - raise ValueError(f"Epic Error: {error_detail}") from exc - except Exception as exc: - logger.error("FHIR DocumentReference create failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) - raise - - resource_id: Optional[str] = None - body: Dict[str, Any] = {} - - location = response.headers.get("Location", "") - if location: - history_marker = location.find("/_history/") - resource_id = location[:history_marker].split("/")[-1] if history_marker != -1 else location.split("/")[-1] - - if not resource_id: - content_length = response.headers.get("content-length", "0") - if content_length != "0" and response.content: - try: - body = response.json() - resource_id = body.get("id") - except Exception: - pass - - if not resource_id: - raise ValueError( - f"Could not extract resource ID from DocumentReference create response. " - f"Status: {response.status_code}, Location: {location!r}, Body: {response.text[:200]!r}" - ) - - logger.info("FHIR DocumentReference create completed | resource_id=%s", resource_id, extra={"trace_id": trace_id}) - return FhirDocumentReferenceCreateOutput(resource_id=resource_id, resource=body if body else None) - - # ------------------------------------------------------------------ - # Action: search_document_reference - # ------------------------------------------------------------------ - - async def _search_document_reference( - self, params: FhirDocumentReferenceSearchInput, *, trace_id: str - ) -> FhirDocumentReferenceSearchOutput: - base_url = self._get_base_url() - auth_header = await self._get_auth_header() - - logger.info("FHIR DocumentReference search", extra={"trace_id": trace_id, "search_params": params.search_params}) - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=30.0, - ) - response.raise_for_status() - except httpx.HTTPStatusError as exc: - logger.error("FHIR DocumentReference search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) - raise - except Exception as exc: - logger.error("FHIR DocumentReference search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) - raise - - data = response.json() - resources: list[Dict[str, Any]] = [] - total = data.get("total") - if data.get("resourceType") == "Bundle" and data.get("entry"): - resources = [e["resource"] for e in data["entry"] if "resource" in e] - - logger.info( - "FHIR DocumentReference search completed | found=%s", - len(resources), - extra={"trace_id": trace_id}, - ) - return FhirDocumentReferenceSearchOutput(resources=resources, total=total) \ No newline at end of file diff --git a/src/connectors/google_drive/__init__.py b/src/connectors/google_drive/__init__.py deleted file mode 100644 index 4ab7ba3..0000000 --- a/src/connectors/google_drive/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Connector subpackage: google_drive diff --git a/src/connectors/google_drive/logic.py b/src/connectors/google_drive/logic.py deleted file mode 100644 index a4b2b3d..0000000 --- a/src/connectors/google_drive/logic.py +++ /dev/null @@ -1,256 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import base64 -import logging -from typing import Any, Union - -from google.oauth2 import service_account -from googleapiclient.discovery import build -from googleapiclient.errors import HttpError -from googleapiclient.http import MediaInMemoryUpload - -from runtime import BaseConnector - -from .exceptions import ( - GoogleDriveAuthError, - GoogleDriveBusinessError, - GoogleDriveFatalError, - GoogleDriveRateLimitError, -) -from .schema import ( - FilesCreateOperation, - FilesDeleteOperation, - FilesGetOperation, - FilesListOperation, - FilesUpdateOperation, - FilesUploadOperation, - GoogleDriveOperationInput, - GoogleDriveOperationOutput, - PermissionsCreateOperation, -) - -logger = logging.getLogger("connectors.google_drive") - -# Performant default for files.list so the API returns only needed metadata. -DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" - -_OperationUnion = Union[ - FilesCreateOperation, - FilesListOperation, - PermissionsCreateOperation, - FilesGetOperation, - FilesUpdateOperation, - FilesUploadOperation, - FilesDeleteOperation, -] - - -class GoogleDriveConnector( - BaseConnector[GoogleDriveOperationInput, GoogleDriveOperationOutput], -): - """ - Google Drive connector for files and permissions operations. - """ - - connector_id = "google_drive" - action = "execute" - - async def internal_execute( - self, params: GoogleDriveOperationInput, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Executing Google Drive operation", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "action_type": params.root.action, - }, - ) - - drive = self._build_client() - - try: - response = await asyncio.to_thread( - self._dispatch_to_sdk, drive, params.root - ) - return GoogleDriveOperationOutput( - raw=response, - description=f"Successfully executed {params.root.action}", - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - except Exception as exc: # noqa: BLE001 - logger.error( - "Unexpected SDK failure", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - raise GoogleDriveFatalError(str(exc)) from exc - - def _dispatch_to_sdk( - self, drive: Any, params: _OperationUnion - ) -> dict[str, Any]: - """Routes the strictly validated model to the correct SDK method.""" - if params.action == "files.create": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - return drive.files().create(body=body, fields='id, name, webViewLink', - supportsAllDrives=True, - ).execute() - - if params.action == "files.list": - fields = params.fields or DEFAULT_LIST_FIELDS - result = drive.files().list( - pageSize=params.page_size, - q=params.query, - fields=fields, - supportsAllDrives=True, - includeItemsFromAllDrives=True, - ).execute() - return result - - if params.action == "permissions.create": - body = { - "role": params.role, - "type": params.type, - "emailAddress": params.email_address, - } - return drive.permissions().create( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - ).execute() - - if params.action == "files.get": - fields = params.fields or "id,name,mimeType,parents" - return ( - drive.files() - .get( - fileId=params.file_id, - fields=fields, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.update": - body: dict[str, Any] = {} - if params.name is not None: - body["name"] = params.name - if params.mime_type is not None: - body["mimeType"] = params.mime_type - - kwargs: dict[str, Any] = {} - if params.add_parents: - kwargs["addParents"] = ",".join(params.add_parents) - if params.remove_parents: - kwargs["removeParents"] = ",".join(params.remove_parents) - - return ( - drive.files() - .update( - fileId=params.file_id, - body=body, - **kwargs, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.upload": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - - if params.content_base64 is not None: - media_bytes = base64.b64decode(params.content_base64) - elif params.content is not None: - media_bytes = params.content.encode("utf-8") - else: - raise ValueError("Either content or content_base64 must be provided for files.upload") - - media = MediaInMemoryUpload( - media_bytes, - mimetype=params.mime_type, - resumable=False, - ) - - return ( - drive.files() - .create( - body=body, - media_body=media, - fields='id, name, webViewLink', - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.delete": - drive.files().update(fileId=params.file_id, - body={'trashed': True}, - supportsAllDrives=True, - ).execute() - return {"file_id": params.file_id, "status": "deleted"} - - raise ValueError(f"Unmapped action router: {params.action}") - - def _translate_and_raise_http_error(self, exc: HttpError) -> None: - """Translates Google's dynamic HTTP errors into static taxonomy classes.""" - status = exc.resp.status - content_str = str(getattr(exc, "content", "") or "") - - if status in (401, 403): - if "quotaExceeded" in content_str or "rateLimitExceeded" in content_str: - raise GoogleDriveRateLimitError( - "Google Drive quota/rate limit exceeded" - ) from exc - raise GoogleDriveAuthError( - "Authentication or permissions failure" - ) from exc - - if status == 429 or status >= 500: - raise GoogleDriveRateLimitError( - "Upstream service unavailable or rate limited" - ) from exc - - if status in (400, 404, 409): - reason = getattr(exc, "reason", str(exc)) - raise GoogleDriveBusinessError( - f"Business logic failure: {reason}" - ) from exc - - raise GoogleDriveFatalError( - f"Unhandled HttpError status {status}" - ) from exc - - def _build_client(self) -> Any: - raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") - try: - info = json.loads(raw_sa) - creds = service_account.Credentials.from_service_account_info( - info, - scopes=["https://www.googleapis.com/auth/drive"], - ) - except json.JSONDecodeError: - # Fallback: treat the secret as a file path - creds = service_account.Credentials.from_service_account_file( - raw_sa.strip(), - scopes=["https://www.googleapis.com/auth/drive"], - ) - return build("drive", "v3", credentials=creds) diff --git a/src/connectors/http_generic/__init__.py b/src/connectors/http_generic/__init__.py deleted file mode 100644 index f65872d..0000000 --- a/src/connectors/http_generic/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Connector subpackage: http_generic diff --git a/src/connectors/http_generic/schema.py b/src/connectors/http_generic/schema.py deleted file mode 100644 index a9df220..0000000 --- a/src/connectors/http_generic/schema.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, Optional - -from pydantic import BaseModel, HttpUrl - - -class HttpRequestInput(BaseModel): - url: HttpUrl - method: str - headers: Optional[Dict[str, str]] = None - params: Optional[Dict[str, str]] = None - body: Optional[Any] = None - - -class HttpResponseOutput(BaseModel): - status_code: int - headers: Dict[str, str] - body: Any - diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py deleted file mode 100644 index a13f9f5..0000000 --- a/src/connectors/manifest.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Type - -from pydantic import BaseModel - -from runtime import BaseConnector - - -def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: - return model.model_json_schema() - - -def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, Any]]: - """ - Build a simple manifest for discovery. - - Each entry describes a connector/action pair and includes JSON Schemas - for the input and output models. This is consumed by Layer C for - REST route generation and MCP tool manifests. - """ - manifest: List[Dict[str, Any]] = [] - for connector in connectors: - input_model = connector._input_model_cls # type: ignore[attr-defined] - output_model = connector._output_model_cls # type: ignore[attr-defined] - manifest.append( - { - "connector_id": connector.connector_id, - "action": connector.action, - "input_schema": _schema_for(input_model), - "output_schema": _schema_for(output_model), - } - ) - return manifest - diff --git a/src/connectors/smtp/__init__.py b/src/connectors/smtp/__init__.py deleted file mode 100644 index 817f413..0000000 --- a/src/connectors/smtp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Connector subpackage: smtp diff --git a/src/connectors/smtp/registration.py b/src/connectors/smtp/registration.py deleted file mode 100644 index 65e76ea..0000000 --- a/src/connectors/smtp/registration.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import aiosmtplib - -from runtime import ErrorCategory, ErrorMapper - - -# Connection / timeout issues are retryable. -ErrorMapper.register(aiosmtplib.errors.SMTPConnectError, ErrorCategory.RETRYABLE, code="SMTP_CONNECT_ERROR") -ErrorMapper.register(aiosmtplib.errors.SMTPTimeoutError, ErrorCategory.RETRYABLE, code="SMTP_TIMEOUT") - -# Authentication failures map to AUTH. -ErrorMapper.register(aiosmtplib.errors.SMTPAuthenticationError, ErrorCategory.AUTH, code="SMTP_AUTH_ERROR") - -# Generic SMTP protocol problems are treated as BUSINESS by default. -ErrorMapper.register(aiosmtplib.errors.SMTPException, ErrorCategory.BUSINESS, code="SMTP_ERROR") - diff --git a/src/connectors/smtp/schema.py b/src/connectors/smtp/schema.py deleted file mode 100644 index 1698024..0000000 --- a/src/connectors/smtp/schema.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from typing import List, Optional - -from pydantic import BaseModel, EmailStr - - -class SmtpSendInput(BaseModel): - host: str - port: int - use_tls: bool = True - username_secret_key: str - password_secret_key: str - from_email: EmailStr - to: List[EmailStr] - subject: str - body: str - - -class SmtpSendOutput(BaseModel): - sent: bool - message_id: Optional[str] = None - diff --git a/src/connectors/stripe/__init__.py b/src/connectors/stripe/__init__.py deleted file mode 100644 index d7426ce..0000000 --- a/src/connectors/stripe/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Connector subpackage: stripe diff --git a/src/connectors/stripe/logic.py b/src/connectors/stripe/logic.py deleted file mode 100644 index 14e973f..0000000 --- a/src/connectors/stripe/logic.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import logging - -import stripe - -from runtime import BaseConnector - -from .schema import ChargeInput, ChargeOutput - -logger = logging.getLogger("connectors.stripe") - - -class StripeChargeConnector(BaseConnector[ChargeInput, ChargeOutput]): - """ - Stripe connector for creating charges using the official Stripe SDK. - """ - - connector_id = "stripe" - action = "charge" - - async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: - # API key is expected to be provided by SecretProvider. - api_key = self.secret_provider.get_secret("stripe_api_key") - stripe.api_key = api_key - - logger.info( - "Creating Stripe charge", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "amount": params.amount, - "currency": params.currency, - }, - ) - - try: - charge = await stripe.Charge.create( # type: ignore[attr-defined] - amount=params.amount, - currency=params.currency, - source=params.source, - description=params.description, - ) - except Exception as exc: # noqa: BLE001 - logger.error( - "Stripe charge creation failed", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "amount": params.amount, - "currency": params.currency, - "error_type": type(exc).__name__, - "message": str(exc), - }, - ) - raise - - logger.info( - "Stripe charge created successfully", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "charge_id": charge.get("id"), - }, - ) - - return ChargeOutput( - charge_id=charge.get("id"), - receipt_url=charge.get("receipt_url"), - ) - diff --git a/src/connectors/stripe/schema.py b/src/connectors/stripe/schema.py deleted file mode 100644 index bf7e6f6..0000000 --- a/src/connectors/stripe/schema.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from pydantic import BaseModel - - -class ChargeInput(BaseModel): - amount: int - currency: str - source: str - description: str | None = None - - -class ChargeOutput(BaseModel): - charge_id: str - receipt_url: str | None = None - diff --git a/src/connectors/fhir_cerner/README.md b/src/node_wire_fhir_cerner/README.md similarity index 92% rename from src/connectors/fhir_cerner/README.md rename to src/node_wire_fhir_cerner/README.md index 47058e7..95cc58c 100644 --- a/src/connectors/fhir_cerner/README.md +++ b/src/node_wire_fhir_cerner/README.md @@ -1,10 +1,16 @@ + + # FHIR Cerner Connector — Technical Documentation > **Platform:** Node Wire > **Standard:** FHIR R4 > **Auth Method:** SMART Backend Services — `private_key_jwt` (RS384) > **Actions:** `read_patient` · `search_encounter` · `create_document_reference` · `search_document_reference` -> **Source:** `src/connectors/fhir_cerner/` +> **Source:** `src/node_wire_fhir_cerner/` > **Test Collection:** `postman_fhir_cerner_collection.json` --- @@ -109,9 +115,9 @@ Cerner's FHIR implementation (especially in the sandbox) has several unique requ | File / Path | Purpose | |---|---| -| `src/connectors/fhir_cerner/logic.py` | Core logic, authentication, and action dispatch | -| `src/connectors/fhir_cerner/schema.py` | Pydantic input/output models and field-level documentation | -| `src/connectors/fhir_cerner/registration.py` | Error mapping and exception handling specifically for Cerner API errors | +| `src/node_wire_fhir_cerner/logic.py` | Core logic, authentication, and action dispatch | +| `src/node_wire_fhir_cerner/schema.py` | Pydantic input/output models and field-level documentation | +| `src/node_wire_fhir_cerner/registration.py` | Error mapping and exception handling specifically for Cerner API errors | | `postman_fhir_cerner_collection.json` | Pre-configured requests to test endpoints end-to-end (at repo root) | --- diff --git a/src/node_wire_fhir_cerner/__init__.py b/src/node_wire_fhir_cerner/__init__.py new file mode 100644 index 0000000..a92bcb3 --- /dev/null +++ b/src/node_wire_fhir_cerner/__init__.py @@ -0,0 +1,5 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""FHIR Cerner connector package.""" diff --git a/src/connectors/fhir_cerner/logic.py b/src/node_wire_fhir_cerner/logic.py similarity index 54% rename from src/connectors/fhir_cerner/logic.py rename to src/node_wire_fhir_cerner/logic.py index 03cc6b0..022573d 100644 --- a/src/connectors/fhir_cerner/logic.py +++ b/src/node_wire_fhir_cerner/logic.py @@ -1,19 +1,27 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import asyncio import base64 import json import logging -import uuid +import os from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import httpx -import jwt -from runtime import BaseConnector, SecretProvider +from node_wire_runtime import BaseConnector, nw_action, sdk_action +from node_wire_runtime.fhir_encounter import assert_encounter_query_has_patient +from node_wire_runtime.mcp_normalizers import ( + normalize_fhir_read_patient, + normalize_fhir_search_encounter, + normalize_fhir_search_patients, +) -from . import registration from .schema import ( FhirCernerDocumentReferenceCreateInput, FhirCernerDocumentReferenceCreateOutput, @@ -21,6 +29,7 @@ FhirCernerDocumentReferenceSearchOutput, FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, + FhirCernerOperationOutput, FhirCernerPatientReadInput, FhirCernerPatientReadOutput, FhirCernerPatientSearchInput, @@ -30,213 +39,123 @@ logger = logging.getLogger("connectors.fhir_cerner") -class _FhirCernerAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirCernerConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. - """ - - connector_id = "fhir_cerner" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) +def _safe_doc_ref_log_summary(doc_ref: Dict[str, Any]) -> Dict[str, Any]: + attachment: Dict[str, Any] = {} + content_items = doc_ref.get("content") + if isinstance(content_items, list) and content_items: + first = content_items[0] + if isinstance(first, dict): + attachment = ( + first.get("attachment", {}) if isinstance(first.get("attachment"), dict) else {} + ) + data_value = attachment.get("data") + data_len = len(data_value) if isinstance(data_value, str) else 0 + return { + "keys": sorted(doc_ref.keys()), + "content_items": len(content_items) if isinstance(content_items, list) else 0, + "attachment_content_type": attachment.get("contentType"), + "attachment_data_length": data_len, + } -class FhirCernerConnector: +class FhirCernerConnector(BaseConnector): """ - Single FHIR/Cerner connector. - - ``connector_id = "fhir_cerner"``. All authentication helpers and action - implementations live here. The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - - Authentication uses Cerner's SMART Backend Services (private_key_jwt) flow, - identical to Epic's implementation — RS384-signed JWT exchanged for an - OAuth2 access token at the configured token endpoint. - - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict. - - .. note:: - Cerner's sandbox name search is case-sensitive. Supply names exactly - as stored in the system. Special characters in search values should be - URL-encoded (httpx handles this automatically). - - Required secrets (configured via SecretProvider): - - cerner_fhir_base_url : Cerner FHIR R4 base URL - - cerner_private_key : RSA private key PEM (newlines may be escaped) - - cerner_kid : Key ID registered in the Cerner code console - - cerner_client_id : Client ID from Cerner app registration - - cerner_token_url : OAuth2 token endpoint URL (from .well-known/smart-configuration - or the Cerner code console) + FHIR/Cerner connector: SMART Backend Services (private_key_jwt), RS384. + + Required secrets: cerner_fhir_base_url, cerner_private_key, cerner_kid, + cerner_client_id, cerner_token_url (optional cerner_scopes). """ connector_id = "fhir_cerner" + action = "execute" + output_model = FhirCernerOperationOutput + + @sdk_action( + "read_patient", + alias_tolerant=True, + mcp_normalize=normalize_fhir_read_patient, + ) + async def read_patient( + self, params: FhirCernerPatientReadInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource=out.resource) + + @sdk_action( + "search_patients", + alias_tolerant=True, + mcp_normalize=normalize_fhir_search_patients, + ) + async def search_patients( + self, params: FhirCernerPatientSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirCernerOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) - def __init__(self, *, secret_provider: SecretProvider) -> None: - self._secret_provider = secret_provider - - self._actions: Dict[str, _FhirCernerAction] = { - "read_patient": _FhirCernerAction( - "read_patient", FhirCernerPatientReadInput, FhirCernerPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirCernerAction( - "search_patients", FhirCernerPatientSearchInput, FhirCernerPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirCernerAction( - "search_encounter", FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirCernerAction( - "create_document_reference", FhirCernerDocumentReferenceCreateInput, FhirCernerDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirCernerAction( - "search_document_reference", FhirCernerDocumentReferenceSearchInput, FhirCernerDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # ------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ + @sdk_action( + "search_encounter", + alias_tolerant=True, + mcp_normalize=normalize_fhir_search_encounter, + ) + async def search_encounter( + self, params: FhirCernerEncounterSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) - def list_actions(self) -> List[_FhirCernerAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) + @nw_action("create_document_reference") + async def create_document_reference( + self, params: FhirCernerDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource_id=out.resource_id, resource=out.resource) - def get_action(self, name: str) -> Optional[_FhirCernerAction]: - """Return the action connector for the given action name.""" - return self._actions.get(name) + @nw_action("search_document_reference") + async def search_document_reference( + self, params: FhirCernerDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) # ------------------------------------------------------------------ - # Shared authentication helpers + # Shared helpers — base URL + auth headers via AuthProvider # ------------------------------------------------------------------ def _get_base_url(self) -> str: - return self._secret_provider.get_secret("cerner_fhir_base_url").rstrip("/") + return self.secret_provider.get_secret("cerner_fhir_base_url").rstrip("/") async def _get_auth_header(self) -> Dict[str, str]: - """ - Obtain an access token via Cerner's SMART Backend Services (private_key_jwt) - and return ready-to-use request headers. + """Delegate to the runtime AuthProvider injected by the factory. - Algorithm: RS384. Token lifetime: 5 minutes. - Reference: https://code-console.cerner.com/ + Returns ready-to-use FHIR request headers including the Bearer token. + Token acquisition, JWT construction, scope resolution and caching are + all handled by the provider. """ - headers = { - "Content-Type": "application/fhir+json", - "Accept": "application/fhir+json", - } - - private_key_str = self._secret_provider.get_secret("cerner_private_key") - kid = self._secret_provider.get_secret("cerner_kid") - client_id = self._secret_provider.get_secret("cerner_client_id") - token_url = self._secret_provider.get_secret("cerner_token_url") - - # Validate required secrets are present and non-empty. - missing = [name for name, val in [ - ("cerner_private_key", private_key_str), - ("cerner_kid", kid), - ("cerner_client_id", client_id), - ("cerner_token_url", token_url), - ] if not (val or "").strip()] - if missing: - raise ValueError(f"Missing or empty required Cerner secrets: {', '.join(missing)}") - - # Guard against the malformed URL pattern that embeds the FHIR host inside the auth URL. - # Correct: .../tenants/{tenant}/protocols/oauth2/profiles/smart-v1/token - # Wrong: .../tenants/{tenant}/hosts/fhir-ehr-code.cerner.com/protocols/... - if "/hosts/" in token_url: - raise ValueError( - "cerner_token_url appears malformed — it contains a '/hosts/' segment which is not " - "valid for the Cerner authorization server. " - "Correct format: https://authorization.cerner.com/tenants/{tenant_id}/protocols/oauth2/profiles/smart-v1/token" - ) - + # Cerner-specific safety check: if a token URL contains '/hosts/', + # it is often a malformed sandbox URL that will return 401. try: - scopes = (self._secret_provider.get_secret("cerner_scopes") or "").strip() + token_url = self.secret_provider.get_secret("cerner_token_url") except Exception: - scopes = "" - - if not scopes: - scopes = "system/Patient.read system/Encounter.read system/DocumentReference.read system/DocumentReference.write" - - logger.debug("Cerner token request | token_url=%s | scopes=%r | client_id=%s", token_url, scopes, client_id) - - # Decode escaped newlines in PEM key if stored as a single-line env var (e.g. "\\n" -> "\n"). - # Avoid codecs.unicode_escape which can corrupt non-ASCII bytes in some PEM keys. - if "\\n" in private_key_str: - private_key_pem = private_key_str.replace("\\n", "\n") - else: - private_key_pem = private_key_str - - now = int(datetime.now(tz=timezone.utc).timestamp()) - jwt_token = jwt.encode( - { - "iss": client_id, - "sub": client_id, - "aud": token_url, - "jti": str(uuid.uuid4()), - "iat": now, - "exp": now + 300, - "scope": scopes, - }, - private_key_pem, - algorithm="RS384", - headers={"alg": "RS384", "typ": "JWT", "kid": kid}, - ) + token_url = None - post_data = { - "grant_type": "client_credentials", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": jwt_token, - "scope": scopes, - } - - async with httpx.AsyncClient() as client: - token_response = await client.post( - token_url, - data=post_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, + if token_url and "/hosts/" in token_url: + raise ValueError( + "Cerner token_url must not contain '/hosts/' (found in secret). " + "Ensure you are using the 'smart-v1/token' endpoint, e.g. " + "https://authorization.cerner.com/tenants/{tenant}/protocols/oauth2/profiles/smart-v1/token" ) - if token_response.status_code != 200: - logger.error( - "OAuth token exchange failed | status=%s | body=%s", - token_response.status_code, token_response.text, - ) - token_response.raise_for_status() - token_data = token_response.json() - access_token = token_data.get("access_token") - if not access_token: - raise ValueError("Cerner token response did not contain an access_token") + headers = await self.get_auth_headers() + # Ensure FHIR content types are present if the provider didn't include them (e.g. StaticTokenAuthProvider). + if "Content-Type" not in headers: + headers["Content-Type"] = "application/fhir+json" + if "Accept" not in headers: + headers["Accept"] = "application/fhir+json" - headers["Authorization"] = f"Bearer {access_token}" return headers # ------------------------------------------------------------------ @@ -251,15 +170,7 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority. - - .. note:: - Cerner's sandbox name matching is case-sensitive — supply names - with the same capitalisation as stored in the system. - """ + """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) if given_name and given_name.strip(): @@ -305,18 +216,30 @@ async def _read_patient( if params.resource_id: url = f"{base_url}/Patient/{params.resource_id}" query_params: Optional[Dict[str, str]] = None - logger.info("FHIR Patient read by ID", extra={"trace_id": trace_id, "resource_id": params.resource_id}) + logger.info( + "FHIR Patient read by ID", + extra={"trace_id": trace_id, "resource_id": params.resource_id}, + ) elif params.given_name or params.family_name or params.name: url = f"{base_url}/Patient" query_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, + ) + logger.info( + "FHIR Patient read by name fields", + extra={"trace_id": trace_id, "query_params": query_params}, ) - logger.info("FHIR Patient read by name fields", extra={"trace_id": trace_id, "query_params": query_params}) elif params.search_params: url = f"{base_url}/Patient" query_params = params.search_params - logger.info("FHIR Patient read by search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Patient read by search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError( "Provide resource_id, or name fields (given_name/family_name/name), " @@ -324,11 +247,23 @@ async def _read_patient( ) try: - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=auth_header, params=query_params, timeout=30.0) + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.get( + url, + headers=auth_header, + params=query_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) response.raise_for_status() except Exception as exc: - logger.error("FHIR Patient read failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Patient read failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -340,7 +275,10 @@ async def _read_patient( else: resource = data - logger.info("FHIR Patient read completed", extra={"trace_id": trace_id, "status_code": response.status_code}) + logger.info( + "FHIR Patient read completed", + extra={"trace_id": trace_id, "status_code": response.status_code}, + ) return FhirCernerPatientReadOutput(resource=resource) # ------------------------------------------------------------------ @@ -368,18 +306,21 @@ async def _search_patients( async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: """Return (rid, resource_or_None, error_or_None).""" try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: resp = await client.get( f"{base_url}/Patient/{rid}", headers=auth_header, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) resp.raise_for_status() return rid, resp.json(), None except Exception as exc: logger.warning( "FHIR Cerner Patient fetch failed | resource_id=%s | error=%s", - rid, str(exc), + rid, + str(exc), extra={"trace_id": trace_id}, ) return rid, None, str(exc) @@ -396,15 +337,21 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ logger.info( "FHIR Cerner Patient multi-ID lookup completed | found=%s | errors=%s", - len(resources), len(errors), + len(resources), + len(errors), extra={"trace_id": trace_id}, ) - return FhirCernerPatientSearchOutput(resources=resources, total=len(resources), errors=errors) + return FhirCernerPatientSearchOutput( + resources=resources, total=len(resources), errors=errors + ) # ---- Mode 2: Name-based search (returns Bundle) ---- name_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, ) if not name_params: raise ValueError( @@ -419,41 +366,46 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: response = await client.get( f"{base_url}/Patient", headers=auth_header, params=name_params, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: logger.error( "FHIR Cerner Patient name search failed | status=%s | body=%s", - exc.response.status_code, exc.response.text, + exc.response.status_code, + exc.response.text, extra={"trace_id": trace_id}, ) raise except Exception as exc: logger.error( "FHIR Cerner Patient name search failed | error=%s: %s", - type(exc).__name__, str(exc), + type(exc).__name__, + str(exc), extra={"trace_id": trace_id}, ) raise data = response.json() - resources = [] + bundle_resources: List[Dict[str, Any]] = [] total = data.get("total") if data.get("resourceType") == "Bundle" and data.get("entry"): - resources = [e["resource"] for e in data["entry"] if "resource" in e] + bundle_resources = [e["resource"] for e in data["entry"] if "resource" in e] logger.info( "FHIR Cerner Patient name search completed | found=%s | total=%s", - len(resources), total, + len(bundle_resources), + total, extra={"trace_id": trace_id}, ) - return FhirCernerPatientSearchOutput(resources=resources, total=total) + return FhirCernerPatientSearchOutput(resources=bundle_resources, total=total) # ------------------------------------------------------------------ # Action: search_encounter @@ -463,30 +415,54 @@ async def _search_encounter( self, params: FhirCernerEncounterSearchInput, *, trace_id: str ) -> FhirCernerEncounterSearchOutput: base_url = self._get_base_url() - auth_header = await self._get_auth_header() if params.patient_id or params.status or params.date: query_params = self._build_encounter_search_params( params.patient_id, params.status, params.date, params.search_params ) - logger.info("FHIR Encounter search by explicit fields", extra={"trace_id": trace_id, "query_params": query_params}) + logger.info( + "FHIR Encounter search by explicit fields", + extra={"trace_id": trace_id, "query_params": query_params}, + ) elif params.search_params: query_params = params.search_params - logger.info("FHIR Encounter search by raw params", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Encounter search by raw params", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError("Provide at least patient_id, status, date OR search_params") + assert_encounter_query_has_patient(query_params) + + auth_header = await self._get_auth_header() + try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: response = await client.get( - f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=30.0, + f"{base_url}/Encounter", + headers=auth_header, + params=query_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR Encounter search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR Encounter search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -495,7 +471,11 @@ async def _search_encounter( if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] - logger.info("FHIR Encounter search completed | found=%s", len(resources), extra={"trace_id": trace_id}) + logger.info( + "FHIR Encounter search completed | found=%s", + len(resources), + extra={"trace_id": trace_id}, + ) return FhirCernerEncounterSearchOutput(resources=resources, total=total) # ------------------------------------------------------------------ @@ -508,19 +488,13 @@ async def _create_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - # Validate context early so callers get the most actionable error. - if params.context: - ctx = dict(params.context) - if ctx.get("encounter") and not ctx.get("period"): - raise ValueError("Cerner requires 'context.period' when 'context.encounter' is provided.") - # Cerner sandbox strictly requires a charset (lowercase, no space) for text types. # Failing to provide it results in: "a character set must be specified" (422). content_type = (params.content_type or "text/plain").strip().lower() if content_type.startswith("text/"): + content_type = content_type.replace(" ", "") if "charset=" not in content_type: - # Match the formatting expected by tests and common HTTP conventions. - content_type = f"{content_type}; charset=UTF-8" + content_type = f"{content_type};charset=utf-8" attachment: Dict[str, Any] = {"contentType": content_type} if params.data: @@ -530,9 +504,13 @@ async def _create_document_reference( else: raise ValueError("Either 'text' or 'data' must be provided") - # Some Cerner tenants require title/creation; default safely when omitted. - attachment["title"] = params.attachment_title or "Document" - attachment["creation"] = params.attachment_creation or datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + # Cerner requires title and creation on the attachment + if not params.attachment_title: + raise ValueError("Cerner requires 'attachment_title' on DocumentReference create.") + attachment["title"] = params.attachment_title + attachment["creation"] = params.attachment_creation or datetime.now( + tz=timezone.utc + ).strftime("%Y-%m-%dT%H:%M:%S.000Z") doc_ref: Dict[str, Any] = { "resourceType": "DocumentReference", @@ -590,28 +568,58 @@ async def _create_document_reference( if params.custodian: doc_ref["custodian"] = params.custodian - # Note: 'description' is intentionally omitted by default + # Note: 'description' is intentionally omitted by default # as Cerner can reject it depending on tenant configuration. if params.context: - doc_ref["context"] = dict(params.context) + context = dict(params.context) + # Cerner REQUIRES context.period whenever context.encounter is set. + # Auto-inject a period using the document date if the caller didn't supply one. + if context.get("encounter") and not context.get("period"): + # Force .000Z precision and provide a 1-hour clinical window + start_dt = datetime.now(tz=timezone.utc) + end_dt = start_dt + timedelta(hours=1) + context["period"] = { + "start": start_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + "end": end_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + } + logger.debug( + "Auto-injected context.period (required by Cerner when encounter is set)", + extra={"trace_id": trace_id}, + ) + doc_ref["context"] = context if params.additional_fields: doc_ref.update(params.additional_fields) # Ensure no connector-specific fields leaked into the root of the FHIR resource. # Cerner will reject the payload with a 422 if it sees unknown root fields. - for field in ["text", "data", "content_type", "attachment_title", "attachment_creation", "doc_status"]: + for field in [ + "text", + "data", + "content_type", + "attachment_title", + "attachment_creation", + "doc_status", + ]: doc_ref.pop(field, None) - # Note: Some Cerner tenants require author/authenticator. The connector does not - # enforce those fields universally; tenants that require them will return 4xx - # with OperationOutcome diagnostics. + # Cerner requires at least one author for clinical note document types. + if not params.author: + raise ValueError( + "Cerner requires 'author' for clinical note document types. " + "Provide at least one author reference, e.g. [{'reference': 'Practitioner/{id}'}]" + ) logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: response = await client.post( - f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=30.0, + f"{base_url}/DocumentReference", + json=doc_ref, + headers=auth_header, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -628,18 +636,30 @@ async def _create_document_reference( parts = [p for p in [severity, code, diag or detail_text] if p] if parts: diagnostics.append(" ".join(parts)) - error_detail = " | ".join(diagnostics) if diagnostics else raw_body + error_detail = ( + " | ".join(diagnostics) + if diagnostics + else f"HTTP {exc.response.status_code} from Cerner FHIR endpoint" + ) except Exception: - error_detail = raw_body + error_detail = f"HTTP {exc.response.status_code} from Cerner FHIR endpoint" logger.error( - "FHIR DocumentReference create failed | status=%s | cerner_error=%s | raw_body=%s | sent_payload=%s", - exc.response.status_code, error_detail, raw_body, json.dumps(doc_ref), + "FHIR DocumentReference create failed | status=%s | cerner_error=%s | body_length=%s | payload_summary=%s", + exc.response.status_code, + error_detail, + len(raw_body), + json.dumps(_safe_doc_ref_log_summary(doc_ref)), extra={"trace_id": trace_id}, ) raise ValueError(f"Cerner Error: {error_detail}") from exc except Exception as exc: - logger.error("FHIR DocumentReference create failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference create failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise resource_id: Optional[str] = None @@ -648,7 +668,11 @@ async def _create_document_reference( location = response.headers.get("Location", "") if location: history_marker = location.find("/_history/") - resource_id = location[:history_marker].split("/")[-1] if history_marker != -1 else location.split("/")[-1] + resource_id = ( + location[:history_marker].split("/")[-1] + if history_marker != -1 + else location.split("/")[-1] + ) if not resource_id: content_length = response.headers.get("content-length", "0") @@ -665,8 +689,14 @@ async def _create_document_reference( f"Status: {response.status_code}, Location: {location!r}, Body: {response.text[:200]!r}" ) - logger.info("FHIR DocumentReference create completed | resource_id=%s", resource_id, extra={"trace_id": trace_id}) - return FhirCernerDocumentReferenceCreateOutput(resource_id=resource_id, resource=body if body else None) + logger.info( + "FHIR DocumentReference create completed | resource_id=%s", + resource_id, + extra={"trace_id": trace_id}, + ) + return FhirCernerDocumentReferenceCreateOutput( + resource_id=resource_id, resource=body if body else None + ) # ------------------------------------------------------------------ # Action: search_document_reference @@ -678,19 +708,37 @@ async def _search_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - logger.info("FHIR DocumentReference search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR DocumentReference search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: response = await client.get( - f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=30.0, + f"{base_url}/DocumentReference", + headers=auth_header, + params=params.search_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR DocumentReference search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR DocumentReference search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -704,4 +752,4 @@ async def _search_document_reference( len(resources), extra={"trace_id": trace_id}, ) - return FhirCernerDocumentReferenceSearchOutput(resources=resources, total=total) \ No newline at end of file + return FhirCernerDocumentReferenceSearchOutput(resources=resources, total=total) diff --git a/src/connectors/fhir_cerner/registration.py b/src/node_wire_fhir_cerner/registration.py similarity index 88% rename from src/connectors/fhir_cerner/registration.py rename to src/node_wire_fhir_cerner/registration.py index 384bfad..6536b7c 100644 --- a/src/connectors/fhir_cerner/registration.py +++ b/src/node_wire_fhir_cerner/registration.py @@ -1,8 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import httpx -from runtime import ErrorCategory, ErrorMapper +from node_wire_runtime import ErrorCategory, ErrorMapper # FHIR/Cerner error mappings for network and HTTP failures. diff --git a/src/connectors/fhir_cerner/schema.py b/src/node_wire_fhir_cerner/schema.py similarity index 79% rename from src/connectors/fhir_cerner/schema.py rename to src/node_wire_fhir_cerner/schema.py index eba29c1..95b26a5 100644 --- a/src/connectors/fhir_cerner/schema.py +++ b/src/node_wire_fhir_cerner/schema.py @@ -1,16 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations -from typing import Any, Dict, List, Optional +import base64 +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- + class FhirCernerPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Cerner.""" + """Input for reading a FHIR Patient resource from Cerner.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. '12345678').""" @@ -23,21 +32,13 @@ class FhirCernerPatientReadInput(BaseModel): """Patient family / last name (used in name-based search).""" name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ + """Full or partial name string — mapped to FHIR 'name' search parameter.""" birthdate: Optional[str] = None """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. - """ + """Raw FHIR search parameters (e.g. {"family": "Smith", "given": "John"}).""" class FhirCernerPatientReadOutput(BaseModel): @@ -51,75 +52,42 @@ class FhirCernerPatientReadOutput(BaseModel): # Patient – Search (multi-ID fan-out OR name search returning multiple results) # --------------------------------------------------------------------------- -class FhirCernerPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Cerner. - - Two modes are supported: - - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirCernerPatientSearchOutput.errors`` rather than raising globally. - - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. +class FhirCernerPatientSearchInput(BaseModel): + """Input for searching / fetching multiple FHIR Patient resources from Cerner.""" - .. note:: - Cerner's sandbox name search is case-sensitive. Use the exact - capitalisation stored in the system (e.g. ``family_name="Smith"`` not - ``"smith"``). The ``name`` parameter maps to the standard FHIR - ``name`` token which Cerner supports as a partial-match. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Cerner Patient IDs to fetch concurrently (e.g. ['12345678', '87654321']).""" + """List of Cerner Patient IDs to fetch concurrently.""" given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirCernerPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources from Cerner.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. - """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- # Encounter – Search # --------------------------------------------------------------------------- + class FhirCernerEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources in Cerner.""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None """Cerner Patient ID to find encounters for (maps to 'patient' FHIR param).""" @@ -147,9 +115,13 @@ class FhirCernerEncounterSearchOutput(BaseModel): # DocumentReference – Create # --------------------------------------------------------------------------- + class FhirCernerDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource in Cerner.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: Optional[list[Dict[str, Any]]] = None """Document identifier. @@ -216,8 +188,20 @@ class FhirCernerDocumentReferenceCreateInput(BaseModel): All provided dates must include a time component. """ - data: Optional[str] = None + @field_validator("data", mode="after") + @classmethod + def validate_base64_data(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + try: + base64.b64decode(v, validate=True) + except Exception: + raise ValueError("data must be a valid base64-encoded string") + return v + + data: Optional[str] = Field(None, max_length=10 * 1024 * 1024) """Base64-encoded document content. Required for both binary files (PDFs) and plain text. + Max size 10MB. Note: If you provide raw text in the ``text`` field, the connector will automatically encode it to base64 for you. @@ -225,7 +209,7 @@ class FhirCernerDocumentReferenceCreateInput(BaseModel): text: Optional[str] = None """Raw string content for the document attachment. - + The connector will automatically base64-encode this string and send it via ``attachment.data``, as the Cerner sandbox does not support ``attachment.text``. """ @@ -252,7 +236,7 @@ class FhirCernerDocumentReferenceCreateInput(BaseModel): custodian: Optional[Dict[str, Any]] = None """Custodian of the document (e.g. Organization reference). - + Example: {"reference": "Organization/{id}"} """ @@ -292,11 +276,15 @@ class FhirCernerDocumentReferenceCreateOutput(BaseModel): # DocumentReference – Search # --------------------------------------------------------------------------- + class FhirCernerDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources in Cerner.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"12345678\"}).""" + """Search parameters (e.g. {"patient": "12345678"}).""" class FhirCernerDocumentReferenceSearchOutput(BaseModel): @@ -307,3 +295,17 @@ class FhirCernerDocumentReferenceSearchOutput(BaseModel): total: Optional[int] = None """Total number of results reported by the Bundle.""" + + +class FhirCernerOperationOutput(BaseModel): + """ + Unified output for all Cerner FHIR actions (BaseConnector single output_model). + + Fields are populated depending on the action; unused fields are None. + """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/fhir_epic/README.md b/src/node_wire_fhir_epic/README.md similarity index 96% rename from src/connectors/fhir_epic/README.md rename to src/node_wire_fhir_epic/README.md index d9b4aaf..f96e65e 100644 --- a/src/connectors/fhir_epic/README.md +++ b/src/node_wire_fhir_epic/README.md @@ -1,10 +1,16 @@ + + # FHIR Epic Connector — Technical Documentation > **Platform:** Node Wire > **Standard:** FHIR R4 > **Auth Method:** SMART Backend Services — RS384 JWT / OAuth2 > **Actions:** `read_patient` · `search_encounter` · `create_document_reference` · `search_document_reference` -> **Source:** `src/connectors/fhir_epic/` +> **Source:** `src/node_wire_fhir_epic/` > **Test Collection:** `postman_fhir_epic_collection.json` --- @@ -123,8 +129,8 @@ A Postman collection is provided at the root: `postman_fhir_epic_collection.json | File / Path | Purpose | |---|---| -| `src/connectors/fhir_epic/logic.py` | Core logic and action dispatch | -| `src/connectors/fhir_epic/schema.py` | Pydantic input/output models | +| `src/node_wire_fhir_epic/logic.py` | Core logic and action dispatch | +| `src/node_wire_fhir_epic/schema.py` | Pydantic input/output models | | `src/bindings/factory.py` | Connector instantiation logic | | `src/bindings/rest_api/app.py` | REST API routing | | `tests/test_fhir_epic.py` | Comprehensive test suite | diff --git a/src/node_wire_fhir_epic/__init__.py b/src/node_wire_fhir_epic/__init__.py new file mode 100644 index 0000000..f01af1d --- /dev/null +++ b/src/node_wire_fhir_epic/__init__.py @@ -0,0 +1,5 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""FHIR Epic connector package.""" diff --git a/src/node_wire_fhir_epic/logic.py b/src/node_wire_fhir_epic/logic.py new file mode 100644 index 0000000..ebf21a4 --- /dev/null +++ b/src/node_wire_fhir_epic/logic.py @@ -0,0 +1,603 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import asyncio +import logging +import os +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import httpx +import json + +from node_wire_runtime import BaseConnector, nw_action, sdk_action +from node_wire_runtime.fhir_encounter import assert_encounter_query_has_patient +from node_wire_runtime.mcp_normalizers import ( + normalize_fhir_read_patient, + normalize_fhir_search_encounter, + normalize_fhir_search_patients, +) + +from .schema import ( + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceCreateOutput, + FhirDocumentReferenceSearchInput, + FhirDocumentReferenceSearchOutput, + FhirEncounterSearchInput, + FhirEncounterSearchOutput, + FhirEpicOperationOutput, + FhirPatientReadInput, + FhirPatientReadOutput, + FhirPatientSearchInput, + FhirPatientSearchOutput, +) + +logger = logging.getLogger("connectors.fhir_epic") + + +def _safe_doc_ref_log_summary(doc_ref: Dict[str, Any]) -> Dict[str, Any]: + attachment: Dict[str, Any] = {} + content_items = doc_ref.get("content") + if isinstance(content_items, list) and content_items: + first = content_items[0] + if isinstance(first, dict): + attachment = ( + first.get("attachment", {}) if isinstance(first.get("attachment"), dict) else {} + ) + data_value = attachment.get("data") + data_len = len(data_value) if isinstance(data_value, str) else 0 + return { + "keys": sorted(doc_ref.keys()), + "content_items": len(content_items) if isinstance(content_items, list) else 0, + "attachment_content_type": attachment.get("contentType"), + "attachment_data_length": data_len, + } + + +class FhirEpicConnector(BaseConnector): + """FHIR/Epic connector: one @nw_action per operation.""" + + connector_id = "fhir_epic" + action = "execute" + output_model = FhirEpicOperationOutput + + @sdk_action( + "read_patient", + alias_tolerant=True, + mcp_normalize=normalize_fhir_read_patient, + ) + async def read_patient( + self, params: FhirPatientReadInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource=out.resource) + + @sdk_action( + "search_patients", + alias_tolerant=True, + mcp_normalize=normalize_fhir_search_patients, + ) + async def search_patients( + self, params: FhirPatientSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirEpicOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) + + @sdk_action( + "search_encounter", + alias_tolerant=True, + mcp_normalize=normalize_fhir_search_encounter, + ) + async def search_encounter( + self, params: FhirEncounterSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) + + @nw_action("create_document_reference") + async def create_document_reference( + self, params: FhirDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource_id=out.resource_id, resource=out.resource) + + @nw_action("search_document_reference") + async def search_document_reference( + self, params: FhirDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) + + # ------------------------------------------------------------------ + # Shared helpers — base URL + auth headers via AuthProvider + # ------------------------------------------------------------------ + + def _get_base_url(self) -> str: + return self.secret_provider.get_secret("epic_fhir_base_url").rstrip("/") + + async def _get_auth_header(self) -> Dict[str, str]: + """Delegate to the runtime AuthProvider injected by the factory. + + Returns ready-to-use FHIR request headers including the Bearer token. + Token acquisition, JWT construction, scope resolution and caching are + all handled by the provider. + """ + headers = await self.get_auth_headers() + # Ensure FHIR content types are present if the provider didn't include them (e.g. StaticTokenAuthProvider). + if "Content-Type" not in headers: + headers["Content-Type"] = "application/fhir+json" + if "Accept" not in headers: + headers["Accept"] = "application/fhir+json" + + return headers + + @staticmethod + def _build_name_search_params( + given_name: Optional[str], + family_name: Optional[str], + name: Optional[str], + birthdate: Optional[str], + extra: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + params: Dict[str, str] = dict(extra or {}) + + if given_name and given_name.strip(): + params["given"] = given_name.strip() + if family_name and family_name.strip(): + params["family"] = family_name.strip() + if name and name.strip() and "given" not in params and "family" not in params: + params["name"] = name.strip() + if birthdate and birthdate.strip(): + params["birthdate"] = birthdate.strip() + + return params + + @staticmethod + def _build_encounter_search_params( + patient_id: Optional[str], + status: Optional[str], + date: Optional[str], + extra: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + params: Dict[str, str] = dict(extra or {}) + + if patient_id and patient_id.strip(): + params["patient"] = patient_id.strip() + if status and status.strip(): + params["status"] = status.strip() + if date and date.strip(): + params["date"] = date.strip() + + return params + + async def _read_patient( + self, params: FhirPatientReadInput, *, trace_id: str + ) -> FhirPatientReadOutput: + base_url = self._get_base_url() + auth_header = await self._get_auth_header() + + if params.resource_id: + url = f"{base_url}/Patient/{params.resource_id}" + query_params: Optional[Dict[str, str]] = None + logger.info( + "FHIR Patient read by ID", + extra={"trace_id": trace_id, "resource_id": params.resource_id}, + ) + elif params.given_name or params.family_name or params.name: + url = f"{base_url}/Patient" + query_params = self._build_name_search_params( + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, + ) + logger.info( + "FHIR Patient read by name fields", + extra={"trace_id": trace_id, "query_params": query_params}, + ) + elif params.search_params: + url = f"{base_url}/Patient" + query_params = params.search_params + logger.info( + "FHIR Patient read by search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) + else: + raise ValueError( + "Provide resource_id, or name fields (given_name/family_name/name), " + "or search_params" + ) + + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.get( + url, + headers=auth_header, + params=query_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + response.raise_for_status() + except Exception as exc: + logger.error( + "FHIR Patient read failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) + raise + + data = response.json() + if data.get("resourceType") == "Bundle": + if data.get("entry"): + resource = data["entry"][0].get("resource", {}) + else: + raise ValueError("No patients found in search results") + else: + resource = data + + logger.info( + "FHIR Patient read completed", + extra={"trace_id": trace_id, "status_code": response.status_code}, + ) + return FhirPatientReadOutput(resource=resource) + + async def _search_patients( + self, params: FhirPatientSearchInput, *, trace_id: str + ) -> FhirPatientSearchOutput: + base_url = self._get_base_url() + auth_header = await self._get_auth_header() + + if params.resource_ids: + ids = [rid.strip() for rid in params.resource_ids if rid.strip()] + if not ids: + raise ValueError("resource_ids list is empty") + + logger.info( + "FHIR Patient multi-ID lookup | count=%s", + len(ids), + extra={"trace_id": trace_id, "resource_ids": ids}, + ) + + async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + resp = await client.get( + f"{base_url}/Patient/{rid}", + headers=auth_header, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + resp.raise_for_status() + return rid, resp.json(), None + except Exception as exc: + logger.warning( + "FHIR Patient fetch failed | resource_id=%s | error=%s", + rid, + str(exc), + extra={"trace_id": trace_id}, + ) + return rid, None, str(exc) + + results = await asyncio.gather(*[_fetch_one(rid) for rid in ids]) + + resources: List[Dict[str, Any]] = [] + errors: List[Dict[str, Any]] = [] + for rid, resource, error in results: + if resource is not None: + resources.append(resource) + else: + errors.append({"resource_id": rid, "error": error or "Unknown error"}) + + logger.info( + "FHIR Patient multi-ID lookup completed | found=%s | errors=%s", + len(resources), + len(errors), + extra={"trace_id": trace_id}, + ) + return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) + + name_params = self._build_name_search_params( + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, + ) + if not name_params: + raise ValueError( + "Provide resource_ids for multi-ID lookup, or at least one of " + "given_name / family_name / name / birthdate / search_params for name-based search" + ) + + logger.info( + "FHIR Patient name search | params=%s", + name_params, + extra={"trace_id": trace_id}, + ) + + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.get( + f"{base_url}/Patient", + headers=auth_header, + params=name_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + logger.error( + "FHIR Patient name search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) + raise + except Exception as exc: + logger.error( + "FHIR Patient name search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) + raise + + data = response.json() + bundle_resources: List[Dict[str, Any]] = [] + total = data.get("total") + if data.get("resourceType") == "Bundle" and data.get("entry"): + bundle_resources = [e["resource"] for e in data["entry"] if "resource" in e] + + logger.info( + "FHIR Patient name search completed | found=%s | total=%s", + len(bundle_resources), + total, + extra={"trace_id": trace_id}, + ) + return FhirPatientSearchOutput(resources=bundle_resources, total=total) + + async def _search_encounter( + self, params: FhirEncounterSearchInput, *, trace_id: str + ) -> FhirEncounterSearchOutput: + base_url = self._get_base_url() + + if params.patient_id or params.status or params.date: + query_params = self._build_encounter_search_params( + params.patient_id, params.status, params.date, params.search_params + ) + logger.info( + "FHIR Encounter search by explicit fields", + extra={"trace_id": trace_id, "query_params": query_params}, + ) + elif params.search_params: + query_params = params.search_params + logger.info( + "FHIR Encounter search by raw params", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) + else: + raise ValueError("Provide at least patient_id, status, date OR search_params") + + assert_encounter_query_has_patient(query_params) + + auth_header = await self._get_auth_header() + + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.get( + f"{base_url}/Encounter", + headers=auth_header, + params=query_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + logger.error( + "FHIR Encounter search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) + raise + except Exception as exc: + logger.error( + "FHIR Encounter search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) + raise + + data = response.json() + resources: list[Dict[str, Any]] = [] + total = data.get("total") + if data.get("resourceType") == "Bundle" and data.get("entry"): + resources = [e["resource"] for e in data["entry"] if "resource" in e] + + logger.info( + "FHIR Encounter search completed | found=%s", + len(resources), + extra={"trace_id": trace_id}, + ) + return FhirEncounterSearchOutput(resources=resources, total=total) + + async def _create_document_reference( + self, params: FhirDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirDocumentReferenceCreateOutput: + base_url = self._get_base_url() + auth_header = await self._get_auth_header() + + doc_ref: Dict[str, Any] = { + "resourceType": "DocumentReference", + "identifier": params.identifier, + "status": params.status, + "type": params.type, + "subject": {"reference": params.subject}, + "date": datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "content": [ + { + "attachment": { + "contentType": params.content_type or "text/plain", + "data": params.data, + } + } + ], + } + if params.category: + doc_ref["category"] = params.category + if params.author: + doc_ref["author"] = params.author + if params.description: + doc_ref["description"] = params.description + if params.context: + doc_ref["context"] = params.context + if params.additional_fields: + doc_ref.update(params.additional_fields) + + logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) + + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.post( + f"{base_url}/DocumentReference", + json=doc_ref, + headers=auth_header, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + try: + resp_json = exc.response.json() + diagnostics = [] + if resp_json.get("resourceType") == "OperationOutcome": + for issue in resp_json.get("issue", []): + if "diagnostics" in issue: + diagnostics.append(issue["diagnostics"]) + error_detail = ( + " | ".join(diagnostics) + if diagnostics + else f"HTTP {exc.response.status_code} from Epic FHIR endpoint" + ) + except Exception: + error_detail = f"HTTP {exc.response.status_code} from Epic FHIR endpoint" + + logger.error( + "FHIR DocumentReference create failed | status=%s | epic_error=%s | payload_summary=%s", + exc.response.status_code, + error_detail, + json.dumps(_safe_doc_ref_log_summary(doc_ref)), + extra={"trace_id": trace_id}, + ) + raise ValueError(f"Epic Error: {error_detail}") from exc + except Exception as exc: + logger.error( + "FHIR DocumentReference create failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) + raise + + resource_id: Optional[str] = None + body: Dict[str, Any] = {} + + location = response.headers.get("Location", "") + if location: + history_marker = location.find("/_history/") + resource_id = ( + location[:history_marker].split("/")[-1] + if history_marker != -1 + else location.split("/")[-1] + ) + + if not resource_id: + content_length = response.headers.get("content-length", "0") + if content_length != "0" and response.content: + try: + body = response.json() + resource_id = body.get("id") + except Exception: + pass + + if not resource_id: + raise ValueError( + f"Could not extract resource ID from DocumentReference create response. " + f"Status: {response.status_code}, Location: {location!r}, Body: {response.text[:200]!r}" + ) + + logger.info( + "FHIR DocumentReference create completed | resource_id=%s", + resource_id, + extra={"trace_id": trace_id}, + ) + return FhirDocumentReferenceCreateOutput( + resource_id=resource_id, resource=body if body else None + ) + + async def _search_document_reference( + self, params: FhirDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirDocumentReferenceSearchOutput: + base_url = self._get_base_url() + auth_header = await self._get_auth_header() + + logger.info( + "FHIR DocumentReference search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) + + try: + async with httpx.AsyncClient( + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + ) as client: + response = await client.get( + f"{base_url}/DocumentReference", + headers=auth_header, + params=params.search_params, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + logger.error( + "FHIR DocumentReference search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) + raise + except Exception as exc: + logger.error( + "FHIR DocumentReference search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) + raise + + data = response.json() + resources: list[Dict[str, Any]] = [] + total = data.get("total") + if data.get("resourceType") == "Bundle" and data.get("entry"): + resources = [e["resource"] for e in data["entry"] if "resource" in e] + + logger.info( + "FHIR DocumentReference search completed | found=%s", + len(resources), + extra={"trace_id": trace_id}, + ) + return FhirDocumentReferenceSearchOutput(resources=resources, total=total) diff --git a/src/connectors/fhir_epic/registration.py b/src/node_wire_fhir_epic/registration.py similarity index 86% rename from src/connectors/fhir_epic/registration.py rename to src/node_wire_fhir_epic/registration.py index 4307f2a..f590ebb 100644 --- a/src/connectors/fhir_epic/registration.py +++ b/src/node_wire_fhir_epic/registration.py @@ -1,8 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import httpx -from runtime import ErrorCategory, ErrorMapper +from node_wire_runtime import ErrorCategory, ErrorMapper # FHIR/Epic error mappings for network and HTTP failures. @@ -19,4 +23,3 @@ # Request errors (DNS issues, invalid URLs, etc.) are generally fatal. ErrorMapper.register(httpx.RequestError, ErrorCategory.FATAL, code="FHIR_REQUEST_ERROR") - diff --git a/src/connectors/fhir_epic/schema.py b/src/node_wire_fhir_epic/schema.py similarity index 64% rename from src/connectors/fhir_epic/schema.py rename to src/node_wire_fhir_epic/schema.py index 99aa9b5..4eb7b3f 100644 --- a/src/connectors/fhir_epic/schema.py +++ b/src/node_wire_fhir_epic/schema.py @@ -1,43 +1,36 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations -from typing import Any, Dict, List, Optional +import base64 +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- + class FhirPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Epic.""" + """Input for reading a FHIR Patient resource.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. 'eXYZ123').""" - # Convenience name fields — take priority over raw search_params when set. given_name: Optional[str] = None - """Patient given / first name (used in name-based search).""" - family_name: Optional[str] = None - """Patient family / last name (used in name-based search).""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. - """ + """Search parameters (e.g. {"family": "Smith", "given": "John"}).""" class FhirPatientReadOutput(BaseModel): @@ -51,80 +44,44 @@ class FhirPatientReadOutput(BaseModel): # Patient – Search (multi-ID fan-out OR name search returning multiple results) # --------------------------------------------------------------------------- -class FhirPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Epic. - - Two modes are supported: - - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirPatientSearchOutput.errors`` rather than raising globally. - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. +class FhirPatientSearchInput(BaseModel): + """Input for searching / fetching multiple FHIR Patient resources from Epic.""" - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Epic Patient IDs to fetch concurrently (e.g. ['eABC', 'eDEF']).""" - given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. - """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- # Encounter – Search # --------------------------------------------------------------------------- + class FhirEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources.""" - patient_id: Optional[str] = None - """FHIR Patient ID to find encounters for (maps to 'patient' FHIR param).""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None status: Optional[str] = None - """Status of the encounters to find (e.g. 'finished', 'arrived').""" - date: Optional[str] = None - """Date or date range for the encounters (e.g. '2024', 'gt2023-01-01').""" - search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters. Used if explicit fields above are not provided.""" class FhirEncounterSearchOutput(BaseModel): @@ -141,9 +98,13 @@ class FhirEncounterSearchOutput(BaseModel): # DocumentReference – Create # --------------------------------------------------------------------------- + class FhirDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: list[Dict[str, Any]] """Document identifier.""" @@ -159,8 +120,17 @@ class FhirDocumentReferenceCreateInput(BaseModel): subject: str """Patient reference string (e.g. 'Patient/{id}'). Required by Epic.""" - data: str - """Base64-encoded document content. Required by Epic.""" + @field_validator("data", mode="after") + @classmethod + def validate_base64_data(cls, v: str) -> str: + try: + base64.b64decode(v, validate=True) + except Exception: + raise ValueError("data must be a valid base64-encoded string") + return v + + data: str = Field(..., max_length=10 * 1024 * 1024) + """Base64-encoded document content. Required by Epic. Max size 10MB.""" content_type: Optional[str] = None """MIME type of the document content (e.g. 'text/plain', 'application/pdf'). Defaults to 'text/plain'.""" @@ -210,11 +180,15 @@ class FhirDocumentReferenceCreateOutput(BaseModel): # DocumentReference – Search # --------------------------------------------------------------------------- + class FhirDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"eXYZ123\"}).""" + """Search parameters (e.g. {"patient": "eXYZ123"}).""" class FhirDocumentReferenceSearchOutput(BaseModel): @@ -224,4 +198,18 @@ class FhirDocumentReferenceSearchOutput(BaseModel): """The list of raw FHIR DocumentReference JSON objects found.""" total: Optional[int] = None - """Total number of results reported by the Bundle.""" \ No newline at end of file + """Total number of results reported by the Bundle.""" + + +class FhirEpicOperationOutput(BaseModel): + """ + Unified output for all Epic FHIR actions (BaseConnector single output_model). + + Fields are populated depending on the action; unused fields are None. + """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/google_drive/README.md b/src/node_wire_google_drive/README.md similarity index 76% rename from src/connectors/google_drive/README.md rename to src/node_wire_google_drive/README.md index 3b44409..2a285a4 100644 --- a/src/connectors/google_drive/README.md +++ b/src/node_wire_google_drive/README.md @@ -1,8 +1,14 @@ + + # Google Drive Connector — Technical Documentation > **Platform:** Node Wire > **Connector ID:** `google_drive` -> **Endpoint:** `POST /connectors/google_drive/execute` +> **REST:** One route per operation, e.g. `POST /connectors/google_drive/files.list` (the `action` field is still set on the body for `BaseConnector` dispatch). > **Discriminator:** `action` field (discriminated-union payload) > **Source:** `connectors/google_drive/` @@ -10,7 +16,21 @@ ## 1. Operations Overview -All requests go through a single `execute` endpoint. The `action` field determines which Google Drive operation runs. All responses share a common output shape and error taxonomy enforced by the runtime. +The runtime validates requests against the discriminated union in `schema.py`, then dispatches to `@nw_action` handlers on `GoogleDriveConnector`. Each handler delegates to an **action spec** in `action_spec.py` that maps the validated model to the Google Drive API v3 client (`googleapiclient`). Shared concerns (thread offload, `HttpError` translation, logging) stay in `logic.py`. All responses share a common output shape and error taxonomy enforced by the runtime. + +### Action-spec layout + +| Piece | Role | +|-------|------| +| [`action_spec.py`](action_spec.py) | `GOOGLE_DRIVE_ACTION_SPECS`: per-action `SdkActionSpec` (resource path, method, field/body mapping, constants, optional `build_kwargs` / `post_process`). | +| [`logic.py`](logic.py) | Client build, `_translate_and_raise_http_error`, `_execute_action_spec`, thin `@nw_action` methods. | +| [`runtime/sdk_action_spec.py`](../../runtime/sdk_action_spec.py) | Reusable primitives: `SdkActionSpec`, `default_build_kwargs`, `execute_spec_in_thread`. | + +**Adding a new operation:** Add a Pydantic variant in `schema.py` (with an `action` discriminator literal), extend the `GoogleDriveOperationInput` union, and add an entry to `GOOGLE_DRIVE_ACTION_SPECS` in `action_spec.py` (or a `build_kwargs` hook for non-generic cases such as multipart upload). `BaseConnector.__init_subclass__` auto-generates the handler — do **not** also add an `@nw_action` method for the same action name, as that will raise a `TypeError` at class-definition time. + +### Migrating other connectors + +Use the same pattern: put declarative mapping in a connector-local `*_action_spec` module; `BaseConnector.__init_subclass__` auto-generates `@nw_action`-equivalent handlers from `action_specs`, so no manual `@nw_action` decorators are needed for spec-driven actions. Use `SdkActionSpec.build_kwargs` when the vendor API needs custom assembly (uploads, explicit `None` args, etc.). ### Available Operations diff --git a/src/node_wire_google_drive/__init__.py b/src/node_wire_google_drive/__init__.py new file mode 100644 index 0000000..9348710 --- /dev/null +++ b/src/node_wire_google_drive/__init__.py @@ -0,0 +1,6 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# + +# Connector subpackage: google_drive diff --git a/src/node_wire_google_drive/action_spec.py b/src/node_wire_google_drive/action_spec.py new file mode 100644 index 0000000..0dc4114 --- /dev/null +++ b/src/node_wire_google_drive/action_spec.py @@ -0,0 +1,214 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +Google Drive action specs: mapping from validated Pydantic inputs to Drive API v3 calls. + +Used by GoogleDriveConnector to reduce per-action boilerplate while preserving +behavior (defaults, field masks, shared drives flags). +""" + +from __future__ import annotations + +import base64 +from typing import Any, Callable, Dict, cast + +from googleapiclient.http import MediaInMemoryUpload +from pydantic import BaseModel + +from node_wire_runtime.mcp_normalizers import normalize_google_drive_files_upload +from node_wire_runtime.sdk_action_spec import SdkActionSpec + +from .schema import ( + FilesCreateOperation, + FilesDeleteOperation, + FilesGetOperation, + FilesListOperation, + FilesUpdateOperation, + FilesUploadOperation, + PermissionsCreateOperation, +) + +DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" + + +def _files_get_fields_kwarg(p: FilesGetOperation) -> str: + return p.fields or "id,name,mimeType,parents" + + +def _files_update_add_parents(p: FilesUpdateOperation) -> str | None: + return ",".join(p.add_parents) if p.add_parents else None + + +def _files_update_remove_parents(p: FilesUpdateOperation) -> str | None: + return ",".join(p.remove_parents) if p.remove_parents else None + + +# Action name -> SdkActionSpec (matches @nw_action("...") strings) +GOOGLE_DRIVE_ACTION_SPECS: Dict[str, SdkActionSpec] = {} + + +def _register_files_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.create"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + body_from_model={ + "name": "name", + "mime_type": "mimeType", + "parents": "parents", + }, + constant_kwargs={ + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + }, + input_model=FilesCreateOperation, + ) + + +def _build_files_list_kwargs(_drive: Any, model: BaseModel) -> Dict[str, Any]: + """Match legacy behavior: pass q/pageToken explicitly even when None.""" + p = model if isinstance(model, FilesListOperation) else FilesListOperation.model_validate(model) + return { + "pageSize": p.page_size, + "q": p.query, + "fields": p.fields or DEFAULT_LIST_FIELDS, + "pageToken": p.page_token, + "supportsAllDrives": True, + "includeItemsFromAllDrives": True, + } + + +def _register_files_list() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.list"] = SdkActionSpec( + resource_segments=("files",), + method_name="list", + build_kwargs=_build_files_list_kwargs, + input_model=FilesListOperation, + ) + + +def _register_files_get() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.get"] = SdkActionSpec( + resource_segments=("files",), + method_name="get", + kwargs_from_model={"file_id": "fileId"}, + computed_kwargs=cast( + Dict[str, Callable[[BaseModel], Any]], + {"fields": _files_get_fields_kwarg}, + ), + constant_kwargs={"supportsAllDrives": True}, + input_model=FilesGetOperation, + ) + + +def _register_files_update() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.update"] = SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "name": "name", + "mime_type": "mimeType", + }, + computed_kwargs=cast( + Dict[str, Callable[[BaseModel], Any]], + { + "addParents": _files_update_add_parents, + "removeParents": _files_update_remove_parents, + }, + ), + constant_kwargs={"supportsAllDrives": True}, + include_empty_body=True, + input_model=FilesUpdateOperation, + ) + + +def _build_upload_kwargs(drive: Any, model: BaseModel) -> Dict[str, Any]: + params = ( + model + if isinstance(model, FilesUploadOperation) + else FilesUploadOperation.model_validate(model) + ) + body = { + k: v + for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() + if v is not None + } + if params.content_base64 is not None: + media_bytes = base64.b64decode(params.content_base64) + elif params.content is not None: + media_bytes = params.content.encode("utf-8") + else: + raise ValueError("Either content or content_base64 must be provided for files.upload") + media = MediaInMemoryUpload( + media_bytes, + mimetype=params.mime_type, + resumable=False, + ) + return { + "body": body, + "media_body": media, + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + } + + +def _register_files_upload() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.upload"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + build_kwargs=_build_upload_kwargs, + input_model=FilesUploadOperation, + alias_tolerant=True, + mcp_normalize=normalize_google_drive_files_upload, + ) + + +def _register_files_delete() -> None: + def _post_delete(_result: Any, model: BaseModel) -> Dict[str, Any]: + file_id = getattr(model, "file_id", None) + return {"file_id": file_id, "status": "deleted"} + + GOOGLE_DRIVE_ACTION_SPECS["files.delete"] = SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_constant={"trashed": True}, + constant_kwargs={"supportsAllDrives": True}, + post_process=_post_delete, + input_model=FilesDeleteOperation, + ) + + +def _register_permissions_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["permissions.create"] = SdkActionSpec( + resource_segments=("permissions",), + method_name="create", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "role": "role", + "type": "type", + "email_address": "emailAddress", + "domain": "domain", + }, + constant_kwargs={"supportsAllDrives": True}, + input_model=PermissionsCreateOperation, + ) + + +def _init_specs() -> None: + _register_files_create() + _register_files_list() + _register_files_get() + _register_files_update() + _register_files_upload() + _register_files_delete() + _register_permissions_create() + + +_init_specs() diff --git a/src/connectors/google_drive/exceptions.py b/src/node_wire_google_drive/exceptions.py similarity index 82% rename from src/connectors/google_drive/exceptions.py rename to src/node_wire_google_drive/exceptions.py index fd3a14b..d56de6d 100644 --- a/src/connectors/google_drive/exceptions.py +++ b/src/node_wire_google_drive/exceptions.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations diff --git a/src/node_wire_google_drive/logic.py b/src/node_wire_google_drive/logic.py new file mode 100644 index 0000000..7aa8a88 --- /dev/null +++ b/src/node_wire_google_drive/logic.py @@ -0,0 +1,134 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import logging +from typing import Any + +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError + +from node_wire_runtime import BaseConnector +from node_wire_runtime.models import ErrorCategory +from node_wire_runtime.sdk_action_spec import execute_spec_in_thread + +from .action_spec import DEFAULT_LIST_FIELDS, GOOGLE_DRIVE_ACTION_SPECS +from .exceptions import ( + GoogleDriveAuthError, + GoogleDriveBusinessError, + GoogleDriveFatalError, + GoogleDriveRateLimitError, +) +from .schema import GoogleDriveOperationOutput + +logger = logging.getLogger("connectors.google_drive") + +# Re-export for tests and callers that imported from logic. +__all__ = ["DEFAULT_LIST_FIELDS", "GoogleDriveConnector"] + + +class GoogleDriveConnector(BaseConnector): + """ + Google Drive connector: Drive API v3 operations are driven by action specs + (see action_spec.py) and thin @nw_action handlers for logging and dispatch. + """ + + connector_id = "google_drive" + action = "execute" + output_model = GoogleDriveOperationOutput + action_specs = GOOGLE_DRIVE_ACTION_SPECS + + error_map = { + GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), + GoogleDriveRateLimitError: (ErrorCategory.RETRYABLE, "GDRIVE_RATE_LIMIT"), + GoogleDriveBusinessError: (ErrorCategory.BUSINESS, "GDRIVE_BUSINESS_RULE"), + GoogleDriveFatalError: (ErrorCategory.FATAL, "GDRIVE_FATAL"), + } + + def build_client(self) -> Any: + import asyncio + + # get_client_credentials() is async; run it synchronously here since + # build_client() is called from the synchronous get_client() accessor. + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # In an async context, we can't use run_until_complete. + # Instead, fetch credentials synchronously via the underlying + # ServiceAccountAuthProvider._build_credentials() pattern. + # This code path is reached during connector initialisation + # inside an async frame (e.g. in tests with pytest-asyncio). + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + creds = pool.submit( + lambda: asyncio.run(self._auth_provider.get_client_credentials()) + ).result() + else: + creds = loop.run_until_complete(self._auth_provider.get_client_credentials()) + except RuntimeError: + creds = asyncio.run(self._auth_provider.get_client_credentials()) + + if creds is None: + # Fallback for NoAuthProvider or unconfigured provider — + # attempt direct secret resolution for backward compatibility. + raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") + try: + from google.oauth2 import service_account # type: ignore[import] + import json as _json + + info = _json.loads(raw_sa) + creds = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/drive"], + ) + except Exception: + creds = service_account.Credentials.from_service_account_file( + raw_sa.strip(), + scopes=["https://www.googleapis.com/auth/drive"], + ) + + return build("drive", "v3", credentials=creds) + + def _translate_and_raise_http_error(self, exc: HttpError) -> None: + status = exc.resp.status + content_str = str(getattr(exc, "content", "") or "") + + if status in (401, 403): + if "quotaExceeded" in content_str or "rateLimitExceeded" in content_str: + raise GoogleDriveRateLimitError("Google Drive quota/rate limit exceeded") from exc + raise GoogleDriveAuthError("Authentication or permissions failure") from exc + + if status == 429 or status >= 500: + raise GoogleDriveRateLimitError("Upstream service unavailable or rate limited") from exc + + if status in (400, 404, 409): + reason = getattr(exc, "reason", str(exc)) + raise GoogleDriveBusinessError(f"Business logic failure: {reason}") from exc + + raise GoogleDriveFatalError(f"Unhandled HttpError status {status}") from exc + + async def _execute_action_spec( + self, + action_name: str, + params: Any, + *, + trace_id: str, + log_extra: dict[str, Any] | None = None, + ) -> GoogleDriveOperationOutput: + spec = GOOGLE_DRIVE_ACTION_SPECS.get(action_name) + if spec is None: + raise ValueError(f"No action spec registered for {action_name!r}") + drive = self.get_client() + extra = {"trace_id": trace_id, **(log_extra or {})} + logger.info("Google Drive %s", action_name, extra=extra) + try: + raw = await execute_spec_in_thread(drive, spec, params) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=raw, + description=f"Successfully executed {action_name}", + ) diff --git a/src/connectors/google_drive/registration.py b/src/node_wire_google_drive/registration.py similarity index 78% rename from src/connectors/google_drive/registration.py rename to src/node_wire_google_drive/registration.py index b0566e8..a93eb2c 100644 --- a/src/connectors/google_drive/registration.py +++ b/src/node_wire_google_drive/registration.py @@ -1,6 +1,10 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations -from runtime import ErrorCategory, ErrorMapper +from node_wire_runtime import ErrorCategory, ErrorMapper from .exceptions import ( GoogleDriveAuthError, diff --git a/src/connectors/google_drive/schema.py b/src/node_wire_google_drive/schema.py similarity index 62% rename from src/connectors/google_drive/schema.py rename to src/node_wire_google_drive/schema.py index a2f22e8..8f3f5d2 100644 --- a/src/connectors/google_drive/schema.py +++ b/src/node_wire_google_drive/schema.py @@ -1,8 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from typing import Annotated, Any, Dict, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator class BaseDriveOperation(BaseModel): @@ -11,9 +15,6 @@ class BaseDriveOperation(BaseModel): model_config = ConfigDict(extra="forbid") -# --- Specific Operation Schemas --- - - class FilesCreateOperation(BaseDriveOperation): action: Literal["files.create"] name: str = Field(..., description="The name of the file.") @@ -23,7 +24,15 @@ class FilesCreateOperation(BaseDriveOperation): class FilesListOperation(BaseDriveOperation): action: Literal["files.list"] - page_size: int = Field(10, ge=1, le=100) + page_size: Optional[int] = Field( + 10, ge=1, le=100, description="Do not send null; omit if unsure." + ) + + @field_validator("page_size", mode="before") + @classmethod + def _default_page_size(cls, v: Any) -> int: + return 10 if v is None else int(v) + query: Optional[str] = Field(None, description="Search query string.") fields: Optional[str] = Field( None, @@ -32,6 +41,10 @@ class FilesListOperation(BaseDriveOperation): "uses a performant default: nextPageToken, files(id, name, mimeType, webViewLink)." ), ) + page_token: Optional[str] = Field( + None, + description="Token for the next page of results from a previous files.list response.", + ) class PermissionsCreateOperation(BaseDriveOperation): @@ -42,6 +55,13 @@ class PermissionsCreateOperation(BaseDriveOperation): type: Literal["user", "group", "domain", "anyone"] domain: Optional[str] = Field(None, description="G Suite domain when type is domain.") + @field_validator("email_address", "domain", mode="before") + @classmethod + def _empty_str_to_none(cls, v: Any) -> Any: + if isinstance(v, str) and not v.strip(): + return None + return v + @model_validator(mode="after") def require_fields_for_perm_type(self) -> "PermissionsCreateOperation": if self.type in ("user", "group"): @@ -58,9 +78,7 @@ class FilesGetOperation(BaseDriveOperation): file_id: str fields: Optional[str] = Field( None, - description=( - "Optional fields mask; if omitted, a safe default is used by the connector." - ), + description=("Optional fields mask; if omitted, a safe default is used by the connector."), ) @@ -83,7 +101,20 @@ class FilesUploadOperation(BaseDriveOperation): mime_type: str = Field(..., description="The MIME type of the file content.") parents: Optional[list[str]] = Field(None, description="List of parent folder IDs.") content: Optional[str] = Field(None, description="UTF-8 text content to upload.") - content_base64: Optional[str] = Field(None, description="Base64 encoded binary content to upload.") + content_base64: Optional[str] = Field( + None, description="Base64 encoded binary content to upload." + ) + + @model_validator(mode="after") + def exactly_one_of_content_or_base64(self) -> "FilesUploadOperation": + """Match Drive upload semantics: exactly one body source (aligned with action_spec).""" + has_text = self.content is not None + has_b64 = self.content_base64 is not None + if not has_text and not has_b64: + raise ValueError("Provide exactly one of 'content' or 'content_base64'.") + if has_text and has_b64: + raise ValueError("Provide exactly one of 'content' or 'content_base64', not both.") + return self class FilesDeleteOperation(BaseDriveOperation): @@ -91,11 +122,7 @@ class FilesDeleteOperation(BaseDriveOperation): file_id: str -# --- The Envelope --- -# The runtime validates against this single type. Pydantic automatically -# routes the validation to the correct sub-model based on the "action" field. -# RootModel accepts **raw_input in __init__ so BaseConnector's _input_model_cls(**raw_input) works. -_OperationUnion = Annotated[ +_GoogleDriveOperationUnion = Annotated[ Union[ FilesCreateOperation, FilesListOperation, @@ -108,9 +135,10 @@ class FilesDeleteOperation(BaseDriveOperation): Field(discriminator="action"), ] -GoogleDriveOperationInput = RootModel[_OperationUnion] +# Discriminated union for tests/agents; must stay aligned with GoogleDriveConnector @nw_action set. +GoogleDriveOperationInput = RootModel[_GoogleDriveOperationUnion] class GoogleDriveOperationOutput(BaseModel): raw: Dict[str, Any] - description: str \ No newline at end of file + description: str diff --git a/src/node_wire_http_generic/__init__.py b/src/node_wire_http_generic/__init__.py new file mode 100644 index 0000000..a76ac90 --- /dev/null +++ b/src/node_wire_http_generic/__init__.py @@ -0,0 +1,6 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# + +# Connector subpackage: http_generic diff --git a/src/connectors/http_generic/logic.py b/src/node_wire_http_generic/logic.py similarity index 60% rename from src/connectors/http_generic/logic.py rename to src/node_wire_http_generic/logic.py index 88afc67..f2400b8 100644 --- a/src/connectors/http_generic/logic.py +++ b/src/node_wire_http_generic/logic.py @@ -1,47 +1,71 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging +import os from typing import Any +from urllib.parse import urlsplit, urlunsplit import httpx -from runtime import BaseConnector +from node_wire_runtime import BaseConnector, nw_action from .schema import HttpRequestInput, HttpResponseOutput logger = logging.getLogger("connectors.http_generic") -class HttpGenericConnector(BaseConnector[HttpRequestInput, HttpResponseOutput]): +def _sanitize_url_for_log(raw_url: str) -> str: + """ + Remove query and fragment from URLs before logging to avoid leaking tokens/PII. + """ + try: + parsed = urlsplit(raw_url) + host = parsed.hostname or "" + if ":" in host and not host.startswith("["): + host = f"[{host}]" + netloc = host + if parsed.port is not None: + netloc = f"{netloc}:{parsed.port}" + return urlunsplit((parsed.scheme, netloc, parsed.path, "", "")) + except Exception: # noqa: BLE001 + return "" + + +class HttpGenericConnector(BaseConnector): """ Lightweight HTTP connector for generic REST integrations. """ connector_id = "http_generic" - action = "request" + output_model = HttpResponseOutput - async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> HttpResponseOutput: + @nw_action("request") + async def request(self, params: HttpRequestInput, *, trace_id: str) -> HttpResponseOutput: """ Perform an HTTP request using httpx. All potential network errors are raised and mapped by the runtime's ErrorMapper, with detailed, human-readable logs at the connector level. """ + safe_url = _sanitize_url_for_log(str(params.url)) logger.info( "Preparing HTTP request", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "request", "method": params.method, - "url": str(params.url), + "url": safe_url, }, ) - print(f"trace_id: {trace_id} from node-wire-connector") - try: - async with httpx.AsyncClient() as client: + timeout = float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")) + async with httpx.AsyncClient(timeout=timeout, trust_env=False) as client: response = await client.request( method=params.method, url=str(params.url), @@ -49,7 +73,7 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> params=params.params, json=params.body if isinstance(params.body, (dict, list)) else None, content=None if isinstance(params.body, (dict, list)) else params.body, - timeout=30.0, + timeout=timeout, ) except Exception as exc: # noqa: BLE001 # Let ErrorMapper classify the exception, but log clear context here. @@ -58,11 +82,11 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "request", "method": params.method, - "url": str(params.url), + "url": safe_url, "error_type": type(exc).__name__, - "message": str(exc), + "error_message": str(exc), }, ) raise @@ -72,9 +96,9 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "request", "method": params.method, - "url": str(params.url), + "url": safe_url, "status_code": response.status_code, }, ) @@ -87,4 +111,3 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> headers=headers, body=response.text, ) - diff --git a/src/connectors/http_generic/registration.py b/src/node_wire_http_generic/registration.py similarity index 84% rename from src/connectors/http_generic/registration.py rename to src/node_wire_http_generic/registration.py index 7444317..361c199 100644 --- a/src/connectors/http_generic/registration.py +++ b/src/node_wire_http_generic/registration.py @@ -1,8 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import httpx -from runtime import ErrorCategory, ErrorMapper +from node_wire_runtime import ErrorCategory, ErrorMapper # Typical HTTP/network error mappings for the generic HTTP connector. @@ -15,4 +19,3 @@ # HTTP status errors are treated as BUSINESS by default; bindings may translate status_code further. ErrorMapper.register(httpx.HTTPStatusError, ErrorCategory.BUSINESS, code="HTTP_STATUS_ERROR") - diff --git a/src/node_wire_http_generic/schema.py b/src/node_wire_http_generic/schema.py new file mode 100644 index 0000000..a365cac --- /dev/null +++ b/src/node_wire_http_generic/schema.py @@ -0,0 +1,69 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import ipaddress +from typing import Any, Dict, Literal, Optional +from urllib.parse import urlsplit + +from pydantic import BaseModel, HttpUrl, field_validator + +_ALLOWED_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE"} +_BLOCKED_HOSTNAMES = { + "localhost", + "metadata.google.internal", + "metadata", +} + + +class HttpRequestInput(BaseModel): + action: Literal["request"] = "request" + url: HttpUrl + method: str + headers: Optional[Dict[str, str]] = None + params: Optional[Dict[str, str]] = None + body: Optional[Any] = None + + @field_validator("method", mode="before") + @classmethod + def normalize_and_validate_method(cls, value: Any) -> Any: + if not isinstance(value, str): + raise ValueError("method must be a string") + normalized = value.strip().upper() + if normalized not in _ALLOWED_METHODS: + raise ValueError(f"method must be one of: {', '.join(sorted(_ALLOWED_METHODS))}") + return normalized + + @field_validator("url") + @classmethod + def block_internal_targets(cls, value: HttpUrl) -> HttpUrl: + parts = urlsplit(str(value)) + host = (parts.hostname or "").strip().lower().rstrip(".") + if host in _BLOCKED_HOSTNAMES: + raise ValueError("url host is blocked by outbound security policy") + if _is_blocked_ip_literal(host): + raise ValueError("url host resolves to a blocked network target") + return value + + +def _is_blocked_ip_literal(host: str) -> bool: + try: + ip_obj = ipaddress.ip_address(host) + except ValueError: + return False + if ip_obj.is_loopback or ip_obj.is_private or ip_obj.is_link_local: + return True + if ip_obj.is_multicast or ip_obj.is_reserved or ip_obj.is_unspecified: + return True + # Explicit cloud metadata target. + if str(ip_obj) == "169.254.169.254": + return True + return False + + +class HttpResponseOutput(BaseModel): + status_code: int + headers: Dict[str, str] + body: Any diff --git a/src/node_wire_runtime/__init__.py b/src/node_wire_runtime/__init__.py new file mode 100644 index 0000000..1beaa6e --- /dev/null +++ b/src/node_wire_runtime/__init__.py @@ -0,0 +1,67 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from .models import ConnectorResponse, ErrorCategory +from .errors import ErrorMapper +from .secrets import SecretProvider, EnvSecretProvider, SecretNotFoundError, SecretProviderError +from .policy import PolicyHook, PolicyDenied +from .caller_identity import CallerIdentity, build_caller_identity +from .auth import ( + AuthProvider, + NoAuthProvider, + StaticTokenAuthProvider, + OAuth2AuthProvider, + ServiceAccountAuthProvider, +) +from .base_connector import ( + BaseConnector, + NestedConnectorActionError, + nw_action, + sdk_action, + _CONNECTOR_REGISTRY, +) +from .sdk_action_spec import ( + SdkActionSpec, + default_build_kwargs, + execute_spec_in_thread, + navigate_resource, +) +from .streaming import ( + StreamSignal, + stream_completion_log, + resolve_stream_buffer_ms, + BufferedStreamIterator, +) + +__all__ = [ + "ConnectorResponse", + "ErrorCategory", + "ErrorMapper", + "SecretProvider", + "EnvSecretProvider", + "SecretNotFoundError", + "SecretProviderError", + "PolicyHook", + "PolicyDenied", + "CallerIdentity", + "build_caller_identity", + "AuthProvider", + "NoAuthProvider", + "StaticTokenAuthProvider", + "OAuth2AuthProvider", + "ServiceAccountAuthProvider", + "BaseConnector", + "NestedConnectorActionError", + "sdk_action", + "nw_action", + "_CONNECTOR_REGISTRY", + "SdkActionSpec", + "default_build_kwargs", + "execute_spec_in_thread", + "navigate_resource", + "StreamSignal", + "stream_completion_log", + "resolve_stream_buffer_ms", + "BufferedStreamIterator", +] diff --git a/src/node_wire_runtime/auth/__init__.py b/src/node_wire_runtime/auth/__init__.py new file mode 100644 index 0000000..71ce353 --- /dev/null +++ b/src/node_wire_runtime/auth/__init__.py @@ -0,0 +1,47 @@ +""" +node_wire_runtime.auth +======================= + +Pluggable authentication layer for Node Wire connectors. + +All providers implement :class:`AuthProvider` and are safe to inject into any +:class:`~node_wire_runtime.base_connector.BaseConnector` subclass via the +``auth_provider=`` constructor argument. + +Available providers +------------------- +NoAuthProvider + Null-object — returns empty headers. Default when no ``auth:`` block is + present in ``connectors.yaml``. + +StaticTokenAuthProvider + Reads a single secret via :class:`~node_wire_runtime.secrets.SecretProvider` + and injects it as ``Authorization: Bearer `` (or a custom header). + Optionally base64-encodes the value for HTTP Basic auth. + +OAuth2AuthProvider + Fetches and caches OAuth 2.0 access tokens (Client Credentials grant). + Supports ``private_key_jwt`` (SMART Backend Services / Epic / Cerner) and + ``client_secret_post``. Uses ``asyncio.Lock`` to prevent concurrent + token-refresh storms. + +ServiceAccountAuthProvider + Resolves a Google service-account JSON secret and returns + ``google.oauth2.service_account.Credentials`` via + :meth:`~AuthProvider.get_client_credentials`. Used by the Google Drive + connector; returns empty HTTP headers. +""" + +from .base import AuthProvider +from .no_auth import NoAuthProvider +from .oauth2 import OAuth2AuthProvider +from .service_account import ServiceAccountAuthProvider +from .static_token import StaticTokenAuthProvider + +__all__ = [ + "AuthProvider", + "NoAuthProvider", + "StaticTokenAuthProvider", + "OAuth2AuthProvider", + "ServiceAccountAuthProvider", +] diff --git a/src/node_wire_runtime/auth/base.py b/src/node_wire_runtime/auth/base.py new file mode 100644 index 0000000..0792341 --- /dev/null +++ b/src/node_wire_runtime/auth/base.py @@ -0,0 +1,64 @@ +""" +node_wire_runtime.auth.base +============================ + +Abstract base class for all authentication providers. + +All authentication falls into two categories: + + Static Credentials — fixed secrets (API keys, service-account JSON, SMTP passwords). + Dynamic Tokens — short-lived tokens that must be fetched and cached. + +Both are unified behind this interface so connectors stay credential-agnostic. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class AuthProvider(ABC): + """ + Abstract port for connector authentication. + + Connectors call :meth:`get_headers` to obtain ready-to-use HTTP request + headers and :meth:`get_client_credentials` when they need SDK-level + objects (e.g. ``google.oauth2.credentials.Credentials``). + + Override :meth:`refresh` to force credential renewal after a 401. + """ + + @abstractmethod + async def get_headers(self) -> Dict[str, str]: + """ + Return a dict of HTTP headers required to authenticate the request. + + For bearer-token flows this is ``{"Authorization": "Bearer "}``. + For connectors that authenticate at the SDK level (e.g. Google Drive), + this may return an empty dict. + + Implementations are responsible for caching and refreshing tokens + transparently; callers must not cache the result themselves. + """ + raise NotImplementedError + + async def get_client_credentials(self) -> Any: + """ + Return SDK-level credentials (e.g. ``google.oauth2.Credentials``). + + The default implementation returns ``None``; override in providers + that need to supply credentials to vendor SDKs rather than HTTP headers. + """ + return None + + async def refresh(self) -> None: + """ + Force a refresh of any cached credentials on the next call. + + The default is a no-op; override in providers that maintain a cache + (e.g. :class:`~node_wire_runtime.auth.oauth2.OAuth2AuthProvider`). + + Call this after receiving a 401/403 to ensure the next request uses + freshly-issued credentials. + """ diff --git a/src/node_wire_runtime/auth/no_auth.py b/src/node_wire_runtime/auth/no_auth.py new file mode 100644 index 0000000..c0bd35a --- /dev/null +++ b/src/node_wire_runtime/auth/no_auth.py @@ -0,0 +1,34 @@ +""" +node_wire_runtime.auth.no_auth +================================ + +Null-object implementation of :class:`AuthProvider`. + +Returns empty headers and ``None`` credentials. Acts as the safe default +when no ``auth:`` block is present in ``connectors.yaml``, ensuring connectors +never receive ``None`` and never need to guard against an unconfigured provider. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from .base import AuthProvider + + +class NoAuthProvider(AuthProvider): + """ + No-op authentication provider. + + Suitable for connectors that handle auth at the SDK level in a custom + ``build_client()`` override, or for internal/localhost endpoints that + require no credentials. + """ + + async def get_headers(self) -> Dict[str, str]: + """Return an empty dict — no auth headers injected.""" + return {} + + async def get_client_credentials(self) -> Any: + """Return ``None`` — no SDK credentials required.""" + return None diff --git a/src/node_wire_runtime/auth/oauth2.py b/src/node_wire_runtime/auth/oauth2.py new file mode 100644 index 0000000..8f0842f --- /dev/null +++ b/src/node_wire_runtime/auth/oauth2.py @@ -0,0 +1,368 @@ +""" +node_wire_runtime.auth.oauth2 +================================ + +:class:`OAuth2AuthProvider` — implements the OAuth 2.0 Client Credentials +grant with two assertion methods: + + ``private_key_jwt`` — SMART Backend Services / RS384 JWT-bearer assertion. + Used by Epic and Cerner FHIR endpoints. + ``client_secret_post`` — standard ``client_id`` + ``client_secret`` POST body. + +Security design: + - Tokens are cached in memory using the ``expires_in`` value from the + token response, minus a configurable buffer (default 60 s). + - An ``asyncio.Lock`` serialises concurrent refresh calls to prevent + the thundering-herd problem under high concurrency. + - ``refresh()`` clears the cache so callers can force re-issue after a 401. + - Private keys, client IDs, and token URLs are resolved at call-time via + :class:`~node_wire_runtime.secrets.SecretProvider` so they are never held + in plain text in config files. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from typing import Any, Dict, List, Optional + +import httpx +import jwt + +from node_wire_runtime.secrets import SecretProvider + +from .base import AuthProvider + +logger = logging.getLogger("runtime.auth.oauth2") + +_DEFAULT_BUFFER_SECS = 60 +_DEFAULT_TOKEN_TTL_SECS = 3600 # fallback when expires_in absent from response + + +class OAuth2AuthProvider(AuthProvider): + """ + OAuth 2.0 Client Credentials provider with token caching. + + Parameters + ---------- + secret_provider: + Runtime :class:`SecretProvider` used to resolve all secret references. + grant_method: + ``"private_key_jwt"`` (default) or ``"client_secret_post"``. + token_url_secret: + Secret key whose value is the token endpoint URL. + client_id_secret: + Secret key whose value is ``client_id``. + algorithm: + JWT signing algorithm. Default: ``"RS384"`` (required by SMART). + private_key_secret: + *(private_key_jwt only)* Secret key for the PEM private key. + kid_secret: + *(private_key_jwt only)* Secret key for the JWT ``kid`` header. + client_secret_secret: + *(client_secret_post only)* Secret key for ``client_secret``. + scopes: + List of OAuth scopes. If ``None``, no ``scope`` param is sent. + scopes_secret: + Alternative: secret key whose value is a space-separated scope string. + Overrides ``scopes`` if set. + extra_content_type_headers: + Additional fixed headers merged into the response (e.g. FHIR content-type). + Default: ``{"Content-Type": "application/fhir+json", "Accept": "application/fhir+json"}``. + buffer_secs: + Seconds before ``expires_in`` to treat the token as expired. + Default: 60. + jwt_ttl_secs: + Lifetime of the JWT assertion in seconds. Default: 300. + """ + + def __init__( + self, + *, + secret_provider: SecretProvider, + grant_method: str = "private_key_jwt", + token_url_secret: str, + client_id_secret: str, + algorithm: str = "RS384", + private_key_secret: Optional[str] = None, + kid_secret: Optional[str] = None, + client_secret_secret: Optional[str] = None, + refresh_token_secret: Optional[str] = None, + scopes: Optional[List[str]] = None, + scopes_secret: Optional[str] = None, + extra_content_type_headers: Optional[Dict[str, str]] = None, + buffer_secs: int = _DEFAULT_BUFFER_SECS, + jwt_ttl_secs: int = 300, + ) -> None: + if grant_method not in ("private_key_jwt", "client_secret_post", "refresh_token"): + raise ValueError( + f"Unsupported grant_method {grant_method!r}. " + "Use 'private_key_jwt', 'client_secret_post', or 'refresh_token'." + ) + self._sp = secret_provider + self._grant_method = grant_method + self._token_url_secret = token_url_secret + self._client_id_secret = client_id_secret + self._algorithm = algorithm + self._private_key_secret = private_key_secret + self._kid_secret = kid_secret + self._client_secret_secret = client_secret_secret + self._refresh_token_secret = refresh_token_secret + + self._static_scopes = scopes + self._scopes_secret = scopes_secret + self._extra_headers: Dict[str, str] = ( + extra_content_type_headers + if extra_content_type_headers is not None + else { + "Content-Type": "application/fhir+json", + "Accept": "application/fhir+json", + } + ) + self._buffer_secs = buffer_secs + self._jwt_ttl_secs = jwt_ttl_secs + + # Cache state — protected by _lock. + self._access_token: Optional[str] = None + self._expires_at: float = 0.0 + self._lock: asyncio.Lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + async def get_headers(self) -> Dict[str, str]: + """ + Return ``Authorization: Bearer `` plus any extra fixed headers. + + The token is fetched from the upstream IdP only when necessary + (first call, expiry, or after an explicit :meth:`refresh` call). + Concurrent callers block on an ``asyncio.Lock`` so only one HTTP + request is issued per refresh cycle. + """ + token = await self._get_or_refresh_token() + headers: Dict[str, str] = {"Authorization": f"Bearer {token}"} + headers.update(self._extra_headers) + return headers + + async def refresh(self) -> None: + """ + Invalidate the cached token. + + Call this after receiving a 401/403 so the next :meth:`get_headers` + call fetches a fresh token instead of reusing the (now-rejected) one. + """ + async with self._lock: + logger.debug("OAuth2AuthProvider: cache invalidated by refresh()") + self._access_token = None + self._expires_at = 0.0 + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _is_valid(self) -> bool: + """Return True if the cached token is still within its TTL window.""" + return ( + self._access_token is not None + and time.monotonic() < self._expires_at - self._buffer_secs + ) + + async def _get_or_refresh_token(self) -> str: + """ + Return a valid access token, fetching one if necessary. + + Uses a double-checked locking pattern: + 1. Fast-path: check validity outside the lock (no contention). + 2. Acquire lock and re-check (another coroutine may have refreshed). + 3. Fetch if still invalid. + """ + # Fast path — no lock, no contention. + if self._is_valid(): + return self._access_token # type: ignore[return-value] + + async with self._lock: + # Re-check after acquiring the lock. + if self._is_valid(): + return self._access_token # type: ignore[return-value] + + logger.debug( + "OAuth2AuthProvider: fetching new access token", + extra={"grant_method": self._grant_method}, + ) + token_data = await self._fetch_token() + + access_token = token_data.get("access_token") + if not access_token: + raise ValueError( + "OAuth2 token response did not contain an 'access_token'. " + f"Response keys: {list(token_data.keys())}" + ) + + expires_in = int(token_data.get("expires_in", _DEFAULT_TOKEN_TTL_SECS)) + self._access_token = access_token + self._expires_at = time.monotonic() + expires_in + + logger.debug( + "OAuth2AuthProvider: token cached", + extra={"expires_in": expires_in, "buffer_secs": self._buffer_secs}, + ) + return self._access_token + + def _resolve_scopes(self) -> Optional[str]: + """Resolve the scope string from secret or static list. Returns None if absent.""" + if self._scopes_secret: + try: + val = self._sp.get_secret(self._scopes_secret) + if val and val.strip(): + return val.strip() + except Exception: + pass + if self._static_scopes: + return " ".join(self._static_scopes) + return None + + async def _fetch_token(self) -> Dict[str, Any]: + """Dispatch to the appropriate grant method implementation.""" + if self._grant_method == "private_key_jwt": + return await self._fetch_private_key_jwt() + if self._grant_method == "refresh_token": + return await self._fetch_refresh_token() + return await self._fetch_client_secret_post() + + async def _fetch_refresh_token(self) -> Dict[str, Any]: + """Exchange refresh_token for a new access token.""" + if not self._refresh_token_secret: + raise ValueError( + "OAuth2AuthProvider (refresh_token): 'refresh_token_secret' must be configured." + ) + + client_id = self._sp.get_secret(self._client_id_secret) + client_secret = ( + self._sp.get_secret(self._client_secret_secret) if self._client_secret_secret else None + ) + refresh_token = self._sp.get_secret(self._refresh_token_secret) + token_url = self._sp.get_secret(self._token_url_secret) + + post_data: Dict[str, str] = { + "grant_type": "refresh_token", + "client_id": client_id, + "refresh_token": refresh_token, + } + if client_secret: + post_data["client_secret"] = client_secret + + scope = self._resolve_scopes() + if scope: + post_data["scope"] = scope + + logger.debug( + "OAuth2AuthProvider: refresh_token token request", + extra={"token_url": token_url}, + ) + return await self._post_token(token_url, post_data) + + async def _fetch_private_key_jwt(self) -> Dict[str, Any]: + """ + Exchange a signed JWT assertion for an access token. + + Follows RFC 7523 / SMART Backend Services specification. + """ + if not self._private_key_secret or not self._kid_secret: + raise ValueError( + "OAuth2AuthProvider (private_key_jwt): " + "both 'private_key_secret' and 'kid_secret' must be configured." + ) + + private_key_raw = self._sp.get_secret(self._private_key_secret) + kid = self._sp.get_secret(self._kid_secret) + client_id = self._sp.get_secret(self._client_id_secret) + token_url = self._sp.get_secret(self._token_url_secret) + + # Normalise PEM keys stored as single-line env vars with escaped newlines. + private_key_pem = ( + private_key_raw.replace("\\n", "\n") if "\\n" in private_key_raw else private_key_raw + ) + + now = int(time.time()) + claims: Dict[str, Any] = { + "iss": client_id, + "sub": client_id, + "aud": token_url, + "jti": str(uuid.uuid4()), + "iat": now, + "nbf": now, + "exp": now + self._jwt_ttl_secs, + } + + scope = self._resolve_scopes() + if scope: + claims["scope"] = scope + + jwt_token = jwt.encode( + claims, + private_key_pem, + algorithm=self._algorithm, + headers={"alg": self._algorithm, "typ": "JWT", "kid": kid}, + ) + + post_data: Dict[str, str] = { + "grant_type": "client_credentials", + "client_assertion_type": ("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), + "client_assertion": jwt_token, + } + if scope: + post_data["scope"] = scope + + logger.debug( + "OAuth2AuthProvider: private_key_jwt token request", + extra={"token_url": token_url, "client_id": client_id}, + ) + return await self._post_token(token_url, post_data) + + async def _fetch_client_secret_post(self) -> Dict[str, Any]: + """Exchange client_id + client_secret for an access token.""" + if not self._client_secret_secret: + raise ValueError( + "OAuth2AuthProvider (client_secret_post): " + "'client_secret_secret' must be configured." + ) + + client_id = self._sp.get_secret(self._client_id_secret) + client_secret = self._sp.get_secret(self._client_secret_secret) + token_url = self._sp.get_secret(self._token_url_secret) + + post_data: Dict[str, str] = { + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + } + scope = self._resolve_scopes() + if scope: + post_data["scope"] = scope + + logger.debug( + "OAuth2AuthProvider: client_secret_post token request", + extra={"token_url": token_url}, + ) + return await self._post_token(token_url, post_data) + + @staticmethod + async def _post_token(token_url: str, data: Dict[str, str]) -> Dict[str, Any]: + """POST to the token endpoint and return the parsed JSON body.""" + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logger.error( + "OAuth2 token request failed | status=%s | body=%s", + response.status_code, + response.text, + ) + response.raise_for_status() + return response.json() # type: ignore[no-any-return] diff --git a/src/node_wire_runtime/auth/service_account.py b/src/node_wire_runtime/auth/service_account.py new file mode 100644 index 0000000..fde0bd2 --- /dev/null +++ b/src/node_wire_runtime/auth/service_account.py @@ -0,0 +1,130 @@ +""" +node_wire_runtime.auth.service_account +========================================= + +:class:`ServiceAccountAuthProvider` — Google service-account authentication. + +Instead of injecting HTTP headers, this provider supplies a +``google.oauth2.credentials.Credentials`` object via +:meth:`get_client_credentials`, which vendor SDKs (e.g. ``google-api-python-client``) +consume directly. + +The service-account JSON is stored as a secret and resolved at first use. +Credentials are refreshed automatically by the Google auth library when they +expire; this provider does not implement its own TTL. + +This provider intentionally returns an empty dict from :meth:`get_headers` because +Google Drive authentication is handled at the SDK level, not via HTTP headers set +by the connector. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Optional + +from node_wire_runtime.secrets import SecretProvider + +from .base import AuthProvider + +logger = logging.getLogger("runtime.auth.service_account") + + +class ServiceAccountAuthProvider(AuthProvider): + """ + Google service-account credentials provider. + + Parameters + ---------- + secret_provider: + Runtime :class:`SecretProvider` used to resolve the service-account JSON. + sa_json_secret: + Secret key whose value is either: + - A JSON string containing the full service-account key, **or** + - A filesystem path to the service-account JSON file. + scopes: + List of OAuth2 scopes to request. Default: + ``["https://www.googleapis.com/auth/drive"]``. + """ + + def __init__( + self, + *, + secret_provider: SecretProvider, + sa_json_secret: str, + scopes: Optional[List[str]] = None, + ) -> None: + self._sp = secret_provider + self._sa_json_secret = sa_json_secret + self._scopes = scopes or ["https://www.googleapis.com/auth/drive"] + self._credentials: Any = None + + def _build_credentials(self) -> Any: + """ + Resolve the service-account secret and return a + ``google.oauth2.service_account.Credentials`` object. + + Supports both inline JSON and a file-path fallback for local development. + The import is deferred so that packages without the Google libraries + installed do not fail at import time. + """ + try: + from google.oauth2 import service_account # type: ignore[import] + except ImportError as exc: + raise ImportError( + "ServiceAccountAuthProvider requires 'google-auth'. " + "Install it with: pip install google-auth" + ) from exc + + raw = self._sp.get_secret(self._sa_json_secret) + try: + info = json.loads(raw) + creds = service_account.Credentials.from_service_account_info(info, scopes=self._scopes) + except (json.JSONDecodeError, ValueError): + # Fallback: treat the secret value as a file path. + creds = service_account.Credentials.from_service_account_file( + raw.strip(), scopes=self._scopes + ) + + logger.debug( + "ServiceAccountAuthProvider: credentials built", + extra={ + "sa_json_secret": self._sa_json_secret, + "scopes": self._scopes, + }, + ) + return creds + + async def get_headers(self) -> Dict[str, str]: + """ + Return an empty dict. + + Google Drive auth is handled at the SDK level via + :meth:`get_client_credentials`; no HTTP headers are injected. + """ + return {} + + async def get_client_credentials(self) -> Any: + """ + Return a ``google.oauth2.service_account.Credentials`` instance. + + The credentials object is built once and cached for the lifetime of + this provider. The Google auth library handles token refresh internally. + """ + if self._credentials is None: + self._credentials = self._build_credentials() + return self._credentials + + async def refresh(self) -> None: + """ + Invalidate the cached credentials object. + + Forces :meth:`get_client_credentials` to rebuild from the secret on the + next call, picking up any rotated service-account JSON. + """ + logger.debug( + "ServiceAccountAuthProvider: credentials cache invalidated", + extra={"sa_json_secret": self._sa_json_secret}, + ) + self._credentials = None diff --git a/src/node_wire_runtime/auth/static_token.py b/src/node_wire_runtime/auth/static_token.py new file mode 100644 index 0000000..1b30fd1 --- /dev/null +++ b/src/node_wire_runtime/auth/static_token.py @@ -0,0 +1,96 @@ +""" +node_wire_runtime.auth.static_token +====================================== + +:class:`StaticTokenAuthProvider` — reads a single secret via +:class:`~node_wire_runtime.secrets.SecretProvider` and injects it as an HTTP +request header. + +Suitable for: + - API-key authentication (e.g. Stripe, generic HTTP connectors) + - Pre-issued bearer tokens that do not expire + - HTTP Basic authentication (set ``encoding="base64"``) + +The secret is fetched **once** and held in memory for the lifetime of the +provider instance. Because these secrets are long-lived and do not expire, no +TTL or refresh mechanism is implemented — tear down and recreate the provider +if the secret is rotated. +""" + +from __future__ import annotations + +import base64 +import logging +from typing import Dict, Optional + +from node_wire_runtime.secrets import SecretProvider + +from .base import AuthProvider + +logger = logging.getLogger("runtime.auth.static_token") + + +class StaticTokenAuthProvider(AuthProvider): + """ + Injects a static secret as an HTTP ``Authorization`` (or custom) header. + + Parameters + ---------- + secret_provider: + The runtime :class:`SecretProvider` used to resolve secrets. + secret_key: + The secret key passed to ``secret_provider.get_secret()``. + header_name: + The HTTP header to set. Default: ``"Authorization"``. + prefix: + String prepended to the secret value (with a space separator). + Pass ``""`` for raw injection (e.g. some proprietary API-key headers). + Default: ``"Bearer"``. + encoding: + Optional encoding applied to the raw secret before injection. + Currently supports ``"base64"`` (for HTTP Basic auth pairs that are + already formatted as ``user:password``). Default: ``None``. + """ + + def __init__( + self, + *, + secret_provider: SecretProvider, + secret_key: str, + header_name: str = "Authorization", + prefix: str = "Bearer", + encoding: Optional[str] = None, + ) -> None: + self._secret_provider = secret_provider + self._secret_key = secret_key + self._header_name = header_name + self._prefix = prefix + self._encoding = encoding + self._cached_header: Optional[Dict[str, str]] = None + + def _build_header(self) -> Dict[str, str]: + raw = self._secret_provider.get_secret(self._secret_key) + + if self._encoding == "base64": + raw = base64.b64encode(raw.encode()).decode() + + value = f"{self._prefix} {raw}".strip() if self._prefix else raw + return {self._header_name: value} + + async def get_headers(self) -> Dict[str, str]: + """Return the header dict, computing it once and caching thereafter.""" + if self._cached_header is None: + logger.debug( + "StaticTokenAuthProvider: resolving secret", + extra={"secret_key": self._secret_key, "header": self._header_name}, + ) + self._cached_header = self._build_header() + return dict(self._cached_header) + + async def refresh(self) -> None: + """Invalidate the cached header so the secret is re-read on the next call.""" + logger.debug( + "StaticTokenAuthProvider: cache invalidated", + extra={"secret_key": self._secret_key}, + ) + self._cached_header = None diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py new file mode 100644 index 0000000..d57c064 --- /dev/null +++ b/src/node_wire_runtime/base_connector.py @@ -0,0 +1,672 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import contextvars +import inspect +import logging +import uuid +from abc import ABC +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + Annotated, + Any, + Callable, + ClassVar, + Dict, + Optional, + Tuple, + Type, + Union, + cast, + get_type_hints, + List, +) + +from opentelemetry import trace +from opentelemetry.trace import Tracer +from pybreaker import CircuitBreaker +from pydantic import BaseModel, Field, RootModel, ValidationError + +from .auth import AuthProvider, NoAuthProvider +from .errors import ErrorMapper +from .models import ConnectorResponse, ErrorCategory +from .policy import PolicyContext, PolicyHook, PolicyDenied +from .resilience import with_resilience +from .secrets import SecretProvider +from .sdk_action_spec import SdkActionSpec + +logger = logging.getLogger("runtime.base_connector") +tracer: Tracer = trace.get_tracer("runtime") +ErrorMapper.register(PolicyDenied, ErrorCategory.AUTH, code="POLICY_DENIED") + + +class NestedConnectorActionError(Exception): + """Nested action invoked via :meth:`call_action` returned ``ConnectorResponse.success=False``.""" + + def __init__(self, response: ConnectorResponse) -> None: + self.response = response + msg = response.message or response.error_code or "Nested action failed" + super().__init__(msg) + + +def _merge_nested_failure_details(nested: ConnectorResponse) -> Any: + """Attach nested trace id for debugging without dropping existing ``details``.""" + tid = nested.trace_id + d = nested.details + if tid is None or tid == "": + return d + if d is None: + return {"nested_trace_id": tid} + if isinstance(d, dict): + merged = dict(d) + merged.setdefault("nested_trace_id", tid) + return merged + return {"nested_trace_id": tid, "nested_details": d} + + +# principal, tenant_id, scopes — set during :meth:`run` for nested :meth:`call_action`. +_caller_execution_ctx: contextvars.ContextVar[ + tuple[Optional[str], Optional[str], Optional[tuple[str, ...]]] | None +] = contextvars.ContextVar("nw_connector_caller_execution", default=None) + +# Populated by BaseConnector.__init_subclass__ +_CONNECTOR_REGISTRY: Dict[str, Type["BaseConnector"]] = {} + + +def _make_spec_handler( + action_name: str, + input_model: Any, + output_model: Any, + cls_qualname: str, + cls_module: str, + alias_tolerant: bool = False, + mcp_normalize: Optional[Callable[[Dict[str, Any]], None]] = None, + requires_auth: bool = True, + scopes: Optional[List[str]] = None, + rate_limit: Optional[Dict[str, Any]] = None, + deprecated: bool = False, +) -> Any: + """ + Build a single async handler function for one action_specs entry. + Using a factory function (rather than a loop + default-arg trick) ensures + action_name is captured by value in the closure and does not appear in the + method signature seen by inspect.signature / get_type_hints. + """ + fn_name = action_name.replace(".", "_").replace("-", "_") + + async def _handler(self, params, *, trace_id: str): + return await self._execute_action_spec(action_name, params, trace_id=trace_id) + + handler_fn: Any = _handler + handler_fn.__name__ = fn_name + handler_fn.__qualname__ = f"{cls_qualname}.{fn_name}" + handler_fn.__module__ = cls_module + # Set actual type objects (not strings) so get_type_hints() resolves correctly + # even when `from __future__ import annotations` is active in the connector module. + handler_fn.__annotations__ = {"params": input_model, "return": output_model} + handler_fn._sdk_action_name = action_name + handler_fn._alias_tolerant = alias_tolerant + handler_fn._mcp_normalize = mcp_normalize + handler_fn._requires_auth = requires_auth + handler_fn._scopes = scopes + handler_fn._rate_limit = rate_limit + handler_fn._deprecated = deprecated + # Backward-compatible alias for legacy callers/tests. + handler_fn._nw_action_name = action_name + return handler_fn + + +def _generate_methods_from_action_specs(cls: Any) -> None: + """ + For each entry in cls.action_specs, generate an async @nw_action method and + attach it to cls. Called at the top of BaseConnector.__init_subclass__ so the + existing discovery loop picks up the generated methods. + + Opt-in: only triggers when the class defines action_specs in its own __dict__. + """ + specs = cls.__dict__.get("action_specs") + if specs is None: + return + + fallback_output = getattr(cls, "output_model", None) + + for action_name, spec in specs.items(): + if not isinstance(spec, SdkActionSpec): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] must be a SdkActionSpec instance" + ) + input_model = spec.input_model + if not (isinstance(input_model, type) and issubclass(input_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] requires " + "input_model=" + ) + + output_model = spec.output_model if spec.output_model is not None else fallback_output + if not (isinstance(output_model, type) and issubclass(output_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] has no resolvable " + "output_model — set it on the SdkActionSpec or define cls.output_model" + ) + + fn_name = action_name.replace(".", "_").replace("-", "_") + if fn_name in cls.__dict__: + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] conflicts with " + f"existing method {fn_name!r}" + ) + + handler = _make_spec_handler( + action_name, + input_model, + output_model, + cls.__qualname__, + cls.__module__, + alias_tolerant=spec.alias_tolerant, + mcp_normalize=spec.mcp_normalize, + requires_auth=spec.requires_auth, + scopes=spec.scopes, + rate_limit=spec.rate_limit, + deprecated=spec.deprecated, + ) + setattr(cls, fn_name, handler) + + +def sdk_action( + name: str, + *, + alias_tolerant: bool = False, + mcp_normalize: Optional[Callable[[Dict[str, Any]], None]] = None, + requires_auth: bool = True, + scopes: Optional[List[str]] = None, + rate_limit: Optional[Dict[str, Any]] = None, + deprecated: bool = False, +): + """ + Mark a connector method as a named, auto-discoverable action. + + The decorated method must be async and have full type annotations for its + params (first arg after self) and return type. + + Set alias_tolerant=True for actions whose MCP input schema should accept + extra/alias fields (e.g. LLM-generated aliases) before normalization runs. + + Optional mcp_normalize mutates tool argument dicts in place before connector.run. + """ + + def decorator(fn: Any) -> Any: + fn._sdk_action_name = name + fn._alias_tolerant = alias_tolerant + fn._mcp_normalize = mcp_normalize + fn._requires_auth = requires_auth + fn._scopes = scopes + fn._rate_limit = rate_limit + fn._deprecated = deprecated + # Backward-compatible alias for legacy callers/tests. + fn._nw_action_name = name + return fn + + return decorator + + +def nw_action(name: str): + """Backward-compatible decorator alias for sdk_action().""" + return sdk_action(name) + + +@dataclass +class NwActionMeta: + """Metadata for one @nw_action method.""" + + name: str + fn_name: str + input_model: Type[BaseModel] + output_model: Type[BaseModel] + alias_tolerant: bool = False + mcp_normalize: Optional[Callable[[Dict[str, Any]], None]] = None + requires_auth: bool = True + scopes: Optional[List[str]] = None + rate_limit: Optional[Dict[str, Any]] = None + deprecated: bool = False + + +class BaseConnector(ABC): + """ + Base class for all connectors. + + Subclasses define: + - connector_id: str + - output_model: Type[BaseModel] (common output envelope for all actions) + - error_map: optional mapping of exception -> (ErrorCategory, code) + - build_client() / get_client() for vendor SDK lifecycle + + Actions are declared with @nw_action("resource.operation") on async methods. + """ + + connector_id: str + action: str = "execute" + + error_map: ClassVar[Dict[Type[BaseException], Tuple[ErrorCategory, str]]] = {} + output_model: ClassVar[Type[BaseModel]] + + _action_registry: ClassVar[Dict[str, NwActionMeta]] + _union_input_model: ClassVar[Type[RootModel[Any]]] + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + # Phase 0: auto-generate @nw_action methods from action_specs (opt-in). + # Must run before the dir(cls) discovery loop below. + _generate_methods_from_action_specs(cls) + + registry: Dict[str, NwActionMeta] = {} + for attr_name in dir(cls): + method = getattr(cls, attr_name, None) + if not callable(method): + continue + action_name = getattr(method, "_sdk_action_name", None) or getattr( + method, "_nw_action_name", None + ) + if not action_name: + continue + + try: + hints = get_type_hints(method) + except Exception: + hints = {} + + try: + sig_params = [ + p + for p in inspect.signature(method).parameters.values() + if p.name not in ("self", "trace_id") + ] + input_param_name = sig_params[0].name if sig_params else None + except (ValueError, TypeError): + input_param_name = None + + if not input_param_name: + raise TypeError( + f"{cls.__name__}.{attr_name}: @nw_action method must have a params argument " + "after self" + ) + + input_model = hints.get(input_param_name) + output_model = hints.get("return") + if ( + input_model is None + or not isinstance(input_model, type) + or not issubclass(input_model, BaseModel) + ): + raise TypeError( + f"{cls.__name__}.{attr_name}: missing or invalid type hint for " + f"parameter {input_param_name!r}" + ) + if ( + output_model is None + or not isinstance(output_model, type) + or not issubclass(output_model, BaseModel) + ): + raise TypeError(f"{cls.__name__}.{attr_name}: missing or invalid return type hint") + + registry[action_name] = NwActionMeta( + name=action_name, + fn_name=attr_name, + input_model=input_model, + output_model=output_model, + alias_tolerant=getattr(method, "_alias_tolerant", False), + mcp_normalize=getattr(method, "_mcp_normalize", None), + requires_auth=getattr(method, "_requires_auth", True), + scopes=getattr(method, "_scopes", None), + rate_limit=getattr(method, "_rate_limit", None), + deprecated=getattr(method, "_deprecated", False), + ) + + cls._action_registry = registry + + valid_models = [m.input_model for m in registry.values()] + if not valid_models: + raise TypeError(f"{cls.__name__}: BaseConnector must define at least one @nw_action") + + if len(valid_models) == 1: + root_for_rm: Any = valid_models[0] + else: + root_for_rm = Annotated[ + Union[tuple(valid_models)], # type: ignore[arg-type] + Field(discriminator="action"), + ] + + cls._union_input_model = cast(Type[RootModel[Any]], RootModel[root_for_rm]) + cls._union_input_model.model_rebuild() + + own_error_map = cls.__dict__.get("error_map", {}) + for exc_type, (category, code) in own_error_map.items(): + ErrorMapper.register(exc_type, category, code=code) + + if "connector_id" in cls.__dict__: + _CONNECTOR_REGISTRY[cls.connector_id] = cls + logger.debug( + "Registered BaseConnector subclass", + extra={"connector_id": cls.connector_id}, + ) + + def __init__( + self, + *, + secret_provider: Optional[SecretProvider] = None, + policy_hook: Optional[PolicyHook] = None, + auth_provider: Optional[AuthProvider] = None, + ) -> None: + cls = type(self) + self._input_model_cls = cls._union_input_model + self._output_model_cls = cls.output_model + self._secret_provider = secret_provider + self._policy_hook = policy_hook + # Default to NoAuthProvider (null-object) so connectors never receive None. + self._auth_provider: AuthProvider = ( + auth_provider if auth_provider is not None else NoAuthProvider() + ) + self._breakers: dict[str, CircuitBreaker] = defaultdict(self._create_breaker) + self._client: Any = None + + def _create_breaker(self) -> CircuitBreaker: + cls = type(self) + return CircuitBreaker( + fail_max=5, + reset_timeout=30, + name=f"{cls.__name__}_breaker", + ) + + def _breaker_key(self, tenant_id: Optional[str]) -> str: + return tenant_id or "__default__" + + def _breaker_for_tenant(self, tenant_id: Optional[str]) -> CircuitBreaker: + # Tests may delete `_breakers` to simulate cache loss; rebuild lazily. + if not hasattr(self, "_breakers"): + self._breakers = defaultdict(self._create_breaker) + return self._breakers[self._breaker_key(tenant_id)] + + @property + def secret_provider(self) -> SecretProvider: + if self._secret_provider is None: + raise RuntimeError("SecretProvider has not been configured for this connector.") + return self._secret_provider + + @property + def auth_provider(self) -> AuthProvider: + """The :class:`AuthProvider` configured for this connector. + + Always returns a valid provider — defaults to :class:`NoAuthProvider` + when none was injected, so callers never need a ``None`` guard. + """ + return self._auth_provider + + async def get_auth_headers(self) -> Dict[str, str]: + """Return authentication headers from the configured :class:`AuthProvider`. + + Connectors should call this instead of reading secrets directly:: + + headers = await self.get_auth_headers() + # merge with any connector-specific headers + headers.update({"Content-Type": "application/json"}) + + Returns an empty dict when the provider is :class:`NoAuthProvider`. + """ + return await self._auth_provider.get_headers() + + async def run( + self, + raw_input: Dict[str, Any], + principal: Optional[str] = None, + tenant_id: Optional[str] = None, + scopes: Optional[tuple[str, ...]] = None, + ) -> ConnectorResponse: + """ + Public execution entrypoint. + + - Generates a trace ID + - Starts an OpenTelemetry span + - Validates input + - Executes policy hook + - Wraps internal execution with retries and circuit breaking + - Maps exceptions into the standard error taxonomy + """ + trace_id = str(uuid.uuid4()) + + with tracer.start_as_current_span( + "connector.run", + attributes={ + "connector.id": self.connector_id, + "connector.action": self.action, + "tenant.id": tenant_id or "", + "principal.id": principal or "", + "trace.id": trace_id, + }, + ): + logger.info( + "Starting connector execution", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": self.action, + }, + ) + + token = _caller_execution_ctx.set((principal, tenant_id, scopes)) + try: + try: + input_model = self._input_model_cls.model_validate(raw_input) + except ValidationError as exc: + logger.error( + "Input validation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": self.action, + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + details = [ + {"loc": e["loc"], "msg": e["msg"], "type": e["type"]} for e in exc.errors() + ] + return ConnectorResponse( + success=False, + error_code="VALIDATION_ERROR", + error_category=ErrorCategory.BUSINESS, + message="Input validation failed; please check the request payload.", + trace_id=trace_id, + details=details, + ) + + # Policy hook + if self._policy_hook is not None: + input_payload = input_model.model_dump() + policy_action = str(input_payload.get("action", self.action)) + context = PolicyContext( + connector_id=self.connector_id, + action=policy_action, + input_payload=input_payload, + principal=principal, + tenant_id=tenant_id, + scopes=scopes, + ) + try: + self._policy_hook.check(context) + except PolicyDenied as exc: + logger.warning( + "AUDIT: Execution blocked by policy hook", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": self.action, + "error_type": type(exc).__name__, + "error_message": str(exc), + "audit": True, + "audit_event": "policy_denial", + "tenant_id": tenant_id, + "principal": principal, + }, + ) + mapped = ErrorMapper.resolve(exc) + return ConnectorResponse( + success=False, + error_code=mapped.code, + error_category=mapped.category, + message=str(exc), + trace_id=trace_id, + ) + + execute_with_resilience = with_resilience(self._breaker_for_tenant(tenant_id)) + + @execute_with_resilience + async def _do_execute(*, trace_id: str) -> Any: + return await self.internal_execute(input_model, trace_id=trace_id) + + output_model = await _do_execute(trace_id=trace_id) + + logger.info( + "Connector execution completed successfully", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": self.action, + }, + ) + + return ConnectorResponse( + success=True, + data=output_model.model_dump(), + trace_id=trace_id, + ) + except NestedConnectorActionError as exc: + nested = exc.response + logger.warning( + "Nested connector action failed via call_action", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "nested_error_code": nested.error_code or "", + "nested_trace_id": nested.trace_id, + }, + ) + return ConnectorResponse( + success=False, + error_code=nested.error_code, + error_category=nested.error_category, + message=nested.message, + trace_id=trace_id, + details=_merge_nested_failure_details(nested), + ) + except Exception as exc: # noqa: BLE001 + mapped = ErrorMapper.resolve(exc) + logger.error( + "Connector execution failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": self.action, + "error_code": mapped.code, + "error_category": mapped.category.value, + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + return ConnectorResponse( + success=False, + error_code=mapped.code, + error_category=mapped.category, + message=str(exc), + trace_id=trace_id, + ) + finally: + _caller_execution_ctx.reset(token) + + @classmethod + def get_registry(cls) -> Dict[str, Type[BaseConnector]]: + """Public access to the global connector registry.""" + return dict(_CONNECTOR_REGISTRY) + + @classmethod + def sdk_action_metas(cls) -> Dict[str, NwActionMeta]: + """Registry of action name -> metadata (for manifest/ingress).""" + return dict(cls._action_registry) + + @classmethod + def nw_action_metas(cls) -> Dict[str, NwActionMeta]: + """Backward-compatible alias for sdk_action_metas().""" + return dict(cls._action_registry) + + def build_client(self) -> Any: + """Override in subclasses to build the vendor SDK client.""" + return None + + def get_client(self) -> Any: + if self._client is None: + self._client = self.build_client() + return self._client + + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + """Dispatch to the @nw_action method matching the validated input.""" + root = params.root if hasattr(params, "root") else params + action_key = getattr(root, "action", None) + if action_key is None: + raise ValueError(f"Input model missing action discriminator: {type(root).__name__}") + + meta = self._action_registry.get(str(action_key)) + if meta is None: + raise ValueError( + f"Connector {self.connector_id!r} has no registered action {action_key!r}. " + f"Available: {list(self._action_registry)}" + ) + fn = getattr(self, meta.fn_name) + logger.debug( + "Dispatching action", + extra={ + "connector_id": self.connector_id, + "action": action_key, + "trace_id": trace_id, + }, + ) + return await fn(root, trace_id=trace_id) + + async def call_action( + self, + name: str, + params_dict: Dict[str, Any], + *, + principal: Optional[str] = None, + tenant_id: Optional[str] = None, + scopes: Optional[tuple[str, ...]] = None, + ) -> Any: + """Invoke another action via :meth:`run` so policy hooks and resilience apply. + + When called from within an action that was entered through :meth:`run` + (e.g. MCP/REST with identity), caller ``principal`` / ``tenant_id`` / + ``scopes`` are inherited from that outer run unless overridden here. + """ + meta = self._action_registry.get(name) + if meta is None: + raise ValueError( + f"call_action: unknown action {name!r} on connector {self.connector_id!r}" + ) + p, t, s = principal, tenant_id, scopes + if p is None and t is None and s is None: + inherited = _caller_execution_ctx.get() + if inherited is not None: + p, t, s = inherited + + payload = dict(params_dict) + payload["action"] = name + resp = await self.run(payload, principal=p, tenant_id=t, scopes=s) + if not resp.success: + if resp.error_code == "POLICY_DENIED": + raise PolicyDenied(resp.message or "Policy denied") + raise NestedConnectorActionError(resp) + if resp.data is None: + raise RuntimeError("call_action: connector returned no data") + return meta.output_model.model_validate(resp.data) diff --git a/src/node_wire_runtime/caller_identity.py b/src/node_wire_runtime/caller_identity.py new file mode 100644 index 0000000..9b0949f --- /dev/null +++ b/src/node_wire_runtime/caller_identity.py @@ -0,0 +1,66 @@ +"""Transport-neutral caller identity for connector execution and policy hooks.""" + +from __future__ import annotations + +import json +import os +import re +from dataclasses import dataclass +from typing import Any, Mapping + + +@dataclass(frozen=True) +class CallerIdentity: + """Who is calling ``connector.run`` (REST, MCP, or other bindings).""" + + principal: str + tenant_id: str | None + scopes: tuple[str, ...] + claims: Mapping[str, Any] + auth_type: str + + +def build_caller_identity(claims: Mapping[str, Any], auth_type: str) -> CallerIdentity: + """Build identity from JWT-style claims (``sub``, ``tenant_id``, ``scopes`` / ``scope``).""" + principal = str(claims.get("sub") or claims.get("client_id") or "unknown") + tenant_val = claims.get("tenant_id") + tenant_id = str(tenant_val) if tenant_val is not None else None + raw_scopes = claims.get("scopes") + if raw_scopes is None: + raw_scopes = claims.get("scope") + if isinstance(raw_scopes, str): + scopes = tuple(s for s in raw_scopes.split(" ") if s) + elif isinstance(raw_scopes, (list, tuple, set)): + scopes = tuple(str(s) for s in raw_scopes if str(s).strip()) + else: + scopes = tuple() + return CallerIdentity( + principal=principal, + tenant_id=tenant_id, + scopes=scopes, + claims=dict(claims), + auth_type=auth_type, + ) + + +def parse_api_key_scopes_from_env(env_var: str) -> tuple[str, ...]: + """ + Parse scopes for shared API keys (MCP / REST), e.g. ``NW_MCP_API_KEY_SCOPES``. + + Accepts: + + - JSON array: ``["mcp:smtp.send_email","mcp:other"]`` + - Whitespace or comma separated tokens: ``mcp:a mcp:b`` or ``mcp:a,mcp:b`` + + Empty / unset means **no** scopes (not wildcard). + """ + raw = os.environ.get(env_var) + if raw is None or not str(raw).strip(): + return tuple() + raw = str(raw).strip() + if raw.startswith("["): + parsed = json.loads(raw) + if not isinstance(parsed, list): + raise ValueError(f"{env_var} JSON must be an array of strings") + return tuple(str(s).strip() for s in parsed if str(s).strip()) + return tuple(p for p in re.split(r"[\s,]+", raw) if p) diff --git a/src/node_wire_runtime/connector_registry.py b/src/node_wire_runtime/connector_registry.py new file mode 100644 index 0000000..7012f56 --- /dev/null +++ b/src/node_wire_runtime/connector_registry.py @@ -0,0 +1,190 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +node_wire_runtime.connector_registry +===================================== + +Entry-point–based connector auto-registration. + +Installed connector packages declare themselves via the ``node_wire.connectors`` +entry point group in their ``pyproject.toml``: + + [project.entry-points."node_wire.connectors"] + fhir_epic = "node_wire_fhir_epic.logic" + +Calling :func:`auto_register` loads each allowed connector package's ``logic`` module +(triggering ``BaseConnector.__init_subclass__`` registration) and its optional +``registration`` module (for ``ErrorMapper`` side effects). + +Allowlist (recommended for production): + +* ``NW_ALLOWED_CONNECTORS`` — comma-separated entry point **names** (e.g. ``fhir_epic,http_generic``). + If unset or empty, no entry points are loaded (secure default). + +* ``NW_CONNECTOR_MODULE_PREFIX`` — if set (default ``node_wire_``), entry points whose + target module does not start with this prefix are skipped with a warning. + +""" + +from __future__ import annotations + +import importlib +import logging +import os +from importlib.metadata import EntryPoint, entry_points +from typing import List + +logger = logging.getLogger("node_wire_runtime.connector_registry") + + +def _parse_allowed_names() -> set[str]: + """Return allowed entry point names. Defaults to empty set (nothing allowed).""" + raw = os.environ.get("NW_ALLOWED_CONNECTORS") + if raw is None or not str(raw).strip(): + return set() + return {x.strip() for x in str(raw).split(",") if x.strip()} + + +def _module_prefix() -> str | None: + """Prefix that logic module names must start with; None disables the check.""" + raw = os.environ.get("NW_CONNECTOR_MODULE_PREFIX") + if raw is None: + return "node_wire_" + s = str(raw).strip() + return s if s else None + + +def _logic_module_dotted_path(ep: EntryPoint) -> str: + """Dotted import path for the entry point target (e.g. ``node_wire_fhir_epic.logic``).""" + val = ep.value.strip() + if ":" in val: + return val.split(":", 1)[0].strip() + return val + + +def _parent_package_for_logic_module(logic_module: str) -> str: + """``node_wire_fhir_epic.logic`` -> ``node_wire_fhir_epic``.""" + return logic_module.rsplit(".", 1)[0] + + +def _should_skip_ep(ep: EntryPoint, allowed: set[str], prefix: str | None) -> bool: + if ep.name not in allowed: + logger.warning( + "Skipping connector entry point %r (not in NW_ALLOWED_CONNECTORS)", + ep.name, + ) + return True + logic_mod = _logic_module_dotted_path(ep) + if prefix and not logic_mod.startswith(prefix): + logger.warning( + "Skipping connector entry point %r: module %r does not start with NW_CONNECTOR_MODULE_PREFIX %r", + ep.name, + logic_mod, + prefix, + ) + return True + return False + + +def auto_register() -> List[str]: + """Load connector packages declared under ``node_wire.connectors``. + + For each entry point: + 1. Load the ``logic`` module — triggers ``BaseConnector.__init_subclass__``, + which populates ``_CONNECTOR_REGISTRY``. + 2. Attempt to load a sibling ``registration`` module (optional) for + ``ErrorMapper`` registrations and other import-time side effects. + + If an allowed connector is not discovered via entry points, attempts to fallback + to importing the logic module directly. + + Returns the list of loaded module name strings (useful for testing / logging). + """ + loaded: List[str] = [] + allowed = _parse_allowed_names() + prefix = _module_prefix() + + discovered_names = set() + + for ep in entry_points(group="node_wire.connectors"): + if _should_skip_ep(ep, allowed, prefix): + continue + + logic_mod = _logic_module_dotted_path(ep) + importlib.import_module(logic_mod) + loaded.append(logic_mod) + logger.debug("Registered connector: %s (%s)", ep.name, ep.value) + + pkg = _parent_package_for_logic_module(logic_mod) + reg_name = f"{pkg}.registration" + try: + importlib.import_module(reg_name) + loaded.append(reg_name) + logger.debug("Loaded registration module: %s", reg_name) + except ModuleNotFoundError as exc: + if exc.name == reg_name: + pass + else: + logger.error( + "Import error inside %s (missing dep: %s): %s", + reg_name, + exc.name, + exc, + ) + raise + except Exception as exc: + logger.error("Unexpected error loading %s: %s", reg_name, exc) + raise + + discovered_names.add(ep.name) + + # Fallback for allowlisted names not discovered via entry points + for name in allowed: + if name not in discovered_names: + pkg_prefix = prefix if prefix is not None else "node_wire_" + pkg = f"{pkg_prefix}{name}" + logic_mod = f"{pkg}.logic" + reg_name = f"{pkg}.registration" + + try: + importlib.import_module(logic_mod) + loaded.append(logic_mod) + logger.debug("Registered connector via fallback: %s (%s)", name, logic_mod) + except ModuleNotFoundError as exc: + if exc.name == logic_mod or exc.name == pkg: + logger.debug("Fallback connector module not found: %s", logic_mod) + continue + else: + logger.error( + "Import error inside fallback %s (missing dep: %s): %s", + logic_mod, + exc.name, + exc, + ) + raise + except Exception as exc: + logger.error("Unexpected error loading fallback %s: %s", logic_mod, exc) + raise + + try: + importlib.import_module(reg_name) + loaded.append(reg_name) + logger.debug("Loaded registration module via fallback: %s", reg_name) + except ModuleNotFoundError as exc: + if exc.name == reg_name: + pass + else: + logger.error( + "Import error inside fallback registration %s (missing dep: %s): %s", + reg_name, + exc.name, + exc, + ) + raise + except Exception as exc: + logger.error("Unexpected error loading fallback registration %s: %s", reg_name, exc) + raise + + return loaded diff --git a/src/node_wire_runtime/connectors.yaml.sample b/src/node_wire_runtime/connectors.yaml.sample new file mode 100644 index 0000000..7544538 --- /dev/null +++ b/src/node_wire_runtime/connectors.yaml.sample @@ -0,0 +1,126 @@ +## +## SPDX-FileCopyrightText: 2026 AOT Technologies +## SPDX-License-Identifier: Apache-2.0 +## + +# connectors.yaml — Node Wire connector configuration +# +# Copy this file to your deployment root and set: +# export NW_CONFIG_PATH=/path/to/connectors.yaml +# +# SECURITY RULE: This file must never contain secrets. +# - Non-sensitive config (base_url, host, port, from_email) → safe to version-control here +# - Secrets (client_id, private_key, api_key, password) → environment variables or a +# cloud secrets backend (AWS Secrets Manager, HashiCorp Vault, Azure Key Vault, GCP SM) +# +# ---------------------------------------------------------------------------- +# auth: block reference +# ---------------------------------------------------------------------------- +# Each connector optionally declares an auth: block that controls which +# AuthProvider the runtime injects at startup. Supported provider types: +# +# provider: none → NoAuthProvider (default — no auth headers) +# provider: static_token → StaticTokenAuthProvider +# secret_key: +# header_name: Authorization # optional, default: Authorization +# prefix: Bearer # optional, default: Bearer; use "" for raw +# encoding: base64 # optional, for HTTP Basic auth +# +# provider: oauth2 → OAuth2AuthProvider (token cached, Lock-protected) +# grant_method: private_key_jwt # or client_secret_post +# token_url_secret: # token endpoint URL +# client_id_secret: # OAuth2 client_id +# algorithm: RS384 # optional (private_key_jwt only) +# private_key_secret: # PEM private key (private_key_jwt only) +# kid_secret: # JWT kid header (private_key_jwt only) +# client_secret_secret: # client_secret (client_secret_post only) +# scopes_secret: # space-separated scope string (optional) +# scopes: [...] # static scope list (optional) +# buffer_secs: 60 # optional, expire buffer before expires_in +# jwt_ttl_secs: 300 # optional, JWT assertion lifetime +# +# provider: service_account → ServiceAccountAuthProvider (Google APIs) +# sa_json_secret: +# scopes: [...] # optional, defaults to drive scope +# +# provider: static_credentials → SMTP-style username+password pair +# username_secret: SMTP_USERNAME # optional, defaults to SMTP_USERNAME +# password_secret: SMTP_PASSWORD # optional, defaults to SMTP_PASSWORD + +connectors: + fhir_epic: + enabled: false + exposed_via: [rest, grpc, mcp] + # Non-sensitive endpoint config — safe in YAML + base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: epic_token_url + client_id_secret: epic_client_id + private_key_secret: epic_private_key + kid_secret: epic_kid + algorithm: RS384 + buffer_secs: 60 + jwt_ttl_secs: 300 + + fhir_cerner: + enabled: false + exposed_via: [rest, grpc, mcp] + base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: cerner_token_url + client_id_secret: cerner_client_id + private_key_secret: cerner_private_key + kid_secret: cerner_kid + algorithm: RS384 + # Optional: override default scopes from a secret (space-separated) + scopes_secret: cerner_scopes + # Default scopes used when cerner_scopes secret is absent or empty: + scopes: + - system/Patient.read + - system/Encounter.read + - system/DocumentReference.read + - system/DocumentReference.write + buffer_secs: 60 + jwt_ttl_secs: 300 + + google_drive: + enabled: false + exposed_via: [rest, grpc, mcp] + # Optional: restrict uploads to a specific folder (leave empty for root) + folder_id: "${GDRIVE_FOLDER_ID:}" + auth: + provider: service_account + sa_json_secret: GOOGLE_DRIVE_SA_JSON + scopes: + - https://www.googleapis.com/auth/drive + + smtp: + enabled: false + exposed_via: [rest, grpc, mcp] + # Non-sensitive SMTP server config — safe in YAML + host: "smtp.example.com" + port: 587 + from_email: "noreply@example.com" + auth: + provider: static_credentials + username_secret: SMTP_USERNAME + password_secret: SMTP_PASSWORD + + stripe: + enabled: false + exposed_via: [grpc, mcp] + auth: + provider: static_token + secret_key: stripe_api_key + header_name: Authorization + prefix: "" # Stripe expects the raw key with no "Bearer" prefix + + http_generic: + enabled: false + exposed_via: [rest, grpc, mcp] + # auth: not set — defaults to NoAuthProvider + diff --git a/src/runtime/errors.py b/src/node_wire_runtime/errors.py similarity index 78% rename from src/runtime/errors.py rename to src/node_wire_runtime/errors.py index 76acb87..56b01f7 100644 --- a/src/runtime/errors.py +++ b/src/node_wire_runtime/errors.py @@ -1,7 +1,11 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Type +from typing import Dict, Optional, Type from .models import ErrorCategory @@ -22,7 +26,9 @@ class ErrorMapper: _registry: Dict[Type[BaseException], MappedError] = {} @classmethod - def register(cls, exc_type: Type[BaseException], category: ErrorCategory, code: Optional[str] = None) -> None: + def register( + cls, exc_type: Type[BaseException], category: ErrorCategory, code: Optional[str] = None + ) -> None: """ Register an exception type with a category and optional stable error code. """ diff --git a/src/node_wire_runtime/fhir_encounter.py b/src/node_wire_runtime/fhir_encounter.py new file mode 100644 index 0000000..1d8ea12 --- /dev/null +++ b/src/node_wire_runtime/fhir_encounter.py @@ -0,0 +1,24 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""FHIR Encounter search helpers shared by Epic/Cerner connectors.""" + +from __future__ import annotations + +from typing import Dict + + +def assert_encounter_query_has_patient(query_params: Dict[str, str]) -> None: + """ + Require a patient filter on Encounter search (enterprise default). + + Prevents broad or accidental unscoped queries that return 400 from the vendor + or leak unrelated encounters. + """ + p = query_params.get("patient") + if not p or not str(p).strip(): + raise ValueError( + "Encounter search requires a patient-scoped filter: set patient_id, " + "or include patient in search_params." + ) diff --git a/src/node_wire_runtime/ingress.py b/src/node_wire_runtime/ingress.py new file mode 100644 index 0000000..60b9e44 --- /dev/null +++ b/src/node_wire_runtime/ingress.py @@ -0,0 +1,62 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +Shared ingress helpers for connector bindings (MCP, REST, gRPC). + +Tool/route action is authoritative for MCP and REST; normalizers map LLM aliases +to canonical Pydantic fields before validation. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict + +from node_wire_runtime import BaseConnector + +logger = logging.getLogger("node_wire_runtime.ingress") + + +def normalize_mcp_tool_arguments( + connector: BaseConnector, action: str, arguments: Dict[str, Any] +) -> Dict[str, Any]: + """ + Apply action-registered argument normalizers (see SdkActionMeta.mcp_normalize). + + Used for MCP, REST, and gRPC so the same alias mapping applies across bindings. + Mutates a copy of ``arguments`` and returns it. + """ + args = dict(arguments) + if not isinstance(connector, BaseConnector): + return args + meta = type(connector).sdk_action_metas().get(action) + if meta is not None and meta.mcp_normalize is not None: + meta.mcp_normalize(args) + return args + + +def enforce_authoritative_action(payload: Dict[str, Any], authoritative_action: str) -> None: + """ + Ensure the payload does not contradict the invoked tool or REST route. + + After :func:`normalize_mcp_tool_arguments`, the payload may contain a temporary + ``action`` alias (e.g. ``upload``) that normalizers rewrite to the canonical + action; those must match ``authoritative_action`` before the final assignment. + + :raises ValueError: if ``action`` is present and differs from ``authoritative_action``. + """ + if "action" not in payload: + return + raw = payload.get("action") + if raw is None: + return + if isinstance(raw, str) and not raw.strip(): + return + current = str(raw).strip() if isinstance(raw, str) else str(raw) + if current != authoritative_action: + raise ValueError( + f"Payload 'action' {raw!r} does not match the invoked route or tool " + f"action {authoritative_action!r}." + ) diff --git a/src/node_wire_runtime/manifest.py b/src/node_wire_runtime/manifest.py new file mode 100644 index 0000000..6e5c268 --- /dev/null +++ b/src/node_wire_runtime/manifest.py @@ -0,0 +1,132 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import copy +from typing import Any, Dict, List, Type + +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector +from node_wire_runtime.models import ErrorCategory + +# Bump when published input/output schema shape policy changes (MCP clients cache tools/list). +MCP_MANIFEST_CONTRACT_VERSION = "3" + + +def _schema_for(model: Type[BaseModel], *, strict: bool = True) -> Dict[str, Any]: + schema = copy.deepcopy(model.model_json_schema(by_alias=False)) + # Remove `action` from `required`: it is always auto-injected from the tool + # name by invoke_tool (run_args.setdefault("action", action)), so LLMs must + # not be required to pass it. Keeping it as an optional property is fine. + if "required" in schema: + schema["required"] = [f for f in schema["required"] if f != "action"] + if not schema["required"]: + del schema["required"] + # Only remove additionalProperties:false for alias-tolerant actions so that + # common LLM aliases (e.g. mimeType → mime_type) are not rejected by the + # MCP SDK's JSON-Schema validation layer before our normalization runs. + # Strict actions retain additionalProperties:false for proper contract enforcement. + if not strict: + schema.pop("additionalProperties", None) + return schema + + +def _strip_action_field_from_json_schema(schema: Dict[str, Any]) -> None: + """ + Remove ``action`` from published input schemas for MCP/REST tool contracts. + + The binding injects ``action`` from the tool name or URL path; exposing it in + ``inputSchema`` invites redundant or legacy values (e.g. ``upload``). + Mutates ``schema`` in place (recurses into ``$defs``). + """ + props = schema.get("properties") + if isinstance(props, dict) and "action" in props: + del props["action"] + defs = schema.get("$defs") + if isinstance(defs, dict): + for sub in defs.values(): + if isinstance(sub, dict): + _strip_action_field_from_json_schema(sub) + # oneOf / anyOf branches + for key in ("oneOf", "anyOf", "allOf"): + branch = schema.get(key) + if isinstance(branch, list): + for item in branch: + if isinstance(item, dict): + _strip_action_field_from_json_schema(item) + + +def _error_category_json_schema() -> Dict[str, Any]: + """Inline enum from runtime ErrorCategory (single source of truth, no drift).""" + return { + "type": "string", + "enum": [e.value for e in ErrorCategory], + } + + +def _connector_response_schema(output_model: Type[BaseModel]) -> Dict[str, Any]: + """ + Build the ConnectorResponse envelope schema with `data` typed to the + specific output_model for this action. + + Built by hand (not from ConnectorResponse.model_json_schema()) to avoid + $defs/$ref pollution from ErrorCategory and to keep the schema self-contained. + """ + output_schema = _schema_for(output_model) + return { + "type": "object", + "title": "ConnectorResponse", + "properties": { + "success": {"type": "boolean"}, + "data": {"anyOf": [output_schema, {"type": "null"}]}, + "error_code": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "error_category": { + "anyOf": [ + _error_category_json_schema(), + {"type": "null"}, + ] + }, + "message": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "trace_id": {"type": "string"}, + "details": {}, + }, + "required": ["success", "trace_id"], + } + + +def build_manifest( + connectors: List[BaseConnector], + *, + strip_input_action: bool = True, +) -> List[Dict[str, Any]]: + """ + One manifest entry per SDK @sdk_action (specific input/output schemas). + + :param strip_input_action: When True (default), omit ``action`` from the + published ``input_schema`` properties. Bindings inject ``action`` from + the MCP tool name or REST path; keeping it out of ``inputSchema`` avoids + redundant/legacy client payloads. + """ + manifest: List[Dict[str, Any]] = [] + for connector in connectors: + cid = connector.connector_id + for action_name, meta in type(connector).sdk_action_metas().items(): + input_schema = _schema_for(meta.input_model, strict=not meta.alias_tolerant) + if strip_input_action: + _strip_action_field_from_json_schema(input_schema) + manifest.append( + { + "connector_id": cid, + "action": action_name, + "input_schema": input_schema, + "output_schema": _connector_response_schema(meta.output_model), + "requires_auth": meta.requires_auth, + "scopes": meta.scopes if meta.scopes is not None else [], + "rate_limit": meta.rate_limit if meta.rate_limit is not None else {}, + "deprecated": meta.deprecated, + } + ) + return manifest diff --git a/src/node_wire_runtime/mcp_contract.py b/src/node_wire_runtime/mcp_contract.py new file mode 100644 index 0000000..f164504 --- /dev/null +++ b/src/node_wire_runtime/mcp_contract.py @@ -0,0 +1,48 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +MCP contract flags: phased deprecation of legacy tool arguments. + +Environment variables (enterprise rollout): + +- ``NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD``: ``allow`` | ``warn`` | ``reject`` + - Legacy: ``action: "upload"`` in the tool payload for ``google_drive.files.upload``. + - Default: ``warn`` (rewrite to canonical + log once per process is not required; use WARNING). + - ``reject``: do not rewrite; authoritative tool name + ``enforce_authoritative_action`` fails. +""" + +from __future__ import annotations + +import logging +import os +from typing import Literal + +logger = logging.getLogger("runtime.mcp_contract") + +ENV_LEGACY_GDRIVE_ACTION_UPLOAD = "NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD" + + +def legacy_gdrive_action_upload_mode() -> Literal["allow", "warn", "reject"]: + raw = (os.environ.get(ENV_LEGACY_GDRIVE_ACTION_UPLOAD) or "warn").strip().lower() + if raw in ("allow", "warn", "reject"): + return raw # type: ignore[return-value] + logger.warning( + "Invalid %s=%r; using 'warn'", + ENV_LEGACY_GDRIVE_ACTION_UPLOAD, + raw, + ) + return "warn" + + +def log_legacy_gdrive_action_upload_usage() -> None: + """Structured log line for metrics/aggregation (no PII).""" + logger.info( + "mcp.legacy.alias | alias=action_upload | tool=google_drive.files.upload", + extra={ + "event": "mcp.legacy.alias", + "alias": "action_upload", + "tool": "google_drive.files.upload", + }, + ) diff --git a/src/node_wire_runtime/mcp_normalizers.py b/src/node_wire_runtime/mcp_normalizers.py new file mode 100644 index 0000000..8b5962c --- /dev/null +++ b/src/node_wire_runtime/mcp_normalizers.py @@ -0,0 +1,216 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +Per-action MCP tool argument normalizers. + +Each function mutates the arguments dict in place (same contract as before refactor). +Registered on actions via @sdk_action(..., mcp_normalize=...) or SdkActionSpec(..., mcp_normalize=...). +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List + +from node_wire_runtime.mcp_contract import ( + legacy_gdrive_action_upload_mode, + log_legacy_gdrive_action_upload_usage, +) + +logger = logging.getLogger("runtime.mcp_normalizers") + + +def _split_ids(value: Any) -> List[str]: + """Turn comma-separated string or list into a list of non-empty IDs.""" + if value is None: + return [] + if isinstance(value, list): + return [str(x).strip() for x in value if str(x).strip()] + s = str(value).strip() + if not s: + return [] + return [p.strip() for p in s.split(",") if p.strip()] + + +def _normalize_search_params_keys(sp: Dict[str, Any]) -> Dict[str, Any]: + """Map legacy/LLM keys inside search_params to FHIR-friendly names.""" + if not sp: + return {} + out = dict(sp) + if "patientId" in out and "identifier" not in out: + out["identifier"] = out.pop("patientId") + if "givenName" in out and "given" not in out: + out["given"] = out.pop("givenName") + if "familyName" in out and "family" not in out: + out["family"] = out.pop("familyName") + return out + + +def _is_missing_or_blank(value: Any) -> bool: + if value is None: + return True + if isinstance(value, str) and not value.strip(): + return True + return False + + +def normalize_fhir_read_patient(args: Dict[str, Any]) -> None: + """Map legacy LLM keys for FHIR read_patient (Epic/Cerner).""" + if not (args.get("resource_id") or "").strip(): + pid = args.get("patient_id") or args.get("patientId") + if pid is not None and str(pid).strip(): + args["resource_id"] = str(pid).strip() + args.pop("patient_id", None) + args.pop("patientId", None) + if not args.get("family_name") and args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + +def normalize_fhir_search_encounter(args: Dict[str, Any]) -> None: + """ + Map common LLM/FHIR mistakes for search_encounter (Epic/Cerner). + + - Root ``patient`` / ``patientId`` -> ``patient_id`` (strip ``Patient/`` prefix). + - Root ``sort`` -> FHIR ``_sort`` (merged into ``search_params``). + - ``sort`` inside ``search_params`` -> ``_sort``. + """ + if not (args.get("patient_id") or "").strip(): + p = args.get("patient") or args.get("patientId") + if p is not None and str(p).strip(): + p_str = str(p).strip() + if p_str.startswith("Patient/"): + p_str = p_str[len("Patient/") :] + args["patient_id"] = p_str + args.pop("patient", None) + args.pop("patientId", None) + + sp: Dict[str, Any] = { + **(dict(args["search_params"]) if isinstance(args.get("search_params"), dict) else {}) + } + root_sort = args.pop("sort", None) + root_usort = args.pop("_sort", None) + if root_sort is not None and str(root_sort).strip() and "_sort" not in sp: + sp["_sort"] = str(root_sort).strip() + elif root_usort is not None and str(root_usort).strip() and "_sort" not in sp: + sp["_sort"] = str(root_usort).strip() + if "sort" in sp and "_sort" not in sp: + sp["_sort"] = str(sp.pop("sort")).strip() + if sp: + args["search_params"] = sp + + +def normalize_fhir_search_patients(args: Dict[str, Any]) -> None: + """Map legacy LLM keys for FHIR search_patients (Epic/Cerner).""" + if not args.get("resource_ids"): + raw = args.get("patient_ids") or args.get("patientIds") + ids = _split_ids(raw) + if ids: + args["resource_ids"] = ids + args.pop("patient_ids", None) + args.pop("patientIds", None) + if not args.get("family_name") and args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + +def normalize_google_drive_files_upload(args: Dict[str, Any]) -> None: + """ + Map common LLM mistakes for files.upload to FilesUploadOperation fields. + Mutates args in place. Canonical keys already set on the root win over aliases/nesting. + """ + media = args.get("media") + if media is not None: + if isinstance(media, dict): + if _is_missing_or_blank(args.get("name")) and not _is_missing_or_blank( + media.get("name") + ): + args["name"] = media.get("name") + + if _is_missing_or_blank(args.get("mime_type")): + mt = media.get("mime_type") or media.get("mimeType") + if not _is_missing_or_blank(mt): + args["mime_type"] = mt + + if _is_missing_or_blank(args.get("parents")): + parents = media.get("parents") + if isinstance(parents, list) and parents: + args["parents"] = parents + elif isinstance(parents, str) and parents.strip(): + args["parents"] = _split_ids(parents) + + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + b64 = media.get("content_base64") or media.get("base64") or media.get("data") + if not _is_missing_or_blank(b64): + args["content_base64"] = b64 + else: + text = media.get("content") or media.get("text") or media.get("body") + if not _is_missing_or_blank(text): + args["content"] = text + elif isinstance(media, str): + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + if media.strip(): + args["content"] = media + + args.pop("media", None) + + args.pop("media_body", None) + + nested = args.get("file") + if isinstance(nested, dict): + for key in ("name", "mime_type", "parents", "content", "content_base64"): + if key in nested and _is_missing_or_blank(args.get(key)): + args[key] = nested[key] + if _is_missing_or_blank(args.get("mime_type")) and nested.get("mimeType"): + args["mime_type"] = nested["mimeType"] + args.pop("file", None) + + if not _is_missing_or_blank(args.get("mimeType")) and _is_missing_or_blank( + args.get("mime_type") + ): + args["mime_type"] = args["mimeType"] + args.pop("mimeType", None) + + if args.get("action") == "upload": + mode = legacy_gdrive_action_upload_mode() + if mode == "reject": + logger.warning( + "Rejected legacy action value 'upload' for google_drive.files.upload " + "(set %s=allow or omit action; tool name is authoritative).", + "NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD", + ) + else: + if mode == "warn": + logger.warning( + "Deprecated: action 'upload' in google_drive.files.upload payload; " + "omit 'action' or use 'files.upload'. " + "Set NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD=reject to hard-fail." + ) + log_legacy_gdrive_action_upload_usage() + args["action"] = "files.upload" + + +def normalize_smtp_send_email(args: Dict[str, Any]) -> None: + """Map common LLM aliases for smtp.send_email to SmtpSendInput fields.""" + if _is_missing_or_blank(args.get("from_email")): + for alias in ("from", "sender", "from_addr"): + if not _is_missing_or_blank(args.get(alias)): + args["from_email"] = args[alias] + break + for alias in ("from", "sender", "from_addr"): + args.pop(alias, None) + + if isinstance(args.get("to"), str): + args["to"] = [args["to"]] diff --git a/src/runtime/models.py b/src/node_wire_runtime/models.py similarity index 68% rename from src/runtime/models.py rename to src/node_wire_runtime/models.py index 2fb135c..9daf43c 100644 --- a/src/runtime/models.py +++ b/src/node_wire_runtime/models.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from typing import Any, Optional @@ -22,4 +26,6 @@ class ConnectorResponse(BaseModel): error_category: Optional[ErrorCategory] = None message: Optional[str] = None trace_id: str - details: Optional[Any] = None # e.g. validation errors: [{"loc": ["url"], "msg": "...", "type": "..."}] + details: Optional[Any] = ( + None # e.g. validation errors: [{"loc": ["url"], "msg": "...", "type": "..."}] + ) diff --git a/src/runtime/observability.py b/src/node_wire_runtime/observability.py similarity index 57% rename from src/runtime/observability.py rename to src/node_wire_runtime/observability.py index e5f6eff..a85293f 100644 --- a/src/runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -1,8 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging import os -from typing import Optional +from typing import Optional, cast from opentelemetry._logs import set_logger_provider from opentelemetry import trace @@ -10,9 +14,9 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler -from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogExporter from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased logger = logging.getLogger("runtime.observability") @@ -33,6 +37,69 @@ def filter(self, record: logging.LogRecord) -> bool: # noqa: A003 return True +_SENSITIVE_KEYS = { + "patient", + "ssn", + "secret", + "password", + "email", + "phone", + "dob", + "encounter", + "resourceid", +} + + +def _is_sensitive(key: str) -> bool: + k = key.lower().replace("_", "").replace("-", "").replace(" ", "") + for s in _SENSITIVE_KEYS: + if s in k: + return True + return False + + +class SanitizingSpanExporter(SpanExporter): + def __init__(self, delegate: SpanExporter): + self._delegate = delegate + + def export(self, spans): + for span in spans: + if hasattr(span, "_attributes") and span._attributes: + for k in list(span._attributes.keys()): + if _is_sensitive(k): + span._attributes[k] = "***REDACTED***" + return self._delegate.export(spans) + + def shutdown(self): + return self._delegate.shutdown() + + def force_flush(self, timeout_millis: int = 30000): + if hasattr(self._delegate, "force_flush"): + return self._delegate.force_flush(timeout_millis) + return True + + +class SanitizingLogExporter(LogExporter): + def __init__(self, delegate: LogExporter): + self._delegate = delegate + + def export(self, batch): + for record in batch: + if hasattr(record, "attributes") and record.attributes: + for k in list(record.attributes.keys()): + if _is_sensitive(k): + record.attributes[k] = "***REDACTED***" + return self._delegate.export(batch) + + def shutdown(self): + return self._delegate.shutdown() + + def force_flush(self, timeout_millis: int = 30000): + if hasattr(self._delegate, "force_flush"): + return self._delegate.force_flush(timeout_millis) + return True + + def init_observability(app_name: str = "node_wire") -> None: """ Initialize OpenTelemetry + OpenLLMetry/Traceloop for the process. @@ -67,10 +134,12 @@ def init_observability(app_name: str = "node_wire") -> None: otlp_headers: Optional[str] = os.getenv("OTEL_EXPORTER_OTLP_HEADERS") - span_exporter = OTLPSpanExporter( - headers=dict( - header.split("=", 1) for header in otlp_headers.split(",") - ) if otlp_headers else None, + span_exporter = SanitizingSpanExporter( + OTLPSpanExporter( + headers=dict(header.split("=", 1) for header in otlp_headers.split(",")) + if otlp_headers + else None, + ) ) span_processor = BatchSpanProcessor(span_exporter) @@ -79,10 +148,15 @@ def init_observability(app_name: str = "node_wire") -> None: # Logs: export Python logging records via OTLP/HTTP to the local collector. # This enables Loki ingestion when using grafana/otel-lgtm. - log_exporter = OTLPLogExporter( - headers=dict( - header.split("=", 1) for header in otlp_headers.split(",") - ) if otlp_headers else None, + log_exporter = SanitizingLogExporter( + cast( + LogExporter, + OTLPLogExporter( + headers=dict(header.split("=", 1) for header in otlp_headers.split(",")) + if otlp_headers + else None, + ), + ) ) logger_provider = LoggerProvider(resource=resource) logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) @@ -98,11 +172,10 @@ def init_observability(app_name: str = "node_wire") -> None: from traceloop.sdk import Traceloop Traceloop.init( - app_name=app_name, + app_name=app_name, ) except Exception as exc: # pragma: no cover - defensive; should not fail app startup logger.warning("Failed to initialize Traceloop/OpenLLMetry: %s", exc) _INITIALIZED = True logger.info("Observability initialized for app %s", app_name) - diff --git a/src/node_wire_runtime/policies/mcp_scope_policy.py b/src/node_wire_runtime/policies/mcp_scope_policy.py new file mode 100644 index 0000000..dc281e2 --- /dev/null +++ b/src/node_wire_runtime/policies/mcp_scope_policy.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Mapping, Optional + +from dotenv import load_dotenv + +from node_wire_runtime.policy import PolicyContext, PolicyDenied, PolicyHook + +logger = logging.getLogger("runtime.policy.scope") + +# Public for tests and MCP tool listing (must match hook behavior). +DEFAULT_SCOPE_MODE_ALLOW = "allow" +DEFAULT_SCOPE_MODE_DENY = "deny" + + +def _truthy_default_mode(val: str | None) -> str: + if val is None: + return DEFAULT_SCOPE_MODE_ALLOW + v = val.strip().lower() + if v in ("deny", "default-deny", "closed"): + return DEFAULT_SCOPE_MODE_DENY + return DEFAULT_SCOPE_MODE_ALLOW + + +def load_scope_policy_default_from_env() -> str: + """Return ``allow`` or ``deny`` from ``NW_MCP_SCOPE_POLICY_DEFAULT``.""" + raw = os.environ.get("NW_MCP_SCOPE_POLICY_DEFAULT") + if not raw or not str(raw).strip(): + return DEFAULT_SCOPE_MODE_ALLOW + return _truthy_default_mode(str(raw)) + + +def resolve_required_scope_for_action( + *, + connector_id: str, + action: str, + action_scope_map: Mapping[str, str], + default_mode: str, +) -> Optional[str]: + """ + Determine the scope string required for this action. + + - **allow** (default): only enforce when ``NW_MCP_ACTION_SCOPE_MAP_JSON`` has + an entry for ``connector_id.action``. + - **deny**: require either that explicit map entry or the conventional + fallback ``mcp:.``. + """ + action_key = f"{connector_id}.{action}" + explicit = action_scope_map.get(action_key) + if explicit: + return explicit + if default_mode == DEFAULT_SCOPE_MODE_DENY: + return f"mcp:{connector_id}.{action}" + return None + + +def action_allowed_for_identity_scopes( + *, + connector_id: str, + action: str, + principal: Optional[str], + tenant_id: Optional[str], + scopes: Optional[tuple[str, ...]], + action_scope_map: Mapping[str, str], + default_mode: str, +) -> bool: + """ + Same authorization decision as :class:`ScopePolicyHook` / ``tools/list`` filtering. + + Returns True if the action should be visible or executable for this caller. + """ + required = resolve_required_scope_for_action( + connector_id=connector_id, + action=action, + action_scope_map=action_scope_map, + default_mode=default_mode, + ) + scope_tuple = tuple(scopes or ()) + # Defer transport-specific authz until caller identity is propagated. + if required and not principal and not scope_tuple: + logger.info( + "Scope policy bypassed due to missing caller identity", + extra={ + "action_key": f"{connector_id}.{action}", + "required_scope": required, + }, + ) + return True + if not required: + return True + scope_set = set(scope_tuple) + return required in scope_set or "*" in scope_set + + +class ScopePolicyHook(PolicyHook): + def __init__( + self, + action_scope_map: Mapping[str, str], + *, + default_mode: str = DEFAULT_SCOPE_MODE_ALLOW, + ) -> None: + self._map = dict(action_scope_map) + self._default_mode = ( + default_mode + if default_mode in (DEFAULT_SCOPE_MODE_ALLOW, DEFAULT_SCOPE_MODE_DENY) + else DEFAULT_SCOPE_MODE_ALLOW + ) + + def check(self, context: PolicyContext) -> None: + action_key = f"{context.connector_id}.{context.action}" + required = resolve_required_scope_for_action( + connector_id=context.connector_id, + action=context.action, + action_scope_map=self._map, + default_mode=self._default_mode, + ) + scopes = tuple(context.scopes or ()) + if required and not context.principal and not scopes: + logger.info( + "Scope policy bypassed due to missing caller identity", + extra={ + "action_key": action_key, + "required_scope": required, + }, + ) + return + logger.info( + "Scope policy evaluating action", + extra={ + "action_key": action_key, + "required_scope": required or "", + "principal": context.principal or "", + "tenant_id": context.tenant_id or "", + "scopes": list(scopes), + }, + ) + if not required: + return + scope_set = set(scopes) + if required in scope_set or "*" in scope_set: + return + raise PolicyDenied(f"Missing required scope: {required}") + + +def load_scope_map_from_env() -> dict[str, str]: + raw = os.environ.get("NW_MCP_ACTION_SCOPE_MAP_JSON") + if not raw: + # Mirror MCP auth bootstrap behavior: recover config from project .env + # when launch paths inherit incomplete shell env. Use override=False so + # explicitly set variables (e.g. pytest conftest, production injection) are not + # stomped by repo .env — same as playground/scenarios load_dotenv(). + if os.environ.get("NW_REST_LOAD_DOTENV", "true").lower() not in ("0", "false", "no"): + repo_root_env = Path(__file__).resolve().parents[3] / ".env" + load_dotenv(override=False) + load_dotenv(repo_root_env, override=False) + raw = os.environ.get("NW_MCP_ACTION_SCOPE_MAP_JSON") + if not raw: + logger.info("Scope policy map not configured (env empty)") + return {} + parsed = json.loads(raw) + if not isinstance(parsed, dict): + raise ValueError("NW_MCP_ACTION_SCOPE_MAP_JSON must be a JSON object.") + out: dict[str, str] = {} + for key, value in parsed.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError( + "NW_MCP_ACTION_SCOPE_MAP_JSON must map string action keys to string scopes." + ) + out[key] = value + logger.info( + "Scope policy map loaded", + extra={"entries": len(out), "action_keys": sorted(out.keys())}, + ) + return out diff --git a/src/runtime/policy.py b/src/node_wire_runtime/policy.py similarity index 83% rename from src/runtime/policy.py rename to src/node_wire_runtime/policy.py index 14e7034..3904ce7 100644 --- a/src/runtime/policy.py +++ b/src/node_wire_runtime/policy.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from abc import ABC, abstractmethod @@ -12,6 +16,7 @@ class PolicyContext: input_payload: Mapping[str, Any] principal: Optional[str] = None tenant_id: Optional[str] = None + scopes: Optional[tuple[str, ...]] = None class PolicyDenied(Exception): diff --git a/src/node_wire_runtime/rate_limit.py b/src/node_wire_runtime/rate_limit.py new file mode 100644 index 0000000..4caf487 --- /dev/null +++ b/src/node_wire_runtime/rate_limit.py @@ -0,0 +1,60 @@ +""" +In-memory Token Bucket rate limiter to prevent DoS across bindings. +Configuration via environment variables: + - NW_RATE_LIMIT_BURST: maximum number of tokens (default: 50) + - NW_RATE_LIMIT_REFILL_RATE: tokens added per second (default: 10.0) +""" + +from __future__ import annotations + +import asyncio +import os +import time + + +class RateLimitExceeded(Exception): + """Raised when the rate limit has been exceeded.""" + + pass + + +class TokenBucket: + def __init__(self, capacity: float, refill_rate: float) -> None: + """ + :param capacity: Maximum number of tokens the bucket can hold. + :param refill_rate: Number of tokens added to the bucket per second. + """ + self.capacity = float(capacity) + self.refill_rate = float(refill_rate) + self.tokens = self.capacity + self.last_refill = time.monotonic() + self._lock = asyncio.Lock() + + async def acquire(self, amount: int = 1) -> None: + """ + Attempt to acquire `amount` tokens from the bucket. + :raises RateLimitExceeded: if there are not enough tokens available. + """ + async with self._lock: + now = time.monotonic() + elapsed = now - self.last_refill + + # Refill the bucket based on elapsed time + self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate) + self.last_refill = now + + if self.tokens >= amount: + self.tokens -= amount + else: + raise RateLimitExceeded("Global rate limit exceeded. Please try again later.") + + +# Global default instance configured via environment variables +burst = float(os.environ.get("NW_RATE_LIMIT_BURST", "50")) +rate = float(os.environ.get("NW_RATE_LIMIT_REFILL_RATE", "10.0")) + +# Check if rate limiting is disabled for tests +if os.environ.get("NW_RATE_LIMIT_DISABLED", "false").lower() in ("0", "false", "no"): + global_rate_limiter = TokenBucket(capacity=burst, refill_rate=rate) +else: + global_rate_limiter = TokenBucket(capacity=float("inf"), refill_rate=float("inf")) diff --git a/src/node_wire_runtime/resilience.py b/src/node_wire_runtime/resilience.py new file mode 100644 index 0000000..53f753c --- /dev/null +++ b/src/node_wire_runtime/resilience.py @@ -0,0 +1,131 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import logging +from functools import wraps +from typing import Any, Awaitable, Callable, Coroutine, TypeVar + +from pybreaker import CircuitBreaker, CircuitBreakerError +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from .errors import ErrorMapper +from .models import ErrorCategory + +logger = logging.getLogger("runtime.resilience") + +T = TypeVar("T") + + +class _AbortRetry(BaseException): + """Wraps a non-retryable exception to escape tenacity's retry loop.""" + + def __init__(self, cause: Exception) -> None: + self.cause = cause + super().__init__(str(cause)) + + +def _resolve_breaker( + breaker: CircuitBreaker | Callable[[], CircuitBreaker], +) -> CircuitBreaker: + # CircuitBreaker instances are callable (__call__); treat concrete instances + # before callable check so we don't invoke them like factory functions. + if isinstance(breaker, CircuitBreaker): + return breaker + return breaker() if callable(breaker) else breaker + + +def with_resilience( + breaker: CircuitBreaker | Callable[[], CircuitBreaker], + max_attempts: int = 3, + base_wait: float = 0.5, + max_wait: float = 5.0, +) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Coroutine[Any, Any, T]]]: + """ + Decorator that applies retry (Tenacity) and circuit breaking (PyBreaker) + around an async function that may raise exceptions. + """ + + def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Coroutine[Any, Any, T]]: + @wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> T: + trace_id: str = kwargs.get("trace_id", "unknown-trace") + + async def _call() -> T: + current_breaker = _resolve_breaker(breaker) + if current_breaker.state.name == "open": + logger.error( + "Circuit breaker is OPEN; rejecting call", + extra={ + "trace_id": trace_id, + "component": "resilience", + "error": "circuit open", + }, + ) + raise CircuitBreakerError("Circuit breaker is open") + try: + result = await fn(*args, **kwargs) + current_breaker._state.on_success() # noqa: SLF001 + return result + except Exception as exc: + current_breaker._state.on_failure(exc) # noqa: SLF001 + raise + except NameError: + # pybreaker < 1.0 requires Tornado's `gen` in call_async. + # Fall back to a direct call until pybreaker is upgraded to >= 1.0. + return await fn(*args, **kwargs) + + try: + async for attempt in AsyncRetrying( + retry=retry_if_exception_type(Exception), + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=base_wait, max=max_wait), + reraise=True, + ): + with attempt: + try: + return await _call() + except Exception as exc: # noqa: BLE001 + mapped = ErrorMapper.resolve(exc) + if mapped.category is not ErrorCategory.RETRYABLE: + # Non-retryable: log, then escape the retry loop entirely. + logger.error( + "Non-retryable error during execution", + extra={ + "trace_id": trace_id, + "error_code": mapped.code, + "error_category": mapped.category.value, + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise _AbortRetry(exc) + + logger.warning( + "Retryable error during execution; will retry", + extra={ + "trace_id": trace_id, + "error_code": mapped.code, + "error_category": mapped.category.value, + "attempt_number": attempt.retry_state.attempt_number, + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + except _AbortRetry as abort: + raise abort.cause + + # Should not be reached because reraise=True ensures RetryError is propagated. + raise RuntimeError("Exhausted retries without success") + + return wrapper + + return decorator diff --git a/src/node_wire_runtime/sdk_action_spec.py b/src/node_wire_runtime/sdk_action_spec.py new file mode 100644 index 0000000..1a84704 --- /dev/null +++ b/src/node_wire_runtime/sdk_action_spec.py @@ -0,0 +1,135 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +Generic action-spec primitives for SDK-backed connectors (e.g. googleapiclient). + +Subclasses describe how validated Pydantic models map to vendor SDK calls: +resource navigation, method name, keyword/body mapping, constants, and optional +custom builders or post-processors. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple + +from pydantic import BaseModel + + +def navigate_resource(client: Any, segments: Tuple[str, ...]) -> Any: + """Traverse discovery-style APIs: client.files().permissions()...""" + api = client + for seg in segments: + api = getattr(api, seg)() + return api + + +def default_build_kwargs( + *, + kwargs_from_model: Dict[str, str], + body_from_model: Optional[Dict[str, str]], + body_constant: Optional[Dict[str, Any]], + constant_kwargs: Dict[str, Any], + computed_kwargs: Dict[str, Callable[[BaseModel], Any]], + include_empty_body: bool, + model: BaseModel, +) -> Dict[str, Any]: + """Build SDK method kwargs from a validated input model.""" + kw: Dict[str, Any] = dict(constant_kwargs) + + for attr, sdk_name in kwargs_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + kw[sdk_name] = val + + for sdk_name, fn in computed_kwargs.items(): + val = fn(model) + if val is not None: + kw[sdk_name] = val + + body: Dict[str, Any] = {} + if body_constant: + body.update(body_constant) + if body_from_model: + for attr, bkey in body_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + body[bkey] = val + + if body_from_model is not None or body_constant is not None: + if body or include_empty_body: + kw["body"] = body + + return kw + + +@dataclass(frozen=True) +class SdkActionSpec: + """ + Describes one vendor SDK call: resource().method(**kwargs).execute() + + When ``build_kwargs`` is None, kwargs are built from the mapping fields. + When ``build_kwargs`` is set, it receives (client, model) and must return + the full kwargs dict for the SDK method. + """ + + resource_segments: Tuple[str, ...] + method_name: str + kwargs_from_model: Dict[str, str] = field(default_factory=dict) + body_from_model: Optional[Dict[str, str]] = None + body_constant: Optional[Dict[str, Any]] = None + constant_kwargs: Dict[str, Any] = field(default_factory=dict) + computed_kwargs: Dict[str, Callable[[BaseModel], Any]] = field(default_factory=dict) + # Pass body={} when the API requires a body key even if empty (e.g. files.update). + include_empty_body: bool = False + build_kwargs: Optional[Callable[[Any, BaseModel], Dict[str, Any]]] = None + post_process: Optional[Callable[[Any, BaseModel], Any]] = None + # Set these when the spec is declared in a connector's action_specs class var. + # input_model is required; output_model falls back to cls.output_model if None. + input_model: Optional[Any] = None + output_model: Optional[Any] = None + alias_tolerant: bool = False + # Optional: mutates MCP tool args dict in place before connector.run (see mcp_normalizers). + mcp_normalize: Optional[Callable[[Dict[str, Any]], None]] = None + # Security metadata + requires_auth: bool = True + scopes: Optional[List[str]] = None + rate_limit: Optional[Dict[str, Any]] = None + deprecated: bool = False + + +def build_method_kwargs(spec: SdkActionSpec, client: Any, model: BaseModel) -> Dict[str, Any]: + if spec.build_kwargs is not None: + return spec.build_kwargs(client, model) + return default_build_kwargs( + kwargs_from_model=spec.kwargs_from_model, + body_from_model=spec.body_from_model, + body_constant=spec.body_constant, + constant_kwargs=spec.constant_kwargs, + computed_kwargs=spec.computed_kwargs, + include_empty_body=spec.include_empty_body, + model=model, + ) + + +def execute_spec_sync(client: Any, spec: SdkActionSpec, model: BaseModel) -> Any: + """Run spec.method_name on navigated resource; return execute() result (sync).""" + kwargs = build_method_kwargs(spec, client, model) + resource_api = navigate_resource(client, spec.resource_segments) + method = getattr(resource_api, spec.method_name) + result = method(**kwargs).execute() + if spec.post_process is not None: + return spec.post_process(result, model) + return result + + +async def execute_spec_in_thread( + client: Any, + spec: SdkActionSpec, + model: BaseModel, +) -> Any: + """Run execute_spec_sync in a worker thread (for sync googleapiclient).""" + return await asyncio.to_thread(execute_spec_sync, client, spec, model) diff --git a/src/node_wire_runtime/secrets/__init__.py b/src/node_wire_runtime/secrets/__init__.py new file mode 100644 index 0000000..beb097c --- /dev/null +++ b/src/node_wire_runtime/secrets/__init__.py @@ -0,0 +1,39 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +node_wire_runtime.secrets +========================= + +Pluggable secret resolution for Node Wire connectors. + +Baseline (always available): + SecretProvider — abstract base class + EnvSecretProvider — reads from os.environ (default for all deployments) + SecretNotFoundError — key absent in this provider + SecretProviderError — provider itself is broken (auth / network / config) + ChainedSecretProvider — tries providers in order; only falls through on NotFound + +Cloud backends (installed as extras): + AwsSecretsManagerProvider pip install node-wire-runtime[aws] + HashiCorpVaultProvider pip install node-wire-runtime[vault] + AzureKeyVaultProvider pip install node-wire-runtime[azure] + GcpSecretManagerProvider pip install node-wire-runtime[gcp] +""" + +from node_wire_runtime.secrets.base import ( + EnvSecretProvider, + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) +from node_wire_runtime.secrets.chained import ChainedSecretProvider + +__all__ = [ + "SecretProvider", + "EnvSecretProvider", + "SecretNotFoundError", + "SecretProviderError", + "ChainedSecretProvider", +] diff --git a/src/node_wire_runtime/secrets/aws.py b/src/node_wire_runtime/secrets/aws.py new file mode 100644 index 0000000..2e9c52a --- /dev/null +++ b/src/node_wire_runtime/secrets/aws.py @@ -0,0 +1,50 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import json + +from node_wire_runtime.secrets.base import ( + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) + +try: + import boto3 + from botocore.exceptions import BotoCoreError, ClientError +except ImportError as _e: + raise ImportError( + "boto3 is required for AwsSecretsManagerProvider. " + "Install it with: pip install 'node-wire-runtime[aws]'" + ) from _e + + +class AwsSecretsManagerProvider(SecretProvider): + """Fetches a JSON secret bundle from AWS Secrets Manager at init time. + + Keys in the JSON map to connector secret names (e.g. ``epic_client_id``). + Raise SecretProviderError on auth / network / config failures so the caller + knows the provider itself is broken rather than a single key being absent. + """ + + def __init__(self, secret_name: str, region: str = "us-east-1") -> None: + try: + client = boto3.client("secretsmanager", region_name=region) + raw = client.get_secret_value(SecretId=secret_name)["SecretString"] + self._data: dict = json.loads(raw) + except ClientError as exc: + code = exc.response["Error"]["Code"] + if code == "ResourceNotFoundException": + raise SecretNotFoundError(secret_name) from exc + raise SecretProviderError(f"AWS Secrets Manager error ({code}): {secret_name}") from exc + except BotoCoreError as exc: + raise SecretProviderError(f"AWS connection error for secret {secret_name!r}") from exc + + def get_secret(self, key: str) -> str: + try: + return self._data[key] + except KeyError: + raise SecretNotFoundError(key) diff --git a/src/node_wire_runtime/secrets/azure.py b/src/node_wire_runtime/secrets/azure.py new file mode 100644 index 0000000..6bcbf91 --- /dev/null +++ b/src/node_wire_runtime/secrets/azure.py @@ -0,0 +1,53 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from node_wire_runtime.secrets.base import ( + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) + +try: + from azure.identity import DefaultAzureCredential + from azure.keyvault.secrets import SecretClient + from azure.core.exceptions import ResourceNotFoundError, HttpResponseError +except ImportError as _e: + raise ImportError( + "azure-keyvault-secrets and azure-identity are required for AzureKeyVaultProvider. " + "Install with: pip install 'node-wire-runtime[azure]'" + ) from _e + + +class AzureKeyVaultProvider(SecretProvider): + """Reads individual secrets from Azure Key Vault on demand. + + Uses DefaultAzureCredential — works with managed identities, environment + credentials, and interactive logins without changing application code. + """ + + def __init__(self, vault_url: str) -> None: + try: + credential = DefaultAzureCredential() + self._client = SecretClient(vault_url=vault_url, credential=credential) + except Exception as exc: + raise SecretProviderError( + f"Failed to initialise Azure Key Vault client for {vault_url!r}: {exc}" + ) from exc + + def get_secret(self, key: str) -> str: + # Azure KV names use hyphens; map underscores for convention compatibility. + azure_name = key.replace("_", "-") + try: + secret = self._client.get_secret(azure_name) + if secret.value is None: + raise SecretNotFoundError(key) + return secret.value + except ResourceNotFoundError: + raise SecretNotFoundError(key) + except HttpResponseError as exc: + raise SecretProviderError( + f"Azure Key Vault HTTP error for secret {key!r}: {exc}" + ) from exc diff --git a/src/node_wire_runtime/secrets/base.py b/src/node_wire_runtime/secrets/base.py new file mode 100644 index 0000000..337dbd0 --- /dev/null +++ b/src/node_wire_runtime/secrets/base.py @@ -0,0 +1,63 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +from abc import ABC, abstractmethod + + +class SecretNotFoundError(KeyError): + """The requested key does not exist in this provider.""" + + +class SecretProviderError(RuntimeError): + """The provider itself failed (auth, network, config). Do not swallow.""" + + +class SecretProvider(ABC): + """ + Abstract port for secret resolution. + + Implementations may use environment variables, a cloud secrets manager, + or any other secure storage backend. + """ + + @abstractmethod + def get_secret(self, key: str) -> str: + """Return the secret value for the given key, or raise SecretNotFoundError.""" + raise NotImplementedError + + +class EnvSecretProvider(SecretProvider): + """SecretProvider backed by environment variables. + + Strips surrounding whitespace and quotes from values. + Tries the key as-is, then uppercased. + Raises :class:`SecretNotFoundError` if the key is absent (fail-closed). + + Set ``NW_ENV_SECRET_LEGACY_EMPTY=true`` to restore legacy behaviour of returning + ``""`` when a variable is missing (not recommended for production). + """ + + def __init__(self, *, legacy_empty_on_missing: bool | None = None) -> None: + self._env = os.environ + if legacy_empty_on_missing is None: + legacy_empty_on_missing = os.environ.get("NW_ENV_SECRET_LEGACY_EMPTY", "").lower() in ( + "1", + "true", + "yes", + ) + self._legacy_empty_on_missing = legacy_empty_on_missing + + def get_secret(self, key: str) -> str: + val = self._env.get(key) + if val is not None: + return val.strip(" '\"") + val = self._env.get(key.upper()) + if val is not None: + return val.strip(" '\"") + if self._legacy_empty_on_missing: + return "" + raise SecretNotFoundError(key) diff --git a/src/node_wire_runtime/secrets/chained.py b/src/node_wire_runtime/secrets/chained.py new file mode 100644 index 0000000..aae67e3 --- /dev/null +++ b/src/node_wire_runtime/secrets/chained.py @@ -0,0 +1,45 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import logging +from typing import Sequence + +from node_wire_runtime.secrets.base import ( + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) + +logger = logging.getLogger("node_wire_runtime.secrets.chained") + + +class ChainedSecretProvider(SecretProvider): + """Try providers in order. + + Falls through ONLY on SecretNotFoundError / KeyError (key absent). + Propagates SecretProviderError immediately — never mask a broken provider. + """ + + def __init__(self, *providers: SecretProvider) -> None: + if not providers: + raise ValueError("ChainedSecretProvider requires at least one provider") + self._providers: Sequence[SecretProvider] = providers + + def get_secret(self, key: str) -> str: + last_not_found: Exception | None = None + for provider in self._providers: + try: + return provider.get_secret(key) + except SecretProviderError: + # Provider is broken (IAM, network, config). Fail hard. + raise + except (SecretNotFoundError, KeyError) as exc: + last_not_found = exc + continue # Try next provider + + raise SecretNotFoundError( + f"Secret '{key}' not found in any of {len(self._providers)} provider(s)" + ) from last_not_found diff --git a/src/node_wire_runtime/secrets/gcp.py b/src/node_wire_runtime/secrets/gcp.py new file mode 100644 index 0000000..bda68ac --- /dev/null +++ b/src/node_wire_runtime/secrets/gcp.py @@ -0,0 +1,51 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from node_wire_runtime.secrets.base import ( + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) + +try: + from google.cloud import secretmanager + from google.api_core.exceptions import NotFound, GoogleAPICallError +except ImportError as _e: + raise ImportError( + "google-cloud-secret-manager is required for GcpSecretManagerProvider. " + "Install with: pip install 'node-wire-runtime[gcp]'" + ) from _e + + +class GcpSecretManagerProvider(SecretProvider): + """Reads the latest version of a GCP Secret Manager secret at init time. + + The secret should be a JSON object whose keys map to connector secret names. + Uses Application Default Credentials (ADC) — works with workload identity, + service account key files, and gcloud auth. + """ + + def __init__(self, project_id: str, secret_id: str, version: str = "latest") -> None: + import json + + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version}" + try: + response = client.access_secret_version(request={"name": name}) + payload = response.payload.data.decode("utf-8") + self._data: dict = json.loads(payload) + except NotFound: + raise SecretNotFoundError(f"{project_id}/{secret_id}") + except GoogleAPICallError as exc: + raise SecretProviderError( + f"GCP Secret Manager error for {project_id}/{secret_id}: {exc}" + ) from exc + + def get_secret(self, key: str) -> str: + try: + return self._data[key] + except KeyError: + raise SecretNotFoundError(key) diff --git a/src/node_wire_runtime/secrets/vault.py b/src/node_wire_runtime/secrets/vault.py new file mode 100644 index 0000000..fb4374a --- /dev/null +++ b/src/node_wire_runtime/secrets/vault.py @@ -0,0 +1,56 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from node_wire_runtime.secrets.base import ( + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) + +try: + import hvac +except ImportError as _e: + raise ImportError( + "hvac is required for HashiCorpVaultProvider. " + "Install it with: pip install 'node-wire-runtime[vault]'" + ) from _e + + +class HashiCorpVaultProvider(SecretProvider): + """Reads a KV-v2 secret from HashiCorp Vault. + + Fetches the secret at init time and caches it in memory. + Raises SecretProviderError on Vault connectivity / auth failures. + """ + + def __init__( + self, + secret_path: str, + *, + url: str = "http://127.0.0.1:8200", + token: str | None = None, + mount_point: str = "secret", + ) -> None: + try: + client = hvac.Client(url=url, token=token) + if not client.is_authenticated(): + raise SecretProviderError("Vault client is not authenticated") + response = client.secrets.kv.v2.read_secret_version( + path=secret_path, mount_point=mount_point + ) + self._data: dict = response["data"]["data"] + except SecretProviderError: + raise + except hvac.exceptions.InvalidPath: + raise SecretNotFoundError(secret_path) + except hvac.exceptions.VaultError as exc: + raise SecretProviderError(f"Vault error for path {secret_path!r}: {exc}") from exc + + def get_secret(self, key: str) -> str: + try: + return self._data[key] + except KeyError: + raise SecretNotFoundError(key) diff --git a/src/node_wire_runtime/streaming.py b/src/node_wire_runtime/streaming.py new file mode 100644 index 0000000..e66079d --- /dev/null +++ b/src/node_wire_runtime/streaming.py @@ -0,0 +1,93 @@ +import os +import time +import logging +from enum import Enum +from typing import AsyncIterator, Dict, Any, Optional + +logger = logging.getLogger("runtime.streaming") + + +class StreamSignal(str, Enum): + STARTED = "started" + CHUNK = "chunk" + COMPLETED = "completed" + FAILED = "failed" + + +def stream_completion_log(trace_id: str, success: bool, *, connector_id: str, action: str) -> None: + status = StreamSignal.COMPLETED.value if success else StreamSignal.FAILED.value + msg = "Stream completed" if success else "Stream failed" + extra = { + "trace_id": trace_id, + "connector_id": connector_id, + "action": action, + "stream_status": status, + } + if success: + logger.info( + "%s | trace_id=%s | connector_id=%s | action=%s | status=%s", + msg, + trace_id, + connector_id, + action, + status, + extra=extra, + ) + else: + logger.warning( + "%s | trace_id=%s | connector_id=%s | action=%s | status=%s", + msg, + trace_id, + connector_id, + action, + status, + extra=extra, + ) + + +def resolve_stream_buffer_ms(override: Optional[int] = None) -> int: + if override is not None: + return max(0, min(int(override), 30000)) + val = os.environ.get("NW_STREAM_BUFFER_MS", "0").strip() + try: + n = int(val) + except ValueError: + n = 0 + return max(0, min(n, 30000)) + + +async def BufferedStreamIterator( + iterator: AsyncIterator[Dict[str, Any]], + buffer_ms: int, + trace_id: str, + connector_id: str = "agent", + action: str = "stream", +) -> AsyncIterator[Dict[str, Any]]: + success = True + try: + if buffer_ms <= 0: + async for item in iterator: + yield item + return + + buffer_sec = buffer_ms / 1000.0 + buffer = [] + last_flush = time.monotonic() + + async for item in iterator: + buffer.append(item) + now = time.monotonic() + if now - last_flush >= buffer_sec: + for b_item in buffer: + yield b_item + buffer.clear() + last_flush = now + + for b_item in buffer: + yield b_item + except Exception: + success = False + raise + finally: + # Automatically emit completion log when stream ends + stream_completion_log(trace_id, success, connector_id=connector_id, action=action) diff --git a/src/node_wire_salesforce/__init__.py b/src/node_wire_salesforce/__init__.py new file mode 100644 index 0000000..b2c3109 --- /dev/null +++ b/src/node_wire_salesforce/__init__.py @@ -0,0 +1 @@ +# Connector subpackage: salesforce diff --git a/src/node_wire_salesforce/logic.py b/src/node_wire_salesforce/logic.py new file mode 100644 index 0000000..a60ac08 --- /dev/null +++ b/src/node_wire_salesforce/logic.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional, Tuple, Type, ClassVar +import httpx + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.models import ErrorCategory +from .schema import ( + CreateLeadInput, + CreateContactInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, + SalesforceOperationOutput, +) + +logger = logging.getLogger("connectors.salesforce") + + +class SalesforceTransientError(httpx.HTTPStatusError): + """Exception for transient Salesforce errors that should be retried.""" + + pass + + +class SalesforceConnector(BaseConnector): + """Salesforce connector for managing Leads and Contacts.""" + + connector_id = "salesforce" + action = "execute" # Multi-action dispatcher + output_model = SalesforceOperationOutput + + error_map: ClassVar[Dict[Type[BaseException], Tuple[ErrorCategory, str]]] = { + httpx.ConnectError: (ErrorCategory.RETRYABLE, "SALESFORCE_CONNECT_ERROR"), + httpx.TimeoutException: (ErrorCategory.RETRYABLE, "SALESFORCE_TIMEOUT"), + SalesforceTransientError: (ErrorCategory.RETRYABLE, "SALESFORCE_TRANSIENT_ERROR"), + httpx.HTTPStatusError: (ErrorCategory.BUSINESS, "SALESFORCE_API_ERROR"), + } + + def _get_base_url(self) -> str: + return self.secret_provider.get_secret("salesforce_instance_url").rstrip("/") + + def _get_api_version(self) -> str: + return "v58.0" + + async def _get_auth_headers(self) -> Dict[str, str]: + return await self.get_auth_headers() + + @nw_action("create_lead") + async def create_lead( + self, params: CreateLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest( + "POST", "Lead", params.model_dump(by_alias=True, exclude={"action"}), trace_id + ) + + @nw_action("read_lead") + async def read_lead(self, params: ReadLeadInput, *, trace_id: str) -> SalesforceOperationOutput: + return await self._execute_rest("GET", f"Lead/{params.record_id}", None, trace_id) + + @nw_action("update_lead") + async def update_lead( + self, params: UpdateLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest( + "PATCH", f"Lead/{params.record_id}", params.fields, trace_id + ) + + @nw_action("delete_lead") + async def delete_lead( + self, params: DeleteLeadInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("DELETE", f"Lead/{params.record_id}", None, trace_id) + + @nw_action("create_contact") + async def create_contact( + self, params: CreateContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest( + "POST", "Contact", params.model_dump(by_alias=True, exclude={"action"}), trace_id + ) + + @nw_action("read_contact") + async def read_contact( + self, params: ReadContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("GET", f"Contact/{params.record_id}", None, trace_id) + + @nw_action("update_contact") + async def update_contact( + self, params: UpdateContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest( + "PATCH", f"Contact/{params.record_id}", params.fields, trace_id + ) + + @nw_action("delete_contact") + async def delete_contact( + self, params: DeleteContactInput, *, trace_id: str + ) -> SalesforceOperationOutput: + return await self._execute_rest("DELETE", f"Contact/{params.record_id}", None, trace_id) + + async def _execute_rest( + self, method: str, path: str, payload: Optional[Dict[str, Any]], trace_id: str + ) -> SalesforceOperationOutput: + base_url = self._get_base_url() + api_version = self._get_api_version() + url = f"{base_url}/services/data/{api_version}/sobjects/{path}" + + headers = await self._get_auth_headers() + if payload: + headers["Content-Type"] = "application/json" + if isinstance(payload, dict): + payload = {k: v for k, v in payload.items() if v is not None} + + logger.info( + "Executing Salesforce REST call", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "method": method, + "path": path, + }, + ) + + async with httpx.AsyncClient() as client: + try: + response = await client.request( + method, url, headers=headers, json=payload, timeout=30.0 + ) + + # Handle transient errors (5xx) by raising a retryable exception + if response.status_code >= 500: + raise SalesforceTransientError( + message=f"Salesforce server error: {response.status_code}", + request=response.request, + response=response, + ) + + response.raise_for_status() + + data = {} + if response.content: + try: + data = response.json() + except Exception: + data = {"text": response.text} + + obj_type = path.split("/")[0] + res_id = data.get("id") or data.get("Id") if isinstance(data, dict) else None + + if not res_id and "/" in path: + res_id = path.split("/")[1] + + return SalesforceOperationOutput( + success=True, resource_id=res_id, resource_type=obj_type, data=data + ) + except Exception as exc: + # We log and re-raise to let the platform (ErrorMapper + Resilience) handle it + logger.error( + "Salesforce REST call failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "method": method, + "path": path, + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise diff --git a/src/node_wire_salesforce/registration.py b/src/node_wire_salesforce/registration.py new file mode 100644 index 0000000..86de1be --- /dev/null +++ b/src/node_wire_salesforce/registration.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +# Salesforce registration module. +# Mappings for httpx errors are now handled directly in SalesforceConnector.error_map +# in logic.py, which BaseConnector registers automatically. +# This module remains for package-level registration/discovery side effects. diff --git a/src/node_wire_salesforce/schema.py b/src/node_wire_salesforce/schema.py new file mode 100644 index 0000000..c51d0a9 --- /dev/null +++ b/src/node_wire_salesforce/schema.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Literal, Optional, Union +from pydantic import BaseModel, Field, field_validator, ConfigDict +import re + +SALESFORCE_ID_REGEX = re.compile(r"^[a-zA-Z0-9]{15,18}$") + + +class SalesforceError(BaseModel): + message: str + code: Optional[str] = None + fields: Optional[List[str]] = None + + +class SalesforceOperationOutput(BaseModel): + success: bool = True + resource_id: Optional[str] = None + resource_type: Optional[str] = None + data: Optional[Dict[str, Any]] = None + errors: Optional[List[SalesforceError]] = None + + +# Creation Models +class CreateLeadInput(BaseModel): + action: Literal["create_lead"] = "create_lead" + last_name: str = Field(..., alias="LastName") + company: str = Field(..., alias="Company") + first_name: Optional[str] = Field(None, alias="FirstName") + title: Optional[str] = Field(None, alias="Title") + email: Optional[str] = Field(None, alias="Email") + phone: Optional[str] = Field(None, alias="Phone") + mobile_phone: Optional[str] = Field(None, alias="MobilePhone") + street: Optional[str] = Field(None, alias="Street") + city: Optional[str] = Field(None, alias="City") + state: Optional[str] = Field(None, alias="State") + postal_code: Optional[str] = Field(None, alias="PostalCode") + country: Optional[str] = Field(None, alias="Country") + description: Optional[str] = Field(None, alias="Description") + lead_source: Optional[str] = Field(None, alias="LeadSource") + status: Optional[str] = Field(None, alias="Status") + rating: Optional[str] = Field(None, alias="Rating") + website: Optional[str] = Field(None, alias="Website") + number_of_employees: Optional[int] = Field(None, alias="NumberOfEmployees") + industry: Optional[str] = Field(None, alias="Industry") + annual_revenue: Optional[float] = Field(None, alias="AnnualRevenue") + + model_config = ConfigDict(populate_by_name=True) + + +class CreateContactInput(BaseModel): + action: Literal["create_contact"] = "create_contact" + last_name: str = Field(..., alias="LastName") + first_name: Optional[str] = Field(None, alias="FirstName") + account_id: Optional[str] = Field(None, alias="AccountId") + title: Optional[str] = Field(None, alias="Title") + + @field_validator("account_id") + @classmethod + def validate_account_id(cls, v: Optional[str]) -> Optional[str]: + if v and not SALESFORCE_ID_REGEX.match(v): + raise ValueError( + "Invalid Salesforce AccountId format (must be 15 or 18 alphanumeric characters)" + ) + return v + + email: Optional[str] = Field(None, alias="Email") + phone: Optional[str] = Field(None, alias="Phone") + mobile_phone: Optional[str] = Field(None, alias="MobilePhone") + mailing_street: Optional[str] = Field(None, alias="MailingStreet") + mailing_city: Optional[str] = Field(None, alias="MailingCity") + mailing_state: Optional[str] = Field(None, alias="MailingState") + mailing_postal_code: Optional[str] = Field(None, alias="MailingPostalCode") + mailing_country: Optional[str] = Field(None, alias="MailingCountry") + description: Optional[str] = Field(None, alias="Description") + lead_source: Optional[str] = Field(None, alias="LeadSource") + department: Optional[str] = Field(None, alias="Department") + + model_config = ConfigDict(populate_by_name=True) + + +# Read/Delete Models +class SalesforceResourceInput(BaseModel): + action: Literal["read_lead", "delete_lead", "read_contact", "delete_contact"] + record_id: str + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError( + "Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)" + ) + return v + + +class ReadLeadInput(SalesforceResourceInput): + action: Literal["read_lead"] = "read_lead" + + +class DeleteLeadInput(SalesforceResourceInput): + action: Literal["delete_lead"] = "delete_lead" + + +class ReadContactInput(SalesforceResourceInput): + action: Literal["read_contact"] = "read_contact" + + +class DeleteContactInput(SalesforceResourceInput): + action: Literal["delete_contact"] = "delete_contact" + + +# Update Models +class UpdateLeadInput(BaseModel): + action: Literal["update_lead"] = "update_lead" + record_id: str + fields: Dict[str, Any] + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError( + "Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)" + ) + return v + + +class UpdateContactInput(BaseModel): + action: Literal["update_contact"] = "update_contact" + record_id: str + fields: Dict[str, Any] + + @field_validator("record_id") + @classmethod + def validate_id(cls, v: str) -> str: + if not SALESFORCE_ID_REGEX.match(v): + raise ValueError( + "Invalid Salesforce record_id format (must be 15 or 18 alphanumeric characters)" + ) + return v + + +SalesforceInput = Union[ + CreateLeadInput, + CreateContactInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, +] diff --git a/src/node_wire_slack/README.md b/src/node_wire_slack/README.md new file mode 100644 index 0000000..4f1fb8e --- /dev/null +++ b/src/node_wire_slack/README.md @@ -0,0 +1,126 @@ +# Slack Connector — Technical Documentation + +> **Platform:** Node Wire +> **Connector ID:** `slack` +> **REST:** One route per operation, e.g. `POST /connectors/slack/post_message`. +> **Discriminator:** `action` field (discriminated-union payload) +> **Source:** `src/node_wire_slack/` + +--- + +## 1. Operations Overview + +The Slack connector provides high-level actions for messaging and file management. It follows the standard Node-Wire 4-file structure, ensuring consistent authentication, error handling, and schema validation. + +### Architecture +- **Schema Validation**: All inputs are validated using Pydantic models in `schema.py`. The `action` field acts as a discriminator to route payloads to the correct handler. +- **Channel Resolution**: A specialized internal helper, `_resolve_channel_id`, automatically maps flexible target identifiers for messaging operations (Channel Names like `#general`, Channel IDs, or User IDs like `U...`) to the correct Slack identifiers. For `upload_file`, Slack's external upload API requires a real conversation ID, so unresolved channel names are rejected before upload begins. +- **File Uploads**: Implements Slack's recommended 3-step external upload flow (`getUploadURLExternal`), supporting both absolute filesystem paths and base64-encoded content. +- **Authentication**: Bot tokens (`xoxb-...`) are resolved at call-time via the `SecretProvider`, ensuring no credentials are ever logged or stored on the instance. + +### Available Operations + +| Action | Description | +|---|---| +| `post_message` | Send a message to a channel, group, or user | +| `send_direct_message` | Send a private message to a specific user (resolves User ID to DM) | +| `upload_file` | Upload and share a file to a channel or direct message | + +--- + +## 2. Operation Reference + +### `post_message` + +Sends a message to a Slack conversation. Supports plain text and rich Block Kit layouts. + +| Field | Type | Required | Description | +|---|---|---|---| +| `action` | string | ✅ | Must be `"post_message"` | +| `channel` | string | ✅ | Target Channel ID (`C...`), Name (`#general`), or User ID (`U...`) | +| `message` | string | ✅ | Plain-text fallback message (markdown supported) | +| `blocks` | array / string | No | Block Kit payload as JSON string or pre-parsed list | +| `token_secret_key` | string | No | SecretProvider key (Default: `SLACK_BOT_TOKEN`) | + +**Request body — with blocks:** + +```json +{ + "action": "post_message", + "channel": "#general", + "message": "Hello from Node-Wire!", + "blocks": [ + { + "type": "section", + "text": { "type": "mrkdwn", "text": "Hello *Node-Wire*!" } + } + ] +} +``` + +--- + +### `send_direct_message` + +A specialized action for private communication. If a User ID is provided, the connector automatically opens/resolves the DM channel before posting. + +| Field | Type | Required | Description | +|---|---|---|---| +| `action` | string | ✅ | Must be `"send_direct_message"` | +| `channel` | string | ✅ | Target User ID (`U...`), Channel ID (`D...`), or Name | +| `message` | string | ✅ | Plain-text fallback message | +| `blocks` | array / string | No | Optional Block Kit payload | + +**Request body:** + +```json +{ + "action": "send_direct_message", + "channel": "U12345678", + "message": "Private clinical notification." +} +``` + +--- + +### `upload_file` + +Uploads a file to Slack using the external-upload flow. Supports local file paths (sandboxed) or raw base64 data. +Unlike `chat.postMessage`, this Slack API requires a real conversation ID when sharing the uploaded file, so unresolved names like `#general` are not accepted. + +| Field | Type | Required | Description | +|---|---|---|---| +| `action` | string | ✅ | Must be `"upload_file"` | +| `channel` | string | ✅ | Target Channel, Name, or User ID to share the file with | +| `filename` | string | No | Display name for the uploaded file | +| `initial_comment` | string | No | Message posted alongside the file | +| `filepath` | string | No* | Absolute path to local file (sandboxed). *Required if content_base64 is empty* | +| `content_base64` | string | No* | Base64-encoded content. *Required if filepath is empty* | + +**Request body — via Filename:** + +```json +{ + "action": "upload_file", + "channel": "C12345678", + "filename": "patient_summary.pdf", + "filepath": "/slack_attachments/summary_123.pdf" +} +``` + +--- + +## 3. Error Taxonomy + +All Slack API errors are mapped into the Node-Wire platform taxonomy. The connector translates `ok: false` responses into typed exceptions for consistent retry and troubleshooting. + +| Error Code | Category | Trigger Conditions | +|---|---|---| +| `SLACK_AUTH_ERROR` | `AUTH` | Token invalid, revoked, or account inactive | +| `SLACK_PERMISSION_ERROR` | `AUTH` | Missing required OAuth scopes (e.g., `chat:write`) | +| `SLACK_RATE_LIMIT` | `RETRYABLE` | HTTP 429 or `ratelimited` error | +| `SLACK_UPLOAD_ERROR` | `BUSINESS` | Bad content, file too large, or upload step failure | +| `SLACK_MESSAGE_ERROR` | `BUSINESS` | Channel not found, invalid blocks, or other message errors | + +### Implementation Note +The connector enforces a default upload limit (configurable via `NW_SLACK_UPLOAD_LIMIT_MB`) to prevent memory exhaustion during base64 decoding. diff --git a/src/node_wire_slack/__init__.py b/src/node_wire_slack/__init__.py new file mode 100644 index 0000000..74ae7ef --- /dev/null +++ b/src/node_wire_slack/__init__.py @@ -0,0 +1 @@ +# Connector subpackage: slack diff --git a/src/node_wire_slack/exceptions.py b/src/node_wire_slack/exceptions.py new file mode 100644 index 0000000..fa379fc --- /dev/null +++ b/src/node_wire_slack/exceptions.py @@ -0,0 +1,28 @@ +""" +Domain exception hierarchy for the Slack connector. + +These exceptions are raised by logic.py and mapped to ErrorCategory codes +by registration.py via ErrorMapper. +""" + +from __future__ import annotations + + +class SlackAuthError(Exception): + """Raised when the bot token is invalid, revoked, or the account is inactive.""" + + +class SlackPermissionError(Exception): + """Raised when the token lacks the required OAuth scope for the operation.""" + + +class SlackRateLimitError(Exception): + """Raised on HTTP 429 or Slack's `ratelimited` error — eligible for retry.""" + + +class SlackUploadError(Exception): + """Raised when a file upload step fails (bad content, missing fields, Slack error).""" + + +class SlackMessageError(Exception): + """Raised when a chat.postMessage call fails for a business-logic reason.""" diff --git a/src/node_wire_slack/logic.py b/src/node_wire_slack/logic.py new file mode 100644 index 0000000..3218a83 --- /dev/null +++ b/src/node_wire_slack/logic.py @@ -0,0 +1,488 @@ +""" +Slack connector for Node-Wire. + +Structure mirrors node_wire_smtp/logic.py: + - Private async HTTP helpers at module level (no separate helper file). + - SlackConnector(BaseConnector) with one @sdk_action per operation. + +The Slack Bot Token is NEVER logged or included in exceptions. +All credentials are resolved at call-time via SecretProvider. +""" + +from __future__ import annotations + +import base64 +import binascii +import json +import logging +import os +import re +from typing import Any + +import httpx + +from node_wire_runtime import BaseConnector, sdk_action + +from .exceptions import ( + SlackAuthError, + SlackMessageError, + SlackPermissionError, + SlackRateLimitError, + SlackUploadError, +) +from .schema import ( + SlackOutput, + SlackPostMessageInput, + SlackSendDirectMessageInput, + SlackUploadFileInput, +) + +logger = logging.getLogger("connectors.slack") + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_CHAT_POST_URL = "https://slack.com/api/chat.postMessage" +_GET_UPLOAD_URL = "https://slack.com/api/files.getUploadURLExternal" +_COMPLETE_UPLOAD_URL = "https://slack.com/api/files.completeUploadExternal" + +_DEFAULT_TIMEOUT = 30.0 +_HARD_UPLOAD_LIMIT_MB = 100 +_DEFAULT_UPLOAD_LIMIT_MB = 50 +_CHANNEL_ID_RE = re.compile(r"^[CGDZ][A-Z0-9]{8,}$") + + +def _get_api_url(path: str) -> str: + """Helper to construct Slack API URLs, allowing base URL override via NW_SLACK_API_BASE_URL.""" + base = os.environ.get("NW_SLACK_API_BASE_URL", "https://slack.com/api").rstrip("/") + return f"{base}/{path.lstrip('/')}" + + +_CHAT_POST_URL = _get_api_url("chat.postMessage") +_GET_UPLOAD_URL = _get_api_url("files.getUploadURLExternal") +_COMPLETE_UPLOAD_URL = _get_api_url("files.completeUploadExternal") + +# Sandboxed directory for filesystem-based uploads. +_ATTACHMENTS_DIR = os.environ.get("NW_SLACK_ATTACHMENTS_DIR", "/slack_attachments") + +# Slack error strings that map to specific domain exceptions. +_AUTH_ERRORS = frozenset({"invalid_auth", "token_revoked", "account_inactive", "not_authed"}) +_SCOPE_ERRORS = frozenset({"missing_scope", "invalid_scopes"}) +_RATE_ERRORS = frozenset({"ratelimited"}) + + +# --------------------------------------------------------------------------- +# Private HTTP helpers (module-level, not a separate file) +# --------------------------------------------------------------------------- + + +def _auth_headers(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _is_valid_channel_id(value: str) -> bool: + """Return True when *value* matches Slack's channel_id format.""" + return bool(_CHANNEL_ID_RE.fullmatch(value.strip())) + + +def _raise_for_slack_error(response_json: dict[str, Any], http_status: int) -> None: + """Translate a Slack `ok: false` payload into a typed domain exception.""" + slack_error = response_json.get("error", "unknown") + messages = response_json.get("response_metadata", {}).get("messages", []) + detail = ". ".join(messages) if messages else str(slack_error) + + if slack_error in _AUTH_ERRORS: + raise SlackAuthError("Slack authentication failed or token was revoked.") + if slack_error in _SCOPE_ERRORS: + raise SlackPermissionError(f"Slack permission error: {detail}") + if slack_error in _RATE_ERRORS or http_status == 429: + raise SlackRateLimitError(f"Slack rate limit: {detail}") + raise SlackMessageError(f"Slack API error '{slack_error}': {detail}") + + +async def _post_json( + url: str, + token: str, + body: dict[str, Any], + timeout: float = _DEFAULT_TIMEOUT, +) -> dict[str, Any]: + """POST JSON to the Slack API. Raises a typed exception on failure.""" + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + url, + headers={**_auth_headers(token), "Content-Type": "application/json"}, + json=body, + ) + data = response.json() + if not data.get("ok"): + _raise_for_slack_error(data, response.status_code) + return data + + +async def _get_upload_url( + token: str, + filename: str, + length: int, + timeout: float = _DEFAULT_TIMEOUT, +) -> tuple[str, str]: + """Step 1 of the external upload flow. Returns (upload_url, file_id).""" + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + _GET_UPLOAD_URL, + headers=_auth_headers(token), + data={"filename": filename, "length": str(length)}, + ) + data = response.json() + if not data.get("ok"): + _raise_for_slack_error(data, response.status_code) + upload_url = data.get("upload_url", "") + file_id = data.get("file_id", "") + if not upload_url or not file_id: + raise SlackUploadError("Slack did not return upload_url or file_id.") + return upload_url, file_id + + +async def _upload_bytes( + upload_url: str, + content: bytes, + timeout: float = _DEFAULT_TIMEOUT, +) -> None: + """Step 2: PUT raw bytes to the pre-signed URL Slack returned.""" + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + upload_url, + content=content, + headers={"Content-Type": "application/octet-stream"}, + ) + if response.status_code != 200: + raise SlackUploadError(f"Upload to pre-signed URL failed with HTTP {response.status_code}.") + + +async def _complete_upload( + token: str, + file_id: str, + title: str, + channel_id: str = "", + initial_comment: str = "", + timeout: float = _DEFAULT_TIMEOUT, +) -> dict[str, Any]: + """Step 3: Finalise the upload and optionally share to a channel.""" + data: dict[str, Any] = { + "files": json.dumps([{"id": file_id, "title": title}]), + } + if _is_valid_channel_id(channel_id): + data["channel_id"] = channel_id + if initial_comment: + data["initial_comment"] = initial_comment + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + _COMPLETE_UPLOAD_URL, + headers=_auth_headers(token), + data=data, + ) + resp_data = response.json() + if not resp_data.get("ok"): + _raise_for_slack_error(resp_data, response.status_code) + return resp_data + + +def _resolve_blocks(blocks: Any) -> list[Any] | None: + """Parse Block Kit payload from a JSON string or pass through a list. + Raises SlackMessageError on invalid JSON.""" + if blocks is None: + return None + if isinstance(blocks, list): + return blocks + if isinstance(blocks, str) and blocks.strip(): + try: + parsed = json.loads(blocks) + except (json.JSONDecodeError, TypeError) as exc: + raise SlackMessageError(f"Invalid blocks JSON: {exc}") from exc + if not isinstance(parsed, list): + raise SlackMessageError("blocks must be a JSON array.") + return parsed + return None + + +def _get_upload_limit_bytes() -> int: + raw = os.environ.get("NW_SLACK_UPLOAD_LIMIT_MB", "") + try: + mb = int(raw.strip()) if raw.strip() else _DEFAULT_UPLOAD_LIMIT_MB + except ValueError: + mb = _DEFAULT_UPLOAD_LIMIT_MB + mb = max(1, min(mb, _HARD_UPLOAD_LIMIT_MB)) + return mb * 1024 * 1024 + + +def _resolve_upload_path(filepath: str) -> str: + """Validate that *filepath* is under the sandboxed attachments directory.""" + allowed = os.path.realpath(_ATTACHMENTS_DIR) + if not os.path.isabs(filepath): + raise SlackUploadError( + f"filepath must be an absolute path under '{allowed}'. Got: {filepath!r}" + ) + candidate = os.path.realpath(filepath) + if candidate != allowed and not candidate.startswith(allowed + os.sep): + raise SlackUploadError(f"filepath must be under '{allowed}'. Got: {filepath!r}") + return candidate + + +async def _resolve_channel_id(token: str, target: str) -> str: + """ + Resolve a target (Channel name, Channel ID, or User ID) to a Slack Channel ID. + - If NW_SLACK_SKIP_RESOLVE=true, returns target as-is (useful for mocks/restricted envs). + - If it already looks like a Channel ID (C, G, D), return it. + - If it starts with U or W, it's a User ID; call conversations.open to get the DM channel. + - Otherwise, return as-is (names like #general are handled natively by chat.postMessage). + """ + target = target.strip() + if not target: + return target + + if os.environ.get("NW_SLACK_SKIP_RESOLVE", "").lower() == "true": + logger.debug(f"Skipping channel resolution for {target} (NW_SLACK_SKIP_RESOLVE=true)") + return target + + prefix = target[0].upper() + + # Already a channel/group/dm ID + if prefix in ("C", "G", "D", "Z"): + return target + + # User ID -> Resolve to DM channel + if prefix in ("U", "W"): + try: + async with httpx.AsyncClient(timeout=_DEFAULT_TIMEOUT) as client: + response = await client.post( + _get_api_url("conversations.open"), + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"users": target}, + ) + data = response.json() + if data.get("ok"): + resolved_id = data["channel"]["id"] + logger.debug(f"Resolved User ID {target} to DM channel {resolved_id}") + return resolved_id + + # If Slack returns ok: false, fallback to the original ID + logger.warning( + f"Failed to resolve User ID {target} to DM channel: {data.get('error')}" + ) + return target + except Exception as exc: + # Catch network errors (ConnectError, etc.) and fallback to original ID + logger.warning(f"Network error resolving User ID {target} to DM channel: {exc}") + return target + + return target + + +# --------------------------------------------------------------------------- +# Connector +# --------------------------------------------------------------------------- + + +class SlackConnector(BaseConnector): + """ + Slack connector: post messages, send DMs, and upload files to Slack channels. + + Authentication uses a Slack Bot Token (xoxb-…) fetched at call-time via + SecretProvider. The token is never stored on the instance or emitted in logs. + + Actions + ------- + post_message — chat.postMessage to a channel + send_direct_message — chat.postMessage to a user DM (same API, by user ID) + upload_file — 3-step external upload (getUploadURLExternal flow) + """ + + connector_id = "slack" + output_model = SlackOutput + + # ------------------------------------------------------------------ + # post_message + # ------------------------------------------------------------------ + + @sdk_action("post_message") + async def post_message(self, params: SlackPostMessageInput, *, trace_id: str) -> SlackOutput: + logger.info( + "Sending Slack channel message", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "post_message", + "channel": params.channel, + }, + ) + token = self.secret_provider.get_secret(params.token_secret_key) + channel_id = await _resolve_channel_id(token, params.channel) + + body: dict[str, Any] = {"channel": channel_id, "text": params.message} + parsed_blocks = _resolve_blocks(params.blocks) + if parsed_blocks is not None: + body["blocks"] = parsed_blocks + + data = await _post_json(_CHAT_POST_URL, token, body) + + logger.info( + "Slack channel message sent", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "post_message", + "channel": channel_id, + "ts": data.get("ts"), + }, + ) + return SlackOutput( + ok=True, + ts=data.get("ts"), + channel=data.get("channel") or channel_id, + description=f"Message sent to {params.channel}.", + raw=data, + ) + + # ------------------------------------------------------------------ + # send_direct_message + # ------------------------------------------------------------------ + + @sdk_action("send_direct_message") + async def send_direct_message( + self, params: SlackSendDirectMessageInput, *, trace_id: str + ) -> SlackOutput: + logger.info( + "Sending Slack direct message", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "send_direct_message", + "channel": params.channel, + }, + ) + token = self.secret_provider.get_secret(params.token_secret_key) + channel_id = await _resolve_channel_id(token, params.channel) + + body: dict[str, Any] = {"channel": channel_id, "text": params.message} + parsed_blocks = _resolve_blocks(params.blocks) + if parsed_blocks is not None: + body["blocks"] = parsed_blocks + + data = await _post_json(_CHAT_POST_URL, token, body) + + logger.info( + "Slack direct message sent", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "send_direct_message", + "channel": channel_id, + "ts": data.get("ts"), + }, + ) + return SlackOutput( + ok=True, + ts=data.get("ts"), + channel=data.get("channel") or channel_id, + description=f"Direct message sent to user {params.channel}.", + raw=data, + ) + + # ------------------------------------------------------------------ + # upload_file + # ------------------------------------------------------------------ + + @sdk_action("upload_file") + async def upload_file(self, params: SlackUploadFileInput, *, trace_id: str) -> SlackOutput: + logger.info( + "Starting Slack file upload", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "upload_file", + "channel": params.channel, + }, + ) + token = self.secret_provider.get_secret(params.token_secret_key) + channel_id = await _resolve_channel_id(token, params.channel) + if params.channel and not _is_valid_channel_id(channel_id): + raise SlackUploadError( + f"Could not resolve {params.channel!r} to a valid Slack channel ID. " + "Provide a channel ID (for example C01AB2CD3EF) instead of a channel name." + ) + limit_bytes = _get_upload_limit_bytes() + + # --- Resolve content bytes --- + if params.filepath: + safe_path = _resolve_upload_path(params.filepath) + if not os.path.isfile(safe_path): + raise SlackUploadError(f"No such file in upload directory: {params.filepath!r}") + size = os.path.getsize(safe_path) + effective_filename = params.filename or os.path.basename(safe_path) + if size > limit_bytes: + raise SlackUploadError( + f"File '{effective_filename}' is {size / 1024 / 1024:.2f} MB, " + f"exceeds limit of {limit_bytes / 1024 / 1024:.0f} MB." + ) + with open(safe_path, "rb") as fh: + content_bytes = fh.read() + + elif params.content_base64: + effective_filename = params.filename or "upload.bin" + try: + content_bytes = base64.b64decode(params.content_base64, validate=True) + except binascii.Error as exc: + raise SlackUploadError(f"Invalid base64 content: {exc}") from exc + if len(content_bytes) > limit_bytes: + raise SlackUploadError( + f"Decoded content is {len(content_bytes) / 1024 / 1024:.2f} MB, " + f"exceeds limit of {limit_bytes / 1024 / 1024:.0f} MB." + ) + + else: + raise SlackUploadError("Either 'filepath' or 'content_base64' must be provided.") + + # --- 3-step external upload --- + logger.info( + "Requesting upload URL from Slack", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "upload_file", + "nw_filename": effective_filename, + "size_bytes": len(content_bytes), + }, + ) + upload_url, file_id = await _get_upload_url(token, effective_filename, len(content_bytes)) + + await _upload_bytes(upload_url, content_bytes) + + data = await _complete_upload( + token, + file_id, + title=effective_filename, + channel_id=channel_id, + initial_comment=params.initial_comment, + ) + + logger.info( + "Slack file upload completed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "upload_file", + "channel": channel_id, + "file_id": file_id, + }, + ) + return SlackOutput( + ok=True, + file_id=file_id, + channel=params.channel, + description=f"File '{effective_filename}' uploaded to {params.channel}.", + raw=data, + ) diff --git a/src/node_wire_slack/registration.py b/src/node_wire_slack/registration.py new file mode 100644 index 0000000..31d53fd --- /dev/null +++ b/src/node_wire_slack/registration.py @@ -0,0 +1,34 @@ +""" +ErrorMapper registrations for the Slack connector. + +Mirrors node_wire_google_drive/registration.py — registers domain exceptions +from exceptions.py so the runtime can translate them into the standard +ConnectorResponse error taxonomy. +""" + +from __future__ import annotations + +from node_wire_runtime import ErrorCategory, ErrorMapper + +from .exceptions import ( + SlackAuthError, + SlackMessageError, + SlackPermissionError, + SlackRateLimitError, + SlackUploadError, +) + +# Auth failures — token is invalid or revoked. +ErrorMapper.register(SlackAuthError, ErrorCategory.AUTH, code="SLACK_AUTH_ERROR") + +# Permission failures — token lacks the required OAuth scope. +ErrorMapper.register(SlackPermissionError, ErrorCategory.AUTH, code="SLACK_PERMISSION_ERROR") + +# Rate-limit — eligible for automatic retry by the runtime. +ErrorMapper.register(SlackRateLimitError, ErrorCategory.RETRYABLE, code="SLACK_RATE_LIMIT") + +# Upload failures — bad content, missing fields, or Slack API error during upload. +ErrorMapper.register(SlackUploadError, ErrorCategory.BUSINESS, code="SLACK_UPLOAD_ERROR") + +# Message failures — channel not found, payload rejected, etc. +ErrorMapper.register(SlackMessageError, ErrorCategory.BUSINESS, code="SLACK_MESSAGE_ERROR") diff --git a/src/node_wire_slack/schema.py b/src/node_wire_slack/schema.py new file mode 100644 index 0000000..54079d5 --- /dev/null +++ b/src/node_wire_slack/schema.py @@ -0,0 +1,134 @@ +""" +Pydantic v2 input/output models for the Slack connector. + +All input models include an `action` discriminator field so that +BaseConnector can build a discriminated union and route to the correct +@sdk_action method automatically — the same pattern used by Google Drive. + +Only `channel` and `message` are required for messaging actions. +Authentication is handled via SecretProvider (never hard-coded here). +""" + +from __future__ import annotations + +from typing import Annotated, Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + + +# --------------------------------------------------------------------------- +# Shared input base +# --------------------------------------------------------------------------- + + +class _BaseSlackInput(BaseModel): + """Fields shared by every Slack input model.""" + + model_config = ConfigDict(extra="forbid") + + token_secret_key: str = Field( + default="SLACK_BOT_TOKEN", + description=( + "SecretProvider key that holds the Slack Bot Token (xoxb-…). " + "Override only when running multiple bots." + ), + ) + + +# --------------------------------------------------------------------------- +# Action: post_message +# --------------------------------------------------------------------------- + + +class SlackPostMessageInput(_BaseSlackInput): + """Send a message to a Slack channel.""" + + action: Literal["post_message"] = "post_message" + channel: str = Field( + ..., description="Target Channel ID (C…), Name (#general), or User ID (U…)." + ) + message: str = Field(..., description="Plain-text fallback message (markdown supported).") + blocks: Optional[Union[str, List[Any]]] = Field( + default=None, + description="Block Kit payload as a JSON string or a pre-parsed list.", + ) + + +# --------------------------------------------------------------------------- +# Action: send_direct_message +# --------------------------------------------------------------------------- + + +class SlackSendDirectMessageInput(_BaseSlackInput): + """Send a direct message to a Slack user.""" + + action: Literal["send_direct_message"] = "send_direct_message" + channel: str = Field( + ..., description="Target User ID (U…), Channel ID (C…), or Name (#general)." + ) + message: str = Field(..., description="Plain-text fallback message (markdown supported).") + blocks: Optional[Union[str, List[Any]]] = Field( + default=None, + description="Block Kit payload as a JSON string or a pre-parsed list.", + ) + + +# --------------------------------------------------------------------------- +# Action: upload_file +# --------------------------------------------------------------------------- + + +class SlackUploadFileInput(_BaseSlackInput): + """Upload a file to a Slack channel or DM via the external-upload API.""" + + action: Literal["upload_file"] = "upload_file" + channel: str = Field( + ..., + description=( + "Target Channel ID (C/G/D/Z...) or User ID (U/W...) to share the file with. " + "Channel names like #general are not accepted by Slack's external upload API." + ), + ) + filename: str = Field(default="", description="Display name for the uploaded file.") + initial_comment: str = Field(default="", description="Message posted alongside the file.") + filepath: str = Field( + default="", + description=( + "Absolute path to a file under the sandboxed attachments directory " + "(NW_SLACK_ATTACHMENTS_DIR). Mutually exclusive with content_base64." + ), + ) + content_base64: str = Field( + default="", + description="Base64-encoded file content. Mutually exclusive with filepath.", + ) + + +# --------------------------------------------------------------------------- +# Discriminated union — used by BaseConnector internally +# --------------------------------------------------------------------------- + +_SlackOperationUnion = Annotated[ + Union[ + SlackPostMessageInput, + SlackSendDirectMessageInput, + SlackUploadFileInput, + ], + Field(discriminator="action"), +] + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +class SlackOutput(BaseModel): + """Unified output envelope for all Slack actions.""" + + ok: bool = Field(..., description="True when Slack acknowledged the request.") + ts: Optional[str] = Field(default=None, description="Message timestamp (chat actions).") + file_id: Optional[str] = Field(default=None, description="File ID (upload action).") + channel: Optional[str] = Field(default=None, description="Channel the message was sent to.") + description: str = Field(default="", description="Human-readable summary of the outcome.") + raw: Dict[str, Any] = Field(default_factory=dict, description="Full Slack API response.") diff --git a/src/node_wire_smtp/__init__.py b/src/node_wire_smtp/__init__.py new file mode 100644 index 0000000..68d3669 --- /dev/null +++ b/src/node_wire_smtp/__init__.py @@ -0,0 +1,6 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# + +# Connector subpackage: smtp diff --git a/src/connectors/smtp/logic.py b/src/node_wire_smtp/logic.py similarity index 52% rename from src/connectors/smtp/logic.py rename to src/node_wire_smtp/logic.py index c5809ba..9847b60 100644 --- a/src/connectors/smtp/logic.py +++ b/src/node_wire_smtp/logic.py @@ -1,41 +1,70 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import logging +import os from email.message import EmailMessage import aiosmtplib -from runtime import BaseConnector +from node_wire_runtime import BaseConnector, sdk_action +from node_wire_runtime.mcp_normalizers import normalize_smtp_send_email from .schema import SmtpSendInput, SmtpSendOutput logger = logging.getLogger("connectors.smtp") -class SmtpConnector(BaseConnector[SmtpSendInput, SmtpSendOutput]): +class SmtpConnector(BaseConnector): """ SMTP connector for sending emails via aiosmtplib. """ connector_id = "smtp" - action = "send_email" + output_model = SmtpSendOutput - async def internal_execute(self, params: SmtpSendInput, *, trace_id: str) -> SmtpSendOutput: + @sdk_action( + "send_email", + alias_tolerant=True, + mcp_normalize=normalize_smtp_send_email, + ) + async def send_email(self, params: SmtpSendInput, *, trace_id: str) -> SmtpSendOutput: + # Derive a domain-only hint so the sender identity (PII) is never written to logs. + _sender_domain = ( + str(params.from_email).split("@")[-1] if "@" in str(params.from_email) else "unknown" + ) logger.info( "Preparing SMTP message", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "send_email", "host": params.host, "port": params.port, - "from_email": str(params.from_email), + "sender_domain": _sender_domain, "recipient_count": len(params.to), }, ) - username = self.secret_provider.get_secret(params.username_secret_key) - password = self.secret_provider.get_secret(params.password_secret_key) + # Resolve credentials from AuthProvider (injected by factory). + # Falls back to environment variables for backward compatibility when + # the connector is instantiated without an explicit auth_provider. + creds = await self._auth_provider.get_client_credentials() + if creds is not None and isinstance(creds, (list, tuple)) and len(creds) == 2: + username, password = str(creds[0]), str(creds[1]) + else: + # Fallback: resolve from environment / secret_provider directly. + try: + username = self.secret_provider.get_secret("SMTP_USERNAME") + password = self.secret_provider.get_secret("SMTP_PASSWORD") + except Exception: + import os as _os + + username = _os.environ.get("SMTP_USERNAME", "") + password = _os.environ.get("SMTP_PASSWORD", "") message = EmailMessage() message["From"] = str(params.from_email) @@ -53,7 +82,7 @@ async def internal_execute(self, params: SmtpSendInput, *, trace_id: str) -> Smt password=password, use_tls=use_implicit, start_tls=params.use_tls and not use_implicit, - timeout=30.0, + timeout=float(os.getenv("AOT_CONNECTOR_TIMEOUT", "30.0")), ) except Exception as exc: # noqa: BLE001 logger.error( @@ -61,7 +90,7 @@ async def internal_execute(self, params: SmtpSendInput, *, trace_id: str) -> Smt extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "send_email", "host": params.host, "port": params.port, "error_type": type(exc).__name__, @@ -75,13 +104,13 @@ async def internal_execute(self, params: SmtpSendInput, *, trace_id: str) -> Smt extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "send_email", "host": params.host, "port": params.port, + "sender_domain": _sender_domain, "response": str(response), }, ) # aiosmtplib returns (code, message) tuple; message-id is not guaranteed, keep output simple. return SmtpSendOutput(sent=True) - diff --git a/src/node_wire_smtp/registration.py b/src/node_wire_smtp/registration.py new file mode 100644 index 0000000..8483583 --- /dev/null +++ b/src/node_wire_smtp/registration.py @@ -0,0 +1,26 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import aiosmtplib + +from node_wire_runtime import ErrorCategory, ErrorMapper + + +# Connection / timeout issues are retryable. +ErrorMapper.register( + aiosmtplib.errors.SMTPConnectError, ErrorCategory.RETRYABLE, code="SMTP_CONNECT_ERROR" +) +ErrorMapper.register( + aiosmtplib.errors.SMTPTimeoutError, ErrorCategory.RETRYABLE, code="SMTP_TIMEOUT" +) + +# Authentication failures map to AUTH. +ErrorMapper.register( + aiosmtplib.errors.SMTPAuthenticationError, ErrorCategory.AUTH, code="SMTP_AUTH_ERROR" +) + +# Generic SMTP protocol problems are treated as BUSINESS by default. +ErrorMapper.register(aiosmtplib.errors.SMTPException, ErrorCategory.BUSINESS, code="SMTP_ERROR") diff --git a/src/node_wire_smtp/schema.py b/src/node_wire_smtp/schema.py new file mode 100644 index 0000000..dddf70a --- /dev/null +++ b/src/node_wire_smtp/schema.py @@ -0,0 +1,93 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +import re +from typing import Any, List, Literal, Optional, Union + +from pydantic import BaseModel, EmailStr, model_validator + + +def _strip_env(s: str) -> str: + return s.strip(" '\"") + + +def _extract_email(value: str) -> str: + """Pydantic EmailStr does not accept 'Name '.""" + match = re.search(r"<(.+?)>", value) + return (match.group(1) if match else value).strip() + + +class SmtpSendInput(BaseModel): + """ + Send an email via SMTP. + + Only ``to``, ``subject``, and ``body`` are required — connection settings + (``host``, ``port``, ``use_tls``) fall back to server-side environment + variables when not supplied. + + Credentials (username and password) are **not** part of this schema. + They are managed entirely by the :class:`AuthProvider` injected into the + connector by the factory, keeping secrets out of the request payload. + """ + + action: Literal["send_email"] = "send_email" + host: str = "" + port: int = 0 + use_tls: bool = True + from_email: Optional[EmailStr] = None + to: Union[str, List[EmailStr]] + subject: str + body: str + + @model_validator(mode="before") + @classmethod + def _fill_env_and_normalize(cls, values: Any) -> Any: + if not isinstance(values, dict): + return values + + if not (values.get("host") or "").strip(): + values["host"] = _strip_env(os.environ.get("SMTP_HOST", "smtp.gmail.com")) + port_raw = values.get("port") + if port_raw in (None, "", 0): + values["port"] = int(_strip_env(os.environ.get("SMTP_PORT", "587"))) + if "use_tls" not in values: + values["use_tls"] = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" + + if "from" in values and not values.get("from_email"): + values["from_email"] = values.pop("from") + + fe = values.get("from_email") + if fe is None or not str(fe).strip(): + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + else: + values["from_email"] = _extract_email(_strip_env(str(fe))) + + # Guardrail: reject placeholder / invalid sender hints from callers + sender = str(values["from_email"]) + if not sender or "@" not in sender or "system_default" in sender: + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + + raw_to = values.get("to") + if isinstance(raw_to, str): + values["to"] = [_extract_email(raw_to)] + elif isinstance(raw_to, list): + values["to"] = [_extract_email(str(x)) for x in raw_to] + + return values + + +class SmtpSendOutput(BaseModel): + sent: bool + message_id: Optional[str] = None diff --git a/src/node_wire_stripe/README.md b/src/node_wire_stripe/README.md new file mode 100644 index 0000000..45a2da4 --- /dev/null +++ b/src/node_wire_stripe/README.md @@ -0,0 +1,68 @@ +# Node Wire Connector — Stripe + +The Stripe connector provides a reliable, async adapter for processing payments and managing subscriptions using the Stripe Python SDK. It follows the Node Wire platform contract: consistent error handling, resilience (retries/circuit breaking), and standardized telemetry. + +## Supported Actions + +The connector exposes several actions through the `@nw_action` decorator. Each action is available via REST, gRPC, and MCP. + +| Action | Description | Key Parameters | +| :--- | :--- | :--- | +| `charge` | Legacy charge creation. | `amount`, `currency`, `source` | +| `create_payment_intent` | Modern payment flow for one-time payments. | `amount`, `currency`, `customer_id`, `confirm` | +| `create_subscription` | Create a recurring subscription. | `customer_id`, `price_id`, `card_token` | +| `cancel_subscription` | Terminate or schedule the end of a subscription. | `subscription_id`, `cancel_at_period_end` | +| `issue_refund` | Full or partial refund for a charge or payment intent. | `charge_id` or `payment_intent_id`, `amount` | + +## Setup & Configuration + +### Environment Variables + +The connector requires a Stripe secret API key. By default, the `EnvSecretProvider` looks for: + +- `STRIPE_API_KEY`: Your Stripe secret key (e.g., `sk_test_...` or `sk_live_...`). + +Add this to your `.env` or system environment: + +```bash +STRIPE_API_KEY=sk_test_your_secret_key +``` + +### Enabling the Connector + +In `config/connectors.yaml`, ensuring the connector is enabled and exposed: + +```yaml +connectors: + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] +``` + +## Detailed Action Reference + +### `create_subscription` + +This action supports multiple payment integration flows: + +1. **Saved Payment Method**: Pass `default_payment_method` with an existing `pm_xxx` ID. +2. **Card Token**: Pass `card_token` (e.g., `tok_visa`). The connector will automatically create a PaymentMethod and attach it to the customer before creating the subscription. +3. **SCA / Action Required**: If the subscription requires further action (like 3D Secure), the connector returns the `client_secret` from the associated Setup Intent or Payment Intent. + +### `cancel_subscription` + +- Set `cancel_at_period_end: true` to let the subscription finish its current cycle. +- Set `cancel_at_period_end: false` (default) to terminate the subscription immediately. + +## Error Handling + +Mapped Stripe exceptions to Node Wire error categories: + +- `RateLimitError` -> `RETRYABLE` (`STRIPE_RATE_LIMIT`) +- `CardError` -> `BUSINESS` (`STRIPE_CARD_ERROR`) +- `AuthenticationError` -> `AUTH` (`STRIPE_AUTH_ERROR`) +- `APIConnectionError` -> `RETRYABLE` (`STRIPE_API_CONNECTION`) +- `InvalidRequestError` -> `BUSINESS` (`STRIPE_INVALID_REQUEST`) +- `StripeError` -> `FATAL` (`STRIPE_ERROR`) + +Trace IDs are included in all error responses for easier troubleshooting in the Stripe Dashboard. diff --git a/src/node_wire_stripe/__init__.py b/src/node_wire_stripe/__init__.py new file mode 100644 index 0000000..baa5905 --- /dev/null +++ b/src/node_wire_stripe/__init__.py @@ -0,0 +1,6 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# + +# Connector subpackage: stripe diff --git a/src/node_wire_stripe/logic.py b/src/node_wire_stripe/logic.py new file mode 100644 index 0000000..6cd8750 --- /dev/null +++ b/src/node_wire_stripe/logic.py @@ -0,0 +1,343 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import asyncio +import logging +from typing import Any, ClassVar, cast + +import stripe + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.models import ErrorCategory + +from .schema import ( + CancelSubscriptionInput, + ChargeInput, + CreatePaymentIntentInput, + CreateSubscriptionInput, + IssueRefundInput, + StripeOperationOutput, +) + +logger = logging.getLogger("connectors.stripe") + + +class StripeConnector(BaseConnector): + """Stripe connector: payments and subscriptions as @nw_action methods.""" + + connector_id = "stripe" + output_model = StripeOperationOutput + + error_map: ClassVar[dict[type[BaseException], tuple[ErrorCategory, str]]] = { + stripe.error.RateLimitError: (ErrorCategory.RETRYABLE, "STRIPE_RATE_LIMIT"), + stripe.error.APIConnectionError: (ErrorCategory.RETRYABLE, "STRIPE_API_CONNECTION"), + stripe.error.CardError: (ErrorCategory.BUSINESS, "STRIPE_CARD_ERROR"), + stripe.error.InvalidRequestError: (ErrorCategory.BUSINESS, "STRIPE_INVALID_REQUEST"), + stripe.error.AuthenticationError: (ErrorCategory.AUTH, "STRIPE_AUTH_ERROR"), + stripe.error.StripeError: (ErrorCategory.FATAL, "STRIPE_ERROR"), + } + + def _get_api_key(self) -> str: + return self.secret_provider.get_secret("stripe_api_key") + + @nw_action("charge") + async def charge(self, params: ChargeInput, *, trace_id: str) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Creating Stripe charge", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "charge", + "amount": params.amount, + "currency": params.currency, + }, + ) + + def _create() -> stripe.Charge: + create = cast(Any, stripe.Charge.create) + return create( + api_key=api_key, + amount=params.amount, + currency=params.currency, + source=params.source, + customer=params.customer_id, + description=params.description, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + charge = await asyncio.to_thread(_create) + except Exception as exc: + logger.error( + "Stripe charge creation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "charge", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + charge_id=getattr(charge, "id", None), + receipt_url=getattr(charge, "receipt_url", None), + status="succeeded" if getattr(charge, "paid", False) else "failed", + ) + + @nw_action("create_payment_intent") + async def create_payment_intent( + self, params: CreatePaymentIntentInput, *, trace_id: str + ) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Creating Stripe Payment Intent", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_payment_intent", + "amount": params.amount, + "currency": params.currency, + }, + ) + + def _create() -> stripe.PaymentIntent: + create = cast(Any, stripe.PaymentIntent.create) + return create( + api_key=api_key, + amount=params.amount, + currency=params.currency, + customer=params.customer_id, + payment_method=params.payment_method, + confirm=params.confirm, + description=params.description, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + pi = await asyncio.to_thread(_create) + except Exception as exc: + logger.error( + "Stripe Payment Intent creation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_payment_intent", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + payment_intent_id=getattr(pi, "id", None), + client_secret=getattr(pi, "client_secret", None), + status=getattr(pi, "status", None), + ) + + @nw_action("create_subscription") + async def create_subscription( + self, params: CreateSubscriptionInput, *, trace_id: str + ) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Creating Stripe Subscription", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_subscription", + "customer_id": params.customer_id, + "price_id": params.price_id, + }, + ) + + def _create() -> stripe.Subscription: + payment_method_id = params.default_payment_method + + # If card_token is provided, create and attach PaymentMethod + if params.card_token: + pm = stripe.PaymentMethod.create( + api_key=api_key, + type="card", + card={"token": params.card_token}, + ) + stripe.PaymentMethod.attach( + pm.id, + api_key=api_key, + customer=params.customer_id, + ) + payment_method_id = pm.id + + return cast(Any, stripe.Subscription.create)( + api_key=api_key, + customer=params.customer_id, + items=[{"price": params.price_id}], + payment_behavior=params.payment_behavior, + default_payment_method=payment_method_id, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + sub = await asyncio.to_thread(_create) + except Exception as exc: + logger.error( + "Stripe Subscription creation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "create_subscription", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + # Subscription might have a setup_intent or latest_invoice.payment_intent + client_secret = None + pending_setup_intent = getattr(sub, "pending_setup_intent", None) + latest_invoice_id = getattr(sub, "latest_invoice", None) + + def _stripe_obj_id(obj: Any) -> str: + if isinstance(obj, str): + return obj + oid = getattr(obj, "id", None) + return str(oid) if oid is not None else str(obj) + + if pending_setup_intent: + si = await asyncio.to_thread( + stripe.SetupIntent.retrieve, + _stripe_obj_id(pending_setup_intent), + api_key=api_key, + ) + client_secret = getattr(si, "client_secret", None) + elif latest_invoice_id: + inv = await asyncio.to_thread( + stripe.Invoice.retrieve, + _stripe_obj_id(latest_invoice_id), + api_key=api_key, + ) + pi_id = getattr(inv, "payment_intent", None) + if pi_id: + pi = await asyncio.to_thread( + stripe.PaymentIntent.retrieve, + _stripe_obj_id(pi_id), + api_key=api_key, + ) + client_secret = getattr(pi, "client_secret", None) + + return StripeOperationOutput( + subscription_id=getattr(sub, "id", None), + status=getattr(sub, "status", None), + client_secret=client_secret, + ) + + @nw_action("cancel_subscription") + async def cancel_subscription( + self, params: CancelSubscriptionInput, *, trace_id: str + ) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Cancelling Stripe Subscription", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "cancel_subscription", + "subscription_id": params.subscription_id, + }, + ) + + def _cancel() -> stripe.Subscription: + if params.cancel_at_period_end: + return stripe.Subscription.modify( + params.subscription_id, + api_key=api_key, + cancel_at_period_end=True, + idempotency_key=params.idempotency_key or trace_id, + ) + else: + return stripe.Subscription.cancel( + params.subscription_id, + api_key=api_key, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + sub = await asyncio.to_thread(_cancel) + except Exception as exc: + logger.error( + "Stripe Subscription cancellation failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "cancel_subscription", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + subscription_id=getattr(sub, "id", None), + status=getattr(sub, "status", None), + ) + + @nw_action("issue_refund") + async def issue_refund( + self, params: IssueRefundInput, *, trace_id: str + ) -> StripeOperationOutput: + api_key = self._get_api_key() + + logger.info( + "Issuing Stripe Refund", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "issue_refund", + "charge_id": params.charge_id, + "payment_intent_id": params.payment_intent_id, + }, + ) + + def _refund() -> stripe.Refund: + create = cast(Any, stripe.Refund.create) + return create( + api_key=api_key, + charge=params.charge_id, + payment_intent=params.payment_intent_id, + amount=params.amount, + reason=params.reason, + metadata=params.metadata, + idempotency_key=params.idempotency_key or trace_id, + ) + + try: + refund = await asyncio.to_thread(_refund) + except Exception as exc: + logger.error( + "Stripe Refund issuance failed", + extra={ + "trace_id": trace_id, + "connector_id": self.connector_id, + "action": "issue_refund", + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) + raise + + return StripeOperationOutput( + refund_id=getattr(refund, "id", None), + status=getattr(refund, "status", None), + ) diff --git a/src/connectors/stripe/registration.py b/src/node_wire_stripe/registration.py similarity index 65% rename from src/connectors/stripe/registration.py rename to src/node_wire_stripe/registration.py index de6df01..7ea2f53 100644 --- a/src/connectors/stripe/registration.py +++ b/src/node_wire_stripe/registration.py @@ -1,14 +1,19 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import stripe -from runtime import ErrorCategory, ErrorMapper +from node_wire_runtime import ErrorCategory, ErrorMapper # Stripe SDK error mappings following the spec examples. ErrorMapper.register(stripe.error.RateLimitError, ErrorCategory.RETRYABLE, code="STRIPE_RATE_LIMIT") ErrorMapper.register(stripe.error.CardError, ErrorCategory.BUSINESS, code="STRIPE_CARD_ERROR") ErrorMapper.register(stripe.error.AuthenticationError, ErrorCategory.AUTH, code="STRIPE_AUTH_ERROR") -ErrorMapper.register(stripe.error.APIConnectionError, ErrorCategory.RETRYABLE, code="STRIPE_API_CONNECTION") +ErrorMapper.register( + stripe.error.APIConnectionError, ErrorCategory.RETRYABLE, code="STRIPE_API_CONNECTION" +) ErrorMapper.register(stripe.error.StripeError, ErrorCategory.FATAL, code="STRIPE_ERROR") - diff --git a/src/node_wire_stripe/schema.py b/src/node_wire_stripe/schema.py new file mode 100644 index 0000000..b64b491 --- /dev/null +++ b/src/node_wire_stripe/schema.py @@ -0,0 +1,127 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from typing import Any, Annotated, Literal + +from pydantic import BaseModel, Field, field_validator + + +class ChargeInput(BaseModel): + action: Literal["charge"] = "charge" + amount: Annotated[int, Field(ge=1, le=99_999_999)] + currency: Annotated[str, Field(pattern=r"^[a-z]{3}$")] + source: str + customer_id: str | None = None + description: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field( + None, description="Optional unique key to prevent duplicate operations." + ) + + @field_validator("currency", mode="before") + @classmethod + def normalize_currency(cls, value: object) -> object: + if isinstance(value, str): + return value.strip().lower() + return value + + +class ChargeOutput(BaseModel): + charge_id: str + receipt_url: str | None = None + + +class CancelSubscriptionInput(BaseModel): + action: Literal["cancel_subscription"] = "cancel_subscription" + subscription_id: str + cancel_at_period_end: bool = False + idempotency_key: str | None = Field( + None, description="Optional unique key to prevent duplicate operations." + ) + + +class CancelSubscriptionOutput(BaseModel): + subscription_id: str + status: str + + +class CreatePaymentIntentInput(BaseModel): + action: Literal["create_payment_intent"] = "create_payment_intent" + amount: Annotated[int, Field(ge=1, le=99_999_999)] + currency: Annotated[str, Field(pattern=r"^[a-z]{3}$")] + customer_id: str | None = None + payment_method: str | None = None + confirm: bool = False + description: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field( + None, description="Optional unique key to prevent duplicate operations." + ) + + @field_validator("currency", mode="before") + @classmethod + def normalize_currency(cls, value: object) -> object: + if isinstance(value, str): + return value.strip().lower() + return value + + +class CreatePaymentIntentOutput(BaseModel): + payment_intent_id: str + client_secret: str | None = None + status: str + + +class CreateSubscriptionInput(BaseModel): + action: Literal["create_subscription"] = "create_subscription" + customer_id: str + price_id: str + payment_behavior: str = "default_incomplete" + default_payment_method: str | None = None + card_token: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field( + None, description="Optional unique key to prevent duplicate operations." + ) + + +class CreateSubscriptionOutput(BaseModel): + subscription_id: str + client_secret: str | None = None + status: str + + +class IssueRefundInput(BaseModel): + action: Literal["issue_refund"] = "issue_refund" + charge_id: str | None = None + payment_intent_id: str | None = None + amount: int | None = Field(None, ge=1, le=99999999) + reason: str | None = None + metadata: dict | None = None + idempotency_key: str | None = Field( + None, description="Optional unique key to prevent duplicate operations." + ) + + +class IssueRefundOutput(BaseModel): + refund_id: str + status: str + + +class StripeOperationOutput(BaseModel): + """ + Unified output model for all Stripe actions. + The actual result will be contained in one or more of these fields. + """ + + charge_id: str | None = None + receipt_url: str | None = None + subscription_id: str | None = None + status: str | None = None + payment_intent_id: str | None = None + client_secret: str | None = None + refund_id: str | None = None + raw: dict[str, Any] | None = None diff --git a/src/runtime/__init__.py b/src/runtime/__init__.py deleted file mode 100644 index 76d63e9..0000000 --- a/src/runtime/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .models import ConnectorResponse, ErrorCategory -from .errors import ErrorMapper -from .base import BaseConnector -from .secrets import SecretProvider -from .policy import PolicyHook, PolicyDenied - -__all__ = [ - "ConnectorResponse", - "ErrorCategory", - "ErrorMapper", - "BaseConnector", - "SecretProvider", - "PolicyHook", - "PolicyDenied", -] diff --git a/src/runtime/base.py b/src/runtime/base.py deleted file mode 100644 index 25596d9..0000000 --- a/src/runtime/base.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import annotations - -import logging -import uuid -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TypeVar - -from opentelemetry import trace -from opentelemetry.trace import Tracer -from pybreaker import CircuitBreaker -from pydantic import BaseModel, ValidationError - -from .errors import ErrorMapper -from .models import ConnectorResponse, ErrorCategory -from .policy import PolicyContext, PolicyHook, PolicyDenied -from .resilience import with_resilience -from .secrets import SecretProvider - -logger = logging.getLogger("runtime.base") -tracer: Tracer = trace.get_tracer("runtime") - -InputModelT = TypeVar("InputModelT", bound=BaseModel) -OutputModelT = TypeVar("OutputModelT", bound=BaseModel) - - -class BaseConnector(ABC, Generic[InputModelT, OutputModelT]): - """ - Base class for all connectors. - - This is the single execution entrypoint used by all bindings. - """ - - connector_id: str - action: str - - def __init__( - self, - input_model: Type[InputModelT], - output_model: Type[OutputModelT], - secret_provider: Optional[SecretProvider] = None, - policy_hook: Optional[PolicyHook] = None, - breaker: Optional[CircuitBreaker] = None, - ) -> None: - self._input_model_cls = input_model - self._output_model_cls = output_model - self._secret_provider = secret_provider - self._policy_hook = policy_hook - self._breaker = breaker or CircuitBreaker( - fail_max=5, - reset_timeout=30, - name=f"{self.__class__.__name__}_breaker", - ) - - @property - def secret_provider(self) -> SecretProvider: - if self._secret_provider is None: - raise RuntimeError("SecretProvider has not been configured for this connector.") - return self._secret_provider - - async def run( - self, - raw_input: Dict[str, Any], - principal: Optional[str] = None, - tenant_id: Optional[str] = None, - ) -> ConnectorResponse: - """ - Public execution entrypoint. - - - Generates a trace ID - - Starts an OpenTelemetry span - - Validates input - - Executes policy hook - - Wraps internal execution with retries and circuit breaking - - Maps exceptions into the standard error taxonomy - """ - trace_id = str(uuid.uuid4()) - print(f"trace_id: {trace_id} from runtime.base") - - with tracer.start_as_current_span( - "connector.run", - attributes={ - "connector.id": self.connector_id, - "connector.action": self.action, - "tenant.id": tenant_id or "", - "principal.id": principal or "", - "trace.id": trace_id, - }, - ): - logger.info( - "Starting connector execution", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - }, - ) - - try: - try: - input_model = self._input_model_cls(**raw_input) - except ValidationError as exc: - logger.error( - "Input validation failed", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - # Expose validation details so clients know which fields failed. - details = [ - {"loc": e["loc"], "msg": e["msg"], "type": e["type"]} - for e in exc.errors() - ] - return ConnectorResponse( - success=False, - error_code="VALIDATION_ERROR", - error_category=ErrorCategory.BUSINESS, - message="Input validation failed; please check the request payload.", - trace_id=trace_id, - details=details, - ) - - # Policy hook - if self._policy_hook is not None: - context = PolicyContext( - connector_id=self.connector_id, - action=self.action, - input_payload=input_model.model_dump(), - principal=principal, - tenant_id=tenant_id, - ) - try: - self._policy_hook.check(context) - except PolicyDenied as exc: - logger.warning( - "Execution blocked by policy hook", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - mapped = ErrorMapper.resolve(exc) - return ConnectorResponse( - success=False, - error_code=mapped.code, - error_category=mapped.category, - message=str(exc), - trace_id=trace_id, - ) - - execute_with_resilience = with_resilience(self._breaker) - - @execute_with_resilience - async def _do_execute(*, trace_id: str) -> OutputModelT: - return await self.internal_execute(input_model, trace_id=trace_id) - - output_model = await _do_execute(trace_id=trace_id) - - logger.info( - "Connector execution completed successfully - runtime.base", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - }, - ) - - return ConnectorResponse( - success=True, - data=output_model.model_dump(), - trace_id=trace_id, - ) - except Exception as exc: # noqa: BLE001 - mapped = ErrorMapper.resolve(exc) - logger.error( - "Connector execution failed", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_code": mapped.code, - "error_category": mapped.category.value, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - return ConnectorResponse( - success=False, - error_code=mapped.code, - error_category=mapped.category, - message=str(exc), - trace_id=trace_id, - ) - - @abstractmethod - async def internal_execute(self, params: InputModelT, *, trace_id: str) -> OutputModelT: - """ - Implement connector-specific logic here. - - All external calls must be wrapped in try/except blocks with clear, - human-readable logging messages. Any raised exceptions will be - standardized by the ErrorMapper. - """ - raise NotImplementedError diff --git a/src/runtime/resilience.py b/src/runtime/resilience.py deleted file mode 100644 index fc49845..0000000 --- a/src/runtime/resilience.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import logging -from functools import wraps -from typing import Any, Awaitable, Callable, Coroutine, TypeVar - -from pybreaker import CircuitBreaker, CircuitBreakerError -from tenacity import AsyncRetrying, RetryError, retry_if_exception_type, stop_after_attempt, wait_exponential - -from .errors import ErrorMapper -from .models import ErrorCategory - -logger = logging.getLogger("runtime.resilience") - -T = TypeVar("T") - - -def with_resilience( - breaker: CircuitBreaker, - max_attempts: int = 3, - base_wait: float = 0.5, - max_wait: float = 5.0, -) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Coroutine[Any, Any, T]]]: - """ - Decorator that applies retry (Tenacity) and circuit breaking (PyBreaker) - around an async function that may raise exceptions. - """ - - def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Coroutine[Any, Any, T]]: - @wraps(fn) - async def wrapper(*args: Any, **kwargs: Any) -> T: - trace_id: str = kwargs.get("trace_id", "unknown-trace") - - async def _call() -> T: - if breaker.state.name == "open": - logger.error( - "Circuit breaker is OPEN; rejecting call", - extra={"trace_id": trace_id, "component": "resilience", "error": "circuit open"}, - ) - raise CircuitBreakerError("Circuit breaker is open") - try: - result = await fn(*args, **kwargs) - breaker._state.on_success() # noqa: SLF001 - return result - except Exception as exc: - breaker._state.on_failure(exc) # noqa: SLF001 - raise - except NameError: - # pybreaker < 1.0 requires Tornado's `gen` in call_async. - # Fall back to a direct call until pybreaker is upgraded to >= 1.0. - return await fn(*args, **kwargs) - - async for attempt in AsyncRetrying( - retry=retry_if_exception_type(Exception), - stop=stop_after_attempt(max_attempts), - wait=wait_exponential(multiplier=base_wait, max=max_wait), - reraise=True, - ): - with attempt: - try: - return await _call() - except Exception as exc: # noqa: BLE001 - mapped = ErrorMapper.resolve(exc) - if mapped.category is not ErrorCategory.RETRYABLE: - # Non-retryable: log and re-raise without further retries. - logger.error( - "Non-retryable error during execution", - extra={ - "trace_id": trace_id, - "error_code": mapped.code, - "error_category": mapped.category.value, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - raise - - logger.warning( - "Retryable error during execution; will retry", - extra={ - "trace_id": trace_id, - "error_code": mapped.code, - "error_category": mapped.category.value, - "attempt_number": attempt.retry_state.attempt_number, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - raise - - # Should not be reached because reraise=True ensures RetryError is propagated. - raise RetryError("Exhausted retries without success") - - return wrapper - - return decorator diff --git a/src/runtime/secrets.py b/src/runtime/secrets.py deleted file mode 100644 index da864a8..0000000 --- a/src/runtime/secrets.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - - -class SecretProvider(ABC): - """ - Abstract port for secret resolution. - - Implementations live in Layer C and may use environment variables, - a secrets manager, or any other secure storage. - """ - - @abstractmethod - def get_secret(self, key: str) -> str: - """Return the secret value for the given key or raise an exception.""" - raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e61b03d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,69 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Shared pytest configuration. + +REST API tests default to ``NW_REST_AUTH_DISABLED=true`` so existing tests do not need +headers. MCP tests default to ``NW_MCP_AUTH_DISABLED=true`` for the same reason. +Tests that assert authentication behavior override these env vars. +""" + +from __future__ import annotations + +import os +import importlib +import warnings +from pathlib import Path + +import pytest + +_TESTS_ROOT = Path(__file__).resolve().parent + +# Ensure tests can import app.py which builds dynamic routes via factory (needs allowed connectors to not crash M3 fail-fast) +os.environ["NW_ALLOWED_CONNECTORS"] = "http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner" +# Skip REST bind dotenv so repo `.env` cannot override the allowlist above during collection/import. +os.environ["NW_REST_LOAD_DOTENV"] = "false" +# Use a connector config where optional connectors (e.g. slack, salesforce) are disabled so CI and +# devs without those packages still match the narrow allowlist (see tests/fixtures/connectors_for_tests.yaml). +os.environ["NW_CONFIG_PATH"] = str(_TESTS_ROOT / "fixtures" / "connectors_for_tests.yaml") + + +def _preload_connector_logic_modules() -> None: + """Register connectors without relying on ``importlib.metadata`` entry points. + + Ensures :func:`bindings.rest_api.app._build_dynamic_routes` sees connectors when + tests run with ``PYTHONPATH=src`` but without an editable install. + """ + for mod in ( + "node_wire_http_generic.logic", + "node_wire_smtp.logic", + "node_wire_stripe.logic", + "node_wire_google_drive.logic", + "node_wire_fhir_epic.logic", + "node_wire_fhir_cerner.logic", + ): + try: + importlib.import_module(mod) + except ImportError as exc: + warnings.warn( + f"tests: could not import {mod!r} (connectors may be missing in this env): {exc}", + UserWarning, + stacklevel=2, + ) + except Exception as exc: + raise RuntimeError( + f"tests: unexpected error importing connector module {mod!r}" + ) from exc + + +_preload_connector_logic_modules() + + +@pytest.fixture(autouse=True) +def _rest_auth_disabled_for_tests(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_REST_AUTH_DISABLED", "true") + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "true") + monkeypatch.setenv("NW_RATE_LIMIT_BURST", "1000") # Increase for tests + monkeypatch.setenv("NW_RATE_LIMIT_REFILL_RATE", "100.0") # Increase for tests + monkeypatch.setenv("NW_RATE_LIMIT_DISABLED", "true") # Disable rate limiting for tests diff --git a/tests/fixtures/bandit_minimal_report.json b/tests/fixtures/bandit_minimal_report.json new file mode 100644 index 0000000..a29319b --- /dev/null +++ b/tests/fixtures/bandit_minimal_report.json @@ -0,0 +1,28 @@ +{ + "errors": [], + "generated_at": "2026-01-01T00:00:00Z", + "metrics": { + "_totals": { + "CONFIDENCE.HIGH": 0, + "CONFIDENCE.LOW": 0, + "CONFIDENCE.MEDIUM": 0, + "CONFIDENCE.UNDEFINED": 0, + "SEVERITY.HIGH": 0, + "SEVERITY.LOW": 1, + "SEVERITY.MEDIUM": 0, + "SEVERITY.UNDEFINED": 0, + "loc": 100, + "nosec": 0, + "skipped_tests": 0 + } + }, + "results": [ + { + "filename": "src/example.py", + "line_number": 1, + "issue_severity": "LOW", + "test_id": "B999", + "issue_text": "Example finding for summary script tests." + } + ] +} diff --git a/tests/fixtures/connectors_for_tests.yaml b/tests/fixtures/connectors_for_tests.yaml new file mode 100644 index 0000000..8768e17 --- /dev/null +++ b/tests/fixtures/connectors_for_tests.yaml @@ -0,0 +1,90 @@ +# Test fixture: mirrors ../config/connectors.yaml but optional connectors not on +# the pytest allowlist are disabled (slack, salesforce). +# Enabling them here would fail ConnectorFactory.load() when not registered +# (NW_ALLOWED_CONNECTORS / missing package). +# +# SECURITY RULE: This file must never contain secrets. + +connectors: + http_generic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + + smtp: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + host: "smtp.example.com" + port: 587 + from_email: "noreply@example.com" + auth: + provider: static_credentials + username_secret: SMTP_USERNAME + password_secret: SMTP_PASSWORD + + stripe: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: static_token + secret_key: stripe_api_key + header_name: Authorization + prefix: "" + + google_drive: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: service_account + sa_json_secret: GOOGLE_DRIVE_SA_JSON + scopes: + - https://www.googleapis.com/auth/drive + + fhir_epic: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${EPIC_FHIR_BASE_URL:https://fhir.epic.sandbox/api/FHIR/R4}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: EPIC_TOKEN_URL + client_id_secret: EPIC_CLIENT_ID + private_key_secret: EPIC_PRIVATE_KEY + kid_secret: EPIC_KID + algorithm: RS384 + + fhir_cerner: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] + base_url: "${CERNER_FHIR_BASE_URL:https://fhir-ehr-code.cerner.com/r4/your-tenant-id}" + auth: + provider: oauth2 + grant_method: private_key_jwt + token_url_secret: CERNER_TOKEN_URL + client_id_secret: CERNER_CLIENT_ID + private_key_secret: CERNER_PRIVATE_KEY + kid_secret: CERNER_KID + algorithm: RS384 + scopes_secret: CERNER_SCOPES + scopes: + - system/Patient.read + - system/Encounter.read + - system/DocumentReference.read + - system/DocumentReference.write + + salesforce: + enabled: false + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: oauth2 + grant_method: refresh_token + token_url_secret: SALESFORCE_TOKEN_URL + client_id_secret: SALESFORCE_CLIENT_ID + client_secret_secret: SALESFORCE_CLIENT_SECRET + refresh_token_secret: SALESFORCE_REFRESH_TOKEN + + slack: + enabled: false + exposed_via: ["rest", "grpc", "mcp"] + auth: + provider: static_token + secret_key: SLACK_BOT_TOKEN diff --git a/tests/playground/cerner/README.md b/tests/playground/cerner/README.md new file mode 100644 index 0000000..80c72b7 --- /dev/null +++ b/tests/playground/cerner/README.md @@ -0,0 +1,84 @@ + + +# Cerner Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Cerner connector panel, click the run button with the +pre-filled defaults, and assert on the rendered pipeline state. No mocking +— every test hits the real Cerner FHIR R4 Sandbox API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_cerner_post_consultation_default` | Post-consultation sync — pre-filled patient Nancy Smart, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's +`fetch("/scenarios/cerner-post-consultation")` call routes to the real backend, +which authenticates via private-key JWT and calls the real Cerner FHIR R4 +Sandbox. No `page.route()` interception. + +The form is pre-filled in the HTML with a sandbox patient (`12724066`) and +encounter (`97957281`) — no field changes or dropdown selections are needed +before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Cerner tests +uv run pytest tests/playground/cerner/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/cerner/ --no-cov -v -s +``` + +> **Note:** Cerner tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `CERNER_CLIENT_ID` | Cerner backend application client ID | +| `CERNER_PRIVATE_KEY` | RSA private key (PEM) used for private-key JWT auth | +| `CERNER_TOKEN_URL` | Cerner token endpoint URL | +| `CERNER_KID` | Key ID (`kid`) that matches the public key registered in Cerner | +| `CERNER_FHIR_BASE_URL` | Base FHIR R4 URL including tenant ID (defaults to the Cerner code sandbox if unset) | +| `CERNER_SCOPES` | Space-separated OAuth2 scopes (optional; defaults defined in `connectors.yaml`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Cerner tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `CERNER_CLIENT_ID` | `CERNER_CLIENT_ID` | +| `CERNER_PRIVATE_KEY` | `CERNER_PRIVATE_KEY` | +| `CERNER_TOKEN_URL` | `CERNER_TOKEN_URL` | +| `CERNER_KID` | `CERNER_KID` | +| `CERNER_FHIR_BASE_URL` | `CERNER_FHIR_BASE_URL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The test uses the Cerner open Sandbox patient and encounter IDs pre-filled in the +Playground HTML. These are read-only sandbox resources — no records are created +or modified in a real Cerner environment. No cleanup is required after the session. diff --git a/tests/playground/cerner/cerner_page.py b/tests/playground/cerner/cerner_page.py new file mode 100644 index 0000000..76ca363 --- /dev/null +++ b/tests/playground/cerner/cerner_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class CernerPage: + """Page Object Model for the Cerner connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Cerner card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='cerner']") + + # Panel root and header + self.panel: Locator = page.locator("#cerner-panel") + self.title: Locator = page.locator("#cerner-panel .card-title h2") + self.run_btn: Locator = page.locator("#cerner-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Cerner card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the Cerner workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/cerner/conftest.py b/tests/playground/cerner/conftest.py new file mode 100644 index 0000000..3ac2004 --- /dev/null +++ b/tests/playground/cerner/conftest.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def cerner_connector_available(api_server_url: str) -> None: + """Skip the entire Cerner test session if the connector returns HTTP 500. + + This happens when Cerner FHIR credentials are missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'fhir_cerner'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/cerner-post-consultation", + json={ + "patient_id": "12724066", + "encounter_id": "97957281", + "patient_given": "Nancy", + "patient_family": "Smart", + "note_text": "health-check", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Cerner connector not available ({detail}). " + "Ensure Cerner credentials are configured and 'fhir_cerner' is in NW_ALLOWED_CONNECTORS " + "(or leave it unset)." + ) diff --git a/tests/playground/cerner/test_cerner_integration.py b/tests/playground/cerner/test_cerner_integration.py new file mode 100644 index 0000000..8866c3d --- /dev/null +++ b/tests/playground/cerner/test_cerner_integration.py @@ -0,0 +1,47 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Cerner connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Cerner panel, +clicks the run button with pre-filled defaults, and asserts the resulting +pipeline state — no API mocking, real Cerner FHIR Sandbox calls. + +Required env vars (loaded from .env): + Cerner credentials (e.g. CERNER_CLIENT_ID, CERNER_CLIENT_SECRET, CERNER_BASE_URL) +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.cerner.cerner_page import CernerPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Cerner FHIR API calls + + +def _navigate_to_cerner(page: Page) -> CernerPage: + PlaygroundHomePage(page).click_connectors() + cerner = CernerPage(page) + cerner.navigate_to_panel() + return cerner + + +def test_cerner_post_consultation_default(playground_page: Page) -> None: + """Submit a Cerner consultation with default pre-filled values; all 4 steps must succeed.""" + cerner = _navigate_to_cerner(playground_page) + cerner.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(cerner.final_result).to_be_visible(timeout=_TIMEOUT) + expect(cerner.summary_text).to_contain_text("Cerner EHR") + expect(cerner.result_tag).to_be_visible() + expect(playground_page.locator("#cerner-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(cerner.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() diff --git a/tests/playground/conftest.py b/tests/playground/conftest.py new file mode 100644 index 0000000..641d81f --- /dev/null +++ b/tests/playground/conftest.py @@ -0,0 +1,93 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +import socket +import threading +import time +from pathlib import Path + +from dotenv import load_dotenv +import httpx +import pytest + +_REPO_ROOT = Path(__file__).resolve().parent.parent.parent + +# tests/conftest.py restricts NW_ALLOWED_CONNECTORS to the narrow CI-safe set and +# points NW_CONFIG_PATH at a fixture yaml that disables salesforce/slack. +# Playground integration tests hit real external services and need the full allowlist +# and the real config/connectors.yaml, so override those values here before any app +# import occurs. +os.environ["NW_ALLOWED_CONNECTORS"] = ( + "http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner,salesforce,slack" +) +os.environ["NW_CONFIG_PATH"] = str(_REPO_ROOT / "config" / "connectors.yaml") +os.environ["NW_REST_LOAD_DOTENV"] = "true" + +# Load .env before any app imports so connectors initialise with real credentials. +load_dotenv(override=False) + + +@pytest.fixture(scope="session") +def browser_type_launch_args(browser_type_launch_args): + """Override Playwright launch arguments dynamically via environment variables.""" + env_val = ( + os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") or os.getenv("PLAYWRIGHT_HEADLESS") + ) + is_headed = False + if env_val: + env_val_lower = env_val.lower().strip() + if env_val_lower in ("true", "1", "yes"): + is_headed = True + elif env_val_lower in ("false", "0", "no") and os.getenv("PLAYWRIGHT_HEADLESS"): + is_headed = True + return {**browser_type_launch_args, "headless": not is_headed} + + +@pytest.fixture(scope="session") +def api_server_url(): + """Start the real FastAPI server on a free port and yield its base URL. + + The playground UI is served at /playground/ and the scenarios API at + /scenarios/*, so browser fetch() calls with relative paths resolve + correctly without any Playwright route interception. + """ + import uvicorn # noqa: PLC0415 + from bindings.rest_api.app import app as rest_app # noqa: PLC0415 + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + config = uvicorn.Config(rest_app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base = f"http://127.0.0.1:{port}" + with httpx.Client(timeout=2) as probe: + for _ in range(60): + try: + probe.get(f"{base}/health") + break + except Exception: + time.sleep(0.3) + else: + pytest.fail("FastAPI server did not start within 18 seconds") + + yield base + + server.should_exit = True + thread.join(timeout=5) + + +@pytest.fixture +def playground_page(page, api_server_url: str): + """Navigate to the playground served by the real FastAPI server.""" + page.goto(f"{api_server_url}/playground/") + page.wait_for_load_state("domcontentloaded") + return page diff --git a/tests/playground/epic_fhir/README.md b/tests/playground/epic_fhir/README.md new file mode 100644 index 0000000..c9c78e7 --- /dev/null +++ b/tests/playground/epic_fhir/README.md @@ -0,0 +1,82 @@ + + +# Epic FHIR Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Epic FHIR (EHR) connector panel, click the run button with +the pre-filled defaults, and assert on the rendered pipeline state. No mocking +— every test hits the real Epic FHIR Sandbox API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_epic_fhir_post_consultation_default` | Post-consultation sync — pre-filled patient Jason Smith, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/post-consultation")` +call routes to the real backend, which authenticates via private-key JWT and calls +the real Epic FHIR Sandbox. No `page.route()` interception. + +The form is pre-filled in the HTML with a sandbox patient (`e63wRTbPfr1p8UW81d8Seiw3`) +and encounter (`ecgXt3jVqNNpsXnNXZ3KljA3`) — no field changes or dropdown +selections are needed before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Epic FHIR tests +uv run pytest tests/playground/epic_fhir/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/epic_fhir/ --no-cov -v -s +``` + +> **Note:** Epic FHIR tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `EPIC_CLIENT_ID` | Epic backend application client ID | +| `EPIC_PRIVATE_KEY` | RSA private key (PEM) used for private-key JWT auth | +| `EPIC_TOKEN_URL` | Epic token endpoint URL | +| `EPIC_KID` | Key ID (`kid`) that matches the public key registered in Epic | +| `EPIC_FHIR_BASE_URL` | Base FHIR R4 URL (defaults to the Epic Sandbox if unset) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Epic FHIR tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `EPIC_CLIENT_ID` | `EPIC_CLIENT_ID` | +| `EPIC_PRIVATE_KEY` | `EPIC_PRIVATE_KEY` | +| `EPIC_TOKEN_URL` | `EPIC_TOKEN_URL` | +| `EPIC_KID` | `EPIC_KID` | +| `EPIC_FHIR_BASE_URL` | `EPIC_FHIR_BASE_URL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The test uses the Epic open Sandbox patient and encounter IDs pre-filled in the +Playground HTML. These are read-only sandbox resources — no records are created +or modified in a real Epic environment. No cleanup is required after the session. diff --git a/tests/playground/epic_fhir/conftest.py b/tests/playground/epic_fhir/conftest.py new file mode 100644 index 0000000..6b6cde5 --- /dev/null +++ b/tests/playground/epic_fhir/conftest.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def epic_fhir_connector_available(api_server_url: str) -> None: + """Skip the entire Epic FHIR test session if the connector returns HTTP 500. + + This happens when Epic FHIR credentials are missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'fhir_epic'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/post-consultation", + json={ + "patient_id": "e63wRTbPfr1p8UW81d8Seiw3", + "encounter_id": "ecgXt3jVqNNpsXnNXZ3KljA3", + "patient_given": "Jason", + "patient_family": "Smith", + "note_text": "health-check", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Epic FHIR connector not available ({detail}). " + "Ensure Epic credentials are configured and 'fhir_epic' is in NW_ALLOWED_CONNECTORS " + "(or leave it unset)." + ) diff --git a/tests/playground/epic_fhir/epic_fhir_page.py b/tests/playground/epic_fhir/epic_fhir_page.py new file mode 100644 index 0000000..0a65c4a --- /dev/null +++ b/tests/playground/epic_fhir/epic_fhir_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class EpicFhirPage: + """Page Object Model for the Epic FHIR (EHR) connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Epic FHIR card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='ehr']") + + # Panel root and header + self.panel: Locator = page.locator("#ehr-panel") + self.title: Locator = page.locator("#ehr-panel .card-title h2") + self.run_btn: Locator = page.locator("#run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Epic FHIR card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the Epic FHIR workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/epic_fhir/test_epic_fhir_integration.py b/tests/playground/epic_fhir/test_epic_fhir_integration.py new file mode 100644 index 0000000..c7b52ca --- /dev/null +++ b/tests/playground/epic_fhir/test_epic_fhir_integration.py @@ -0,0 +1,47 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Epic FHIR connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Epic FHIR panel, +clicks the run button with pre-filled defaults, and asserts the resulting +pipeline state — no API mocking, real Epic FHIR Sandbox calls. + +Required env vars (loaded from .env): + Epic FHIR credentials (e.g. EPIC_CLIENT_ID, EPIC_CLIENT_SECRET, EPIC_BASE_URL) +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.epic_fhir.epic_fhir_page import EpicFhirPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Epic FHIR API calls + + +def _navigate_to_epic_fhir(page: Page) -> EpicFhirPage: + PlaygroundHomePage(page).click_connectors() + epic = EpicFhirPage(page) + epic.navigate_to_panel() + return epic + + +def test_epic_fhir_post_consultation_default(playground_page: Page) -> None: + """Submit a consultation with default pre-filled values; all 4 steps must succeed.""" + epic = _navigate_to_epic_fhir(playground_page) + epic.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(epic.final_result).to_be_visible(timeout=_TIMEOUT) + expect(epic.summary_text).to_contain_text("Epic") + expect(epic.result_tag).to_be_visible() + expect(playground_page.locator("#run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(epic.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() diff --git a/tests/playground/gdrive/README.md b/tests/playground/gdrive/README.md new file mode 100644 index 0000000..fe68471 --- /dev/null +++ b/tests/playground/gdrive/README.md @@ -0,0 +1,86 @@ + + +# Google Drive Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Google Drive connector panel, and assert on the rendered +pipeline state. No mocking — every test hits the real Google Drive API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_gdrive_list_files_default_page_size` | `files.list` — default page size | +| `test_gdrive_list_files_explicit_page_size` | `files.list` — explicit page_size=5 | +| `test_gdrive_list_files_with_query` | `files.list` — mimeType filter | +| `test_gdrive_get_file` | `files.get` — valid file ID with field mask | +| `test_gdrive_get_file_without_fields` | `files.get` — no fields mask | +| `test_gdrive_get_file_invalid_id` | `files.get` — nonexistent ID, expects error state | +| `test_gdrive_update_file_name` | `files.update` — rename file | +| `test_gdrive_update_file_name_and_mime` | `files.update` — rename + mime_type | +| `test_gdrive_upload_file` | `files.upload` — attach file, fill recipient, assert 4-step pipeline | +| `test_gdrive_upload_remove_and_reattach` | `files.upload` — remove attachment UI, re-attach | +| `test_gdrive_switch_list_then_get` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/gdrive-archival")` +calls route to the real backend, which calls the real Google Drive API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all GDrive tests +uv run pytest tests/playground/gdrive/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/gdrive/ --no-cov -v -s +``` + +> **Note:** GDrive tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `GOOGLE_DRIVE_SA_JSON` | Service-account JSON (path to file or full JSON string inline) | +| `GOOGLE_DRIVE_FOLDER_ID` | Google Drive folder ID where test files are uploaded | +| `GDRIVE_TEST_RECIPIENT_EMAIL` | Sharing recipient email for upload tests (default: `test@mailinator.com`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +GDrive tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `GOOGLE_DRIVE_SA_JSON` | `GOOGLE_DRIVE_SA_JSON` | +| `GOOGLE_DRIVE_FOLDER_ID` | `GOOGLE_DRIVE_FOLDER_ID` | +| `GDRIVE_TEST_RECIPIENT_EMAIL` | `GDRIVE_TEST_RECIPIENT_EMAIL` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The `uploaded_test_file_id` session fixture uploads a small file +(`nw-integration-test.txt`) to Google Drive once per test session. This file +is **not automatically deleted** after the tests finish — clean it up manually +via the Google Drive UI if needed. + +The `files.update` tests rename this file but do not delete it. diff --git a/tests/playground/gdrive/conftest.py b/tests/playground/gdrive/conftest.py new file mode 100644 index 0000000..3c40aec --- /dev/null +++ b/tests/playground/gdrive/conftest.py @@ -0,0 +1,68 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import base64 +import os + +import httpx +import pytest + +_TEST_RECIPIENT_EMAIL = os.environ.get("GDRIVE_TEST_RECIPIENT_EMAIL", "test@mailinator.com") + + +@pytest.fixture(scope="session") +def real_gdrive_file_id(api_server_url: str) -> str: + """Return a real Google Drive file ID by listing the Drive via the API. + + Skips the test if no files exist in the configured Drive folder. + """ + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/gdrive-archival", + json={"action": "files.list", "list_page_size": 5}, + ) + resp.raise_for_status() + data = resp.json() + files = data.get("steps", [{}])[0].get("data", {}).get("raw", {}).get("files", []) + if not files: + pytest.skip("No files found in Google Drive — skipping tests that need a real file ID") + return files[0]["id"] + + +@pytest.fixture(scope="session") +def uploaded_test_file_id(api_server_url: str) -> str: + """Upload a small test file to Google Drive once per session and return its ID. + + Used by files.update tests so they operate on a disposable file. + Note: the file is left in Google Drive after the session (manual cleanup needed). + """ + content = b"node-wire integration test file - safe to delete" + with httpx.Client(timeout=60) as client: + resp = client.post( + f"{api_server_url}/scenarios/gdrive-archival", + json={ + "action": "files.upload", + "document_name": "nw-integration-test.txt", + "recipient_email": _TEST_RECIPIENT_EMAIL, + "file_base64": base64.b64encode(content).decode(), + "file_mime_type": "text/plain", + }, + ) + resp.raise_for_status() + data = resp.json() + file_id = data.get("final_resource_id") + if not file_id: + pytest.skip( + f"Setup upload failed — cannot run update tests. " + f"Error: {data.get('error_message') or 'no file_id returned'}" + ) + return file_id + + +@pytest.fixture(scope="session") +def test_recipient_email() -> str: + """Email address used as the sharing recipient in upload tests.""" + return _TEST_RECIPIENT_EMAIL diff --git a/tests/playground/gdrive/gdrive_page.py b/tests/playground/gdrive/gdrive_page.py new file mode 100644 index 0000000..46ae6de --- /dev/null +++ b/tests/playground/gdrive/gdrive_page.py @@ -0,0 +1,140 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class GoogleDrivePage: + """Page Object Model for the Google Drive connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Google Drive card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='gdrive']") + + # Panel root and main headers + self.panel: Locator = page.locator("#gdrive-panel") + self.title: Locator = page.locator("#gdrive-panel .card-title h2") + self.action_select: Locator = page.locator("#gdrive-action-select") + self.run_btn: Locator = page.locator("#gdrive-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- files.upload action elements --- + self.upload_section: Locator = page.locator("#gdrive-upload-only") + self.recipient_email: Locator = page.locator( + "#gdrive-upload-only input[name='recipient_email']" + ) + self.doc_name_group: Locator = page.locator("#gdrive-doc-name-group") + self.document_name: Locator = page.locator( + "#gdrive-doc-name-group input[name='document_name']" + ) + self.file_section: Locator = page.locator("#gdrive-file-section") + self.file_input: Locator = page.locator("#gdrive-file") + self.file_drop_zone: Locator = page.locator("#file-drop-zone") + self.file_chosen_preview: Locator = page.locator("#file-chosen-preview") + self.preview_name: Locator = page.locator("#file-chosen-preview .preview-name") + self.remove_file_btn: Locator = page.locator("#file-chosen-preview .remove-file-btn") + + # --- files.get action elements --- + self.get_section: Locator = page.locator("#gdrive-get-only") + self.get_file_id: Locator = page.locator("#gdrive-get-only input[name='get_file_id']") + self.get_fields: Locator = page.locator("#gdrive-get-only input[name='get_fields']") + + # --- files.update action elements --- + self.update_section: Locator = page.locator("#gdrive-update-only") + self.update_file_id: Locator = page.locator( + "#gdrive-update-only input[name='update_file_id']" + ) + self.update_name: Locator = page.locator("#gdrive-update-only input[name='update_name']") + self.update_mime_type: Locator = page.locator( + "#gdrive-update-only input[name='update_mime_type']" + ) + self.update_add_parents: Locator = page.locator( + "#gdrive-update-only input[name='update_add_parents']" + ) + self.update_remove_parents: Locator = page.locator( + "#gdrive-update-only input[name='update_remove_parents']" + ) + + # --- files.list action elements --- + self.list_section: Locator = page.locator("#gdrive-list-only") + self.list_page_size: Locator = page.locator( + "#gdrive-list-only input[name='list_page_size']" + ) + self.list_query: Locator = page.locator("#gdrive-list-only input[name='list_query']") + self.list_fields: Locator = page.locator("#gdrive-list-only input[name='list_fields']") + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(4)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Google Drive card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_upload_fields(self, recipient_email: str, doc_name: str | None = None) -> None: + """Fill upload parameters.""" + self.recipient_email.fill(recipient_email) + if doc_name is not None: + # First ensure doc_name field is shown by switching to Write Note/sub-mode if needed, + # or fill directly if exposed. + self.document_name.fill(doc_name) + + def fill_get_fields(self, file_id: str, fields: str | None = None) -> None: + """Fill get parameters.""" + self.get_file_id.fill(file_id) + if fields is not None: + self.get_fields.fill(fields) + + def fill_update_fields( + self, + file_id: str, + new_name: str | None = None, + mime_type: str | None = None, + add_parents: str | None = None, + remove_parents: str | None = None, + ) -> None: + """Fill update parameters.""" + self.update_file_id.fill(file_id) + if new_name is not None: + self.update_name.fill(new_name) + if mime_type is not None: + self.update_mime_type.fill(mime_type) + if add_parents is not None: + self.update_add_parents.fill(add_parents) + if remove_parents is not None: + self.update_remove_parents.fill(remove_parents) + + def fill_list_fields( + self, + page_size: int | None = None, + query: str | None = None, + fields: str | None = None, + ) -> None: + """Fill list parameters.""" + if page_size is not None: + self.list_page_size.fill(str(page_size)) + if query is not None: + self.list_query.fill(query) + if fields is not None: + self.list_fields.fill(fields) + + def submit(self) -> None: + """Submit the form to execute the archival/orchestration workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/gdrive/test_gdrive_integration.py b/tests/playground/gdrive/test_gdrive_integration.py new file mode 100644 index 0000000..4d2e0a9 --- /dev/null +++ b/tests/playground/gdrive/test_gdrive_integration.py @@ -0,0 +1,268 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Google Drive connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Google Drive panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Google Drive calls. + +Required env vars (loaded from .env): + GOOGLE_DRIVE_SA_JSON — service-account JSON (path or inline JSON) + GOOGLE_DRIVE_FOLDER_ID — target folder for uploads + GDRIVE_TEST_RECIPIENT_EMAIL — email used as sharing recipient (default: test@mailinator.com) +""" + +from __future__ import annotations + +import tempfile + +from playwright.sync_api import Page, expect + +from tests.playground.gdrive.gdrive_page import GoogleDrivePage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT_STEP = 20_000 # ms — single-step operations (list, get) +_TIMEOUT_MULTI = 45_000 # ms — multi-step operations (upload, update) + + +def _navigate_to_gdrive(page: Page) -> GoogleDrivePage: + PlaygroundHomePage(page).click_connectors() + gdrive = GoogleDrivePage(page) + gdrive.navigate_to_panel() + return gdrive + + +# ── files.list ──────────────────────────────────────────────────────────────── + + +def test_gdrive_list_files_default_page_size(playground_page: Page) -> None: + """List files with the default page size; assert the pipeline step succeeds.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("file(s)") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_gdrive_list_files_explicit_page_size(playground_page: Page) -> None: + """List files with page_size=5; summary must mention the requested page size.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.fill_list_fields(page_size=5) + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("page size 5") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_gdrive_list_files_with_query(playground_page: Page) -> None: + """List files filtered by mimeType query; step label and success state must appear.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.list") + gdrive.fill_list_fields(page_size=10, query="mimeType='text/plain'") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── files.get ───────────────────────────────────────────────────────────────── + + +def test_gdrive_get_file(playground_page: Page, real_gdrive_file_id: str) -> None: + """Retrieve metadata for a real file; assert single-step success and result card.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id, "id,name,mimeType") + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("Google Drive file metadata") + expect(gdrive.result_tag).to_contain_text(real_gdrive_file_id) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_gdrive_get_file_without_fields(playground_page: Page, real_gdrive_file_id: str) -> None: + """files.get without a fields mask; Drive returns default metadata fields.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id) # no fields argument + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +def test_gdrive_get_file_invalid_id(playground_page: Page) -> None: + """files.get with a nonexistent ID; the pipeline step must show the error state.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.get") + gdrive.fill_get_fields("this-id-does-not-exist-9999999999") + gdrive.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_hidden() + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(gdrive.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +# ── files.update ────────────────────────────────────────────────────────────── + + +def test_gdrive_update_file_name(playground_page: Page, uploaded_test_file_id: str) -> None: + """Rename the integration-test file; assert all 4 update pipeline steps succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.update") + gdrive.fill_update_fields( + file_id=uploaded_test_file_id, + new_name="nw-integration-test-renamed.txt", + ) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(gdrive.summary_text).to_contain_text("Updated") + expect(gdrive.result_tag).to_contain_text(uploaded_test_file_id) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_gdrive_update_file_name_and_mime( + playground_page: Page, uploaded_test_file_id: str +) -> None: + """Update both the file name and mime_type; all 4 steps must succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + gdrive.select_action("files.update") + gdrive.fill_update_fields( + file_id=uploaded_test_file_id, + new_name="nw-integration-test-v2.txt", + mime_type="text/plain", + ) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── files.upload ────────────────────────────────────────────────────────────── + + +def test_gdrive_upload_file(playground_page: Page, test_recipient_email: str) -> None: + """Attach a temp file, fill recipient email, submit, assert all 4 steps succeed.""" + gdrive = _navigate_to_gdrive(playground_page) + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_ui_test_") as tmp: + tmp.write(b"Integration test document - uploaded via Playwright UI test.") + tmp_path = tmp.name + + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_hidden() + expect(gdrive.preview_name).to_contain_text("nw_ui_test_") + + gdrive.fill_upload_fields(test_recipient_email) + gdrive.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT_MULTI) + + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_MULTI) + expect(gdrive.summary_text).to_contain_text("archived to Google Drive") + expect(gdrive.summary_text).to_contain_text(test_recipient_email) + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(gdrive.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_gdrive_upload_remove_and_reattach(playground_page: Page) -> None: + """Remove an attached file → drop zone reappears; re-attach → preview is restored.""" + gdrive = _navigate_to_gdrive(playground_page) + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_reattach_") as tmp: + tmp.write(b"Reattach UI test content - safe to delete") + tmp_path = tmp.name + + # Attach + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_hidden() + + # Remove + gdrive.remove_file_btn.click() + expect(gdrive.file_chosen_preview).to_be_hidden(timeout=3_000) + expect(gdrive.file_drop_zone).to_be_visible() + + # Re-attach + gdrive.file_input.set_input_files(tmp_path) + expect(gdrive.file_chosen_preview).to_be_visible(timeout=3_000) + expect(gdrive.preview_name).to_contain_text("nw_reattach_") + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_gdrive_switch_list_then_get(playground_page: Page, real_gdrive_file_id: str) -> None: + """Run files.list, switch to files.get on the same page — both must complete successfully.""" + gdrive = _navigate_to_gdrive(playground_page) + + # First run: files.list + gdrive.select_action("files.list") + gdrive.submit() + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("file(s)") + + # Switch action and run again + gdrive.select_action("files.get") + gdrive.fill_get_fields(real_gdrive_file_id) + gdrive.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.final_result).to_be_visible(timeout=_TIMEOUT_STEP) + expect(gdrive.summary_text).to_contain_text("Google Drive file metadata") + expect(playground_page.locator("#gdrive-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() diff --git a/tests/playground/home_page.py b/tests/playground/home_page.py new file mode 100644 index 0000000..08910a6 --- /dev/null +++ b/tests/playground/home_page.py @@ -0,0 +1,69 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class PlaygroundHomePage: + """Page Object Model for the node-wire Playground Home (landing/selection) page.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Section views + self.root_selection_view: Locator = page.locator("#root-selection-view") + self.main_layout: Locator = page.locator(".layout-main") + + # Header components + self.brand_header: Locator = page.locator(".dashboard-header") + self.brand_title: Locator = page.locator(".brand-text h1") + self.tagline: Locator = page.locator(".tagline") + self.header_actions: Locator = page.locator("#header-actions") + + # Selection Cards + self.selection_cards: Locator = page.locator(".selection-card") + + # Agentic Workflow Card + self.agentic_card: Locator = page.locator(".selection-card.card-mcp") + self.agentic_card_title: Locator = self.agentic_card.locator("h3") + self.agentic_card_desc: Locator = self.agentic_card.locator("p") + + # Connectors Card + self.connectors_card: Locator = page.locator(".selection-card.card-connectors") + self.connectors_card_title: Locator = self.connectors_card.locator("h3") + self.connectors_card_desc: Locator = self.connectors_card.locator("p") + + # Connector Apps Card + self.connector_apps_card: Locator = page.locator(".selection-card.card-apps-directory") + self.connector_apps_card_title: Locator = self.connector_apps_card.locator("h3") + self.connector_apps_card_desc: Locator = self.connector_apps_card.locator("p") + + # Connector Apps sub-menu view + self.connector_apps_view: Locator = page.locator("#connector-apps-selection-view") + self.apps_back_btn: Locator = page.locator("#apps-back-btn") + + # Navigation + self.back_selection_btn: Locator = page.locator("#back-selection-btn") + + def click_agentic_workflow(self) -> None: + """Click the Agentic Workflow (MCP) selection card to navigate to the agent view.""" + self.agentic_card.click() + + def click_connectors(self) -> None: + """Click the Connectors selection card to navigate to the clinical workflows view.""" + self.connectors_card.click() + + def click_connector_apps(self) -> None: + """Click the Connector Apps selection card to navigate to the apps sub-menu.""" + self.connector_apps_card.click() + + def go_back_from_apps(self) -> None: + """Click the back button inside the Connector Apps sub-menu.""" + self.apps_back_btn.click() + + def go_back_to_selection(self) -> None: + """Click the back button to return to the selection page.""" + self.back_selection_btn.click() diff --git a/tests/playground/http_connector/README.md b/tests/playground/http_connector/README.md new file mode 100644 index 0000000..a1d8912 --- /dev/null +++ b/tests/playground/http_connector/README.md @@ -0,0 +1,68 @@ + + +# HTTP Connector Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the HTTP connector (IT Ops) panel, click the run button with the +pre-filled defaults, and assert on the rendered pipeline state. No mocking — +every test makes a real HTTP POST via the `http_generic` connector. + +## What is tested + +| Test | Action | +|------|--------| +| `test_http_connector_submit_incident_default` | IT incident report — pre-filled High severity Gateway Proxy incident, all 4 steps must succeed | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/report-incident")` +call routes to the real backend, which formats an ITSM payload and dispatches it +via `http_generic` to `https://httpbin.org/post` — a public echo endpoint. No +`page.route()` interception. + +The form is pre-filled in the HTML with a sample incident (title, description, +severity, component, reporter) — no field changes or dropdown selections are +needed before clicking run. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all HTTP connector tests +uv run pytest tests/playground/http_connector/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/http_connector/ --no-cov -v -s +``` + +> **Note:** HTTP connector tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +No connector credentials are needed — `http_generic` dispatches to the public +`httpbin.org` endpoint. + +| Variable | Description | +|----------|-------------| +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +HTTP connector tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +No secrets are required for this connector beyond the standard auth bypass flag. + +## Test data and cleanup + +All requests are sent to `https://httpbin.org/post`, which echoes the payload +and discards it. No records are persisted anywhere. No cleanup is required after +the session. diff --git a/tests/playground/http_connector/conftest.py b/tests/playground/http_connector/conftest.py new file mode 100644 index 0000000..e9ae907 --- /dev/null +++ b/tests/playground/http_connector/conftest.py @@ -0,0 +1,33 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def http_connector_available(api_server_url: str) -> None: + """Skip the entire HTTP connector test session if the connector returns HTTP 500. + + This happens when NW_ALLOWED_CONNECTORS is set but does not include 'http_generic'. + """ + with httpx.Client(timeout=15) as client: + resp = client.post( + f"{api_server_url}/scenarios/report-incident", + json={ + "title": "health-check", + "severity": "HIGH", + "component": "Gateway Proxy", + "description": "health-check", + "reported_by": "DevOps Team Alpha", + }, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"HTTP connector not available ({detail}). " + "Ensure 'http_generic' is in NW_ALLOWED_CONNECTORS (or leave it unset)." + ) diff --git a/tests/playground/http_connector/http_connector_page.py b/tests/playground/http_connector/http_connector_page.py new file mode 100644 index 0000000..1d08758 --- /dev/null +++ b/tests/playground/http_connector/http_connector_page.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class HttpConnectorPage: + """Page Object Model for the HTTP connector (IT Ops) panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the HTTP connector card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='itops']") + + # Panel root and header + self.panel: Locator = page.locator("#itops-panel") + self.title: Locator = page.locator("#itops-panel .card-title h2") + self.run_btn: Locator = page.locator("#itops-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Output and log elements + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the HTTP connector card in system connectors to open the panel.""" + self.connector_card.click() + + def submit(self) -> None: + """Submit the form to execute the HTTP connector workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/http_connector/test_http_connector_integration.py b/tests/playground/http_connector/test_http_connector_integration.py new file mode 100644 index 0000000..12398ea --- /dev/null +++ b/tests/playground/http_connector/test_http_connector_integration.py @@ -0,0 +1,46 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""HTTP connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the HTTP connector (IT Ops) +panel, clicks the run button with pre-filled defaults, and asserts the +resulting pipeline state — no API mocking, real HTTP calls to httpbin.org. + +No credentials required; http_generic uses a public endpoint. +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.http_connector.http_connector_page import HttpConnectorPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 20_000 # ms — 4-step pipeline with httpbin.org calls + + +def _navigate_to_http_connector(page: Page) -> HttpConnectorPage: + PlaygroundHomePage(page).click_connectors() + http = HttpConnectorPage(page) + http.navigate_to_panel() + return http + + +def test_http_connector_submit_incident_default(playground_page: Page) -> None: + """Submit an IT incident with default pre-filled values; all 4 steps must succeed.""" + http = _navigate_to_http_connector(playground_page) + http.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(http.final_result).to_be_visible(timeout=_TIMEOUT) + expect(http.summary_text).to_contain_text("IT Incident") + expect(http.result_tag).to_be_visible() + expect(playground_page.locator("#itops-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(http.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() diff --git a/tests/playground/salesforce/README.md b/tests/playground/salesforce/README.md new file mode 100644 index 0000000..ed67970 --- /dev/null +++ b/tests/playground/salesforce/README.md @@ -0,0 +1,108 @@ + + +# Salesforce CRM Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Salesforce CRM connector panel, and assert on the rendered +pipeline state. No mocking — every test hits the real Salesforce API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_sf_create_lead_minimal` | `create_lead` — required fields only (LastName + Company) | +| `test_sf_create_lead_full` | `create_lead` — with first name and email | +| `test_sf_create_contact_minimal` | `create_contact` — required field only (LastName) | +| `test_sf_create_contact_with_email` | `create_contact` — with first name and email | +| `test_sf_read_lead` | `read_lead` — valid Lead ID, asserts success state | +| `test_sf_read_lead_invalid_id` | `read_lead` — nonexistent ID, expects error state | +| `test_sf_read_contact` | `read_contact` — valid Contact ID, asserts success state | +| `test_sf_read_contact_invalid_id` | `read_contact` — nonexistent ID, expects error state | +| `test_sf_update_lead` | `update_lead` — rename + company change | +| `test_sf_update_lead_email` | `update_lead` — email-only update | +| `test_sf_update_contact` | `update_contact` — name update | +| `test_sf_update_contact_email` | `update_contact` — email-only update | +| `test_sf_delete_lead` | `delete_lead` — delete a freshly created Lead | +| `test_sf_delete_contact` | `delete_contact` — delete a freshly created Contact | +| `test_sf_switch_create_lead_to_read` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/salesforce-*")` +calls route to the real backend, which calls the real Salesforce API via OAuth2 +refresh token. No `page.route()` interception. + +Session fixtures create Lead and Contact records once via the REST API for use +across read and update tests. Delete tests each create their own fresh record. +All generated names and emails use random suffixes (e.g. `Lead839201`, +`test748203@mailinator.com`) so repeated runs never collide on duplicate values. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Salesforce tests +uv run pytest tests/playground/salesforce/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/salesforce/ --no-cov -v -s + +# Run a single test +uv run pytest tests/playground/salesforce/ --no-cov -v -k test_sf_create_lead_minimal +``` + +> **Note:** Salesforce tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `SALESFORCE_INSTANCE_URL` | Your Salesforce org URL, e.g. `https://orgname.my.salesforce.com` | +| `SALESFORCE_TOKEN_URL` | OAuth2 token endpoint, e.g. `https://login.salesforce.com/services/oauth2/token` | +| `SALESFORCE_CLIENT_ID` | Connected App client ID | +| `SALESFORCE_CLIENT_SECRET` | Connected App client secret | +| `SALESFORCE_REFRESH_TOKEN` | Long-lived OAuth2 refresh token | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## CI / GitHub Actions + +Salesforce tests run **only on manual `workflow_dispatch`** trigger +(`Actions → CI – Pytest → Run workflow`). + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `SALESFORCE_INSTANCE_URL` | `SALESFORCE_INSTANCE_URL` | +| `SALESFORCE_TOKEN_URL` | `SALESFORCE_TOKEN_URL` | +| `SALESFORCE_CLIENT_ID` | `SALESFORCE_CLIENT_ID` | +| `SALESFORCE_CLIENT_SECRET` | `SALESFORCE_CLIENT_SECRET` | +| `SALESFORCE_REFRESH_TOKEN` | `SALESFORCE_REFRESH_TOKEN` | + +Set these under **Settings → Secrets and variables → Actions** before triggering +the workflow. + +## Test data and cleanup + +The `real_sf_lead_id` and `real_sf_contact_id` session fixtures create one Lead +and one Contact in Salesforce at the start of the test session. These records are +**not automatically deleted** after the tests finish — clean them up manually via +the Salesforce UI or Developer Console if needed. Look for records with names +matching the pattern `IntegLead` and `IntegContact`. + +The `deletable_lead_id` and `deletable_contact_id` fixtures create a fresh record +per delete test and those records are consumed (deleted) by the test itself. + +Update tests mutate the session-scoped records in place (name, email). Because +Salesforce does not enforce unique constraints on Lead/Contact names, this is safe +to run multiple times without conflicts. diff --git a/tests/playground/salesforce/__init__.py b/tests/playground/salesforce/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/playground/salesforce/conftest.py b/tests/playground/salesforce/conftest.py new file mode 100644 index 0000000..cb7f4a5 --- /dev/null +++ b/tests/playground/salesforce/conftest.py @@ -0,0 +1,80 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import httpx +import pytest + +from tests.playground.salesforce.helpers import rnd as _rnd + + +def _create_lead(api_server_url: str, last_name: str, company: str) -> str: + """Create a Salesforce Lead via the REST API and return its record ID.""" + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/salesforce-create-lead", + json={"last_name": last_name, "company": company}, + ) + resp.raise_for_status() + data = resp.json() + record_id = data.get("final_resource_id") + if not record_id: + pytest.skip( + f"Salesforce Lead creation failed — cannot run dependent tests. " + f"Error: {data.get('error_message') or 'no record ID returned'}" + ) + return record_id + + +def _create_contact(api_server_url: str, last_name: str) -> str: + """Create a Salesforce Contact via the REST API and return its record ID.""" + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/salesforce-create-contact", + json={"last_name": last_name}, + ) + resp.raise_for_status() + data = resp.json() + record_id = data.get("final_resource_id") + if not record_id: + pytest.skip( + f"Salesforce Contact creation failed — cannot run dependent tests. " + f"Error: {data.get('error_message') or 'no record ID returned'}" + ) + return record_id + + +@pytest.fixture(scope="session") +def real_sf_lead_id(api_server_url: str) -> str: + """Create a Salesforce Lead once per session for read and update tests. + + The Lead is left in Salesforce after the session (manual cleanup needed). + """ + return _create_lead(api_server_url, f"IntegLead{_rnd()}", f"Corp{_rnd()}") + + +@pytest.fixture(scope="session") +def real_sf_contact_id(api_server_url: str) -> str: + """Create a Salesforce Contact once per session for read and update tests. + + The Contact is left in Salesforce after the session (manual cleanup needed). + """ + return _create_contact(api_server_url, f"IntegContact{_rnd()}") + + +@pytest.fixture +def deletable_lead_id(api_server_url: str) -> str: + """Create a fresh Salesforce Lead per test for delete tests. + + Each invocation creates a new record so the delete test always operates + on an existing record. + """ + return _create_lead(api_server_url, f"DelLead{_rnd()}", f"Corp{_rnd()}") + + +@pytest.fixture +def deletable_contact_id(api_server_url: str) -> str: + """Create a fresh Salesforce Contact per test for delete tests.""" + return _create_contact(api_server_url, f"DelContact{_rnd()}") diff --git a/tests/playground/salesforce/helpers.py b/tests/playground/salesforce/helpers.py new file mode 100644 index 0000000..ac20088 --- /dev/null +++ b/tests/playground/salesforce/helpers.py @@ -0,0 +1,15 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import random + + +def rnd() -> str: + return str(random.randint(100_000, 999_999)) + + +def random_email() -> str: + return f"test{rnd()}@mailinator.com" diff --git a/tests/playground/salesforce/salesforce_page.py b/tests/playground/salesforce/salesforce_page.py new file mode 100644 index 0000000..344a531 --- /dev/null +++ b/tests/playground/salesforce/salesforce_page.py @@ -0,0 +1,153 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class SalesforcePage: + """Page Object Model for the Salesforce CRM connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Connector card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='salesforce']") + + # Panel root and top-level controls + self.panel: Locator = page.locator("#salesforce-panel") + self.action_select: Locator = page.locator("#salesforce-action-select") + self.run_btn: Locator = page.locator("#salesforce-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- Lead section (create_lead / update_lead) --- + self.lead_section: Locator = page.locator("#salesforce-section-lead") + self.lead_id: Locator = page.locator("#salesforce-section-lead input[name='lead_id']") + self.lead_first_name: Locator = page.locator( + "#salesforce-section-lead input[name='lead_first_name']" + ) + self.lead_last_name: Locator = page.locator( + "#salesforce-section-lead input[name='lead_last_name']" + ) + self.lead_company: Locator = page.locator( + "#salesforce-section-lead input[name='lead_company']" + ) + self.lead_email: Locator = page.locator("#salesforce-section-lead input[name='lead_email']") + + # --- Contact section (create_contact / update_contact) --- + self.contact_section: Locator = page.locator("#salesforce-section-contact") + self.contact_id: Locator = page.locator( + "#salesforce-section-contact input[name='contact_id']" + ) + self.contact_first_name: Locator = page.locator( + "#salesforce-section-contact input[name='contact_first_name']" + ) + self.contact_last_name: Locator = page.locator( + "#salesforce-section-contact input[name='contact_last_name']" + ) + self.contact_email: Locator = page.locator( + "#salesforce-section-contact input[name='contact_email']" + ) + self.contact_account_id: Locator = page.locator( + "#salesforce-section-contact input[name='contact_account_id']" + ) + + # --- Generic ID section (read_lead / read_contact / delete_lead / delete_contact) --- + self.id_only_section: Locator = page.locator("#salesforce-section-id-only") + self.generic_record_id: Locator = page.locator( + "#salesforce-section-id-only input[name='generic_record_id']" + ) + + # --- Output / log elements (shared across connectors) --- + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Salesforce card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the CRM action via the select element.""" + self.action_select.select_option(action) + + def fill_lead_fields( + self, + last_name: str, + company: str, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Lead create form fields.""" + self.lead_last_name.fill(last_name) + self.lead_company.fill(company) + if first_name is not None: + self.lead_first_name.fill(first_name) + if email is not None: + self.lead_email.fill(email) + + def fill_lead_update_fields( + self, + record_id: str, + last_name: str | None = None, + company: str | None = None, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Lead update form fields (record ID + any changed fields).""" + self.lead_id.fill(record_id) + if last_name is not None: + self.lead_last_name.fill(last_name) + if company is not None: + self.lead_company.fill(company) + if first_name is not None: + self.lead_first_name.fill(first_name) + if email is not None: + self.lead_email.fill(email) + + def fill_contact_fields( + self, + last_name: str, + first_name: str | None = None, + email: str | None = None, + account_id: str | None = None, + ) -> None: + """Fill Contact create form fields.""" + self.contact_last_name.fill(last_name) + if first_name is not None: + self.contact_first_name.fill(first_name) + if email is not None: + self.contact_email.fill(email) + if account_id is not None: + self.contact_account_id.fill(account_id) + + def fill_contact_update_fields( + self, + record_id: str, + last_name: str | None = None, + first_name: str | None = None, + email: str | None = None, + ) -> None: + """Fill Contact update form fields (record ID + any changed fields).""" + self.contact_id.fill(record_id) + if last_name is not None: + self.contact_last_name.fill(last_name) + if first_name is not None: + self.contact_first_name.fill(first_name) + if email is not None: + self.contact_email.fill(email) + + def fill_id_only(self, record_id: str) -> None: + """Fill the generic record ID field used by read/delete actions.""" + self.generic_record_id.fill(record_id) + + def submit(self) -> None: + """Click the run button to execute the selected CRM action.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to the connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/salesforce/test_salesforce_integration.py b/tests/playground/salesforce/test_salesforce_integration.py new file mode 100644 index 0000000..4caad9e --- /dev/null +++ b/tests/playground/salesforce/test_salesforce_integration.py @@ -0,0 +1,347 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Salesforce CRM connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Salesforce panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Salesforce calls. + +Required env vars (loaded from .env): + SALESFORCE_INSTANCE_URL — https://.my.salesforce.com + SALESFORCE_TOKEN_URL — OAuth2 token endpoint + SALESFORCE_CLIENT_ID — Connected App client ID + SALESFORCE_CLIENT_SECRET — Connected App client secret + SALESFORCE_REFRESH_TOKEN — Long-lived refresh token +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.salesforce.helpers import rnd as _rnd, random_email as _email +from tests.playground.salesforce.salesforce_page import SalesforcePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 20_000 # ms — all Salesforce operations are single-step + + +def _navigate_to_salesforce(page: Page) -> SalesforcePage: + PlaygroundHomePage(page).click_connectors() + sf = SalesforcePage(page) + sf.navigate_to_panel() + return sf + + +# ── create_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_create_lead_minimal(playground_page: Page) -> None: + """Create a Lead with only the required fields (LastName + Company).""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_lead") + sf.fill_lead_fields(last_name=f"Lead{_rnd()}", company=f"Corp{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_create_lead_full(playground_page: Page) -> None: + """Create a Lead with first name and email in addition to required fields.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_lead") + sf.fill_lead_fields( + last_name=f"Lead{_rnd()}", + company=f"Corp{_rnd()}", + first_name="John", + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── create_contact ──────────────────────────────────────────────────────────── + + +def test_sf_create_contact_minimal(playground_page: Page) -> None: + """Create a Contact with only the required LastName field.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_contact") + sf.fill_contact_fields(last_name=f"Contact{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Contact created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_create_contact_with_email(playground_page: Page) -> None: + """Create a Contact with first name and email.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("create_contact") + sf.fill_contact_fields( + last_name=f"Contact{_rnd()}", + first_name="Jane", + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Contact created successfully") + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── read_lead ───────────────────────────────────────────────────────────────── + + +def test_sf_read_lead(playground_page: Page, real_sf_lead_id: str) -> None: + """Retrieve metadata for a real Lead; assert single-step success and result card.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_lead") + sf.fill_id_only(real_sf_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_lead_id) + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_read_lead_invalid_id(playground_page: Page) -> None: + """read_lead with a nonexistent ID; pipeline step must show the error state.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_lead") + sf.fill_id_only("00Q000000000001AAA") + sf.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_hidden() + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(sf.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +# ── read_contact ────────────────────────────────────────────────────────────── + + +def test_sf_read_contact(playground_page: Page, real_sf_contact_id: str) -> None: + """Retrieve metadata for a real Contact; assert single-step success and result card.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_contact") + sf.fill_id_only(real_sf_contact_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_contact_id) + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_read_contact_invalid_id(playground_page: Page) -> None: + """read_contact with a nonexistent ID; pipeline step must show the error state.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("read_contact") + sf.fill_id_only("003000000000001AAA") + sf.submit() + + expect(playground_page.locator("#step-0.error")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_hidden() + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(sf.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +# ── update_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_update_lead(playground_page: Page, real_sf_lead_id: str) -> None: + """Update a Lead's last name; assert single-step success and summary contains the record ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_lead") + sf.fill_lead_update_fields( + record_id=real_sf_lead_id, + last_name=f"Lead{_rnd()}", + company=f"Corp{_rnd()}", + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("updated successfully") + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_update_lead_email(playground_page: Page, real_sf_lead_id: str) -> None: + """Update only a Lead's email; assert success with result ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_lead") + sf.fill_lead_update_fields( + record_id=real_sf_lead_id, + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.result_tag).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── update_contact ──────────────────────────────────────────────────────────── + + +def test_sf_update_contact(playground_page: Page, real_sf_contact_id: str) -> None: + """Update a Contact's name; assert single-step success and summary contains the record ID.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_contact") + sf.fill_contact_update_fields( + record_id=real_sf_contact_id, + last_name=f"Contact{_rnd()}", + first_name="Updated", + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("updated successfully") + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_sf_update_contact_email(playground_page: Page, real_sf_contact_id: str) -> None: + """Update only a Contact's email; assert success.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("update_contact") + sf.fill_contact_update_fields( + record_id=real_sf_contact_id, + email=_email(), + ) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.result_tag).to_contain_text(real_sf_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── delete_lead ─────────────────────────────────────────────────────────────── + + +def test_sf_delete_lead(playground_page: Page, deletable_lead_id: str) -> None: + """Delete a Lead; assert single-step success and the record ID appears in the result.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("delete_lead") + sf.fill_id_only(deletable_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(deletable_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +# ── delete_contact ──────────────────────────────────────────────────────────── + + +def test_sf_delete_contact(playground_page: Page, deletable_contact_id: str) -> None: + """Delete a Contact; assert single-step success and the record ID appears in the result.""" + sf = _navigate_to_salesforce(playground_page) + + sf.select_action("delete_contact") + sf.fill_id_only(deletable_contact_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(deletable_contact_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(sf.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_sf_switch_create_lead_to_read(playground_page: Page, real_sf_lead_id: str) -> None: + """Create a Lead, then switch to read_lead on the same page — both must succeed.""" + sf = _navigate_to_salesforce(playground_page) + + # First run: create_lead + sf.select_action("create_lead") + sf.fill_lead_fields(last_name=f"Lead{_rnd()}", company=f"Corp{_rnd()}") + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text("Lead created successfully") + + # Switch action and run read_lead + sf.select_action("read_lead") + sf.fill_id_only(real_sf_lead_id) + sf.submit() + + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + expect(sf.final_result).to_be_visible(timeout=_TIMEOUT) + expect(sf.summary_text).to_contain_text(real_sf_lead_id) + expect(playground_page.locator("#salesforce-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() diff --git a/tests/playground/slack/README.md b/tests/playground/slack/README.md new file mode 100644 index 0000000..d1e0b6b --- /dev/null +++ b/tests/playground/slack/README.md @@ -0,0 +1,81 @@ + + +# Slack Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Slack connector panel, and assert on the rendered pipeline +state. No mocking — every test hits the real Slack API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_slack_post_message_default` | `post_message` — default message to test channel | +| `test_slack_post_message_custom_message` | `post_message` — custom message content | +| `test_slack_post_message_invalid_channel` | `post_message` — nonexistent channel, expects error at step-1 | +| `test_slack_send_direct_message` | `send_direct_message` — DM to real user (requires `SLACK_TEST_USER_ID`) | +| `test_slack_upload_file` | `upload_file` — attach and upload a temp file | +| `test_slack_upload_remove_and_reattach` | `upload_file` — remove attachment UI, re-attach | +| `test_slack_switch_post_message_then_upload` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/slack-messaging")` +calls route to the real backend, which calls the real Slack API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Slack tests +uv run pytest tests/playground/slack/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/slack/ --no-cov -v -s +``` + +> **Note:** Slack tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `SLACK_BOT_TOKEN` | Slack bot token (`xoxb-...`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## Optional environment variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `SLACK_TEST_CHANNEL` | Target channel for post_message and send_direct_message tests | `#general` | +| `SLACK_TEST_CHANNEL_ID` | Channel **ID** (`C...`) for upload_file tests — required because the Slack external-upload API does not accept channel names | *(skipped if absent)* | +| `SLACK_TEST_USER_ID` | Slack user ID (`U...`) for DM tests | *(skipped if absent)* | + +The bot must be a member of `SLACK_TEST_CHANNEL` and `SLACK_TEST_CHANNEL_ID`. +`test_slack_send_direct_message` is automatically skipped when `SLACK_TEST_USER_ID` is absent. +`test_slack_upload_file` and `test_slack_switch_post_message_then_upload` are automatically skipped when `SLACK_TEST_CHANNEL_ID` is absent (and `SLACK_TEST_CHANNEL` is not already a bare ID). + +## CI / GitHub Actions + +Slack tests run **only on manual `workflow_dispatch`** trigger alongside the other +playground integration tests. + +Credentials are read from repository secrets: + +| Secret | Maps to env var | +|--------|----------------| +| `SLACK_BOT_TOKEN` | `SLACK_BOT_TOKEN` | +| `SLACK_TEST_CHANNEL` | `SLACK_TEST_CHANNEL` | +| `SLACK_TEST_CHANNEL_ID` | `SLACK_TEST_CHANNEL_ID` | +| `SLACK_TEST_USER_ID` | `SLACK_TEST_USER_ID` | diff --git a/tests/playground/slack/conftest.py b/tests/playground/slack/conftest.py new file mode 100644 index 0000000..9c3a180 --- /dev/null +++ b/tests/playground/slack/conftest.py @@ -0,0 +1,75 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os + +import httpx +import pytest + +_DEFAULT_CHANNEL = os.environ.get("SLACK_TEST_CHANNEL", "#general") +_DEFAULT_CHANNEL_ID = os.environ.get("SLACK_TEST_CHANNEL_ID", "") + + +@pytest.fixture(scope="session", autouse=True) +def slack_connector_available(api_server_url: str) -> None: + """Skip the entire Slack test session if the connector returns HTTP 500. + + This happens when SLACK_BOT_TOKEN is missing or when NW_ALLOWED_CONNECTORS + is set but does not include 'slack'. Converts a 25-second timeout per test + into a single fast skip with a clear reason. + """ + with httpx.Client(timeout=10) as client: + resp = client.post( + f"{api_server_url}/scenarios/slack-messaging", + json={"action": "post_message", "channel": "#general", "message": "health-check"}, + ) + if resp.status_code == 500: + detail = resp.json().get("detail", "unknown") + pytest.skip( + f"Slack connector not available ({detail}). " + "Ensure SLACK_BOT_TOKEN is set and 'slack' is in NW_ALLOWED_CONNECTORS (or leave it unset)." + ) + + +@pytest.fixture(scope="session") +def slack_test_channel() -> str: + """Slack channel used as the target for post_message and upload_file tests. + + Defaults to #general. Override via SLACK_TEST_CHANNEL env var. + The bot must be a member of this channel. + """ + return _DEFAULT_CHANNEL + + +@pytest.fixture(scope="session") +def slack_upload_channel() -> str: + """Channel ID used for upload_file tests. + + Prefers SLACK_TEST_CHANNEL_ID (must be a bare channel ID like C0ANP6RADHU). + Falls back to SLACK_TEST_CHANNEL, but skips if that is still a name — the + Slack external-upload API requires an ID, not a name. + """ + if _DEFAULT_CHANNEL_ID: + return _DEFAULT_CHANNEL_ID + if _DEFAULT_CHANNEL and _DEFAULT_CHANNEL[0].upper() in ("C", "G", "D"): + return _DEFAULT_CHANNEL + pytest.skip( + "upload_file tests require a channel ID. " + "Set SLACK_TEST_CHANNEL_ID (e.g. C0ANP6RADHU) in .env." + ) + + +@pytest.fixture(scope="session") +def slack_test_user_id() -> str: + """Slack user ID (U...) used as the target for send_direct_message tests. + + Requires SLACK_TEST_USER_ID env var. Tests that depend on this fixture + are skipped when the var is absent. + """ + user_id = os.environ.get("SLACK_TEST_USER_ID") + if not user_id: + pytest.skip("SLACK_TEST_USER_ID is required for direct message tests") + return user_id diff --git a/tests/playground/slack/slack_page.py b/tests/playground/slack/slack_page.py new file mode 100644 index 0000000..43a54c5 --- /dev/null +++ b/tests/playground/slack/slack_page.py @@ -0,0 +1,88 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class SlackPage: + """Page Object Model for the Slack connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Slack card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='slack']") + + # Panel root and header + self.panel: Locator = page.locator("#slack-panel") + self.title: Locator = page.locator("#slack-panel .card-title h2") + self.action_select: Locator = page.locator("#slack-action-select") + self.run_btn: Locator = page.locator("#slack-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # Shared channel input (always visible) + self.channel: Locator = page.locator("#slack-panel input[name='channel']") + + # --- post_message / send_direct_message section --- + self.message_section: Locator = page.locator("#slack-message-section") + self.message: Locator = page.locator("#slack-message-section textarea[name='message']") + + # --- upload_file section --- + self.file_section: Locator = page.locator("#slack-file-section") + self.filename: Locator = page.locator("#slack-file-section input[name='filename']") + self.initial_comment: Locator = page.locator( + "#slack-file-section input[name='initial_comment']" + ) + self.file_input: Locator = page.locator("#slack-file") + self.file_drop_zone: Locator = page.locator("#slack-file-drop-zone") + self.file_chosen_preview: Locator = page.locator("#slack-file-chosen-preview") + self.preview_name: Locator = page.locator("#slack-file-chosen-preview .preview-name") + self.remove_file_btn: Locator = page.locator("#slack-file-chosen-preview .remove-file-btn") + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(4)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Slack card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_message_fields(self, channel: str | None = None, message: str | None = None) -> None: + """Fill post_message / send_direct_message parameters.""" + if channel is not None: + self.channel.fill(channel) + if message is not None: + self.message.fill(message) + + def fill_upload_fields( + self, + channel: str | None = None, + filename: str | None = None, + initial_comment: str | None = None, + ) -> None: + """Fill upload_file parameters (excluding the file attachment itself).""" + if channel is not None: + self.channel.fill(channel) + if filename is not None: + self.filename.fill(filename) + if initial_comment is not None: + self.initial_comment.fill(initial_comment) + + def submit(self) -> None: + """Submit the form to execute the Slack workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/slack/test_slack_integration.py b/tests/playground/slack/test_slack_integration.py new file mode 100644 index 0000000..21de6c1 --- /dev/null +++ b/tests/playground/slack/test_slack_integration.py @@ -0,0 +1,230 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Slack connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Slack panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Slack API calls. + +Required env vars (loaded from .env): + SLACK_BOT_TOKEN — Slack bot token (xoxb-...) + +Optional env vars: + SLACK_TEST_CHANNEL — target channel (default: #general; bot must be a member) + SLACK_TEST_USER_ID — Slack user ID for DM tests (U...); skipped when absent +""" + +from __future__ import annotations + +import tempfile + +from playwright.sync_api import Page, expect + +from tests.playground.slack.slack_page import SlackPage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 25_000 # ms — 4-step pipeline with async Slack API calls + + +def _navigate_to_slack(page: Page) -> SlackPage: + PlaygroundHomePage(page).click_connectors() + slack = SlackPage(page) + slack.navigate_to_panel() + return slack + + +# ── post_message ────────────────────────────────────────────────────────────── + + +def test_slack_post_message_default(playground_page: Page, slack_test_channel: str) -> None: + """Post a message with default values; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields(channel=slack_test_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("post_message") + expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(slack.result_tag).to_be_visible() + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_slack_post_message_custom_message(playground_page: Page, slack_test_channel: str) -> None: + """Post a message with custom content; summary must reflect the channel.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields( + channel=slack_test_channel, + message="node-wire integration test — safe to ignore.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text(slack_test_channel) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +def test_slack_post_message_invalid_channel(playground_page: Page) -> None: + """Post to a nonexistent channel; step-1 (Dispatch) must show error state.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("post_message") + slack.fill_message_fields(channel="this-channel-does-not-exist-99999") + slack.submit() + + # step-0 (Format Slack Payload) is local — always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Dispatch to Slack API) must fail for an invalid channel + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_hidden() + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(slack.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +# ── send_direct_message ─────────────────────────────────────────────────────── + + +def test_slack_send_direct_message(playground_page: Page, slack_test_user_id: str) -> None: + """Send a DM to a real user; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("send_direct_message") + slack.fill_message_fields( + channel=slack_test_user_id, + message="node-wire DM integration test — safe to ignore.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("send_direct_message") + expect(slack.summary_text).to_contain_text(slack_test_user_id) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +# ── upload_file ─────────────────────────────────────────────────────────────── + + +def test_slack_upload_file(playground_page: Page, slack_upload_channel: str) -> None: + """Attach a temp file and upload it; all 4 steps must succeed.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_slack_test_") as tmp: + tmp.write(b"node-wire Slack upload integration test - safe to delete.") + tmp_path = tmp.name + + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.file_drop_zone).to_be_hidden() + expect(slack.preview_name).to_contain_text("nw_slack_test_") + + slack.fill_upload_fields( + channel=slack_upload_channel, + initial_comment="node-wire integration test upload — safe to delete.", + ) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("upload_file") + expect(slack.summary_text).to_contain_text(slack_upload_channel) + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(slack.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_slack_upload_remove_and_reattach(playground_page: Page) -> None: + """Remove an attached file → drop zone reappears; re-attach → preview is restored.""" + slack = _navigate_to_slack(playground_page) + + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile( + suffix=".txt", delete=False, prefix="nw_slack_reattach_" + ) as tmp: + tmp.write(b"Reattach UI test content - safe to delete.") + tmp_path = tmp.name + + # Attach + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.file_drop_zone).to_be_hidden() + + # Remove + slack.remove_file_btn.click() + expect(slack.file_chosen_preview).to_be_hidden(timeout=3_000) + expect(slack.file_drop_zone).to_be_visible() + + # Re-attach + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + expect(slack.preview_name).to_contain_text("nw_slack_reattach_") + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_slack_switch_post_message_then_upload( + playground_page: Page, slack_test_channel: str, slack_upload_channel: str +) -> None: + """Run post_message then switch to upload_file on the same page — both must succeed.""" + slack = _navigate_to_slack(playground_page) + + # First run: post_message + slack.select_action("post_message") + slack.fill_message_fields(channel=slack_test_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + + # Switch to upload_file and run + slack.select_action("upload_file") + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False, prefix="nw_slack_switch_") as tmp: + tmp.write(b"Cross-action switch test - safe to delete.") + tmp_path = tmp.name + + slack.file_input.set_input_files(tmp_path) + expect(slack.file_chosen_preview).to_be_visible(timeout=3_000) + + slack.fill_upload_fields(channel=slack_upload_channel) + slack.submit() + + for i in range(4): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(slack.final_result).to_be_visible(timeout=_TIMEOUT) + expect(slack.summary_text).to_contain_text("upload_file") + expect(playground_page.locator("#slack-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() diff --git a/tests/playground/stripe/README.md b/tests/playground/stripe/README.md new file mode 100644 index 0000000..9292c8e --- /dev/null +++ b/tests/playground/stripe/README.md @@ -0,0 +1,78 @@ + + +# Stripe Playground Integration Tests + +End-to-end Playwright tests that open the Playground UI in a real browser, +navigate to the Stripe connector panel, and assert on the rendered pipeline +state. No mocking — every test hits the real Stripe test-mode API. + +## What is tested + +| Test | Action | +|------|--------| +| `test_stripe_charge_default` | `charge` — default values (2000 usd) | +| `test_stripe_charge_custom_amount` | `charge` — custom amount + description | +| `test_stripe_charge_no_description` | `charge` — empty description | +| `test_stripe_payment_intent_default` | `payment_intent` — defaults (5000 usd, pm_card_visa) | +| `test_stripe_payment_intent_custom_amount` | `payment_intent` — custom amount, result tag contains pi_ | +| `test_stripe_payment_intent_no_payment_method` | `payment_intent` — no payment method | +| `test_stripe_cancel_subscription_invalid_id` | `cancel_subscription` — nonexistent ID, expects error state | +| `test_stripe_cancel_subscription` | `cancel_subscription` — real subscription ID (requires env vars) | +| `test_stripe_refund_by_charge_id` | `refund` — full refund against a real charge | +| `test_stripe_refund_invalid_id` | `refund` — nonexistent ID, expects error state | +| `test_stripe_switch_charge_then_payment_intent` | Cross-action switch on same page | + +## How it works + +The test session starts a real FastAPI server on a random local port. Playwright +navigates to `/playground/`. The browser's `fetch("/scenarios/stripe-*")` calls +route to the real backend, which calls the real Stripe test-mode API. +No `page.route()` interception. + +## Running locally + +```bash +# Install Playwright browsers (once) +uv run python -m playwright install chromium + +# Run all Stripe tests +uv run pytest tests/playground/stripe/ --no-cov -v + +# Run headed (watch the browser) +PLAYGROUND_HEADED=true uv run pytest tests/playground/stripe/ --no-cov -v -s +``` + +> **Note:** Stripe tests are excluded from the default `uv run pytest` run and +> from regular CI (push/PR). They must be triggered explicitly. + +## Required environment variables + +Set these before running (`.env` is loaded automatically if present): + +| Variable | Description | +|----------|-------------| +| `STRIPE_API_KEY` | Stripe secret key (`sk_test_...`) | +| `NW_REST_AUTH_DISABLED` | Set to `true` to skip REST auth middleware | + +## Optional environment variables (for subscription tests) + +| Variable | Description | +|----------|-------------| +| `STRIPE_TEST_CUSTOMER_ID` | Pre-existing Stripe test customer (`cus_...`) | +| `STRIPE_TEST_PRICE_ID` | Pre-existing Stripe test price (`price_...`) | + +`test_stripe_cancel_subscription` is automatically skipped when these are absent. + +## Test data and cleanup + +The `real_stripe_charge_id` session fixture creates a small charge (`$5.00 usd`) +against the `tok_visa` test token once per session. The `test_stripe_refund_by_charge_id` +test immediately refunds this charge in full, so no balance is left outstanding. + +The optional `real_stripe_subscription_id` fixture creates a subscription that +is cancelled by `test_stripe_cancel_subscription` — leaving no active subscription +after the session. diff --git a/tests/playground/stripe/conftest.py b/tests/playground/stripe/conftest.py new file mode 100644 index 0000000..9ad006a --- /dev/null +++ b/tests/playground/stripe/conftest.py @@ -0,0 +1,63 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os + +import httpx +import pytest + + +@pytest.fixture(scope="session") +def real_stripe_charge_id(api_server_url: str) -> str: + """Create a real Stripe test charge via the API and return its charge ID. + + Uses the default tok_visa source hardcoded in StripeChargeInput so no extra + env vars are needed beyond STRIPE_API_KEY. + """ + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/stripe-charge", + json={"amount": 500, "currency": "usd", "description": "nw-integration-test charge"}, + ) + resp.raise_for_status() + data = resp.json() + charge_id = data.get("final_resource_id") + if not charge_id: + pytest.skip( + f"Stripe charge setup failed — cannot run refund tests. " + f"Error: {data.get('error_message') or 'no charge_id returned'}" + ) + return charge_id + + +@pytest.fixture(scope="session") +def real_stripe_subscription_id(api_server_url: str) -> str: + """Create a real Stripe subscription and return its subscription ID. + + Requires STRIPE_TEST_CUSTOMER_ID and STRIPE_TEST_PRICE_ID env vars. + Tests that use this fixture are skipped when the vars are absent. + """ + customer_id = os.environ.get("STRIPE_TEST_CUSTOMER_ID") + price_id = os.environ.get("STRIPE_TEST_PRICE_ID") + if not customer_id or not price_id: + pytest.skip( + "STRIPE_TEST_CUSTOMER_ID and STRIPE_TEST_PRICE_ID are required for subscription tests" + ) + + with httpx.Client(timeout=30) as client: + resp = client.post( + f"{api_server_url}/scenarios/stripe-subscription", + json={"customer_id": customer_id, "price_id": price_id}, + ) + resp.raise_for_status() + data = resp.json() + sub_id = data.get("final_resource_id") + if not sub_id: + pytest.skip( + f"Stripe subscription setup failed — cannot run cancel tests. " + f"Error: {data.get('error_message') or 'no subscription_id returned'}" + ) + return sub_id diff --git a/tests/playground/stripe/stripe_page.py b/tests/playground/stripe/stripe_page.py new file mode 100644 index 0000000..cc8e883 --- /dev/null +++ b/tests/playground/stripe/stripe_page.py @@ -0,0 +1,135 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +from playwright.sync_api import Page, Locator + + +class StripePage: + """Page Object Model for the Stripe connector panel in the Playground.""" + + def __init__(self, page: Page) -> None: + self.page = page + + # Selector for the Stripe card inside system connectors view + self.connector_card: Locator = page.locator(".connector-card[data-mode='stripe']") + + # Panel root and main headers + self.panel: Locator = page.locator("#stripe-panel") + self.title: Locator = page.locator("#stripe-panel .card-title h2") + self.action_select: Locator = page.locator("#stripe-action-select") + self.run_btn: Locator = page.locator("#stripe-run-btn") + self.back_to_connectors: Locator = page.locator("#back-to-connectors") + + # --- charge action elements --- + self.charge_section: Locator = page.locator("#stripe-section-charge") + self.charge_amount: Locator = page.locator( + "#stripe-section-charge input[name='charge_amount']" + ) + self.charge_currency: Locator = page.locator( + "#stripe-section-charge input[name='charge_currency']" + ) + self.charge_description: Locator = page.locator( + "#stripe-section-charge input[name='charge_description']" + ) + + # --- payment_intent action elements --- + self.pi_section: Locator = page.locator("#stripe-section-pi") + self.pi_amount: Locator = page.locator("#stripe-section-pi input[name='pi_amount']") + self.pi_currency: Locator = page.locator("#stripe-section-pi input[name='pi_currency']") + self.pi_customer: Locator = page.locator("#stripe-section-pi input[name='pi_customer']") + self.pi_payment_method: Locator = page.locator( + "#stripe-section-pi input[name='pi_payment_method']" + ) + + # --- subscription action elements --- + self.sub_section: Locator = page.locator("#stripe-section-sub") + self.sub_customer: Locator = page.locator("#stripe-section-sub input[name='sub_customer']") + self.sub_price: Locator = page.locator("#stripe-section-sub input[name='sub_price']") + + # --- cancel_subscription action elements --- + self.cancel_section: Locator = page.locator("#stripe-section-cancel") + self.cancel_sub_id: Locator = page.locator( + "#stripe-section-cancel input[name='cancel_sub_id']" + ) + + # --- refund action elements --- + self.refund_section: Locator = page.locator("#stripe-section-refund") + self.refund_target_id: Locator = page.locator( + "#stripe-section-refund input[name='refund_target_id']" + ) + self.refund_amount: Locator = page.locator( + "#stripe-section-refund input[name='refund_amount']" + ) + + # --- Output and Logs elements --- + self.pipeline_steps: Locator = page.locator(".flow-node") + self.step_nodes: list[Locator] = [page.locator(f"#step-{i}") for i in range(3)] + self.final_result: Locator = page.locator("#final-result") + self.summary_text: Locator = page.locator("#human-summary") + self.result_tag: Locator = page.locator("#result-id") + self.log_terminal: Locator = page.locator("#log-terminal") + + def navigate_to_panel(self) -> None: + """Click the Stripe card in system connectors to open the panel.""" + self.connector_card.click() + + def select_action(self, action: str) -> None: + """Change the action via the select element.""" + self.action_select.select_option(action) + + def fill_charge_fields( + self, + amount: int | None = None, + currency: str | None = None, + description: str | None = None, + ) -> None: + """Fill charge parameters (all optional — HTML defaults apply if not provided).""" + if amount is not None: + self.charge_amount.fill(str(amount)) + if currency is not None: + self.charge_currency.fill(currency) + if description is not None: + self.charge_description.fill(description) + + def fill_payment_intent_fields( + self, + amount: int | None = None, + currency: str | None = None, + customer_id: str | None = None, + payment_method: str | None = None, + ) -> None: + """Fill payment intent parameters.""" + if amount is not None: + self.pi_amount.fill(str(amount)) + if currency is not None: + self.pi_currency.fill(currency) + if customer_id is not None: + self.pi_customer.fill(customer_id) + if payment_method is not None: + self.pi_payment_method.fill(payment_method) + + def fill_subscription_fields(self, customer_id: str, price_id: str) -> None: + """Fill subscription parameters.""" + self.sub_customer.fill(customer_id) + self.sub_price.fill(price_id) + + def fill_cancel_fields(self, subscription_id: str) -> None: + """Fill cancel subscription parameter.""" + self.cancel_sub_id.fill(subscription_id) + + def fill_refund_fields(self, target_id: str, amount: int | None = None) -> None: + """Fill refund parameters. target_id may be a ch_... or pi_... ID.""" + self.refund_target_id.fill(target_id) + if amount is not None: + self.refund_amount.fill(str(amount)) + + def submit(self) -> None: + """Submit the form to execute the Stripe workflow.""" + self.run_btn.click() + + def go_back(self) -> None: + """Click 'Back to All Connectors' to return to connectors selection view.""" + self.back_to_connectors.click() diff --git a/tests/playground/stripe/test_stripe_integration.py b/tests/playground/stripe/test_stripe_integration.py new file mode 100644 index 0000000..02524bb --- /dev/null +++ b/tests/playground/stripe/test_stripe_integration.py @@ -0,0 +1,263 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Stripe connector Playground real integration tests. + +Each test opens the Playground UI, navigates to the Stripe panel, +selects an action, fills the form, clicks the run button, and asserts the +resulting pipeline state — no API mocking, real Stripe test-mode calls. + +Required env vars (loaded from .env): + STRIPE_API_KEY — Stripe secret key (sk_test_...) + +Optional env vars (for subscription-related tests): + STRIPE_TEST_CUSTOMER_ID — pre-existing Stripe test customer (cus_...) + STRIPE_TEST_PRICE_ID — pre-existing Stripe test price (price_...) +""" + +from __future__ import annotations + +from playwright.sync_api import Page, expect + +from tests.playground.stripe.stripe_page import StripePage +from tests.playground.home_page import PlaygroundHomePage +from tests.playground.utils import maybe_sleep + +_TIMEOUT = 20_000 # ms — all Stripe scenarios are 3-step + + +def _navigate_to_stripe(page: Page) -> StripePage: + PlaygroundHomePage(page).click_connectors() + stripe = StripePage(page) + stripe.navigate_to_panel() + return stripe + + +# ── charge ──────────────────────────────────────────────────────────────────── + + +def test_stripe_charge_default(playground_page: Page) -> None: + """Process a charge with the HTML default values (2000 usd); all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("20.00 USD charge") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_stripe_charge_custom_amount(playground_page: Page) -> None: + """Process a charge with a custom amount and description; summary must reflect the amount.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.fill_charge_fields(amount=1500, currency="usd", description="nw-test charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("15.00 USD charge") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_stripe_charge_no_description(playground_page: Page) -> None: + """Process a charge with an empty description; pipeline must still complete successfully.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("charge") + stripe.fill_charge_fields(amount=1000, currency="usd", description="") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── payment_intent ──────────────────────────────────────────────────────────── + + +def test_stripe_payment_intent_default(playground_page: Page) -> None: + """Create a payment intent with the HTML defaults (5000 usd, pm_card_visa); 3 steps succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("payment intent") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_stripe_payment_intent_custom_amount(playground_page: Page) -> None: + """Create a payment intent with a custom amount; result tag must contain a pi_ ID.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.fill_payment_intent_fields(amount=3000, currency="usd") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.result_tag).to_contain_text("pi_") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +def test_stripe_payment_intent_no_payment_method(playground_page: Page) -> None: + """Create a payment intent without a payment method; backend creates a requires_payment_method PI.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("payment_intent") + stripe.fill_payment_intent_fields(amount=2500, currency="usd", payment_method="") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() + + +# ── cancel_subscription ─────────────────────────────────────────────────────── + + +def test_stripe_cancel_subscription_invalid_id(playground_page: Page) -> None: + """Cancel with a nonexistent subscription ID; step-1 must show error state.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("cancel_subscription") + stripe.fill_cancel_fields("sub_this_does_not_exist_9999") + stripe.submit() + + # step-0 (Locate Resource) is a validation step — it always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Cancel Sub) calls the real Stripe API — it must fail + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_hidden() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(stripe.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +def test_stripe_cancel_subscription( + playground_page: Page, real_stripe_subscription_id: str +) -> None: + """Cancel a real subscription; all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("cancel_subscription") + stripe.fill_cancel_fields(real_stripe_subscription_id) + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("canceled subscription") + expect(stripe.result_tag).to_contain_text(real_stripe_subscription_id) + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +# ── refund ──────────────────────────────────────────────────────────────────── + + +def test_stripe_refund_by_charge_id(playground_page: Page, real_stripe_charge_id: str) -> None: + """Issue a full refund against a real charge; all 3 steps must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("refund") + stripe.fill_refund_fields(real_stripe_charge_id) + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("issued refund") + expect(stripe.result_tag).to_be_visible() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + expect(stripe.log_terminal).to_contain_text("SUCCESS") + + maybe_sleep() + + +def test_stripe_refund_invalid_id(playground_page: Page) -> None: + """Refund with a nonexistent charge ID; step-1 must show error state.""" + stripe = _navigate_to_stripe(playground_page) + + stripe.select_action("refund") + stripe.fill_refund_fields("ch_this_does_not_exist_9999") + stripe.submit() + + # step-0 (Validate Params) is a local validation step — it always succeeds + expect(playground_page.locator("#step-0.success")).to_be_visible(timeout=_TIMEOUT) + # step-1 (Issue Refund) calls the real Stripe API — it must fail + expect(playground_page.locator("#step-1.error")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_hidden() + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Failed") + expect(stripe.log_terminal).to_contain_text("FAILED") + + maybe_sleep() + + +# ── cross-action switch ─────────────────────────────────────────────────────── + + +def test_stripe_switch_charge_then_payment_intent(playground_page: Page) -> None: + """Run a charge, then switch to payment_intent on the same page — both must succeed.""" + stripe = _navigate_to_stripe(playground_page) + + # First run: charge + stripe.select_action("charge") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + + # Switch action and run again + stripe.select_action("payment_intent") + stripe.submit() + + for i in range(3): + expect(playground_page.locator(f"#step-{i}.success")).to_be_visible(timeout=_TIMEOUT) + expect(stripe.final_result).to_be_visible(timeout=_TIMEOUT) + expect(stripe.summary_text).to_contain_text("payment intent") + expect(playground_page.locator("#stripe-run-btn .btn-lbl")).to_have_text("Workflow Active") + + maybe_sleep() diff --git a/tests/playground/test_playground_integration.py b/tests/playground/test_playground_integration.py new file mode 100644 index 0000000..e04df1e --- /dev/null +++ b/tests/playground/test_playground_integration.py @@ -0,0 +1,87 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Playground Home page integration test. + +This test loads the playground page, asserts all header elements and cards, +and verifies interactive transitions between the home page selection view and +individual dashboard views (Agentic Workflow and Connectors) using the Page Object Model. +""" + +from __future__ import annotations + +import os +import time +from playwright.sync_api import Page, expect + +from tests.playground.home_page import PlaygroundHomePage + + +def test_playground_home_page_flow(playground_page: Page) -> None: + """Verify elements visibility, cards presence, and navigation transitions on the Playground Home page.""" + home = PlaygroundHomePage(playground_page) + + # 1. Assert overall page title + assert playground_page.title() == "node-wire Playground" + + # 2. Verify visibility of root components and headers + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + expect(home.brand_header).to_be_visible() + expect(home.brand_title).to_contain_text("node-") + expect(home.tagline).to_be_visible() + expect(home.header_actions).to_be_hidden() + + # 3. Assert card counts and detailed card contents + assert home.selection_cards.count() == 3 + + # Agentic Workflow Card + expect(home.agentic_card).to_be_visible() + expect(home.agentic_card_title).to_have_text("Agentic Workflow") + expect(home.agentic_card_desc).to_contain_text("via ToolHive") + + # Connectors Card + expect(home.connectors_card).to_be_visible() + expect(home.connectors_card_title).to_have_text("Connectors") + expect(home.connectors_card_desc).to_contain_text("Pre-built Clinical Workflows") + + # Connector Apps Card + expect(home.connector_apps_card).to_be_visible() + expect(home.connector_apps_card_title).to_have_text("Connector Apps") + expect(home.connector_apps_card_desc).to_contain_text("built on top of connectors") + + # 4. Test Navigation Flow: Root -> Agentic Workflow -> Root + home.click_agentic_workflow() + expect(home.root_selection_view).to_be_hidden() + expect(home.main_layout).to_be_visible() + + # Return back to home + home.go_back_to_selection() + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + + # 5. Test Navigation Flow: Root -> Connectors -> Root + home.click_connectors() + expect(home.root_selection_view).to_be_hidden() + expect(home.main_layout).to_be_visible() + + # Return back to home + home.go_back_to_selection() + expect(home.root_selection_view).to_be_visible() + expect(home.main_layout).to_be_hidden() + + # 6. Test Navigation Flow: Root -> Connector Apps -> Root + home.click_connector_apps() + expect(home.root_selection_view).to_be_hidden() + expect(home.connector_apps_view).to_be_visible() + + # Return back to home + home.go_back_from_apps() + expect(home.root_selection_view).to_be_visible() + expect(home.connector_apps_view).to_be_hidden() + + # 7. Optional visual delay for headed mode + is_headed = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if is_headed and is_headed.lower().strip() in ("true", "1", "yes"): + time.sleep(5) diff --git a/tests/playground/utils.py b/tests/playground/utils.py new file mode 100644 index 0000000..13345e0 --- /dev/null +++ b/tests/playground/utils.py @@ -0,0 +1,15 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +import time + + +def maybe_sleep() -> None: + """Pause for 3 s when running headed so a developer can observe the result.""" + env = os.getenv("PLAYGROUND_HEADED") or os.getenv("HEADED") + if env and env.lower().strip() in ("true", "1", "yes"): + time.sleep(3) diff --git a/tests/test_aot_runtime_basic.py b/tests/test_aot_runtime_basic.py index abfe1b6..badef98 100644 --- a/tests/test_aot_runtime_basic.py +++ b/tests/test_aot_runtime_basic.py @@ -1,13 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import asyncio +from typing import Literal from pydantic import BaseModel -from runtime import BaseConnector, ConnectorResponse, ErrorCategory, ErrorMapper +from node_wire_runtime import ( + BaseConnector, + nw_action, + ConnectorResponse, + ErrorCategory, + ErrorMapper, +) class InputModel(BaseModel): + action: Literal["double"] = "double" value: int @@ -15,18 +27,19 @@ class OutputModel(BaseModel): doubled: int -class TestConnector(BaseConnector[InputModel, OutputModel]): - connector_id = "test" - action = "double" +class DoubleConnector(BaseConnector): + connector_id = "test_double" + output_model = OutputModel - async def internal_execute(self, params: InputModel, *, trace_id: str) -> OutputModel: + @nw_action("double") + async def double(self, params: InputModel, *, trace_id: str) -> OutputModel: return OutputModel(doubled=params.value * 2) def test_successful_execution(): - connector = TestConnector(InputModel, OutputModel) + connector = DoubleConnector() - response: ConnectorResponse = asyncio.run(connector.run({"value": 2})) + response: ConnectorResponse = asyncio.run(connector.run({"action": "double", "value": 2})) assert response.success is True assert response.data == {"doubled": 4} @@ -35,22 +48,53 @@ def test_successful_execution(): assert isinstance(response.trace_id, str) +def test_successful_execution_uses_tenant_breaker_cache(): + connector = DoubleConnector() + + response: ConnectorResponse = asyncio.run( + connector.run({"action": "double", "value": 3}, tenant_id="tenant-a") + ) + + assert response.success is True + assert response.data == {"doubled": 6} + assert "tenant-a" in connector._breakers + + +def test_successful_execution_rebuilds_missing_breaker_cache(): + connector = DoubleConnector() + del connector._breakers + + response: ConnectorResponse = asyncio.run( + connector.run({"action": "double", "value": 4}, tenant_id="tenant-b") + ) + + assert response.success is True + assert response.data == {"doubled": 8} + assert "tenant-b" in connector._breakers + + class CustomError(Exception): pass -class FailingConnector(BaseConnector[InputModel, OutputModel]): - connector_id = "test" - action = "fail" +class FailInputModel(BaseModel): + action: Literal["fail"] = "fail" + value: int - async def internal_execute(self, params: InputModel, *, trace_id: str) -> OutputModel: + +class FailingConnector(BaseConnector): + connector_id = "test_fail" + output_model = OutputModel + + @nw_action("fail") + async def fail(self, params: FailInputModel, *, trace_id: str) -> OutputModel: raise CustomError("boom") def test_error_mapping_defaults_to_fatal(): - connector = FailingConnector(InputModel, OutputModel) + connector = FailingConnector() - response: ConnectorResponse = asyncio.run(connector.run({"value": 1})) + response: ConnectorResponse = asyncio.run(connector.run({"action": "fail", "value": 1})) assert response.success is False assert response.error_code == "CustomError" @@ -59,11 +103,10 @@ def test_error_mapping_defaults_to_fatal(): def test_error_mapping_custom_category(): ErrorMapper.register(CustomError, ErrorCategory.RETRYABLE, code="CUSTOM_RETRYABLE") - connector = FailingConnector(InputModel, OutputModel) + connector = FailingConnector() - response: ConnectorResponse = asyncio.run(connector.run({"value": 1})) + response: ConnectorResponse = asyncio.run(connector.run({"action": "fail", "value": 1})) assert response.success is False assert response.error_code == "CUSTOM_RETRYABLE" assert response.error_category == ErrorCategory.RETRYABLE - diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 0000000..b619d5f --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,430 @@ +""" +tests/test_auth_providers.py +============================== + +Unit tests for the AuthProvider abstraction layer. + +Covers: + - NoAuthProvider + - StaticTokenAuthProvider (bearer, basic, custom header, refresh) + - OAuth2AuthProvider (cache hit/miss, expiry, concurrent refresh, 401 refresh, + private_key_jwt, client_secret_post, missing access_token) + - ServiceAccountAuthProvider + - BaseConnector.get_auth_headers() delegation + - Factory._build_auth_provider() YAML wiring +""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from node_wire_runtime.auth import ( + NoAuthProvider, + OAuth2AuthProvider, + ServiceAccountAuthProvider, + StaticTokenAuthProvider, +) +from node_wire_runtime.secrets import SecretProvider + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +from pydantic import BaseModel +from node_wire_runtime import BaseConnector, sdk_action + + +class _Input(BaseModel): + action: str = "dummy" + + +class _Output(BaseModel): + ok: bool = True + + +class _DummyConnector(BaseConnector): + connector_id = "test_auth_delegation" + output_model = _Output + + @sdk_action("dummy") + async def dummy(self, params: _Input, *, trace_id: str) -> _Output: + return _Output() + + +class _NoAuthConnector(BaseConnector): + connector_id = "test_no_auth_default" + output_model = _Output + + @sdk_action("x") + async def x(self, params: _Input, *, trace_id: str) -> _Output: + return _Output() + + +class _DictSecretProvider(SecretProvider): + def __init__(self, data: dict) -> None: + self._data = data + + def get_secret(self, key: str) -> str: + if key not in self._data: + from node_wire_runtime.secrets import SecretNotFoundError + + raise SecretNotFoundError(key) + return self._data[key] + + +def _token_response(access_token: str = "tok-abc", expires_in: int = 3600) -> MagicMock: + m = MagicMock() + m.status_code = 200 + m.json.return_value = {"access_token": access_token, "expires_in": expires_in} + return m + + +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_returns_empty_headers() -> None: + provider = NoAuthProvider() + assert await provider.get_headers() == {} + + +@pytest.mark.asyncio +async def test_no_auth_returns_none_credentials() -> None: + provider = NoAuthProvider() + assert await provider.get_client_credentials() is None + + +@pytest.mark.asyncio +async def test_no_auth_refresh_is_noop() -> None: + provider = NoAuthProvider() + await provider.refresh() # must not raise + + +# --------------------------------------------------------------------------- +# StaticTokenAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_static_token_bearer_header() -> None: + sp = _DictSecretProvider({"MY_KEY": "secret-value"}) + provider = StaticTokenAuthProvider(secret_provider=sp, secret_key="MY_KEY") + headers = await provider.get_headers() + assert headers == {"Authorization": "Bearer secret-value"} + + +@pytest.mark.asyncio +async def test_static_token_custom_header_and_prefix() -> None: + sp = _DictSecretProvider({"api_key": "abc123"}) + provider = StaticTokenAuthProvider( + secret_provider=sp, + secret_key="api_key", + header_name="X-Api-Key", + prefix="", + ) + headers = await provider.get_headers() + assert headers == {"X-Api-Key": "abc123"} + + +@pytest.mark.asyncio +async def test_static_token_base64_encoding() -> None: + import base64 + + sp = _DictSecretProvider({"creds": "user:pass"}) + provider = StaticTokenAuthProvider( + secret_provider=sp, + secret_key="creds", + prefix="Basic", + encoding="base64", + ) + headers = await provider.get_headers() + expected = base64.b64encode(b"user:pass").decode() + assert headers["Authorization"] == f"Basic {expected}" + + +@pytest.mark.asyncio +async def test_static_token_cached_after_first_call() -> None: + """Secret provider is called only once; result is cached.""" + sp = _DictSecretProvider({"k": "val"}) + provider = StaticTokenAuthProvider(secret_provider=sp, secret_key="k") + h1 = await provider.get_headers() + h2 = await provider.get_headers() + assert h1 == h2 + + +@pytest.mark.asyncio +async def test_static_token_refresh_clears_cache() -> None: + """After refresh(), the next call rebuilds the header.""" + calls = [] + + class _Counting(SecretProvider): + def get_secret(self, key: str) -> str: + calls.append(key) + return "val" + + provider = StaticTokenAuthProvider(secret_provider=_Counting(), secret_key="k") + await provider.get_headers() + await provider.refresh() + await provider.get_headers() + assert len(calls) == 2 # resolved twice — once per cache population + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — token caching +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_oauth2_token_cache_hit() -> None: + """A second call within TTL must NOT issue another HTTP request.""" + sp = _DictSecretProvider( + { + "client_id": "cid", + "token_url": "https://auth.example.com/token", + "private_key": "---fake-key---", + "kid": "kid1", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + private_key_secret="private_key", + kid_secret="kid", + ) + + with patch( + "node_wire_runtime.auth.oauth2.OAuth2AuthProvider._fetch_token", new_callable=AsyncMock + ) as mock_fetch: + mock_fetch.return_value = {"access_token": "tok-1", "expires_in": 3600} + h1 = await provider.get_headers() + h2 = await provider.get_headers() + + assert mock_fetch.call_count == 1 + assert h1["Authorization"] == "Bearer tok-1" + assert h2["Authorization"] == "Bearer tok-1" + + +@pytest.mark.asyncio +async def test_oauth2_token_cache_miss_on_expiry() -> None: + """An expired token must trigger a new fetch.""" + sp = _DictSecretProvider({"client_id": "x", "token_url": "http://t"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", # dummy + buffer_secs=0, + ) + + with patch( + "node_wire_runtime.auth.oauth2.OAuth2AuthProvider._fetch_token", new_callable=AsyncMock + ) as mock_fetch: + mock_fetch.return_value = {"access_token": "tok-a", "expires_in": 1} + await provider.get_headers() + # Force expiry + provider._expires_at = time.monotonic() - 1 + + mock_fetch.return_value = {"access_token": "tok-b", "expires_in": 3600} + h2 = await provider.get_headers() + + assert mock_fetch.call_count == 2 + assert h2["Authorization"] == "Bearer tok-b" + + +@pytest.mark.asyncio +async def test_oauth2_refresh_clears_cache() -> None: + """Calling refresh() forces a new fetch on the next get_headers().""" + sp = _DictSecretProvider({"client_id": "x", "token_url": "http://t"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + + with patch( + "node_wire_runtime.auth.oauth2.OAuth2AuthProvider._fetch_token", new_callable=AsyncMock + ) as mock_fetch: + mock_fetch.return_value = {"access_token": "tok-1", "expires_in": 3600} + await provider.get_headers() # populates cache + + await provider.refresh() # invalidates cache + + mock_fetch.return_value = {"access_token": "tok-2", "expires_in": 3600} + h2 = await provider.get_headers() # must re-fetch + + assert mock_fetch.call_count == 2 + assert h2["Authorization"] == "Bearer tok-2" + + +@pytest.mark.asyncio +async def test_oauth2_concurrent_refresh_single_fetch() -> None: + """Concurrent get_headers() calls must result in exactly one HTTP fetch (Lock).""" + sp = _DictSecretProvider({"client_id": "x", "token_url": "http://t"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + fetch_count = 0 + + async def _fake_fetch() -> dict: + nonlocal fetch_count + fetch_count += 1 + await asyncio.sleep(0) # yield to allow other coroutines to race + return {"access_token": "tok-concurrent", "expires_in": 3600} + + with patch.object(provider, "_fetch_token", side_effect=_fake_fetch): + results = await asyncio.gather(*[provider.get_headers() for _ in range(10)]) + + assert fetch_count == 1 # exactly one HTTP call despite 10 concurrent waiters + assert all(r["Authorization"] == "Bearer tok-concurrent" for r in results) + + +@pytest.mark.asyncio +async def test_oauth2_401_retry_via_refresh() -> None: + """Simulates: connector receives 401 → calls refresh() → next request gets fresh token.""" + sp = _DictSecretProvider({"client_id": "x", "token_url": "http://t"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + + with patch( + "node_wire_runtime.auth.oauth2.OAuth2AuthProvider._fetch_token", new_callable=AsyncMock + ) as mock_fetch: + mock_fetch.return_value = {"access_token": "old-token", "expires_in": 3600} + h1 = await provider.get_headers() + assert h1["Authorization"] == "Bearer old-token" + + # Simulate 401 — connector calls refresh() + await provider.refresh() + + mock_fetch.return_value = {"access_token": "new-token", "expires_in": 3600} + h2 = await provider.get_headers() + assert h2["Authorization"] == "Bearer new-token" + assert mock_fetch.call_count == 2 + + +@pytest.mark.asyncio +async def test_oauth2_missing_access_token_raises() -> None: + """A token response without access_token must raise ValueError.""" + sp = _DictSecretProvider({"client_id": "x", "token_url": "http://t"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + with patch( + "node_wire_runtime.auth.oauth2.OAuth2AuthProvider._fetch_token", new_callable=AsyncMock + ) as mock_fetch: + mock_fetch.return_value = {"token_type": "bearer"} # no access_token key + with pytest.raises(ValueError, match="access_token"): + await provider.get_headers() + + +@pytest.mark.asyncio +async def test_base_connector_delegates_to_auth_provider(tmp_path: Any) -> None: + """get_auth_headers() returns the provider's headers dict.""" + sp = _DictSecretProvider({"MY_API_KEY": "secret-123"}) + auth = StaticTokenAuthProvider(secret_provider=sp, secret_key="MY_API_KEY") + connector = _DummyConnector(secret_provider=sp, auth_provider=auth) + headers = await connector.get_auth_headers() + assert headers == {"Authorization": "Bearer secret-123"} + + +@pytest.mark.asyncio +async def test_base_connector_no_provider_defaults_to_no_auth(tmp_path: Any) -> None: + """A connector with no auth_provider returns {} from get_auth_headers().""" + connector = _NoAuthConnector() # no auth_provider kwarg + assert await connector.get_auth_headers() == {} + + +# --------------------------------------------------------------------------- +# Factory._build_auth_provider() +# --------------------------------------------------------------------------- + + +def test_factory_defaults_to_no_auth_when_auth_block_absent() -> None: + from bindings.factory import ConnectorFactory + from node_wire_runtime.auth import NoAuthProvider + + sp = _DictSecretProvider({}) + factory = ConnectorFactory.__new__(ConnectorFactory) + factory._secret_provider = sp + factory._configs = {} + factory._connectors = {} + + provider = factory._build_auth_provider("test_connector", {}) + assert isinstance(provider, NoAuthProvider) + + +def test_factory_builds_static_token_provider() -> None: + from bindings.factory import ConnectorFactory + from node_wire_runtime.auth import StaticTokenAuthProvider + + sp = _DictSecretProvider({"my_api_key": "abc"}) + factory = ConnectorFactory.__new__(ConnectorFactory) + factory._secret_provider = sp + + cfg = {"auth": {"provider": "static_token", "secret_key": "my_api_key", "prefix": ""}} + provider = factory._build_auth_provider("stripe", cfg) + assert isinstance(provider, StaticTokenAuthProvider) + + +def test_factory_builds_oauth2_provider() -> None: + from bindings.factory import ConnectorFactory + from node_wire_runtime.auth import OAuth2AuthProvider + + sp = _DictSecretProvider({}) + factory = ConnectorFactory.__new__(ConnectorFactory) + factory._secret_provider = sp + + cfg = { + "auth": { + "provider": "oauth2", + "grant_method": "private_key_jwt", + "token_url_secret": "epic_token_url", + "client_id_secret": "epic_client_id", + "private_key_secret": "epic_private_key", + "kid_secret": "epic_kid", + } + } + provider = factory._build_auth_provider("fhir_epic", cfg) + assert isinstance(provider, OAuth2AuthProvider) + + +def test_factory_builds_service_account_provider() -> None: + from bindings.factory import ConnectorFactory + + sp = _DictSecretProvider({}) + factory = ConnectorFactory.__new__(ConnectorFactory) + factory._secret_provider = sp + + cfg = { + "auth": { + "provider": "service_account", + "sa_json_secret": "GOOGLE_DRIVE_SA_JSON", + } + } + provider = factory._build_auth_provider("google_drive", cfg) + assert isinstance(provider, ServiceAccountAuthProvider) diff --git a/tests/test_bandit_report_summary.py b/tests/test_bandit_report_summary.py new file mode 100644 index 0000000..6f7fb3c --- /dev/null +++ b/tests/test_bandit_report_summary.py @@ -0,0 +1,60 @@ +"""Regression tests for scripts/bandit_report_summary.py (CI log helper).""" + +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +SCRIPT = REPO_ROOT / "scripts" / "bandit_report_summary.py" +FIXTURE = REPO_ROOT / "tests" / "fixtures" / "bandit_minimal_report.json" + + +def test_bandit_report_summary_runs_on_fixture() -> None: + assert SCRIPT.is_file(), "summary script must exist" + assert FIXTURE.is_file(), "fixture must exist" + + proc = subprocess.run( + [sys.executable, str(SCRIPT), str(FIXTURE)], + cwd=str(REPO_ROOT), + check=True, + capture_output=True, + text=True, + ) + assert proc.returncode == 0 + assert "Bandit report summary" in proc.stdout + assert "src/example.py" in proc.stdout + assert "B999" in proc.stdout + assert "HIGH: 0" in proc.stdout + + +def test_bandit_report_summary_missing_file_exits_nonzero() -> None: + proc = subprocess.run( + [sys.executable, str(SCRIPT), str(REPO_ROOT / "nonexistent_bandit.json")], + cwd=str(REPO_ROOT), + capture_output=True, + text=True, + ) + assert proc.returncode == 2 + assert "not found" in proc.stderr.lower() + + +def test_bandit_report_summary_invalid_json_exits_nonzero(tmp_path: Path) -> None: + bad = tmp_path / "bad.json" + bad.write_text("{not json", encoding="utf-8") + proc = subprocess.run( + [sys.executable, str(SCRIPT), str(bad)], + cwd=str(REPO_ROOT), + capture_output=True, + text=True, + ) + assert proc.returncode == 2 + + +def test_bandit_fixture_is_valid_json() -> None: + data = json.loads(FIXTURE.read_text(encoding="utf-8")) + assert "metrics" in data + assert "_totals" in data["metrics"] + assert isinstance(data.get("results"), list) diff --git a/tests/test_base_connector_manifest.py b/tests/test_base_connector_manifest.py new file mode 100644 index 0000000..863c613 --- /dev/null +++ b/tests/test_base_connector_manifest.py @@ -0,0 +1,834 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +from __future__ import annotations + +import os +from importlib import import_module +from typing import Any, Dict + +import pytest +from pydantic import ValidationError + +from bindings.factory import ConnectorFactory +from node_wire_runtime.connector_registry import auto_register +from node_wire_runtime.manifest import build_manifest +from node_wire_stripe.schema import ChargeInput +from node_wire_runtime import BaseConnector +from node_wire_runtime.base_connector import _CONNECTOR_REGISTRY + + +def _normalize_for_mcp(connector_id: str, action: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Test harness: resolve MCP connector and run metadata-driven normalizers.""" + norm = import_module("bindings.mcp_server.server").normalize_mcp_tool_arguments + auto_register() + factory = ConnectorFactory() + factory.load() + connector = factory.get_for_protocol(connector_id, "mcp") + assert connector is not None + return norm(connector, action, arguments) + + +def test_registry_contains_base_connectors(): + auto_register() + assert "google_drive" in _CONNECTOR_REGISTRY + assert "stripe" in _CONNECTOR_REGISTRY + assert "fhir_epic" in _CONNECTOR_REGISTRY + + +def test_manifest_emits_per_action(): + auto_register() + factory = ConnectorFactory() + factory.load() + rest_manifest = build_manifest(factory.list_for_protocol("rest")) + rest_actions = {(e["connector_id"], e["action"]) for e in rest_manifest} + assert ("google_drive", "files.list") in rest_actions + assert ("fhir_epic", "read_patient") in rest_actions + assert ("stripe", "charge") in rest_actions + + mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) + mcp_actions = {(e["connector_id"], e["action"]) for e in mcp_manifest} + assert ("stripe", "charge") in mcp_actions + # Per-action input schema should expose that action's fields (not only a buried union) + for entry in mcp_manifest: + if entry["connector_id"] == "stripe" and entry["action"] == "charge": + props = entry["input_schema"].get("properties", {}) + assert "amount" in props + + +def test_stripe_connector_accepts_charge_payload(): + auto_register() + factory = ConnectorFactory() + factory.load() + connector = factory.get_for_protocol("stripe", "grpc") + assert connector is not None + assert isinstance(connector, BaseConnector) + validated = ChargeInput.model_validate( + {"action": "charge", "amount": 100, "currency": "usd", "source": "tok_visa"} + ) + assert validated.action == "charge" + + +def test_stripe_connector_normalizes_uppercase_currency(): + validated = ChargeInput.model_validate( + {"action": "charge", "amount": 100, "currency": "USD", "source": "tok_visa"} + ) + assert validated.currency == "usd" + + +@pytest.mark.parametrize( + "payload", + [ + {"action": "charge", "amount": 0, "currency": "usd", "source": "tok_visa"}, + {"action": "charge", "amount": -1, "currency": "usd", "source": "tok_visa"}, + {"action": "charge", "amount": 100_000_000, "currency": "usd", "source": "tok_visa"}, + {"action": "charge", "amount": 100, "currency": "US", "source": "tok_visa"}, + {"action": "charge", "amount": 100, "currency": "USDT", "source": "tok_visa"}, + {"action": "charge", "amount": 100, "currency": "us1", "source": "tok_visa"}, + ], +) +def test_stripe_connector_rejects_invalid_charge_payload(payload): + with pytest.raises(ValidationError): + ChargeInput.model_validate(payload) + + +def test_mcp_tool_invoke_sets_action(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + names = {t["name"] for t in tools} + assert "google_drive.files.list" in names + assert "stripe.charge" in names + + +def test_mcp_server_list_tools_includes_output_schema(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + assert tools + assert all("output_schema" in t for t in tools) + + +def test_mcp_server_connector_ids_filters_list_tools(): + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["fhir_cerner"]) + names = {t["name"] for t in server.list_tools()} + assert names + assert all(n.startswith("fhir_cerner.") for n in names) + assert "fhir_epic.read_patient" not in names + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_rejects_disallowed_connector() -> None: + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["google_drive"]) + with pytest.raises(ValueError, match="not allowed"): + await server.invoke_tool( + "smtp.send_email", + {"to": ["doc@example.com"], "subject": "x", "body": "y"}, + ) + + +def test_mcp_server_run_stdio_smoke(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + assert callable(server.run_stdio) + assert callable(server._run_stdio_async) + + +def test_normalize_mcp_tool_arguments_read_patient_maps_legacy_ids(): + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + from node_wire_fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput + + for cid in ("fhir_cerner", "fhir_epic"): + out = _normalize_for_mcp( + cid, + "read_patient", + {"patientId": "12724066"}, + ) + assert out["resource_id"] == "12724066" + assert "patientId" not in out + model = FhirCernerPatientReadInput if cid == "fhir_cerner" else FhirEpicPatientReadInput + model.model_validate({**out, "action": "read_patient"}) + + # Canonical resource_id wins over alias + out2 = _normalize_for_mcp( + "fhir_cerner", + "read_patient", + {"resource_id": "111", "patient_id": "222"}, + ) + assert out2["resource_id"] == "111" + + out3 = _normalize_for_mcp( + "fhir_cerner", + "read_patient", + {"familyName": "Smith", "givenName": "John"}, + ) + assert out3["family_name"] == "Smith" + assert out3["given_name"] == "John" + + +def test_normalize_mcp_tool_arguments_search_patients_maps_legacy(): + from node_wire_fhir_cerner.schema import FhirCernerPatientSearchInput + + out = _normalize_for_mcp( + "fhir_cerner", + "search_patients", + {"patient_ids": "12724066,12724067"}, + ) + assert out["resource_ids"] == ["12724066", "12724067"] + + out2 = _normalize_for_mcp( + "fhir_cerner", + "search_patients", + {"search_params": {"patientId": "12724066"}}, + ) + assert out2["search_params"]["identifier"] == "12724066" + assert "patientId" not in out2["search_params"] + + FhirCernerPatientSearchInput.model_validate({**out2, "action": "search_patients"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_mime_type_alias(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mimeType": "text/plain", + "parents": ["folder1"], + "content": "hello", + }, + ) + assert out["mime_type"] == "text/plain" + assert "mimeType" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_action_upload(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "action": "upload", + "name": "a.txt", + "mime_type": "text/plain", + "content": "x", + }, + ) + assert out["action"] == "files.upload" + FilesUploadOperation.model_validate(out) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_nested_file(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "content": "body", + "file": { + "mime_type": "text/plain", + "name": "nested.txt", + "parents": ["p1"], + }, + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1"] + assert "file" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_string_maps_to_content(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": "hello", + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_text_alias_maps_to_content(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": {"text": "hello"}, + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_base64_maps_to_content_base64(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "name": "a.pdf", + "mime_type": "application/pdf", + "media": {"base64": "Zg=="}, + }, + ) + assert out["content_base64"] == "Zg==" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_metadata_aliases_are_used_when_missing(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "media": { + "name": "nested.txt", + "mimeType": "text/plain", + "parents": "p1,p2", + "content": "hi", + } + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1", "p2"] + assert out["content"] == "hi" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_canonical_content_wins_over_media_alias(): + from node_wire_google_drive.schema import FilesUploadOperation + + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "name": "root.txt", + "mime_type": "text/plain", + "content": "root", + "media": {"content": "ignored"}, + }, + ) + assert out["content"] == "root" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_canonical_mime_type_wins_over_nested(): + out = _normalize_for_mcp( + "google_drive", + "files.upload", + { + "mime_type": "text/plain", + "name": "root.txt", + "content": "c", + "file": {"mime_type": "application/json", "name": "ignored.txt"}, + }, + ) + assert out["mime_type"] == "text/plain" + assert out["name"] == "root.txt" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_tool_passes_normalized_payload_to_connector_run() -> None: + """invoke_tool should apply normalization before BaseConnector.run.""" + from bindings.mcp_server.server import McpServer + from node_wire_runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["fhir_cerner"]) + cerner = server._factory.get_for_protocol("fhir_cerner", "mcp") + assert cerner is not None + + captured: dict = {} + + async def fake_run(raw_input, **_kwargs): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"resource": {"id": "12724066"}}, trace_id="t") + + orig_run = cerner.run + try: + cerner.run = fake_run + await server.invoke_tool("fhir_cerner.read_patient", {"patientId": "12724066"}) + finally: + cerner.run = orig_run + + assert captured["payload"]["resource_id"] == "12724066" + assert captured["payload"].get("action") == "read_patient" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_google_drive_files_upload_normalizes_payload() -> None: + """invoke_tool should normalize Drive upload aliases before connector.run.""" + from bindings.mcp_server.server import McpServer + from node_wire_runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["google_drive"]) + gdrive = server._factory.get_for_protocol("google_drive", "mcp") + assert gdrive is not None + + captured: dict = {} + + async def fake_run(raw_input, **_kwargs): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"raw": {}}, trace_id="t") + + orig_run = gdrive.run + try: + # Set NW_RATE_LIMIT_DISABLED env var to disable rate limiting in MCP server + old_rate_limit = os.environ.get("NW_RATE_LIMIT_DISABLED") + os.environ["NW_RATE_LIMIT_DISABLED"] = "true" + + gdrive.run = fake_run + await server.invoke_tool( + "google_drive.files.upload", + { + "mimeType": "text/plain", + "name": "patient_summary.txt", + "parents": ["folder-id"], + "content": "summary", + "media": {"content": "ignored"}, + "action": "upload", + }, + ) + + # Restore original rate limit value + if old_rate_limit is not None: + os.environ["NW_RATE_LIMIT_DISABLED"] = old_rate_limit + finally: + gdrive.run = orig_run + + assert captured["payload"]["mime_type"] == "text/plain" + assert captured["payload"]["action"] == "files.upload" + assert "mimeType" not in captured["payload"] + assert "media" not in captured["payload"] + + +def test_build_manifest_mcp_input_schema_omits_action_property() -> None: + """MCP/REST manifest must not expose `action` in inputSchema (injected by binding).""" + auto_register() + factory = ConnectorFactory() + factory.load() + for entry in build_manifest(factory.list_for_protocol("mcp")): + props = entry["input_schema"].get("properties") or {} + assert "action" not in props, entry + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_rejects_legacy_upload_when_env_reject( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("NODE_WIRE_LEGACY_GDRIVE_ACTION_UPLOAD", "reject") + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["google_drive"]) + with pytest.raises(ValueError, match="does not match"): + await server.invoke_tool( + "google_drive.files.upload", + { + "name": "x.txt", + "mime_type": "text/plain", + "content": "a", + "action": "upload", + }, + ) + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_rejects_conflicting_action() -> None: + """Tool name action must match payload after normalization (no action spoofing).""" + from bindings.mcp_server.server import McpServer + + # Set NW_RATE_LIMIT_DISABLED env var to disable rate limiting in MCP server + old_rate_limit = os.environ.get("NW_RATE_LIMIT_DISABLED") + os.environ["NW_RATE_LIMIT_DISABLED"] = "true" + + try: + server = McpServer(connector_ids=["google_drive"]) + + with pytest.raises(ValueError, match="does not match"): + await server.invoke_tool( + "google_drive.files.upload", + { + "name": "x.txt", + "mime_type": "text/plain", + "content": "a", + "action": "files.list", + }, + ) + finally: + # Restore original rate limit value + if old_rate_limit is not None: + os.environ["NW_RATE_LIMIT_DISABLED"] = old_rate_limit + + +def test_normalize_fhir_search_encounter_maps_llm_aliases(): + out = _normalize_for_mcp( + "fhir_cerner", + "search_encounter", + { + "patient": "12748336", + "sort": "-date", + "status": "finished", + }, + ) + assert out["patient_id"] == "12748336" + assert out["search_params"]["_sort"] == "-date" + assert out.get("patient") is None + + +def test_normalize_mcp_tool_arguments_smtp_send_email_from_alias(): + from node_wire_smtp.schema import SmtpSendInput + + out = _normalize_for_mcp( + "smtp", + "send_email", + { + "from": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "Hi", + "body": "Hello", + }, + ) + assert out["from_email"] == "sender@example.com" + assert "from" not in out + SmtpSendInput.model_validate({**out, "action": "send_email"}) + + +def test_normalize_mcp_tool_arguments_smtp_send_email_sender_alias(): + out = _normalize_for_mcp( + "smtp", + "send_email", + {"sender": "s@example.com", "to": ["r@example.com"], "subject": "x", "body": "y"}, + ) + assert out["from_email"] == "s@example.com" + assert "sender" not in out + + +def test_normalize_mcp_tool_arguments_smtp_send_email_canonical_wins(): + out = _normalize_for_mcp( + "smtp", + "send_email", + { + "from_email": "canonical@example.com", + "from": "alias@example.com", + "to": ["r@example.com"], + "subject": "x", + "body": "y", + }, + ) + assert out["from_email"] == "canonical@example.com" + assert "from" not in out + + +def test_normalize_mcp_tool_arguments_smtp_send_email_to_string_to_list(): + from node_wire_smtp.schema import SmtpSendInput + + out = _normalize_for_mcp( + "smtp", + "send_email", + {"from_email": "s@example.com", "to": "r@example.com", "subject": "x", "body": "y"}, + ) + assert out["to"] == ["r@example.com"] + SmtpSendInput.model_validate({**out, "action": "send_email"}) + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_smtp_send_email_normalizes_payload() -> None: + """invoke_tool should normalize SMTP aliases before connector.run.""" + from bindings.mcp_server.server import McpServer + from node_wire_runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["smtp"]) + smtp = server._factory.get_for_protocol("smtp", "mcp") + assert smtp is not None + + captured: dict = {} + + async def fake_run(raw_input, **_kwargs): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"sent": True}, trace_id="t") + + orig_run = smtp.run + try: + smtp.run = fake_run + await server.invoke_tool( + "smtp.send_email", + { + "from": "sender@example.com", + "to": "recipient@example.com", + "subject": "Test", + "body": "Body", + }, + ) + finally: + smtp.run = orig_run + + assert captured["payload"]["from_email"] == "sender@example.com" + assert captured["payload"]["to"] == ["recipient@example.com"] + assert "from" not in captured["payload"] + assert captured["payload"].get("action") == "send_email" + + +def test_mcp_server_invoke_tool_malformed_name() -> None: + import asyncio + + from bindings.mcp_server.server import McpServer + + async def _run() -> None: + server = McpServer() + with pytest.raises(ValueError, match="Tool name must be in the form"): + await server.invoke_tool("no_dot_separator", {}) + + asyncio.run(_run()) + + +def test_mcp_server_invoke_tool_connector_not_in_filter() -> None: + import asyncio + + from bindings.mcp_server.server import McpServer + + async def _run() -> None: + server = McpServer(connector_ids=["fhir_cerner"]) + with pytest.raises(ValueError, match="not allowed on this MCP server"): + await server.invoke_tool("fhir_epic.read_patient", {"resource_id": "x"}) + + asyncio.run(_run()) + + +def test_mcp_server_invoke_tool_unknown_connector_id() -> None: + import asyncio + + from bindings.mcp_server.server import McpServer + + async def _run() -> None: + server = McpServer() + with pytest.raises(ValueError, match="not available via MCP"): + await server.invoke_tool("unknown_connector_xyz.read_patient", {}) + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Enterprise-quality schema contract tests +# --------------------------------------------------------------------------- + + +def test_mcp_server_list_tools_output_schema_is_connector_response_envelope(): + """output_schema must be the ConnectorResponse envelope with correct structure.""" + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + assert tools + for t in tools: + assert "output_schema" in t + schema = t["output_schema"] + assert schema.get("title") == "ConnectorResponse" + assert schema.get("type") == "object" + props = schema.get("properties", {}) + assert "success" in props + assert "data" in props + assert "trace_id" in props + assert "error_code" in props + assert "error_category" in props + assert set(schema.get("required", [])) == {"success", "trace_id"} + + +def test_connector_response_schema_embeds_output_model_in_data(): + """_connector_response_schema must inline the output model schema as data, no $ref/$defs.""" + from node_wire_runtime.manifest import _connector_response_schema + from node_wire_smtp.schema import SmtpSendOutput + + schema = _connector_response_schema(SmtpSendOutput) + assert schema["title"] == "ConnectorResponse" + assert schema["type"] == "object" + props = schema["properties"] + assert props["success"] == {"type": "boolean"} + assert props["trace_id"] == {"type": "string"} + # data must contain the SmtpSendOutput properties (nullable union branch) + data_any = props["data"]["anyOf"] + output_branch = next(b for b in data_any if b.get("type") != "null") + assert "sent" in output_branch.get("properties", {}) + # error_category must inline the enum from runtime (no $ref to avoid $defs leakage) + ec = props["error_category"] + enum_values = ec["anyOf"][0]["enum"] + from node_wire_runtime.models import ErrorCategory + + assert set(enum_values) == {e.value for e in ErrorCategory} + assert "$ref" not in str(ec) + assert "$defs" not in schema + + +def test_manifest_strict_action_retains_additional_properties(): + """Actions not marked alias_tolerant must preserve additionalProperties:false.""" + auto_register() + factory = ConnectorFactory() + factory.load() + mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) + + # files.list uses BaseDriveOperation(extra="forbid") and is not alias_tolerant + files_list = next( + e + for e in mcp_manifest + if e["connector_id"] == "google_drive" and e["action"] == "files.list" + ) + assert files_list["input_schema"].get("additionalProperties") is False + + +def test_manifest_alias_tolerant_actions_strip_additional_properties(): + """Actions marked alias_tolerant=True must have additionalProperties removed.""" + auto_register() + factory = ConnectorFactory() + factory.load() + mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) + by_key = {(e["connector_id"], e["action"]): e for e in mcp_manifest} + + # files.upload is alias_tolerant via SdkActionSpec + assert "additionalProperties" not in by_key[("google_drive", "files.upload")]["input_schema"] + # smtp send_email is alias_tolerant via @sdk_action kwarg + assert "additionalProperties" not in by_key[("smtp", "send_email")]["input_schema"] + # fhir read_patient / search_patients / search_encounter are alias_tolerant + assert "additionalProperties" not in by_key[("fhir_cerner", "read_patient")]["input_schema"] + assert "additionalProperties" not in by_key[("fhir_cerner", "search_patients")]["input_schema"] + assert "additionalProperties" not in by_key[("fhir_cerner", "search_encounter")]["input_schema"] + assert "additionalProperties" not in by_key[("fhir_epic", "read_patient")]["input_schema"] + assert "additionalProperties" not in by_key[("fhir_epic", "search_patients")]["input_schema"] + assert "additionalProperties" not in by_key[("fhir_epic", "search_encounter")]["input_schema"] + + +def test_sdk_action_meta_alias_tolerant_propagates(): + """alias_tolerant must be correctly stored in _action_registry for all paths.""" + auto_register() + + # google_drive files.upload: alias_tolerant via SdkActionSpec → _make_spec_handler + gd_cls = _CONNECTOR_REGISTRY["google_drive"] + assert gd_cls._action_registry["files.upload"].alias_tolerant is True + assert gd_cls._action_registry["files.list"].alias_tolerant is False + + # smtp send_email: alias_tolerant via @sdk_action kwarg + smtp_cls = _CONNECTOR_REGISTRY["smtp"] + assert smtp_cls._action_registry["send_email"].alias_tolerant is True + + # fhir connectors + cerner_cls = _CONNECTOR_REGISTRY["fhir_cerner"] + assert cerner_cls._action_registry["read_patient"].alias_tolerant is True + assert cerner_cls._action_registry["search_patients"].alias_tolerant is True + assert cerner_cls._action_registry["search_encounter"].alias_tolerant is True + + epic_cls = _CONNECTOR_REGISTRY["fhir_epic"] + assert epic_cls._action_registry["search_encounter"].alias_tolerant is True + + +def test_manifest_error_category_enum_matches_runtime_error_category(): + """Emitted JSON Schema enum must stay in sync with ErrorCategory.""" + from node_wire_runtime.manifest import _error_category_json_schema + from node_wire_runtime.models import ErrorCategory + + schema = _error_category_json_schema() + assert set(schema["enum"]) == {e.value for e in ErrorCategory} + + +def test_sdk_action_meta_mcp_normalize_propagates(): + """mcp_normalize registered on actions must appear in _action_registry.""" + auto_register() + + gd_cls = _CONNECTOR_REGISTRY["google_drive"] + assert gd_cls._action_registry["files.upload"].mcp_normalize is not None + assert gd_cls._action_registry["files.list"].mcp_normalize is None + + smtp_cls = _CONNECTOR_REGISTRY["smtp"] + assert smtp_cls._action_registry["send_email"].mcp_normalize is not None + + cerner_cls = _CONNECTOR_REGISTRY["fhir_cerner"] + assert cerner_cls._action_registry["read_patient"].mcp_normalize is not None + assert cerner_cls._action_registry["search_patients"].mcp_normalize is not None + assert cerner_cls._action_registry["search_encounter"].mcp_normalize is not None + + epic_cls = _CONNECTOR_REGISTRY["fhir_epic"] + assert epic_cls._action_registry["search_encounter"].mcp_normalize is not None + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_tool_failure_payload_matches_output_schema_shape() -> None: + """Error ConnectorResponse (data=None) matches manifest output_schema (nullable data).""" + from bindings.mcp_server.server import McpServer + from node_wire_runtime.models import ConnectorResponse, ErrorCategory + + server = McpServer(connector_ids=["smtp"]) + smtp = server._factory.get_for_protocol("smtp", "mcp") + assert smtp is not None + + entry = next(e for e in server.list_tools() if e["name"] == "smtp.send_email") + schema = entry["output_schema"] + data_prop = schema["properties"]["data"] + assert {"type": "null"} in data_prop["anyOf"] + + async def fake_run(_raw_input, **_kwargs): + return ConnectorResponse( + success=False, + data=None, + error_code="VALIDATION_ERROR", + error_category=ErrorCategory.BUSINESS, + message="bad", + trace_id="trace-1", + details=[{"loc": ["x"], "msg": "y", "type": "value_error"}], + ) + + orig_run = smtp.run + try: + smtp.run = fake_run + out = await server.invoke_tool( + "smtp.send_email", + {"from_email": "a@b.com", "to": ["x@y.com"], "subject": "s", "body": "b"}, + ) + finally: + smtp.run = orig_run + + assert out["success"] is False + assert out["data"] is None + assert out["error_code"] == "VALIDATION_ERROR" + assert out["trace_id"] == "trace-1" + + +def test_normalize_mcp_tool_arguments_noop_when_action_has_no_normalizer(): + """Strict actions without mcp_normalize should pass args through unchanged.""" + from bindings.mcp_server.server import normalize_mcp_tool_arguments + + auto_register() + factory = ConnectorFactory() + factory.load() + connector = factory.get_for_protocol("google_drive", "mcp") + assert connector is not None + raw = {"action": "files.list", "page_size": 10} + out = normalize_mcp_tool_arguments(connector, "files.list", raw) + assert out == raw diff --git a/tests/test_call_action_policy.py b/tests/test_call_action_policy.py new file mode 100644 index 0000000..db59064 --- /dev/null +++ b/tests/test_call_action_policy.py @@ -0,0 +1,131 @@ +"""Regression: ``BaseConnector.call_action`` must honor scope policy via ``run``.""" + +from __future__ import annotations + +from typing import Literal + +import pytest +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.errors import ErrorMapper +from node_wire_runtime.models import ErrorCategory +from node_wire_runtime.policies.mcp_scope_policy import ScopePolicyHook +from node_wire_runtime.policy import PolicyDenied + + +class _NestedBizError(Exception): + """Test-only mapped exception for nested action failure semantics.""" + + +ErrorMapper.register(_NestedBizError, ErrorCategory.BUSINESS, code="NESTED_BIZ_TEST") + + +class _DelInput(BaseModel): + action: Literal["delegate"] = "delegate" + resource_id: str + + +class _ReadInput(BaseModel): + action: Literal["read_patient"] = "read_patient" + resource_id: str + + +class _Output(BaseModel): + ok: bool + + +class _CompositeConnector(BaseConnector): + # Must not reuse a production ``connector_id`` — ``__init_subclass__`` overwrites + # :data:`_CONNECTOR_REGISTRY` and would break other tests. + connector_id = "policy_test_composite" + output_model = _Output + + @nw_action("delegate") + async def delegate(self, params: _DelInput, *, trace_id: str) -> _Output: + return await self.call_action( + "read_patient", + {"action": "read_patient", "resource_id": params.resource_id}, + ) + + @nw_action("read_patient") + async def read_patient(self, params: _ReadInput, *, trace_id: str) -> _Output: + return _Output(ok=True) + + +class _FailNestedConnector(BaseConnector): + connector_id = "policy_test_fail_nested" + output_model = _Output + + @nw_action("delegate") + async def delegate(self, params: _DelInput, *, trace_id: str) -> _Output: + return await self.call_action( + "read_patient", + {"action": "read_patient", "resource_id": params.resource_id}, + ) + + @nw_action("read_patient") + async def read_patient(self, params: _ReadInput, *, trace_id: str) -> _Output: + raise _NestedBizError("nested failure") + + +@pytest.mark.asyncio +async def test_call_action_inherits_identity_for_nested_policy() -> None: + hook = ScopePolicyHook({"policy_test_composite.read_patient": "mcp:fhir.read_patient"}) + connector = _CompositeConnector(policy_hook=hook) + + resp = await connector.run( + {"action": "delegate", "resource_id": "x"}, + principal="alice", + tenant_id="t1", + scopes=("mcp:fhir.read_patient",), + ) + assert resp.success is True + assert resp.data is not None + assert resp.data["ok"] is True + + +@pytest.mark.asyncio +async def test_call_action_nested_policy_denied_raises() -> None: + hook = ScopePolicyHook({"policy_test_composite.read_patient": "mcp:fhir.read_patient"}) + connector = _CompositeConnector(policy_hook=hook) + + resp = await connector.run( + {"action": "delegate", "resource_id": "x"}, + principal="alice", + tenant_id="t1", + scopes=("mcp:other.scope",), + ) + assert resp.success is False + assert resp.error_code == "POLICY_DENIED" + + +def test_call_action_direct_raises_policy_denied_sync_wrap() -> None: + """PolicyDenied from nested run surfaces through async delegate body.""" + hook = ScopePolicyHook({"policy_test_composite.read_patient": "mcp:fhir.read_patient"}) + connector = _CompositeConnector(policy_hook=hook) + + async def _run() -> None: + await connector.call_action( + "read_patient", + {"action": "read_patient", "resource_id": "x"}, + principal="alice", + scopes=("mcp:other.scope",), + ) + + import asyncio + + with pytest.raises(PolicyDenied): + asyncio.run(_run()) + + +@pytest.mark.asyncio +async def test_call_action_preserves_nested_error_category_and_code() -> None: + connector = _FailNestedConnector(policy_hook=None) + resp = await connector.run({"action": "delegate", "resource_id": "x"}) + assert resp.success is False + assert resp.error_code == "NESTED_BIZ_TEST" + assert resp.error_category == ErrorCategory.BUSINESS + assert "nested failure" in (resp.message or "") + assert resp.details is not None + assert "nested_trace_id" in resp.details diff --git a/tests/test_connector_registry.py b/tests/test_connector_registry.py new file mode 100644 index 0000000..aa344a3 --- /dev/null +++ b/tests/test_connector_registry.py @@ -0,0 +1,98 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for entry-point connector registration and allowlist.""" + +from __future__ import annotations + +from importlib.metadata import EntryPoint +from unittest.mock import patch + +import pytest + +from node_wire_runtime import connector_registry + + +def test_auto_register_respects_nw_allowed_connectors(monkeypatch: pytest.MonkeyPatch) -> None: + """Only listed entry point names are imported when NW_ALLOWED_CONNECTORS is set.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "a") + + eps = [ + EntryPoint(name="a", value="node_wire_a.logic", group="node_wire.connectors"), + EntryPoint(name="b", value="node_wire_b.logic", group="node_wire.connectors"), + ] + + with ( + patch.object(connector_registry, "entry_points", return_value=eps), + patch.object(connector_registry.importlib, "import_module") as mock_imp, + ): + connector_registry.auto_register() + + imported = [c[0][0] for c in mock_imp.call_args_list] + assert "node_wire_a.logic" in imported + assert "node_wire_b.logic" not in imported + + +def test_auto_register_skips_bad_module_prefix() -> None: + fake_ep = EntryPoint( + name="evil", + value="third_party_evil.logic", + group="node_wire.connectors", + ) + assert connector_registry._should_skip_ep(fake_ep, {"evil"}, "node_wire_") is True + + +def test_allowed_connector_not_skipped_when_prefix_matches() -> None: + fake_ep = EntryPoint( + name="http_generic", + value="node_wire_http_generic.logic", + group="node_wire.connectors", + ) + assert connector_registry._should_skip_ep(fake_ep, {"http_generic"}, "node_wire_") is False + + +def test_logic_module_dotted_path_supports_colon_attr() -> None: + ep = EntryPoint( + name="x", + value="node_wire_x.logic:ConnectorClass", + group="node_wire.connectors", + ) + assert connector_registry._logic_module_dotted_path(ep) == "node_wire_x.logic" + + +def test_auto_register_fallback_import_when_no_entry_points( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Allowed connectors not found via entry points fallback to node_wire_.logic import.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "fallback_test") + + # entry_points() returns empty list + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module") as mock_imp, + ): + loaded = connector_registry.auto_register() + + # Should attempt to import node_wire_fallback_test.logic and .registration + imported = [c[0][0] for c in mock_imp.call_args_list] + assert "node_wire_fallback_test.logic" in imported + assert "node_wire_fallback_test.registration" in imported + # Both are mocked to succeed, so both should be in loaded + assert "node_wire_fallback_test.logic" in loaded + assert "node_wire_fallback_test.registration" in loaded + + +def test_auto_register_fallback_respects_custom_prefix(monkeypatch: pytest.MonkeyPatch) -> None: + """Fallback import uses NW_CONNECTOR_MODULE_PREFIX if set.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "fallback_test") + monkeypatch.setenv("NW_CONNECTOR_MODULE_PREFIX", "custom_") + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module") as mock_imp, + ): + connector_registry.auto_register() + + imported = [c[0][0] for c in mock_imp.call_args_list] + assert "custom_fallback_test.logic" in imported diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index a4e633e..b5da862 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -1,46 +1,61 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations -import asyncio +import pytest -from pydantic import BaseModel - -from connectors.http_generic.logic import HttpGenericConnector -from connectors.http_generic.schema import HttpRequestInput, HttpResponseOutput -from connectors.smtp.logic import SmtpConnector -from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput -from runtime import ConnectorResponse, ErrorCategory, SecretProvider -from connectors import auto_register +from node_wire_http_generic.logic import HttpGenericConnector +from node_wire_smtp.logic import SmtpConnector +from node_wire_stripe.logic import StripeConnector +from node_wire_runtime import BaseConnector, SecretProvider +from node_wire_runtime.connector_registry import auto_register class DummySecretProvider(SecretProvider): def __init__(self) -> None: - self._store = {"stripe_api_key": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} + self._store = {"STRIPE_API_KEY": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} def get_secret(self, key: str) -> str: return self._store[key] -def test_auto_register_runs_without_error(): +def test_auto_register_runs_without_error(monkeypatch): + monkeypatch.setenv( + "NW_ALLOWED_CONNECTORS", "fhir_cerner,fhir_epic,google_drive,smtp,stripe,http_generic" + ) imported = auto_register() + if not imported: + pytest.skip( + "importlib.metadata entry points for node_wire.connectors are empty; " + "use `pip install -e .` to run this assertion." + ) assert any("http_generic.registration" in name for name in imported) + assert any("google_drive.logic" in name for name in imported) def test_http_connector_instantiation_only(): - connector = HttpGenericConnector(HttpRequestInput, HttpResponseOutput) + connector = HttpGenericConnector() assert connector.connector_id == "http_generic" - assert connector.action == "request" + assert isinstance(connector, BaseConnector) def test_smtp_connector_instantiation_only(): - connector = SmtpConnector(SmtpSendInput, SmtpSendOutput, secret_provider=DummySecretProvider()) + connector = SmtpConnector(secret_provider=DummySecretProvider()) assert connector.connector_id == "smtp" - assert connector.action == "send_email" + assert isinstance(connector, BaseConnector) def test_stripe_connector_instantiation_only(): - connector = StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=DummySecretProvider()) + connector = StripeConnector(secret_provider=DummySecretProvider()) assert connector.connector_id == "stripe" - assert connector.action == "charge" + assert connector.action == "execute" + +def test_salesforce_connector_instantiation_only(): + store = {"salesforce_instance_url": "https://test.salesforce.com"} + provider = type("Mock", (), {"get_secret": lambda s, k: store[k]})() + connector = BaseConnector.get_registry()["salesforce"](secret_provider=provider) + assert connector.connector_id == "salesforce" + assert "create_lead" in connector._action_registry diff --git a/tests/test_connectors_io.py b/tests/test_connectors_io.py new file mode 100644 index 0000000..e2baa37 --- /dev/null +++ b/tests/test_connectors_io.py @@ -0,0 +1,247 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests for SMTP, HTTP generic, and Stripe connectors with mocked I/O.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import ANY, MagicMock, patch + +import httpx +import pytest +from pydantic import ValidationError + +from node_wire_http_generic.logic import HttpGenericConnector +from node_wire_http_generic.schema import HttpRequestInput +from node_wire_smtp.logic import SmtpConnector +from node_wire_smtp.schema import SmtpSendInput +from node_wire_stripe.logic import StripeConnector +from node_wire_runtime.secrets import SecretProvider + + +class _MapSecrets(SecretProvider): + def __init__(self, mapping: dict[str, str]) -> None: + self._m = mapping + + def get_secret(self, key: str) -> str: + return self._m[key] + + +def test_smtp_internal_execute_calls_aiosmtplib_send() -> None: + secrets = _MapSecrets({"SMTP_USERNAME": "u", "SMTP_PASSWORD": "p"}) + + async def fake_send(*args: object, **kwargs: object) -> tuple[int, str]: + return (250, "OK") + + async def _run() -> None: + with patch("node_wire_smtp.logic.aiosmtplib.send", new=fake_send): + c = SmtpConnector(secret_provider=secrets) + inp = SmtpSendInput( + host="localhost", + port=1025, + use_tls=False, + from_email="a@example.com", + to=["b@example.com"], + subject="s", + body="hi", + ) + out = await c.internal_execute(inp, trace_id="t-1") + assert out.sent is True + + asyncio.run(_run()) + + +def test_smtp_send_email_does_not_log_sender_address() -> None: + """L1 — from_email (PII) must never appear in any log record.""" + import logging + + secrets = _MapSecrets({"SMTP_USERNAME": "u", "SMTP_PASSWORD": "p"}) + + async def fake_send(*args: object, **kwargs: object) -> tuple[int, str]: + return (250, "OK") + + captured: list[logging.LogRecord] = [] + + class _Capture(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + captured.append(record) + + handler = _Capture(level=logging.DEBUG) + smtp_logger = logging.getLogger("connectors.smtp") + # Lower the logger's own level so INFO records are not silently dropped + # when the root logger is configured at WARNING (the pytest default). + _prev_level = smtp_logger.level + smtp_logger.setLevel(logging.DEBUG) + smtp_logger.addHandler(handler) + try: + + async def _run() -> None: + with patch("node_wire_smtp.logic.aiosmtplib.send", new=fake_send): + c = SmtpConnector(secret_provider=secrets) + inp = SmtpSendInput( + host="smtp.example.com", + port=587, + use_tls=True, + from_email="sender@private.example.com", + to=["recipient@other.example.com"], + subject="Test", + body="body", + ) + await c.internal_execute(inp, trace_id="t-pii") + + asyncio.run(_run()) + finally: + smtp_logger.removeHandler(handler) + smtp_logger.setLevel(_prev_level) + + assert len(captured) >= 2, "Expected at least prepare + sent log records" + for record in captured: + # The full sender address must never appear anywhere in the serialised record. + log_text = str(record.__dict__) + assert "sender@private.example.com" not in log_text, ( + f"Sender PII leaked into log record: {log_text!r}" + ) + # The domain-only hint MUST be present in at least the prepare record. + domains = [r.__dict__.get("sender_domain") for r in captured if "sender_domain" in r.__dict__] + assert all(d == "private.example.com" for d in domains), ( + f"Unexpected sender_domain values: {domains}" + ) + + +def test_http_generic_internal_execute() -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = httpx.Headers({"X-Test": "1"}) + mock_resp.text = "response-body" + + class _FakeAsyncClient: + async def __aenter__(self) -> "_FakeAsyncClient": + return self + + async def __aexit__(self, *args: object) -> None: + return None + + async def request(self, **kwargs: object) -> MagicMock: + return mock_resp + + async def _run() -> None: + with patch( + "node_wire_http_generic.logic.httpx.AsyncClient", return_value=_FakeAsyncClient() + ): + c = HttpGenericConnector() + inp = HttpRequestInput(url="http://example.com/path", method="GET") + out = await c.internal_execute(inp, trace_id="t-2") + assert out.status_code == 200 + assert out.body == "response-body" + + asyncio.run(_run()) + + +def test_http_request_input_normalizes_method() -> None: + parsed = HttpRequestInput(url="https://example.com/path", method=" post ") + assert parsed.method == "POST" + + +def test_http_request_input_rejects_unsupported_method() -> None: + with pytest.raises(ValidationError): + HttpRequestInput(url="https://example.com/path", method="TRACE") + + +@pytest.mark.parametrize( + "blocked_url", + [ + "http://localhost/health", + "http://LOCALHOST/health", + "http://127.0.0.1/internal", + "http://10.0.0.25/api", + "http://0.0.0.0/debug", + "http://169.254.169.254/latest/meta-data", + "http://[::1]/health", + "http://metadata.google.internal/computeMetadata/v1", + "http://metadata.google.internal./computeMetadata/v1", + ], +) +def test_http_request_input_rejects_internal_targets(blocked_url: str) -> None: + with pytest.raises(ValidationError): + HttpRequestInput(url=blocked_url, method="GET") + + +def test_http_request_input_allows_public_url() -> None: + parsed = HttpRequestInput(url="https://example.com/path?q=1", method="GET") + assert str(parsed.url) == "https://example.com/path?q=1" + + +def test_http_generic_logs_sanitized_url() -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = httpx.Headers({}) + mock_resp.text = "ok" + + class _FakeAsyncClient: + async def __aenter__(self) -> "_FakeAsyncClient": + return self + + async def __aexit__(self, *args: object) -> None: + return None + + async def request(self, **kwargs: object) -> MagicMock: + return mock_resp + + async def _run() -> None: + with ( + patch( + "node_wire_http_generic.logic.httpx.AsyncClient", return_value=_FakeAsyncClient() + ), + patch("node_wire_http_generic.logic.logger.info") as mocked_info, + ): + c = HttpGenericConnector() + inp = HttpRequestInput( + url="https://user:pass@example.com/path?token=secret&patient=123", + method="GET", + ) + await c.internal_execute(inp, trace_id="t-log") + for call in mocked_info.call_args_list: + extra = call.kwargs.get("extra") or {} + if "url" in extra: + assert extra["url"] == "https://example.com/path" + assert "secret" not in extra["url"] + assert "user:pass" not in extra["url"] + + asyncio.run(_run()) + + +def test_stripe_charge_via_run() -> None: + secrets = _MapSecrets({"stripe_api_key": "sk_test_dummy"}) + + with patch("node_wire_stripe.logic.stripe.Charge") as mock_charge: + mock_charge.create.return_value = MagicMock( + id="ch_123", receipt_url="https://pay.example/r", paid=True + ) + c = StripeConnector(secret_provider=secrets) + + async def _run() -> None: + resp = await c.run( + { + "action": "charge", + "amount": 1000, + "currency": "usd", + "source": "tok_visa", + } + ) + assert resp.success is True + assert resp.data is not None + assert resp.data.get("charge_id") == "ch_123" + + asyncio.run(_run()) + mock_charge.create.assert_called_once_with( + api_key="sk_test_dummy", + amount=1000, + currency="usd", + source="tok_visa", + customer=None, + description=None, + metadata=None, + idempotency_key=ANY, + ) diff --git a/tests/test_entrypoints.py b/tests/test_entrypoints.py new file mode 100644 index 0000000..95aa029 --- /dev/null +++ b/tests/test_entrypoints.py @@ -0,0 +1,109 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for MCP and REST/gRPC process entrypoints.""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + + +def test_mcp_entrypoint_main_calls_run_stdio() -> None: + with patch("bindings.mcp_server.server.McpServer") as MockServer: + from agents import mcp_entrypoint + + mcp_entrypoint.main() + MockServer.assert_called_once_with(server_name="node-wire") + MockServer.return_value.run.assert_called_once() + + +@pytest.mark.parametrize( + "module_path", + [ + "agents.fhir_cerner_mcp", + "agents.fhir_epic_mcp", + "agents.google_drive_mcp", + "agents.smtp_mcp", + ], +) +def test_per_connector_mcp_main_calls_run_stdio(module_path: str) -> None: + with patch("bindings.mcp_server.server.McpServer") as MockServer: + mod = __import__(module_path, fromlist=["main"]) + mod.main() + MockServer.return_value.run.assert_called_once() + + +def test_bindings_entrypoint_api_mode_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MODE", raising=False) + with ( + patch("bindings_entrypoint.init_observability") as mock_obs, + patch("bindings_entrypoint.uvicorn.run") as mock_uv, + ): + import bindings_entrypoint + + bindings_entrypoint.main() + mock_obs.assert_called_once_with(app_name="node-wire") + mock_uv.assert_called_once() + call_kw = mock_uv.call_args[1] + assert call_kw["host"] == "0.0.0.0" + assert call_kw["port"] == 8000 + + +def test_bindings_entrypoint_api_mode_explicit_port(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MODE", "API") + monkeypatch.setenv("PORT", "9000") + with ( + patch("bindings_entrypoint.init_observability"), + patch("bindings_entrypoint.uvicorn.run") as mock_uv, + ): + import bindings_entrypoint + + bindings_entrypoint.main() + assert mock_uv.call_args[1]["port"] == 9000 + + +def test_bindings_entrypoint_grpc_mode(monkeypatch: pytest.MonkeyPatch) -> None: + """GRPC path lazy-imports `serve`; stub the module so generated protos are not required.""" + monkeypatch.setenv("MODE", "GRPC") + mock_serve = MagicMock() + fake_grpc_server = MagicMock() + fake_grpc_server.serve = mock_serve + with ( + patch.dict(sys.modules, {"bindings.grpc_server.server": fake_grpc_server}), + patch("bindings_entrypoint.init_observability"), + ): + import bindings_entrypoint + + bindings_entrypoint.main() + mock_serve.assert_called_once_with(port=50051) + + +def test_bindings_entrypoint_mcp_mode(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MODE", "MCP") + mock_server = MagicMock() + mock_server.list_tools.return_value = [{"name": "a.b"}] + with ( + patch("bindings_entrypoint.init_observability"), + patch("bindings_entrypoint.McpServer", return_value=mock_server), + patch("time.sleep", side_effect=RuntimeError("stop_loop")), + ): + import bindings_entrypoint + + with pytest.raises(RuntimeError, match="stop_loop"): + bindings_entrypoint.main() + mock_server.list_tools.assert_called() + + +def test_bindings_entrypoint_unknown_mode_exits(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MODE", "NOT_A_MODE") + with ( + patch("bindings_entrypoint.init_observability"), + pytest.raises(SystemExit, match="Unknown MODE"), + ): + import bindings_entrypoint + + bindings_entrypoint.main() diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index 33c153e..37220ea 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -1,12 +1,22 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations +from unittest.mock import AsyncMock, MagicMock + +import jwt +import pytest from fastapi.testclient import TestClient from bindings.factory import ConnectorFactory -from bindings.rest_api.app import app +from bindings.rest_api.app import app, get_factory +from node_wire_runtime.models import ConnectorResponse, ErrorCategory -def test_factory_loads_config(): +def test_factory_loads_config(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(ConnectorFactory, "_instantiate", lambda self, cid: MagicMock()) factory = ConnectorFactory() factory.load() @@ -14,7 +24,7 @@ def test_factory_loads_config(): assert http_connector is not None stripe_rest = factory.get_for_protocol("stripe", "rest") - assert stripe_rest is None # stripe not exposed via REST per config + assert stripe_rest is not None # stripe exposed via REST def test_health_endpoint(): @@ -23,3 +33,293 @@ def test_health_endpoint(): assert resp.status_code == 200 assert resp.json() == {"status": "ok"} + +def test_agent_transport_defaults_to_stdio(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("NW_MCP_TRANSPORT", raising=False) + client = TestClient(app) + resp = client.get("/scenarios/agent-transport") + assert resp.status_code == 200 + assert resp.json() == {"transport": "stdio", "label": "stdio"} + + +def test_agent_transport_reports_streamable_http(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("NW_MCP_TRANSPORT", "streamable-http") + client = TestClient(app) + resp = client.get("/scenarios/agent-transport") + assert resp.status_code == 200 + assert resp.json() == { + "transport": "streamable-http", + "label": "Streamable HTTP", + } + + +def test_rest_post_without_auth_returns_401_when_key_required( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_JWT_SECRET", raising=False) + monkeypatch.setenv("NW_REST_API_KEY", "unit-test-secret") + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post( + "/connectors/http_generic/request", json={"method": "GET", "url": "https://example.com"} + ) + finally: + app.dependency_overrides.clear() + + assert r.status_code == 401 + assert "Authentication" in r.json()["detail"] + + +def test_rest_post_with_bearer_succeeds_when_key_required(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_REST_API_KEY", "unit-test-secret") + monkeypatch.setenv("NW_RATE_LIMIT_DISABLED", "true") # Disable rate limiting for this test + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={"ok": True}, trace_id="t-rest") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"Authorization": "Bearer unit-test-secret"}, + ) + finally: + app.dependency_overrides.clear() + + assert r.status_code == 200 + + +def test_rest_post_propagates_api_key_identity_to_connector_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_JWT_SECRET", raising=False) + monkeypatch.delenv("NW_REST_API_KEY_SCOPES", raising=False) + monkeypatch.setenv("NW_REST_API_KEY", "unit-test-secret") + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t-p") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"Authorization": "Bearer unit-test-secret"}, + ) + finally: + app.dependency_overrides.clear() + + stub = mock_factory.get_for_protocol.return_value + kwargs = stub.run.await_args.kwargs + assert kwargs["principal"] == "api-key-user" + assert kwargs["tenant_id"] is None + assert kwargs["scopes"] == () + + +def test_rest_post_propagates_jwt_claims_to_connector_run(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_API_KEY", raising=False) + secret = "rest-jwt-test-secret-at-least-32bytes!!" + monkeypatch.setenv("NW_REST_JWT_SECRET", secret) + + tok = jwt.encode( + {"sub": "alice", "tenant_id": "t-1", "scopes": ["mcp:test.scope"]}, + secret, + algorithm="HS256", + ) + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t-j") + ) + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"Authorization": f"Bearer {tok}"}, + ) + finally: + app.dependency_overrides.clear() + + # The connector needs to be called first to set up the mock + stub = mock_factory.get_for_protocol.return_value + assert stub.run is not None, "Connector mock was not called" + kwargs = stub.run.await_args.kwargs + assert kwargs["principal"] == "alice" + assert kwargs["tenant_id"] == "t-1" + assert kwargs["scopes"] == ("mcp:test.scope",) + + +def test_rest_not_configured_returns_503_when_no_key_and_not_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_REST_API_KEY", raising=False) + monkeypatch.delenv("NW_REST_JWT_SECRET", raising=False) + + client = TestClient(app) + r = client.post("/connectors/http_generic/request", json={}) + assert r.status_code == 503 + assert "not configured" in r.json()["detail"].lower() + + +def test_health_public_when_auth_required(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_REST_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_REST_API_KEY", "unit-test-secret") + + client = TestClient(app) + assert client.get("/health").status_code == 200 + + +def _stub_connector(response: ConnectorResponse) -> MagicMock: + c = MagicMock() + c.run = AsyncMock(return_value=response) + return c + + +def test_rest_post_connector_success() -> None: + """Dynamic POST forwards payload to connector.run and returns JSON with 200.""" + resp_body = ConnectorResponse(success=True, data={"ok": True}, trace_id="t-rest") + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector(resp_body) + + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post( + "/connectors/http_generic/request", json={"method": "GET", "url": "https://example.com"} + ) + finally: + app.dependency_overrides.clear() + + assert r.status_code == 200 + body = r.json() + assert body["success"] is True + assert body["trace_id"] == "t-rest" + mock_factory.get_for_protocol.assert_called_with("http_generic", "rest", action="request") + stub = mock_factory.get_for_protocol.return_value + stub.run.assert_awaited_once() + call_payload = stub.run.await_args[0][0] + assert call_payload["action"] == "request" + assert call_payload["method"] == "GET" + + +def test_rest_post_connector_rejects_conflicting_action_in_body() -> None: + """Body action must match URL path segment (same as MCP tool name authority).""" + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector( + ConnectorResponse(success=True, data={}, trace_id="t") + ) + + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post( + "/connectors/http_generic/request", + json={"action": "wrong_action", "method": "GET", "url": "https://example.com"}, + ) + finally: + app.dependency_overrides.clear() + + assert r.status_code == 400 + assert "does not match" in r.json()["detail"] + + +@pytest.mark.parametrize( + ("category", "expected_status"), + [ + (ErrorCategory.BUSINESS, 400), + (ErrorCategory.AUTH, 401), + (ErrorCategory.RETRYABLE, 503), + (ErrorCategory.FATAL, 500), + ], +) +def test_rest_post_connector_error_category_http_status( + category: ErrorCategory, expected_status: int +) -> None: + resp_body = ConnectorResponse( + success=False, + trace_id="t-err", + error_category=category, + error_code="E1", + message="nope", + ) + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector(resp_body) + + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post( + "/connectors/http_generic/request", json={"method": "GET", "url": "https://example.com"} + ) + finally: + app.dependency_overrides.clear() + + assert r.status_code == expected_status + assert r.json()["success"] is False + assert r.json()["error_category"] == category.value + + +def test_rest_post_connector_not_available_returns_404() -> None: + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = None + + app.dependency_overrides[get_factory] = lambda: mock_factory + try: + client = TestClient(app) + r = client.post("/connectors/http_generic/request", json={}) + finally: + app.dependency_overrides.clear() + + assert r.status_code == 404 + assert r.json()["detail"] == "Connector not available for REST" + + +def test_http_status_for_category_direct() -> None: + from bindings.rest_api.app import _http_status_for_category + + assert _http_status_for_category(None) == 200 + assert _http_status_for_category(ErrorCategory.BUSINESS) == 400 + assert _http_status_for_category(ErrorCategory.AUTH) == 401 + assert _http_status_for_category(ErrorCategory.RETRYABLE) == 503 + assert _http_status_for_category(ErrorCategory.FATAL) == 500 + + +def test_factory_scope_policy_strict_mode_requires_deny_or_map( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_STRICT", "true") + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_DEFAULT", "allow") + monkeypatch.setenv("NW_MCP_ACTION_SCOPE_MAP_JSON", "{}") + + with pytest.raises(ValueError) as exc_info: + ConnectorFactory() + assert "MCP scope policy is effectively disabled" in str(exc_info.value) + + +def test_factory_scope_policy_default_deny_without_map_enables_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_SCOPE_POLICY_STRICT", raising=False) + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_DEFAULT", "deny") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + + factory = ConnectorFactory() + assert factory._policy_hook is not None diff --git a/tests/test_fhir_cerner.py b/tests/test_fhir_cerner.py index a48eb72..9880066 100644 --- a/tests/test_fhir_cerner.py +++ b/tests/test_fhir_cerner.py @@ -1,25 +1,32 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest -from connectors.fhir_cerner.logic import FhirCernerConnector -from runtime import SecretProvider +from node_wire_fhir_cerner.logic import FhirCernerConnector +from node_wire_runtime import SecretProvider # --------------------------------------------------------------------------- # Shared helpers # --------------------------------------------------------------------------- + class MockSecretProvider(SecretProvider): def get_secret(self, key: str) -> str: return { "cerner_fhir_base_url": "https://fhir-myrecord.cerner.com/r4/tenant-id", - "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\\\nMEowIQ...dummy\\\\n-----END RSA PRIVATE KEY-----", + "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\nMEowIQ...dummy\\n-----END RSA PRIVATE KEY-----", "cerner_kid": "dummy-kid", "cerner_client_id": "dummy-client-id", "cerner_token_url": "https://authorization.cerner.com/tenants/tenant-id/protocols/oauth2/profiles/smart-v1/token", + "dummy_token_key": "dummy-access-token", }[key] @@ -31,43 +38,58 @@ def _token_mock() -> MagicMock: def _connector() -> FhirCernerConnector: - """Return a FhirCernerConnector with mock secrets.""" - return FhirCernerConnector(secret_provider=MockSecretProvider()) + """Return a FhirCernerConnector with a static mock token.""" + from node_wire_runtime.auth import StaticTokenAuthProvider + + sp = MockSecretProvider() + auth = StaticTokenAuthProvider( + secret_provider=sp, + secret_key="dummy_token_key", + ) + return FhirCernerConnector(secret_provider=sp, auth_provider=auth) + + +def _token_mock() -> MagicMock: + """Not used by StaticTokenAuthProvider, but kept for compatibility.""" + m = MagicMock() + m.status_code = 200 + m.json.return_value = {"access_token": "dummy-access-token"} + return m # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_cerner_connector_exposes_five_actions(): + +def test_fhir_cerner_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == "fhir_cerner" + assert c.action == "execute" # --------------------------------------------------------------------------- # read_patient — by ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_id(): - action = _connector().get_action("read_patient") - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(resource_id="12345678") + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput(action="read_patient", resource_id="12345678") patient_response = MagicMock() patient_response.status_code = 200 - patient_response.json.return_value = {"resourceType": "Patient", "id": "12345678", "name": [{"family": "Smith"}]} + patient_response.json.return_value = { + "resourceType": "Patient", + "id": "12345678", + "name": [{"family": "Smith"}], + } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "12345678" assert result.resource["name"][0]["family"] == "Smith" @@ -77,23 +99,27 @@ async def test_fhir_cerner_read_patient_by_id(): # read_patient — by raw search_params dict (backward-compat) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_search(): - action = _connector().get_action("read_patient") - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(search_params={"family": "Smith", "given": "John"}) + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, + "resourceType": "Bundle", + "total": 1, "entry": [{"resource": {"resourceType": "Patient", "id": "99887766"}}], } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99887766" @@ -102,28 +128,42 @@ async def test_fhir_cerner_read_patient_by_search(): # read_patient — by explicit given_name / family_name fields # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(given_name=" Jane ", family_name="Doe", birthdate="1990-06-15") - + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput( + action="read_patient", + given_name=" Jane ", + family_name="Doe", + birthdate="1990-06-15", + ) + patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, - "entry": [{"resource": {"resourceType": "Patient", "id": "55551234", "birthDate": "1990-06-15"}}], + "resourceType": "Bundle", + "total": 1, + "entry": [ + {"resource": {"resourceType": "Patient", "id": "55551234", "birthDate": "1990-06-15"}} + ], } - - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + + with ( + patch("node_wire_runtime.auth.oauth2.jwt.encode", return_value="dummy-jwt"), + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response + ) as mock_get, + ): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "55551234" call_kwargs = mock_get.call_args sent_params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) - assert sent_params.get("given") == "Jane" # whitespace stripped + assert sent_params.get("given") == "Jane" # whitespace stripped assert sent_params.get("family") == "Doe" assert sent_params.get("birthdate") == "1990-06-15" @@ -132,23 +172,30 @@ async def test_fhir_cerner_read_patient_by_explicit_name_fields(): # read_patient — by 'name' convenience field # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_name_field(): - action = _connector().get_action("read_patient") - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(name="Johnson") - + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput(action="read_patient", name="Johnson") + patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, + "resourceType": "Bundle", + "total": 1, "entry": [{"resource": {"resourceType": "Patient", "id": "99990001"}}], } - - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + + with ( + patch("node_wire_runtime.auth.oauth2.jwt.encode", return_value="dummy-jwt"), + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response + ) as mock_get, + ): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99990001" call_kwargs = mock_get.call_args @@ -160,27 +207,36 @@ async def test_fhir_cerner_read_patient_by_name_field(): # read_patient — no params raises ValueError # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput() - - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput(action="read_patient") + + with ( + patch("node_wire_runtime.auth.oauth2.jwt.encode", return_value="dummy-jwt"), + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + ): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- # search_patients — multi-ID, all succeed # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_patients_multi_id(): - action = _connector().get_action("search_patients") - from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["11111111", "22222222"]) + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientSearchInput + + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["11111111", "22222222"], + ) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -190,10 +246,8 @@ def _patient_resp(pid: str) -> MagicMock: responses = [_patient_resp("11111111"), _patient_resp("22222222")] - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"11111111", "22222222"} @@ -205,11 +259,16 @@ def _patient_resp(pid: str) -> MagicMock: # search_patients — multi-ID, partial failure # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_patients_partial_failure(): - action = _connector().get_action("search_patients") - from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["99999999", "00000000"]) + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientSearchInput + + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["99999999", "00000000"], + ) good_resp = MagicMock() good_resp.status_code = 200 @@ -219,10 +278,8 @@ async def test_fhir_cerner_search_patients_partial_failure(): bad_resp.status_code = 404 bad_resp.raise_for_status.side_effect = Exception("404 Not Found") - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "99999999" @@ -234,11 +291,13 @@ async def test_fhir_cerner_search_patients_partial_failure(): # search_patients — name-based search returning multiple Bundle entries # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_patients_by_name(): - action = _connector().get_action("search_patients") - from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(family_name="Smith") + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientSearchInput + + params = FhirCernerPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -246,15 +305,27 @@ async def test_fhir_cerner_search_patients_by_name(): "resourceType": "Bundle", "total": 2, "entry": [ - {"resource": {"resourceType": "Patient", "id": "11111111", "name": [{"family": "Smith", "given": ["Alice"]}]}}, - {"resource": {"resourceType": "Patient", "id": "22222222", "name": [{"family": "Smith", "given": ["Bob"]}]}}, + { + "resource": { + "resourceType": "Patient", + "id": "11111111", + "name": [{"family": "Smith", "given": ["Alice"]}], + } + }, + { + "resource": { + "resourceType": "Patient", + "id": "22222222", + "name": [{"family": "Smith", "given": ["Bob"]}], + } + }, ], } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + with patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp + ) as mock_get: + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -268,64 +339,85 @@ async def test_fhir_cerner_search_patients_by_name(): # search_patients — no params raises ValueError # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") - from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput() - - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerPatientSearchInput + + params = FhirCernerPatientSearchInput(action="search_patients") + + with ( + patch("node_wire_runtime.auth.oauth2.jwt.encode", return_value="dummy-jwt"), + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + ): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- # search_encounter # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_encounter(): - action = _connector().get_action("search_encounter") - from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(search_params={"patient": "12345678", "status": "finished"}) + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerEncounterSearchInput + + params = FhirCernerEncounterSearchInput( + action="search_encounter", + search_params={"patient": "12345678", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 enc_response.json.return_value = { - "resourceType": "Bundle", "total": 2, + "resourceType": "Bundle", + "total": 2, "entry": [ {"resource": {"resourceType": "Encounter", "id": "enc-1"}}, {"resource": {"resourceType": "Encounter", "id": "enc-2"}}, ], } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" +@pytest.mark.asyncio +async def test_fhir_cerner_search_encounter_rejects_unscoped_query(): + """Status-only encounter search must not call the FHIR server (enterprise guard).""" + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerEncounterSearchInput + + params = FhirCernerEncounterSearchInput(action="search_encounter", status="finished") + with pytest.raises(ValueError, match="patient-scoped"): + await c.internal_execute(params, trace_id="test-trace") + + @pytest.mark.asyncio async def test_fhir_cerner_search_encounter_by_patient(): - action = _connector().get_action("search_encounter") - from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(patient_id="12345678") + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerEncounterSearchInput + + params = FhirCernerEncounterSearchInput(action="search_encounter", patient_id="12345678") enc_response = MagicMock() enc_response.status_code = 200 enc_response.json.return_value = { - "resourceType": "Bundle", "total": 1, + "resourceType": "Bundle", + "total": 1, "entry": [{"resource": {"resourceType": "Encounter", "id": "enc-1"}}], } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + with patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response + ) as mock_get: + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "enc-1" @@ -338,16 +430,32 @@ async def test_fhir_cerner_search_encounter_by_patient(): # create_document_reference # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_create_document_reference(): - action = _connector().get_action("create_document_reference") - from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", - type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, + doc_status="final", + type={ + "coding": [ + { + "system": "urn:oid:4.5.6", + "code": "18100", + "display": "Employer Group Scan", + "userSelected": True, + } + ], + "text": "Employer Group Scan", + }, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Document", + author=[{"reference": "Practitioner/p1"}], context={ "encounter": [{"reference": "Encounter/enc-1"}], "period": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T01:00:00Z"}, @@ -356,45 +464,59 @@ async def test_fhir_cerner_create_document_reference(): create_response = MagicMock() create_response.status_code = 201 - create_response.headers = {"Location": "https://fhir-myrecord.cerner.com/r4/tenant-id/DocumentReference/doc-456/_history/1"} + create_response.headers = { + "Location": "https://fhir-myrecord.cerner.com/r4/tenant-id/DocumentReference/doc-456/_history/1" + } create_response.content = b"" create_response.text = "" - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: - mock_post.side_effect = [_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + with patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=create_response + ) as mock_post: + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" - _, kwargs = mock_post.call_args_list[1] + _, kwargs = mock_post.call_args assert kwargs["json"]["resourceType"] == "DocumentReference" assert kwargs["json"]["subject"] == {"reference": "Patient/12724066"} # Verify that charset was added to contentType - assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain; charset=UTF-8" + assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain;charset=utf-8" # --------------------------------------------------------------------------- # search_document_reference # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_search_document_reference(): - action = _connector().get_action("search_document_reference") - from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceSearchInput - params = FhirCernerDocumentReferenceSearchInput(search_params={"patient": "12345678"}) + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerDocumentReferenceSearchInput + + params = FhirCernerDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "12345678"}, + ) search_response = MagicMock() search_response.status_code = 200 search_response.json.return_value = { - "resourceType": "Bundle", "total": 1, - "entry": [{"resource": {"resourceType": "DocumentReference", "id": "doc-789", "status": "current", - "type": {"coding": [{"system": "urn:oid:4.5.6", "code": "18100"}]}}}], + "resourceType": "Bundle", + "total": 1, + "entry": [ + { + "resource": { + "resourceType": "DocumentReference", + "id": "doc-789", + "status": "current", + "type": {"coding": [{"system": "urn:oid:4.5.6", "code": "18100"}]}, + } + } + ], } - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" @@ -404,21 +526,129 @@ async def test_fhir_cerner_search_document_reference(): # Validation: LOINC system rejected for Cerner (context.period auto-inject) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_cerner_create_document_reference_validation(): """Verify that ValueError is raised when period is missing but encounter is present.""" - action = _connector().get_action("create_document_reference") - from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", + doc_status="final", type={"coding": [{"system": "http://loinc.org", "code": "11488-4"}]}, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Doc", + author=[{"reference": "Practitioner/p1"}], context={"encounter": [{"reference": "Encounter/enc-1"}]}, ) - with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): - with pytest.raises(ValueError, match="Cerner requires the proprietary CodeSet 72 system"): - await action.internal_execute(params, trace_id="test-trace") + with pytest.raises(ValueError, match="Cerner requires the proprietary CodeSet 72 system"): + await c.internal_execute(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# Auth: malformed token URL (/hosts/) +# --------------------------------------------------------------------------- + + +class MalformedCernerTokenUrlProvider(SecretProvider): + """Token URL containing ``/hosts/`` triggers a clear configuration error.""" + + def __init__(self) -> None: + self._defaults = MockSecretProvider() + + def get_secret(self, key: str) -> str: + if key == "cerner_token_url": + return ( + "https://authorization.cerner.com/tenants/x/hosts/fhir-ehr-code.cerner.com/" + "protocols/oauth2/profiles/smart-v1/token" + ) + return self._defaults.get_secret(key) + + +@pytest.mark.asyncio +async def test_fhir_cerner_auth_rejects_malformed_token_url_with_hosts_segment() -> None: + from node_wire_runtime.auth import OAuth2AuthProvider + + sp = MalformedCernerTokenUrlProvider() + auth = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="cerner_token_url", + client_id_secret="cerner_client_id", + private_key_secret="cerner_private_key", + kid_secret="cerner_kid", + algorithm="RS384", + ) + c = FhirCernerConnector(secret_provider=sp, auth_provider=auth) + from node_wire_fhir_cerner.schema import FhirCernerPatientReadInput + + params = FhirCernerPatientReadInput(action="read_patient", resource_id="123") + with patch("node_wire_runtime.auth.oauth2.jwt.encode", return_value="dummy-assertion"): + with pytest.raises(ValueError, match="/hosts/"): + await c.internal_execute(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# create_document_reference: OperationOutcome on HTTP error +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_fhir_cerner_create_document_reference_operation_outcome_error() -> None: + c = _connector() + from node_wire_fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + + params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", + identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], + status="current", + doc_status="final", + type={ + "coding": [ + { + "system": "urn:oid:4.5.6", + "code": "18100", + "display": "Employer Group Scan", + "userSelected": True, + } + ], + "text": "Employer Group Scan", + }, + subject="Patient/12724066", + data="dGVzdA==", + attachment_title="Document", + author=[{"reference": "Practitioner/p1"}], + context={ + "encounter": [{"reference": "Encounter/enc-1"}], + "period": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T01:00:00Z"}, + }, + ) + + post_req = httpx.Request("POST", "https://fhir.example/DocumentReference") + err_resp = httpx.Response( + 400, + request=post_req, + json={ + "resourceType": "OperationOutcome", + "issue": [ + {"severity": "error", "code": "invalid", "diagnostics": "Cerner rejected payload"} + ], + }, + ) + + async def post_side_effect(*args: object, **kwargs: object) -> httpx.Response | MagicMock: + post_side_effect.calls += 1 + if post_side_effect.calls == 1: + return _token_mock() + return err_resp + + post_side_effect.calls = 0 + + with patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=err_resp): + with pytest.raises(ValueError, match="Cerner Error:.*Cerner rejected payload"): + await c.internal_execute(params, trace_id="test-trace") diff --git a/tests/test_fhir_epic.py b/tests/test_fhir_epic.py index b076da5..5727401 100644 --- a/tests/test_fhir_epic.py +++ b/tests/test_fhir_epic.py @@ -1,17 +1,22 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest -from connectors.fhir_epic.logic import FhirEpicConnector -from runtime import SecretProvider +from node_wire_fhir_epic.logic import FhirEpicConnector +from node_wire_runtime import SecretProvider # --------------------------------------------------------------------------- # Shared helpers # --------------------------------------------------------------------------- + class MockSecretProvider(SecretProvider): def get_secret(self, key: str) -> str: return { @@ -20,6 +25,7 @@ def get_secret(self, key: str) -> str: "epic_kid": "dummy-kid", "epic_client_id": "dummy-client-id", "epic_token_url": "https://fhir.epic.com/token", + "dummy_token_key": "dummy-access-token", }[key] @@ -31,43 +37,58 @@ def _token_mock() -> MagicMock: def _connector() -> FhirEpicConnector: - """Return a FhirEpicConnector with mock secrets.""" - return FhirEpicConnector(secret_provider=MockSecretProvider()) + """Return a FhirEpicConnector with a static mock token.""" + from node_wire_runtime.auth import StaticTokenAuthProvider + + sp = MockSecretProvider() + auth = StaticTokenAuthProvider( + secret_provider=sp, + secret_key="dummy_token_key", + ) + return FhirEpicConnector(secret_provider=sp, auth_provider=auth) + + +def _token_mock() -> MagicMock: + """Not used by StaticTokenAuthProvider, but kept for compatibility if needed.""" + m = MagicMock() + m.status_code = 200 + m.json.return_value = {"access_token": "dummy-access-token"} + return m # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_epic_connector_exposes_five_actions(): + +def test_fhir_epic_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == "fhir_epic" + assert c.action == "execute" # --------------------------------------------------------------------------- # read_patient — by ID # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_id(): - action = _connector().get_action("read_patient") - from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(resource_id="eXYZ123") + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientReadInput + + params = FhirPatientReadInput(action="read_patient", resource_id="eXYZ123") patient_response = MagicMock() patient_response.status_code = 200 - patient_response.json.return_value = {"resourceType": "Patient", "id": "eXYZ123", "name": [{"family": "Smith"}]} + patient_response.json.return_value = { + "resourceType": "Patient", + "id": "eXYZ123", + "name": [{"family": "Smith"}], + } - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eXYZ123" assert result.resource["name"][0]["family"] == "Smith" @@ -77,23 +98,27 @@ async def test_fhir_epic_read_patient_by_id(): # read_patient — by raw search_params dict (backward-compat) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_search(): - action = _connector().get_action("read_patient") - from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(search_params={"family": "Smith", "given": "John"}) + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientReadInput + + params = FhirPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, + "resourceType": "Bundle", + "total": 1, "entry": [{"resource": {"resourceType": "Patient", "id": "eABC"}}], } - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eABC" @@ -102,23 +127,36 @@ async def test_fhir_epic_read_patient_by_search(): # read_patient — by explicit given_name / family_name fields # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") - from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(given_name=" John ", family_name="Smith", birthdate="1980-01-01") - + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientReadInput + + params = FhirPatientReadInput( + action="read_patient", + given_name=" John ", + family_name="Smith", + birthdate="1980-01-01", + ) + patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, - "entry": [{"resource": {"resourceType": "Patient", "id": "eDEF", "birthDate": "1980-01-01"}}], + "resourceType": "Bundle", + "total": 1, + "entry": [ + {"resource": {"resourceType": "Patient", "id": "eDEF", "birthDate": "1980-01-01"}} + ], } - - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + + with ( + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response + ) as mock_get, + ): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eDEF" # Verify the correct FHIR params were built (stripped whitespace) @@ -133,23 +171,29 @@ async def test_fhir_epic_read_patient_by_explicit_name_fields(): # read_patient — by 'name' convenience field # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_name_field(): - action = _connector().get_action("read_patient") - from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(name="Johnson") - + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientReadInput + + params = FhirPatientReadInput(action="read_patient", name="Johnson") + patient_response = MagicMock() patient_response.status_code = 200 patient_response.json.return_value = { - "resourceType": "Bundle", "total": 1, + "resourceType": "Bundle", + "total": 1, "entry": [{"resource": {"resourceType": "Patient", "id": "eGHI"}}], } - - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + + with ( + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response + ) as mock_get, + ): + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eGHI" call_kwargs = mock_get.call_args @@ -161,27 +205,30 @@ async def test_fhir_epic_read_patient_by_name_field(): # read_patient — no params raises ValueError # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") - from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput() # nothing provided - - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientReadInput + + params = FhirPatientReadInput(action="read_patient") + + with patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- # search_patients — multi-ID, all succeed # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_patients_multi_id(): - action = _connector().get_action("search_patients") - from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eABC", "eDEF"]) + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientSearchInput + + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eABC", "eDEF"]) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -191,10 +238,8 @@ def _patient_resp(pid: str) -> MagicMock: responses = [_patient_resp("eABC"), _patient_resp("eDEF")] - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"eABC", "eDEF"} @@ -206,11 +251,13 @@ def _patient_resp(pid: str) -> MagicMock: # search_patients — multi-ID, partial failure # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_patients_partial_failure(): - action = _connector().get_action("search_patients") - from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eGOOD", "eBAD"]) + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientSearchInput + + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eGOOD", "eBAD"]) good_resp = MagicMock() good_resp.status_code = 200 @@ -220,10 +267,8 @@ async def test_fhir_epic_search_patients_partial_failure(): bad_resp.status_code = 404 bad_resp.raise_for_status.side_effect = Exception("404 Not Found") - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "eGOOD" @@ -235,11 +280,13 @@ async def test_fhir_epic_search_patients_partial_failure(): # search_patients — name-based search returning multiple entries # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_patients_by_name(): - action = _connector().get_action("search_patients") - from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(family_name="Smith") + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientSearchInput + + params = FhirPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -247,15 +294,30 @@ async def test_fhir_epic_search_patients_by_name(): "resourceType": "Bundle", "total": 2, "entry": [ - {"resource": {"resourceType": "Patient", "id": "e001", "name": [{"family": "Smith", "given": ["Alice"]}]}}, - {"resource": {"resourceType": "Patient", "id": "e002", "name": [{"family": "Smith", "given": ["Bob"]}]}}, + { + "resource": { + "resourceType": "Patient", + "id": "e001", + "name": [{"family": "Smith", "given": ["Alice"]}], + } + }, + { + "resource": { + "resourceType": "Patient", + "id": "e002", + "name": [{"family": "Smith", "given": ["Bob"]}], + } + }, ], } - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + with ( + patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), + patch( + "httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp + ) as mock_get, + ): + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -270,42 +332,47 @@ async def test_fhir_epic_search_patients_by_name(): # search_patients — no params raises ValueError # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") - from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput() - - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): + c = _connector() + from node_wire_fhir_epic.schema import FhirPatientSearchInput + + params = FhirPatientSearchInput(action="search_patients") + + with patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- # search_encounter # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_encounter(): - action = _connector().get_action("search_encounter") - from connectors.fhir_epic.schema import FhirEncounterSearchInput - params = FhirEncounterSearchInput(search_params={"patient": "eXYZ123", "status": "finished"}) + c = _connector() + from node_wire_fhir_epic.schema import FhirEncounterSearchInput + + params = FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": "eXYZ123", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 enc_response.json.return_value = { - "resourceType": "Bundle", "total": 2, + "resourceType": "Bundle", + "total": 2, "entry": [ {"resource": {"resourceType": "Encounter", "id": "enc-1"}}, {"resource": {"resourceType": "Encounter", "id": "enc-2"}}, ], } - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" @@ -315,14 +382,21 @@ async def test_fhir_epic_search_encounter(): # create_document_reference # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_create_document_reference(): - action = _connector().get_action("create_document_reference") - from connectors.fhir_epic.schema import FhirDocumentReferenceCreateInput + c = _connector() + from node_wire_fhir_epic.schema import FhirDocumentReferenceCreateInput + params = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", - type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, + type={ + "coding": [ + {"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"} + ] + }, subject="Patient/ePD0eeFq.GMHG.aXttqP.Lw3", data="dGVzdA==", context={"related": [{"reference": "Group/eqv3buSV"}]}, @@ -330,17 +404,19 @@ async def test_fhir_epic_create_document_reference(): create_response = MagicMock() create_response.status_code = 201 - create_response.headers = {"Location": "https://fhir.epic.com/api/FHIR/R4/DocumentReference/doc-456/_history/1"} + create_response.headers = { + "Location": "https://fhir.epic.com/api/FHIR/R4/DocumentReference/doc-456/_history/1" + } create_response.content = b"" create_response.text = "" - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: - mock_post.side_effect = [_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + with patch( + "httpx.AsyncClient.post", new_callable=AsyncMock, return_value=create_response + ) as mock_post: + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" - _, kwargs = mock_post.call_args_list[1] + _, kwargs = mock_post.call_args assert kwargs["json"]["resourceType"] == "DocumentReference" assert kwargs["json"]["subject"] == {"reference": "Patient/ePD0eeFq.GMHG.aXttqP.Lw3"} @@ -349,24 +425,36 @@ async def test_fhir_epic_create_document_reference(): # search_document_reference # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_fhir_epic_search_document_reference(): - action = _connector().get_action("search_document_reference") - from connectors.fhir_epic.schema import FhirDocumentReferenceSearchInput - params = FhirDocumentReferenceSearchInput(search_params={"patient": "eXYZ123"}) + c = _connector() + from node_wire_fhir_epic.schema import FhirDocumentReferenceSearchInput + + params = FhirDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "eXYZ123"}, + ) search_response = MagicMock() search_response.status_code = 200 search_response.json.return_value = { - "resourceType": "Bundle", "total": 1, - "entry": [{"resource": {"resourceType": "DocumentReference", "id": "doc-789", "status": "current", - "type": {"coding": [{"system": "urn:oid:4.5.6", "code": "18100"}]}}}], + "resourceType": "Bundle", + "total": 1, + "entry": [ + { + "resource": { + "resourceType": "DocumentReference", + "id": "doc-789", + "status": "current", + "type": {"coding": [{"system": "urn:oid:4.5.6", "code": "18100"}]}, + } + } + ], } - with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ - patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ - patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" diff --git a/tests/test_google_drive.py b/tests/test_google_drive.py index 286d7a2..699c89c 100644 --- a/tests/test_google_drive.py +++ b/tests/test_google_drive.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# from __future__ import annotations import asyncio @@ -7,15 +11,18 @@ import pytest from pydantic import ValidationError -from connectors.google_drive.exceptions import ( +from node_wire_google_drive.exceptions import ( GoogleDriveAuthError, GoogleDriveBusinessError, GoogleDriveFatalError, GoogleDriveRateLimitError, ) -from connectors.google_drive.logic import DEFAULT_LIST_FIELDS, GoogleDriveConnector -from connectors.google_drive.schema import GoogleDriveOperationInput, GoogleDriveOperationOutput -from runtime import SecretProvider +from node_wire_google_drive.logic import DEFAULT_LIST_FIELDS, GoogleDriveConnector +from node_wire_google_drive.schema import ( + FilesUploadOperation, + GoogleDriveOperationInput, +) +from node_wire_runtime import SecretProvider class MockSecretProvider(SecretProvider): @@ -34,11 +41,36 @@ def __init__(self, status: int, *, content: str = "", reason: str = "") -> None: def _connector() -> GoogleDriveConnector: - return GoogleDriveConnector( - input_model=GoogleDriveOperationInput, - output_model=GoogleDriveOperationOutput, - secret_provider=MockSecretProvider(), + return GoogleDriveConnector(secret_provider=MockSecretProvider()) + + +def test_files_upload_operation_requires_exactly_one_body_source() -> None: + FilesUploadOperation.model_validate( + { + "action": "files.upload", + "name": "a.txt", + "mime_type": "text/plain", + "content": "hello", + } ) + with pytest.raises(ValidationError): + FilesUploadOperation.model_validate( + { + "action": "files.upload", + "name": "a.txt", + "mime_type": "text/plain", + } + ) + with pytest.raises(ValidationError): + FilesUploadOperation.model_validate( + { + "action": "files.upload", + "name": "a.txt", + "mime_type": "text/plain", + "content": "a", + "content_base64": "Zg==", + } + ) def test_google_drive_internal_execute_files_list_happy_path(): @@ -50,7 +82,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): list_call = files_api.list.return_value list_call.execute.return_value = {"files": [{"id": "f-1", "name": "Report"}]} - with patch.object(connector, "_build_client", return_value=drive): + with patch.object(connector, "get_client", return_value=drive): result = asyncio.run(connector.internal_execute(params, trace_id="test-trace")) assert result.raw == {"files": [{"id": "f-1", "name": "Report"}]} @@ -59,6 +91,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): pageSize=5, q=None, fields=DEFAULT_LIST_FIELDS, + pageToken=None, supportsAllDrives=True, includeItemsFromAllDrives=True, ) diff --git a/tests/test_google_drive_action_spec.py b/tests/test_google_drive_action_spec.py new file mode 100644 index 0000000..e498159 --- /dev/null +++ b/tests/test_google_drive_action_spec.py @@ -0,0 +1,136 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for Google Drive action specs and SDK call mapping.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +from node_wire_google_drive.action_spec import GOOGLE_DRIVE_ACTION_SPECS +from node_wire_google_drive.logic import GoogleDriveConnector +from node_wire_google_drive.schema import GoogleDriveOperationInput +from node_wire_runtime import SecretProvider + + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "GOOGLE_DRIVE_SA_JSON": '{"type":"service_account","project_id":"dummy"}', + }[key] + + +def _connector() -> GoogleDriveConnector: + return GoogleDriveConnector(secret_provider=MockSecretProvider()) + + +def test_action_spec_registry_covers_all_nw_actions(): + """Every @nw_action on GoogleDriveConnector must have a spec entry.""" + metas = GoogleDriveConnector.nw_action_metas() + for action_name in metas: + assert action_name in GOOGLE_DRIVE_ACTION_SPECS, f"missing spec for {action_name}" + + +def test_files_create_maps_body_and_constants(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "files.create", + "name": "doc.txt", + "mime_type": "text/plain", + "parents": ["p1"], + } + ) + + drive = MagicMock() + files_api = drive.files.return_value + create_call = files_api.create.return_value + create_call.execute.return_value = {"id": "new-id", "name": "doc.txt"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "new-id", "name": "doc.txt"} + files_api.create.assert_called_once_with( + body={"name": "doc.txt", "mimeType": "text/plain", "parents": ["p1"]}, + fields="id, name, webViewLink", + supportsAllDrives=True, + ) + + +def test_files_delete_returns_synthetic_raw(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + {"action": "files.delete", "file_id": "fid-99"} + ) + + drive = MagicMock() + files_api = drive.files.return_value + upd = files_api.update.return_value + upd.execute.return_value = {"id": "fid-99", "trashed": True} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"file_id": "fid-99", "status": "deleted"} + files_api.update.assert_called_once_with( + fileId="fid-99", + body={"trashed": True}, + supportsAllDrives=True, + ) + + +def test_permissions_create_maps_body(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "f1", + "role": "reader", + "type": "user", + "email_address": "a@b.com", + } + ) + + drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"id": "perm-1"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "perm-1"} + perms.create.assert_called_once_with( + fileId="f1", + body={"role": "reader", "type": "user", "emailAddress": "a@b.com"}, + supportsAllDrives=True, + ) + + +def test_permissions_create_excludes_empty_optional_fields(): + """Empty-string email_address and domain must be excluded from the body (not sent as "").""" + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "file-abc", + "role": "reader", + "type": "anyone", + "email_address": "", + "domain": "", + } + ) + + drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"kind": "drive#permission"} + + with patch.object(connector, "get_client", return_value=drive): + asyncio.run(connector.internal_execute(params, trace_id="t-empty")) + + _, kwargs = perms.create.call_args + body = kwargs["body"] + assert "emailAddress" not in body + assert "domain" not in body diff --git a/tests/test_llm_providers.py b/tests/test_llm_providers.py new file mode 100644 index 0000000..5100332 --- /dev/null +++ b/tests/test_llm_providers.py @@ -0,0 +1,361 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""LLM factory and provider unit tests (SDKs mocked).""" + +from __future__ import annotations + +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from agents.llm_factory import LLMMessage, LLMProviderFactory, ToolCall + + +def test_llm_factory_create_gemini_and_anthropic() -> None: + with patch("agents.providers.gemini_provider.genai") as genai_mod: + genai_mod.configure = MagicMock() + genai_mod.GenerativeModel = MagicMock() + p = LLMProviderFactory.create("gemini", api_key="k", model="gemini-2.0-flash") + from agents.providers.gemini_provider import GeminiProvider + + assert isinstance(p, GeminiProvider) + + with patch("agents.providers.anthropic_provider.anthropic") as anth_mod: + anth_mod.Anthropic = MagicMock() + p2 = LLMProviderFactory.create("anthropic", api_key="k", model="claude-3-5-haiku-20241022") + from agents.providers.anthropic_provider import AnthropicProvider + + assert isinstance(p2, AnthropicProvider) + + +@pytest.mark.parametrize( + ("env_provider", "env_key", "env_val", "model_env", "model_val", "expected_cls_path"), + [ + ( + "groq", + "GROQ_API_KEY", + "gk", + "GROQ_MODEL", + "llama-x", + "agents.providers.groq_provider.GroqProvider", + ), + ( + "openai", + "OPENAI_API_KEY", + "ok", + "OPENAI_MODEL", + "gpt-x", + "agents.providers.openai_provider.OpenAIProvider", + ), + ( + "gemini", + "GEMINI_API_KEY", + "gk2", + "GEMINI_MODEL", + "gem-x", + "agents.providers.gemini_provider.GeminiProvider", + ), + ( + "anthropic", + "ANTHROPIC_API_KEY", + "ak", + "ANTHROPIC_MODEL", + "claude-x", + "agents.providers.anthropic_provider.AnthropicProvider", + ), + ], +) +def test_llm_factory_create_from_env( + monkeypatch: pytest.MonkeyPatch, + env_provider: str, + env_key: str, + env_val: str, + model_env: str, + model_val: str, + expected_cls_path: str, +) -> None: + monkeypatch.setenv("LLM_PROVIDER", env_provider) + monkeypatch.setenv(env_key, env_val) + monkeypatch.setenv(model_env, model_val) + if "gemini" in expected_cls_path: + with patch("agents.providers.gemini_provider.genai") as g: + g.configure = MagicMock() + g.GenerativeModel = MagicMock() + provider = LLMProviderFactory.create_from_env() + from agents.providers.gemini_provider import GeminiProvider + + assert isinstance(provider, GeminiProvider) + elif "anthropic" in expected_cls_path: + with patch("agents.providers.anthropic_provider.anthropic") as a: + a.Anthropic = MagicMock() + provider = LLMProviderFactory.create_from_env() + from agents.providers.anthropic_provider import AnthropicProvider + + assert isinstance(provider, AnthropicProvider) + elif "groq" in expected_cls_path: + with patch("agents.providers.groq_provider.Groq"): + provider = LLMProviderFactory.create_from_env() + from agents.providers.groq_provider import GroqProvider + + assert isinstance(provider, GroqProvider) + else: + with patch("agents.providers.openai_provider.OpenAI"): + provider = LLMProviderFactory.create_from_env() + from agents.providers.openai_provider import OpenAIProvider + + assert isinstance(provider, OpenAIProvider) + + +def _openai_style_response(content: str | None, tool_calls: list | None) -> MagicMock: + msg = MagicMock() + msg.content = content + msg.tool_calls = tool_calls or [] + choice = MagicMock() + choice.message = msg + resp = MagicMock() + resp.choices = [choice] + return resp + + +def test_groq_provider_chat_with_tools_and_bad_json_args() -> None: + tc_ok = MagicMock() + tc_ok.id = "c1" + tc_ok.function.name = "a.b" + tc_ok.function.arguments = '{"x": 1}' + tc_bad = MagicMock() + tc_bad.id = "c2" + tc_bad.function.name = "a.c" + tc_bad.function.arguments = "not-json{" + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _openai_style_response(None, [tc_ok, tc_bad]) + with patch("agents.providers.groq_provider.Groq", return_value=mock_client): + from agents.providers.groq_provider import GroqProvider + + p = GroqProvider(api_key="k", model="m") + msgs = [ + LLMMessage(role="user", content="hi"), + LLMMessage( + role="assistant", + content=None, + tool_calls=[ToolCall(id="t1", name="a.b", arguments={"q": 1})], + ), + LLMMessage(role="tool", content="{}", tool_call_id="t1", name="a.b"), + ] + tools = [{"name": "a.b", "description": "d", "input_schema": {"type": "object"}}] + out = p.chat_with_tools(msgs, tools) + assert len(out.tool_calls) == 2 + assert out.tool_calls[0].arguments == {"x": 1} + assert out.tool_calls[1].arguments == {} + + +def test_openai_provider_chat_with_tools() -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _openai_style_response("done", []) + with patch("agents.providers.openai_provider.OpenAI", return_value=mock_client): + from agents.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="k", model="m") + out = p.chat_with_tools([LLMMessage(role="user", content="hello")], []) + assert out.content == "done" + assert out.tool_calls == [] + + +def test_openai_provider_chat_with_tools_bad_json_args() -> None: + """Invalid JSON in tool arguments becomes empty dict (parity with Groq).""" + tc_ok = MagicMock() + tc_ok.id = "c1" + tc_ok.function.name = "a.b" + tc_ok.function.arguments = '{"x": 1}' + tc_bad = MagicMock() + tc_bad.id = "c2" + tc_bad.function.name = "a.c" + tc_bad.function.arguments = "not-json{" + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _openai_style_response(None, [tc_ok, tc_bad]) + with patch("agents.providers.openai_provider.OpenAI", return_value=mock_client): + from agents.providers.openai_provider import OpenAIProvider + + p = OpenAIProvider(api_key="k", model="m") + out = p.chat_with_tools( + [LLMMessage(role="user", content="hi")], + [{"name": "a.b", "description": "d", "input_schema": {"type": "object"}}], + ) + assert len(out.tool_calls) == 2 + assert out.tool_calls[0].arguments == {"x": 1} + assert out.tool_calls[1].arguments == {} + + +def test_anthropic_provider_chat_with_tools() -> None: + block_tu = MagicMock() + block_tu.type = "tool_use" + block_tu.id = "tu1" + block_tu.name = "fhir_cerner.read_patient" + block_tu.input = {"resource_id": "1"} + block_txt = MagicMock() + block_txt.type = "text" + block_txt.text = "ok" + resp = MagicMock() + resp.content = [block_tu] + mock_client = MagicMock() + mock_client.messages.create.return_value = resp + with patch("agents.providers.anthropic_provider.anthropic") as anth_mod: + anth_mod.Anthropic = MagicMock(return_value=mock_client) + from agents.providers.anthropic_provider import AnthropicProvider + + p = AnthropicProvider(api_key="k", model="m") + out = p.chat_with_tools( + [LLMMessage(role="user", content="x")], + [ + { + "name": "fhir_cerner.read_patient", + "description": "d", + "input_schema": {"type": "object"}, + } + ], + ) + assert len(out.tool_calls) == 1 + assert out.tool_calls[0].name == "fhir_cerner.read_patient" + assert out.tool_calls[0].arguments == {"resource_id": "1"} + + +def test_anthropic_provider_tool_use_non_dict_input_becomes_empty_args() -> None: + block_tu = MagicMock() + block_tu.type = "tool_use" + block_tu.id = "tu1" + block_tu.name = "a.b" + block_tu.input = "not-a-dict" + resp = MagicMock() + resp.content = [block_tu] + mock_client = MagicMock() + mock_client.messages.create.return_value = resp + with patch("agents.providers.anthropic_provider.anthropic") as anth_mod: + anth_mod.Anthropic = MagicMock(return_value=mock_client) + from agents.providers.anthropic_provider import AnthropicProvider + + p = AnthropicProvider(api_key="k", model="m") + out = p.chat_with_tools([LLMMessage(role="user", content="x")], []) + assert len(out.tool_calls) == 1 + assert out.tool_calls[0].arguments == {} + + +def test_anthropic_provider_mixed_text_and_tool_use() -> None: + block_txt = MagicMock() + block_txt.type = "text" + block_txt.text = "Planning" + block_tu = MagicMock() + block_tu.type = "tool_use" + block_tu.id = "tu1" + block_tu.name = "a.b" + block_tu.input = {"q": 1} + resp = MagicMock() + resp.content = [block_txt, block_tu] + mock_client = MagicMock() + mock_client.messages.create.return_value = resp + with patch("agents.providers.anthropic_provider.anthropic") as anth_mod: + anth_mod.Anthropic = MagicMock(return_value=mock_client) + from agents.providers.anthropic_provider import AnthropicProvider + + p = AnthropicProvider(api_key="k", model="m") + out = p.chat_with_tools([LLMMessage(role="user", content="go")], []) + assert out.content == "Planning" + assert len(out.tool_calls) == 1 + assert out.tool_calls[0].arguments == {"q": 1} + + +def test_mcp_schema_to_gemini_strips_unknown_keys() -> None: + from agents.providers.gemini_provider import _mcp_schema_to_gemini + + raw = { + "type": "object", + "title": "Root", + "properties": { + "a": {"type": "string", "x-extra": 1}, + }, + "additionalProperties": False, + } + cleaned = _mcp_schema_to_gemini(raw) + assert "title" not in cleaned + assert "additionalProperties" not in cleaned + assert "x-extra" not in cleaned["properties"]["a"] + assert cleaned["properties"]["a"] == {"type": "string"} + + +def test_gemini_provider_chat_with_tools() -> None: + """Inject stub ``google.generativeai.types`` so chat_with_tools can import it.""" + genai_types = types.ModuleType("google.generativeai.types") + genai_types.FunctionDeclaration = MagicMock + genai_types.Tool = MagicMock + sys.modules["google.generativeai.types"] = genai_types + + part_fc = MagicMock() + part_fc.function_call.name = "google_drive.files.upload" + part_fc.function_call.args = {"name": "f.txt", "mime_type": "text/plain"} + type(part_fc).text = property(lambda self: None) + mock_resp = MagicMock() + mock_resp.parts = [part_fc] + mock_chat = MagicMock() + mock_chat.send_message.return_value = mock_resp + mock_model = MagicMock() + mock_model.start_chat.return_value = mock_chat + try: + with patch("agents.providers.gemini_provider.genai") as genai_mod: + genai_mod.configure = MagicMock() + genai_mod.GenerativeModel.return_value = mock_model + genai_mod.protos = MagicMock() + genai_mod.protos.Part = MagicMock(return_value=MagicMock()) + genai_mod.protos.FunctionCall = MagicMock() + genai_mod.protos.FunctionResponse = MagicMock() + from agents.providers.gemini_provider import GeminiProvider + + p = GeminiProvider(api_key="k", model="gemini-2.0-flash") + out = p.chat_with_tools( + [LLMMessage(role="user", content="upload")], + [ + { + "name": "google_drive.files.upload", + "description": "d", + "input_schema": {"type": "object"}, + } + ], + ) + assert len(out.tool_calls) == 1 + assert out.tool_calls[0].name == "google_drive.files.upload" + finally: + sys.modules.pop("google.generativeai.types", None) + + +def test_gemini_provider_text_response_without_tool_calls() -> None: + genai_types = types.ModuleType("google.generativeai.types") + genai_types.FunctionDeclaration = MagicMock + genai_types.Tool = MagicMock + sys.modules["google.generativeai.types"] = genai_types + + part_txt = MagicMock() + part_txt.function_call.name = None + part_txt.text = "Hello from Gemini" + mock_resp = MagicMock() + mock_resp.parts = [part_txt] + mock_chat = MagicMock() + mock_chat.send_message.return_value = mock_resp + mock_model = MagicMock() + mock_model.start_chat.return_value = mock_chat + try: + with patch("agents.providers.gemini_provider.genai") as genai_mod: + genai_mod.configure = MagicMock() + genai_mod.GenerativeModel.return_value = mock_model + genai_mod.protos = MagicMock() + genai_mod.protos.Part = MagicMock(return_value=MagicMock()) + genai_mod.protos.FunctionCall = MagicMock() + genai_mod.protos.FunctionResponse = MagicMock() + from agents.providers.gemini_provider import GeminiProvider + + p = GeminiProvider(api_key="k", model="gemini-2.0-flash") + out = p.chat_with_tools([LLMMessage(role="user", content="hi")], []) + assert out.content == "Hello from Gemini" + assert out.tool_calls == [] + finally: + sys.modules.pop("google.generativeai.types", None) diff --git a/tests/test_mcp_auth.py b/tests/test_mcp_auth.py new file mode 100644 index 0000000..65a3f80 --- /dev/null +++ b/tests/test_mcp_auth.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager + +import jwt +import pytest +from fastapi.testclient import TestClient +from starlette.responses import JSONResponse + +from bindings.mcp_server.auth import ( + McpAuthInvalidError, + McpAuthRequiredError, + authenticate_mcp_request, +) +from bindings.mcp_server.server import McpServer + + +@pytest.fixture(autouse=True) +def _mcp_auth_clear_allowlist_from_host_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Pin allowlist + scope defaults: host ``.env`` or deny-default leaks empty API-key scopes and filters all tools.""" + monkeypatch.setenv( + "NW_ALLOWED_CONNECTORS", + "http_generic,smtp,stripe,google_drive,fhir_epic,fhir_cerner", + ) + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_DEFAULT", "allow") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY_SCOPES", raising=False) + + +def test_mcp_auth_missing_token_returns_401(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError) as exc_info: + authenticate_mcp_request() + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Authentication required" + + +def test_mcp_auth_invalid_token_returns_403(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthInvalidError) as exc_info: + authenticate_mcp_request(meta={"token": "wrong-secret"}) + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Invalid API key or token" + + +def test_mcp_auth_valid_token_allows_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert any(t["name"] == "smtp.send_email" for t in tools) + + +@pytest.mark.asyncio +async def test_mcp_authz_denies_tool_without_scope(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.setenv( + "NW_MCP_ACTION_SCOPE_MAP_JSON", + '{"smtp.send_email":"mcp:smtp.send_email"}', + ) + + token = jwt.encode( + {"sub": "alice", "tenant_id": "tenant-a", "scopes": ["mcp:other.scope"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + resp = await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + + assert resp["success"] is False + assert resp["error_code"] == "POLICY_DENIED" + assert resp["message"] == "Missing required scope: mcp:smtp.send_email" + + +@pytest.mark.asyncio +async def test_mcp_execution_passes_principal_and_tenant( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + + token = jwt.encode( + {"sub": "service-account", "tenant_id": "tenant-42", "scopes": ["*"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + assert identity is not None + + server = McpServer(connector_ids=["smtp"]) + smtp = server._factory.get_for_protocol("smtp", "mcp") + assert smtp is not None + + captured: dict[str, object] = {} + + async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): + captured["payload"] = dict(raw_input) + captured["principal"] = principal + captured["tenant_id"] = tenant_id + captured["scopes"] = tuple(scopes or ()) + from node_wire_runtime.models import ConnectorResponse + + return ConnectorResponse(success=True, data={"ok": True}, trace_id="trace-test") + + orig_run = smtp.run + try: + smtp.run = fake_run + await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + finally: + smtp.run = orig_run + + assert captured["principal"] == "service-account" + assert captured["tenant_id"] == "tenant-42" + assert captured["scopes"] == ("*",) + + +def test_mcp_api_key_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.setenv( + "NW_MCP_ACTION_SCOPE_MAP_JSON", + '{"smtp.send_email":"mcp:smtp.send_email"}', + ) + monkeypatch.setenv("NW_MCP_API_KEY_SCOPES", "mcp:other.scope") + + identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) + assert identity is not None + assert identity.scopes == ("mcp:other.scope",) + + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert not any(t["name"] == "smtp.send_email" for t in tools) + + +def test_mcp_jwt_scopes_filter_tools_list(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.setenv( + "NW_MCP_ACTION_SCOPE_MAP_JSON", + '{"smtp.send_email":"mcp:smtp.send_email"}', + ) + + token = jwt.encode( + {"sub": "alice", "scopes": ["mcp:other.scope"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert not any(t["name"] == "smtp.send_email" for t in tools) + + +@pytest.mark.asyncio +async def test_mcp_default_deny_fallback_scope_invokes_tool( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_DEFAULT", "deny") + + token = jwt.encode( + {"sub": "bob", "scopes": ["mcp:smtp.send_email"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert any(t["name"] == "smtp.send_email" for t in tools) + + smtp = server._factory.get_for_protocol("smtp", "mcp") + assert smtp is not None + + async def fake_run(raw_input, *, principal=None, tenant_id=None, scopes=None): + from node_wire_runtime.models import ConnectorResponse + + assert scopes == ("mcp:smtp.send_email",) + return ConnectorResponse(success=True, data={"ok": True}, trace_id="trace-test") + + orig_run = smtp.run + try: + smtp.run = fake_run + resp = await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + finally: + smtp.run = orig_run + + assert resp["success"] is True + + +@pytest.mark.asyncio +async def test_mcp_default_deny_denies_without_fallback_scope( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_API_KEY", raising=False) + monkeypatch.setenv("NW_MCP_JWT_SECRET", "jwt-secret") + monkeypatch.delenv("NW_MCP_ACTION_SCOPE_MAP_JSON", raising=False) + monkeypatch.setenv("NW_MCP_SCOPE_POLICY_DEFAULT", "deny") + + token = jwt.encode( + {"sub": "bob", "scopes": ["mcp:wrong.scope"]}, + "jwt-secret", + algorithm="HS256", + ) + identity = authenticate_mcp_request(meta={"authorization": f"Bearer {token}"}) + + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert not any(t["name"] == "smtp.send_email" for t in tools) + + resp = await server.invoke_tool( + "smtp.send_email", + { + "from_email": "sender@example.com", + "to": ["recipient@example.com"], + "subject": "x", + "body": "y", + }, + identity=identity, + ) + assert resp["success"] is False + assert resp["error_code"] == "POLICY_DENIED" + + +def test_mcp_api_key_explicit_star_scope_lists_tool(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.setenv( + "NW_MCP_ACTION_SCOPE_MAP_JSON", + '{"smtp.send_email":"mcp:smtp.send_email"}', + ) + monkeypatch.setenv("NW_MCP_API_KEY_SCOPES", "*") + + identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) + server = McpServer(connector_ids=["smtp"]) + tools = server.list_tools(identity=identity) + assert any(t["name"] == "smtp.send_email" for t in tools) + + +class _FakeStreamableSessionManager: + @asynccontextmanager + async def run(self): + yield + + async def handle_request(self, scope, receive, send): + response = JSONResponse({"ok": True}) + await response(scope, receive, send) + + +def test_streamable_http_edge_auth_rejects_missing_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(connector_ids=["smtp"]) + app = server._build_streamable_http_app( + session_manager=_FakeStreamableSessionManager(), + path="/mcp", + ) + client = TestClient(app) + response = client.post("/mcp", json={"jsonrpc": "2.0", "id": "1", "method": "tools/list"}) + + assert response.status_code == 401 + assert response.json()["error_code"] == "MCP_AUTH_REQUIRED" + + +def test_streamable_http_edge_auth_rejects_invalid_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(connector_ids=["smtp"]) + app = server._build_streamable_http_app( + session_manager=_FakeStreamableSessionManager(), + path="/mcp", + ) + client = TestClient(app) + response = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": "1", "method": "tools/list"}, + headers={"Authorization": "Bearer wrong-secret"}, + ) + + assert response.status_code == 403 + assert response.json()["error_code"] == "MCP_AUTH_INVALID" + + +def test_streamable_http_edge_auth_accepts_valid_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(connector_ids=["smtp"]) + app = server._build_streamable_http_app( + session_manager=_FakeStreamableSessionManager(), + path="/mcp", + ) + client = TestClient(app) + response = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": "1", "method": "tools/list"}, + headers={"Authorization": "Bearer unit-test-secret"}, + ) + + assert response.status_code == 200 + assert response.json()["ok"] is True + + +@pytest.mark.asyncio +async def test_streamable_http_identity_context_is_used_by_mcp_server( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(connector_ids=["smtp"]) + identity = authenticate_mcp_request(meta={"token": "unit-test-secret"}) + assert identity is not None + + from bindings.mcp_server.server import _streamable_http_identity_ctx + + token = _streamable_http_identity_ctx.set(identity) + try: + resolved = server._ensure_identity(identity=None, meta=None) + finally: + _streamable_http_identity_ctx.reset(token) + + assert resolved is not None + assert resolved.principal == "api-key-user" + + +def test_mcp_auth_enforced_when_not_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError): + authenticate_mcp_request() + + +def test_mcp_auth_enforced_when_disabled_false(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "false") + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with pytest.raises(McpAuthRequiredError): + authenticate_mcp_request() + + +def test_mcp_auth_skipped_when_disabled_true(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_MCP_AUTH_DISABLED", "true") + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + assert authenticate_mcp_request() is None + + +def test_mcp_auth_legacy_enabled_env_disables_with_warning( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_MCP_AUTH_ENABLED", "true") + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + with caplog.at_level("WARNING"): + assert authenticate_mcp_request() is None + + assert "NW_MCP_AUTH_ENABLED is deprecated" in caplog.text diff --git a/tests/test_mcp_manifest_conformance.py b/tests/test_mcp_manifest_conformance.py new file mode 100644 index 0000000..0933771 --- /dev/null +++ b/tests/test_mcp_manifest_conformance.py @@ -0,0 +1,39 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""MCP manifest contract: unified vs per-connector entrypoints expose consistent tool shapes.""" + +from __future__ import annotations + +from node_wire_runtime.manifest import MCP_MANIFEST_CONTRACT_VERSION + + +def test_manifest_contract_version_is_exported() -> None: + assert MCP_MANIFEST_CONTRACT_VERSION + assert int(MCP_MANIFEST_CONTRACT_VERSION) >= 2 + + +def test_per_connector_tool_names_are_subsets_of_unified() -> None: + from bindings.factory import ConnectorFactory + from node_wire_runtime.connector_registry import auto_register + + auto_register() + factory = ConnectorFactory() + factory.load() + full = {t["name"] for t in _tools_from_server()} + cerner = {t["name"] for t in _tools_from_server(connector_ids=["fhir_cerner"])} + assert cerner == {n for n in full if n.startswith("fhir_cerner.")} + drive = {t["name"] for t in _tools_from_server(connector_ids=["google_drive"])} + assert drive == {n for n in full if n.startswith("google_drive.")} + assert "google_drive.files.upload" in drive + + +def _tools_from_server(connector_ids: list[str] | None = None) -> list[dict]: + from bindings.mcp_server.server import McpServer + + if connector_ids is None: + server = McpServer() + else: + server = McpServer(connector_ids=connector_ids) + return server.list_tools() diff --git a/tests/test_mcp_transport.py b/tests/test_mcp_transport.py new file mode 100644 index 0000000..3f9a93b --- /dev/null +++ b/tests/test_mcp_transport.py @@ -0,0 +1,266 @@ +import pytest +import httpx +from unittest.mock import patch +from bindings.mcp_server.server import McpServer, _http_request_headers +from starlette.applications import Starlette +from starlette.routing import Route +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + + +class _ASGIApp: + def __init__(self, handler): + self.handler = handler + + async def __call__(self, scope, receive, send): + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in scope.get("headers", []) + } + token = _http_request_headers.set(headers) + try: + await self.handler(scope, receive, send) + finally: + _http_request_headers.reset(token) + + +@pytest.fixture(autouse=True) +def allow_only_standard_connectors(monkeypatch): + monkeypatch.setenv( + "NW_ALLOWED_CONNECTORS", "fhir_cerner,fhir_epic,google_drive,smtp,stripe,http_generic" + ) + + +@pytest.mark.anyio +async def test_mcp_transport_stdio_calls_run_stdio(): + server = McpServer() + with patch.object(server, "run_stdio") as mock_run: + server.run(transport="stdio") + mock_run.assert_called_once() + + +@pytest.mark.anyio +async def test_mcp_transport_streamable_http_calls_run_streamable_http(): + server = McpServer() + with patch.object(server, "run_streamable_http") as mock_run: + server.run(transport="streamable-http") + mock_run.assert_called_once() + + +@pytest.mark.anyio +async def test_mcp_transport_invalid_value_fails_fast(): + server = McpServer() + with pytest.raises(ValueError, match="Unsupported MCP transport: invalid"): + server.run(transport="invalid") + + +@pytest.mark.anyio +async def test_mcp_http_server_starts_and_responds(): + server = McpServer(server_name="test-server") + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route( + "/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"] + ) + ] + ) + + async with session_manager.run(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver" + ) as client: + rpc_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + } + response = await client.post( + "/mcp", json=rpc_request, headers={"Accept": "application/json, text/event-stream"} + ) + assert response.status_code == 200 + data = response.json() + assert "jsonrpc" in data + assert "result" in data or "error" in data + if "result" in data: + assert data["result"]["protocolVersion"] == "2024-11-05" + + +@pytest.mark.anyio +async def test_mcp_http_tools_list_success(): + server = McpServer(server_name="test-server") + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route( + "/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"] + ) + ] + ) + + common_headers = {"Accept": "application/json, text/event-stream"} + + async with session_manager.run(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver" + ) as client: + # First initialize + init_resp = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + }, + headers=common_headers, + ) + assert init_resp.status_code == 200 + # Use correct header name Mcp-Session-Id + session_id = init_resp.headers.get("Mcp-Session-Id") + + # Then list tools + headers = common_headers.copy() + if session_id: + headers["Mcp-Session-Id"] = session_id + + list_resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}, + headers=headers, + ) + assert list_resp.status_code == 200 + data = list_resp.json() + assert "tools" in data["result"] + + +@pytest.mark.anyio +async def test_mcp_http_tools_list_accepts_authorization_header(monkeypatch): + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(server_name="test-server", connector_ids=["smtp"]) + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route( + "/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"] + ) + ] + ) + + common_headers = { + "Accept": "application/json, text/event-stream", + "Authorization": "Bearer unit-test-secret", + } + + async with session_manager.run(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver" + ) as client: + init_resp = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + }, + headers=common_headers, + ) + assert init_resp.status_code == 200 + session_id = init_resp.headers.get("Mcp-Session-Id") + + headers = common_headers.copy() + if session_id: + headers["Mcp-Session-Id"] = session_id + + list_resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}, + headers=headers, + ) + assert list_resp.status_code == 200 + data = list_resp.json() + assert "tools" in data["result"] + assert any(t["name"] == "smtp.send_email" for t in data["result"]["tools"]) + + +@pytest.mark.anyio +async def test_mcp_http_tools_list_accepts_x_api_key_header(monkeypatch): + monkeypatch.delenv("NW_MCP_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_MCP_AUTH_ENABLED", raising=False) + monkeypatch.setenv("NW_MCP_API_KEY", "unit-test-secret") + monkeypatch.delenv("NW_MCP_JWT_SECRET", raising=False) + + server = McpServer(server_name="test-server", connector_ids=["smtp"]) + low = server._setup_lowlevel_server() + session_manager = StreamableHTTPSessionManager(low, json_response=True) + + starlette_app = Starlette( + routes=[ + Route( + "/mcp", endpoint=_ASGIApp(session_manager.handle_request), methods=["GET", "POST"] + ) + ] + ) + + common_headers = { + "Accept": "application/json, text/event-stream", + "X-API-Key": "unit-test-secret", + } + + async with session_manager.run(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=starlette_app), base_url="http://testserver" + ) as client: + init_resp = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + }, + headers=common_headers, + ) + assert init_resp.status_code == 200 + session_id = init_resp.headers.get("Mcp-Session-Id") + + headers = common_headers.copy() + if session_id: + headers["Mcp-Session-Id"] = session_id + + list_resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}, + headers=headers, + ) + assert list_resp.status_code == 200 + data = list_resp.json() + assert "tools" in data["result"] + assert any(t["name"] == "smtp.send_email" for t in data["result"]["tools"]) diff --git a/tests/test_observability.py b/tests/test_observability.py new file mode 100644 index 0000000..bb7144b --- /dev/null +++ b/tests/test_observability.py @@ -0,0 +1,124 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for runtime.observability (OpenTelemetry bootstrap).""" + +from __future__ import annotations + +import logging +import sys +import types +from collections.abc import Iterator +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +import node_wire_runtime.observability as obs + + +@contextmanager +def _ensure_traceloop_stub_modules() -> Iterator[None]: + """ + unittest.mock.patch('traceloop.sdk.Traceloop') imports the traceloop package. + When traceloop-sdk is not installed (e.g. global pytest), register minimal + stubs so patch can bind; remove them only if we added them. + """ + added: list[str] = [] + try: + if "traceloop" not in sys.modules: + traceloop_mod = types.ModuleType("traceloop") + sdk_mod = types.ModuleType("traceloop.sdk") + traceloop_mod.sdk = sdk_mod # type: ignore[attr-defined] + sdk_mod.Traceloop = type("Traceloop", (), {}) # placeholder for patch target + sys.modules["traceloop"] = traceloop_mod + sys.modules["traceloop.sdk"] = sdk_mod + added.extend(["traceloop", "traceloop.sdk"]) + elif "traceloop.sdk" not in sys.modules: + sdk_mod = types.ModuleType("traceloop.sdk") + sdk_mod.Traceloop = type("Traceloop", (), {}) + sys.modules["traceloop.sdk"] = sdk_mod + added.append("traceloop.sdk") + yield + finally: + for key in added: + sys.modules.pop(key, None) + + +@pytest.fixture(autouse=True) +def reset_observability_initialized() -> None: + obs._INITIALIZED = False + yield + obs._INITIALIZED = False + + +@contextmanager +def _observability_test_patches(): + """Patches OTEL setup so tests do not mutate global tracer or break logging.""" + with _ensure_traceloop_stub_modules(): + with ( + patch("opentelemetry.trace.set_tracer_provider"), + patch("node_wire_runtime.observability.OTLPSpanExporter") as span_exp, + patch("node_wire_runtime.observability.OTLPLogExporter") as log_exp, + patch("node_wire_runtime.observability.BatchSpanProcessor"), + patch("node_wire_runtime.observability.BatchLogRecordProcessor"), + patch("node_wire_runtime.observability.set_logger_provider"), + patch( + "node_wire_runtime.observability.LoggingHandler", + side_effect=lambda **kwargs: logging.NullHandler(), + ), + patch("traceloop.sdk.Traceloop") as mock_tl, + ): + mock_tl.init = MagicMock() + yield span_exp, log_exp, mock_tl + + +def test_init_observability_idempotent() -> None: + """Second call should not reconfigure exporters.""" + with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + obs.init_observability("app-a") + obs.init_observability("app-b") + assert span_exp.call_count == 1 + assert log_exp.call_count == 1 + + +def test_init_observability_invalid_sampling_ratio_logs_warning( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setenv("AOT_TRACING_SAMPLING_RATIO", "not-a-number") + with _observability_test_patches(): + with caplog.at_level(logging.WARNING, logger="runtime.observability"): + obs.init_observability("app-warn") + assert any("Invalid AOT_TRACING_SAMPLING_RATIO" in r.message for r in caplog.records) + + +def test_init_observability_otel_headers_passed_to_exporters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_HEADERS", "key=value,foo=bar") + with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + obs.init_observability("app-h") + expected_headers = {"key": "value", "foo": "bar"} + assert span_exp.call_args.kwargs.get("headers") == expected_headers + assert log_exp.call_args.kwargs.get("headers") == expected_headers + + +def test_otel_context_filter_sets_empty_trace_when_no_span() -> None: + flt = obs._OtelContextFilter() + log = logging.getLogger("test_otel_filter") + log.addFilter(flt) + record = logging.LogRecord("x", logging.INFO, __file__, 1, "msg", (), None) + assert flt.filter(record) is True + assert record.otel_trace_id == "" + assert record.otel_span_id == "" + + +def test_init_observability_traceloop_failure_does_not_raise( + caplog: pytest.LogCaptureFixture, +) -> None: + with _observability_test_patches() as (_s, _l, mock_tl): + mock_tl.init = MagicMock(side_effect=RuntimeError("traceloop unavailable")) + with caplog.at_level(logging.WARNING, logger="runtime.observability"): + obs.init_observability("app-tl") + assert any("Failed to initialize Traceloop" in r.message for r in caplog.records) diff --git a/tests/test_payload_redaction.py b/tests/test_payload_redaction.py new file mode 100644 index 0000000..57f6065 --- /dev/null +++ b/tests/test_payload_redaction.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import base64 +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from node_wire_fhir_cerner.logic import FhirCernerConnector +from node_wire_fhir_epic.logic import FhirEpicConnector +from node_wire_runtime import SecretProvider +from node_wire_runtime.auth import StaticTokenAuthProvider + + +class EpicSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "epic_fhir_base_url": "https://fhir.epic.com/api/FHIR/R4", + "epic_private_key": "-----BEGIN RSA PRIVATE KEY-----\nMEowIQ...dummy\n-----END RSA PRIVATE KEY-----", + "epic_kid": "dummy-kid", + "epic_client_id": "dummy-client-id", + "epic_token_url": "https://fhir.epic.com/token", + "dummy_token_key": "dummy-access-token", + }[key] + + +class CernerSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "cerner_fhir_base_url": "https://fhir-myrecord.cerner.com/r4/tenant-id", + "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\nMEowIQ...dummy\\n-----END RSA PRIVATE KEY-----", + "cerner_kid": "dummy-kid", + "cerner_client_id": "dummy-client-id", + "cerner_token_url": "https://authorization.cerner.com/tenants/tenant-id/protocols/oauth2/profiles/smart-v1/token", + "dummy_token_key": "dummy-access-token", + }[key] + + +def _epic_connector_for_redaction() -> FhirEpicConnector: + sp = EpicSecretProvider() + auth = StaticTokenAuthProvider(secret_provider=sp, secret_key="dummy_token_key") + return FhirEpicConnector(secret_provider=sp, auth_provider=auth) + + +def _cerner_connector_for_redaction() -> FhirCernerConnector: + sp = CernerSecretProvider() + auth = StaticTokenAuthProvider(secret_provider=sp, secret_key="dummy_token_key") + return FhirCernerConnector(secret_provider=sp, auth_provider=auth) + + +def _serialize_calls(mocked_logger: MagicMock) -> str: + parts: list[str] = [] + for call in mocked_logger.call_args_list: + parts.append(repr(call.args)) + parts.append(repr(call.kwargs)) + return "\n".join(parts) + + +@pytest.mark.asyncio +async def test_fhir_epic_create_document_reference_logs_redacted_payload() -> None: + from node_wire_fhir_epic.schema import FhirDocumentReferenceCreateInput + + connector = _epic_connector_for_redaction() + payload_secret = "SENSITIVE_PAYLOAD_VALUE" + response_secret = "SENSITIVE_RESPONSE_VALUE" + data_b64 = base64.b64encode(payload_secret.encode()).decode("ascii") + + params = FhirDocumentReferenceCreateInput( + action="create_document_reference", + identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], + status="current", + type={ + "coding": [ + {"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"} + ] + }, + subject="Patient/ePD0eeFq.GMHG.aXttqP.Lw3", + data=data_b64, + context={"related": [{"reference": "Group/eqv3buSV"}]}, + ) + + post_req = httpx.Request("POST", "https://fhir.example/DocumentReference") + err_resp = httpx.Response(400, request=post_req, text=response_secret) + + async def post_side_effect(*args: object, **kwargs: object) -> httpx.Response: + # StaticTokenAuthProvider: no separate OAuth POST; only FHIR create hits AsyncClient.post. + return err_resp + + with ( + patch("httpx.AsyncClient.post", new_callable=AsyncMock, side_effect=post_side_effect), + patch("node_wire_fhir_epic.logic.logger.error") as mocked_error, + ): + with pytest.raises(ValueError, match="Epic Error: HTTP 400 from Epic FHIR endpoint"): + await connector.internal_execute(params, trace_id="test-trace") + + logged = _serialize_calls(mocked_error) + assert payload_secret not in logged + assert data_b64 not in logged + assert response_secret not in logged + assert "payload_summary" in logged + + +@pytest.mark.asyncio +async def test_fhir_cerner_create_document_reference_logs_redacted_payload() -> None: + from node_wire_fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput + + connector = _cerner_connector_for_redaction() + payload_secret = "CERNER_SECRET_PAYLOAD" + response_secret = "CERNER_SECRET_RESPONSE" + data_b64 = base64.b64encode(payload_secret.encode()).decode("ascii") + + params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", + identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], + status="current", + doc_status="final", + type={ + "coding": [ + { + "system": "urn:oid:4.5.6", + "code": "18100", + "display": "Employer Group Scan", + "userSelected": True, + } + ], + "text": "Employer Group Scan", + }, + subject="Patient/12724066", + data=data_b64, + attachment_title="Document", + author=[{"reference": "Practitioner/p1"}], + context={ + "encounter": [{"reference": "Encounter/enc-1"}], + "period": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T01:00:00Z"}, + }, + ) + + post_req = httpx.Request("POST", "https://fhir.example/DocumentReference") + err_resp = httpx.Response(400, request=post_req, text=response_secret) + + async def post_side_effect(*args: object, **kwargs: object) -> httpx.Response: + return err_resp + + with ( + patch("httpx.AsyncClient.post", new_callable=AsyncMock, side_effect=post_side_effect), + patch("node_wire_fhir_cerner.logic.logger.error") as mocked_error, + ): + with pytest.raises(ValueError, match="Cerner Error: HTTP 400 from Cerner FHIR endpoint"): + await connector.internal_execute(params, trace_id="test-trace") + + logged = _serialize_calls(mocked_error) + assert payload_secret not in logged + assert data_b64 not in logged + assert response_secret not in logged + assert "payload_summary" in logged diff --git a/tests/test_rest_app_import_env.py b/tests/test_rest_app_import_env.py new file mode 100644 index 0000000..066c852 --- /dev/null +++ b/tests/test_rest_app_import_env.py @@ -0,0 +1,32 @@ +"""Guardrails for REST app import during pytest collection. + +Requires ``conftest`` env + ``connectors_for_tests.yaml`` so enabled connectors match +``NW_ALLOWED_CONNECTORS`` (optional connectors like slack/salesforce stay disabled). +""" + +from __future__ import annotations + +from pathlib import Path + + +def test_pytest_env_disables_rest_dotenv() -> None: + import os + + assert os.environ.get("NW_REST_LOAD_DOTENV", "").lower() in ("false", "0", "no") + + +def test_pytest_uses_test_connector_config_fixture() -> None: + import os + + path = os.environ.get("NW_CONFIG_PATH", "") + assert path + assert path.endswith("connectors_for_tests.yaml") + assert Path(path).is_file() + + +def test_rest_app_module_imports_without_runtime_error() -> None: + """``bindings.rest_api.app`` builds routes at import time; must not raise.""" + import bindings.rest_api.app as rest_app + + assert rest_app.app is not None + assert len(rest_app.app.routes) > 0 diff --git a/tests/test_rest_rate_limit_enforcement.py b/tests/test_rest_rate_limit_enforcement.py new file mode 100644 index 0000000..a8e4a3f --- /dev/null +++ b/tests/test_rest_rate_limit_enforcement.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +from fastapi.testclient import TestClient + +from bindings.rest_api import app as rest_app_module +from bindings.rest_api.app import app, get_factory +from node_wire_runtime.models import ConnectorResponse + + +def _stub_connector() -> MagicMock: + connector = MagicMock() + connector.run = AsyncMock( + return_value=ConnectorResponse(success=True, data={"ok": True}, trace_id="t-limit") + ) + return connector + + +def _make_client(monkeypatch) -> tuple[TestClient, MagicMock]: + monkeypatch.setenv("NW_REST_RATE_LIMIT_ENABLED", "true") + monkeypatch.setenv("NW_REST_RATE_LIMIT_MAX_REQUESTS", "2") + monkeypatch.setenv("NW_REST_RATE_LIMIT_WINDOW_SECONDS", "60") + monkeypatch.setattr(rest_app_module, "_rate_limiter", None) + monkeypatch.setattr(rest_app_module, "_rate_limiter_cfg", None) + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector() + app.dependency_overrides[get_factory] = lambda: mock_factory + return TestClient(app), mock_factory + + +def test_rest_rate_limit_allows_under_threshold(monkeypatch) -> None: + client, _ = _make_client(monkeypatch) + try: + first = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + second = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + finally: + app.dependency_overrides.clear() + assert first.status_code == 200 + assert second.status_code == 200 + + +def test_rest_rate_limit_returns_429_and_retry_after(monkeypatch) -> None: + client, _ = _make_client(monkeypatch) + try: + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + third = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + finally: + app.dependency_overrides.clear() + + assert third.status_code == 429 + assert third.json()["detail"] == "Rate limit exceeded" + retry_after = third.headers.get("Retry-After") + assert retry_after is not None + assert int(retry_after) >= 1 + + +def test_rest_rate_limit_isolated_by_identity(monkeypatch) -> None: + monkeypatch.setenv("NW_REST_RATE_LIMIT_ENABLED", "true") + monkeypatch.setenv("NW_REST_RATE_LIMIT_MAX_REQUESTS", "1") + monkeypatch.setenv("NW_REST_RATE_LIMIT_WINDOW_SECONDS", "60") + monkeypatch.setattr(rest_app_module, "_rate_limiter", None) + monkeypatch.setattr(rest_app_module, "_rate_limiter_cfg", None) + + mock_factory = MagicMock() + mock_factory.get_for_protocol.return_value = _stub_connector() + app.dependency_overrides[get_factory] = lambda: mock_factory + + try: + client = TestClient(app) + first = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-a"}, + ) + second = client.post( + "/connectors/http_generic/request", + json={"method": "GET", "url": "https://example.com"}, + headers={"X-API-Key": "tenant-b"}, + ) + finally: + app.dependency_overrides.clear() + + assert first.status_code == 200 + assert second.status_code == 200 diff --git a/tests/test_runtime_resilience.py b/tests/test_runtime_resilience.py new file mode 100644 index 0000000..ae2794e --- /dev/null +++ b/tests/test_runtime_resilience.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import asyncio + +import pytest +from pydantic import BaseModel +from pybreaker import CircuitBreaker, CircuitBreakerError + +from node_wire_runtime import BaseConnector, ErrorCategory, ErrorMapper, nw_action +from node_wire_runtime.resilience import _resolve_breaker, with_resilience + + +class RetryableTestError(Exception): + pass + + +class FatalTestError(Exception): + pass + + +class RetryInput(BaseModel): + action: str = "retry" + value: int = 1 + + +class RetryOutput(BaseModel): + attempts: int + + +class FlakyConnector(BaseConnector): + connector_id = "test_flaky_resilience" + output_model = RetryOutput + + def __init__(self) -> None: + super().__init__() + self.calls_by_tenant: dict[str, int] = {} + self.failures_by_tenant: dict[str, int] = {} + + @nw_action("retry") + async def retry(self, params: RetryInput, *, trace_id: str) -> RetryOutput: + tenant_key = trace_id.split(":", maxsplit=1)[0] + self.calls_by_tenant[tenant_key] = self.calls_by_tenant.get(tenant_key, 0) + 1 + failures = self.failures_by_tenant.get(tenant_key, 0) + if failures > 0: + self.failures_by_tenant[tenant_key] = failures - 1 + raise RetryableTestError(f"retryable failure for {tenant_key}") + return RetryOutput(attempts=self.calls_by_tenant[tenant_key]) + + async def run_for_tenant(self, tenant_id: str) -> object: + original_internal_execute = self.internal_execute + + async def _tagged_internal_execute(params: object, *, trace_id: str) -> object: + return await original_internal_execute(params, trace_id=f"{tenant_id}:{trace_id}") + + self.internal_execute = _tagged_internal_execute # type: ignore[method-assign] + try: + return await self.run({"action": "retry", "value": 1}, tenant_id=tenant_id) + finally: + self.internal_execute = original_internal_execute # type: ignore[method-assign] + + +@pytest.fixture(autouse=True) +def reset_error_mapper_registry() -> None: + original = dict(ErrorMapper._registry) + try: + ErrorMapper._registry.clear() + ErrorMapper.register(RetryableTestError, ErrorCategory.RETRYABLE, code="RETRYABLE_TEST") + yield + finally: + ErrorMapper._registry.clear() + ErrorMapper._registry.update(original) + + +def test_with_resilience_retries_retryable_errors_until_success() -> None: + connector = FlakyConnector() + connector.failures_by_tenant["tenant-a"] = 2 + + response = asyncio.run(connector.run_for_tenant("tenant-a")) + + assert response.success is True + assert response.data == {"attempts": 3} + assert connector.calls_by_tenant["tenant-a"] == 3 + + +def test_tenant_breaker_state_is_isolated_across_shared_connector_instance() -> None: + connector = FlakyConnector() + connector._breaker_for_tenant("tenant-a").open() + + first = asyncio.run(connector.run_for_tenant("tenant-a")) + other_tenant = asyncio.run(connector.run_for_tenant("tenant-b")) + + assert first.success is False + assert first.error_code == "CircuitBreakerError" + assert first.error_category == ErrorCategory.FATAL + assert other_tenant.success is True + assert other_tenant.data == {"attempts": 1} + + +def test_breaker_cache_uses_distinct_keys_per_tenant() -> None: + connector = FlakyConnector() + + default_breaker = connector._breaker_for_tenant(None) + tenant_a_breaker = connector._breaker_for_tenant("tenant-a") + tenant_b_breaker = connector._breaker_for_tenant("tenant-b") + + assert default_breaker is connector._breaker_for_tenant(None) + assert tenant_a_breaker is connector._breaker_for_tenant("tenant-a") + assert tenant_a_breaker is not tenant_b_breaker + assert default_breaker is not tenant_a_breaker + + +def test_open_breaker_rejects_calls_immediately() -> None: + connector = FlakyConnector() + breaker = connector._breaker_for_tenant("tenant-a") + breaker.open() + + response = asyncio.run(connector.run_for_tenant("tenant-a")) + + assert response.success is False + assert response.error_code == "CircuitBreakerError" + + +def test_circuit_breaker_error_defaults_to_fatal_mapping() -> None: + mapped = ErrorMapper.resolve(CircuitBreakerError("open")) + + assert mapped.code == "CircuitBreakerError" + assert mapped.category == ErrorCategory.FATAL + + +def test_resolve_breaker_returns_same_instance_for_circuit_breaker() -> None: + cb = CircuitBreaker() + assert _resolve_breaker(cb) is cb + + +def test_resolve_breaker_invokes_factory_callable() -> None: + created = CircuitBreaker() + + def factory() -> CircuitBreaker: + return created + + assert _resolve_breaker(factory) is created + + +def test_resolve_breaker_resolved_object_has_state() -> None: + cb = CircuitBreaker() + resolved = _resolve_breaker(cb) + assert hasattr(resolved, "state") + assert resolved.state.name in ("closed", "open", "half-open") + + +def test_with_resilience_accepts_concrete_circuit_breaker_instance() -> None: + breaker = CircuitBreaker() + + @with_resilience(breaker) + async def succeed(*, trace_id: str = "t") -> str: + return "ok" + + assert asyncio.run(succeed(trace_id="x")) == "ok" + + +def test_fatal_errors_do_not_retry() -> None: + class FatalConnector(BaseConnector): + connector_id = "test_fatal_resilience" + output_model = RetryOutput + + @nw_action("retry") + async def retry(self, params: RetryInput, *, trace_id: str) -> RetryOutput: + raise FatalTestError("fatal") + + connector = FatalConnector() + response = asyncio.run(connector.run({"action": "retry", "value": 1}, tenant_id="tenant-a")) + + assert response.success is False + assert response.error_code == "FatalTestError" + assert response.error_category == ErrorCategory.FATAL diff --git a/tests/test_salesforce.py b/tests/test_salesforce.py new file mode 100644 index 0000000..7cefb4d --- /dev/null +++ b/tests/test_salesforce.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from pydantic import ValidationError + +from node_wire_runtime import SecretProvider +from node_wire_salesforce.logic import SalesforceConnector, SalesforceTransientError +from node_wire_salesforce.schema import ( + CreateLeadInput, + ReadLeadInput, + UpdateLeadInput, + DeleteLeadInput, + CreateContactInput, + ReadContactInput, + UpdateContactInput, + DeleteContactInput, +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "salesforce_instance_url": "https://test.salesforce.com", + }[key] + + +def _connector() -> SalesforceConnector: + """Return a SalesforceConnector with mock secrets.""" + conn = SalesforceConnector(secret_provider=MockSecretProvider()) + # Mock auth headers + conn.get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer mock_token"}) + return conn + + +# --------------------------------------------------------------------------- +# Create Contact +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_create_contact_happy_path(): + connector = _connector() + params = CreateContactInput(LastName="Doe", FirstName="John", Email="john@example.com") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 201 + mock_response.content = b'{"id": "003123456789012", "success": true}' + mock_response.json.return_value = {"id": "003123456789012", "success": True} + mock_response.text = '{"id": "003123456789012", "success": true}' + + with patch("httpx.AsyncClient.request", return_value=mock_response): + result = await connector.create_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + assert result.data["id"] == "003123456789012" + + +@pytest.mark.asyncio +async def test_salesforce_create_contact_validation_error(): + # Invalid AccountId (too short) + with pytest.raises(ValidationError) as excinfo: + CreateContactInput(LastName="Doe", AccountId="short") + assert "Invalid Salesforce AccountId format" in str(excinfo.value) + + +# --------------------------------------------------------------------------- +# Update Contact (204 No Content) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_update_contact_204_path(): + connector = _connector() + params = UpdateContactInput(record_id="003123456789012", fields={"FirstName": "Jane"}) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + mock_response.text = "" + + with patch("httpx.AsyncClient.request", return_value=mock_response): + result = await connector.update_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + assert result.data == {} + + +# --------------------------------------------------------------------------- +# Error Handling (Raises Exception) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_error_raises_exception(): + connector = _connector() + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message="Bad Request", request=MagicMock(), response=mock_response + ) + + with patch("httpx.AsyncClient.request", return_value=mock_response): + with pytest.raises(httpx.HTTPStatusError): + await connector.read_contact(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# Transient Error (Raises SalesforceTransientError) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_transient_error_raises(): + connector = _connector() + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 503 + mock_response.text = "Service Unavailable" + + with patch("httpx.AsyncClient.request", return_value=mock_response): + with pytest.raises(SalesforceTransientError): + await connector.read_contact(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# End-to-End internal_execute logic (checks mapping) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_internal_execute_mapping(): + connector = _connector() + # Mocking internal_execute because BaseConnector handles the exception wrapping + + params = ReadContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 503 + mock_response.text = "Transient Error" + + with patch("httpx.AsyncClient.request", return_value=mock_response): + # We call internal_execute directly to bypass BaseConnector.run's retry logic for now + # but check that it raises the expected transient error + with pytest.raises(SalesforceTransientError): + await connector.internal_execute(params, trace_id="test-trace") + + +# --------------------------------------------------------------------------- +# Delete Contact +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_delete_contact_happy_path(): + connector = _connector() + params = DeleteContactInput(record_id="003123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.delete_contact(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "003123456789012" + mock_request.assert_called_once() + assert mock_request.call_args[0][0] == "DELETE" + + +# --------------------------------------------------------------------------- +# Lead Operations +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_salesforce_create_lead_happy_path(): + connector = _connector() + params = CreateLeadInput(LastName="Smith", Company="Acme Corp", Email="smith@acme.com") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 201 + mock_response.content = b'{"id": "00Q123456789012", "success": true}' + mock_response.json.return_value = {"id": "00Q123456789012", "success": True} + mock_response.text = '{"id": "00Q123456789012", "success": true}' + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.create_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert "LastName" in mock_request.call_args[1]["json"] + assert mock_request.call_args[1]["json"]["LastName"] == "Smith" + + +@pytest.mark.asyncio +async def test_salesforce_read_lead_happy_path(): + connector = _connector() + params = ReadLeadInput(record_id="00Q123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.content = b'{"Id": "00Q123456789012", "LastName": "Smith"}' + mock_response.json.return_value = {"Id": "00Q123456789012", "LastName": "Smith"} + + with patch("httpx.AsyncClient.request", return_value=mock_response): + result = await connector.read_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert result.data["LastName"] == "Smith" + + +@pytest.mark.asyncio +async def test_salesforce_update_lead_happy_path(): + connector = _connector() + params = UpdateLeadInput(record_id="00Q123456789012", fields={"Company": "New Acme"}) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.update_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert mock_request.call_args[0][0] == "PATCH" + assert mock_request.call_args[1]["json"]["Company"] == "New Acme" + + +@pytest.mark.asyncio +async def test_salesforce_delete_lead_happy_path(): + connector = _connector() + params = DeleteLeadInput(record_id="00Q123456789012") + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 204 + mock_response.content = b"" + + with patch("httpx.AsyncClient.request", return_value=mock_response) as mock_request: + result = await connector.delete_lead(params, trace_id="test-trace") + + assert result.success is True + assert result.resource_id == "00Q123456789012" + assert mock_request.call_args[0][0] == "DELETE" diff --git a/tests/test_scope_policy_transport.py b/tests/test_scope_policy_transport.py new file mode 100644 index 0000000..7656648 --- /dev/null +++ b/tests/test_scope_policy_transport.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +from typing import Literal + +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.policies.mcp_scope_policy import ( + DEFAULT_SCOPE_MODE_DENY, + ScopePolicyHook, +) + + +class _Input(BaseModel): + action: Literal["read_patient"] = "read_patient" + resource_id: str + + +class _Output(BaseModel): + ok: bool + + +class _PolicyTestConnector(BaseConnector): + # Do not shadow production ``fhir_epic`` in the global registry. + connector_id = "policy_transport_test" + output_model = _Output + + @nw_action("read_patient") + async def read_patient(self, params: _Input, *, trace_id: str) -> _Output: + return _Output(ok=True) + + +def _connector_with_scope_map() -> _PolicyTestConnector: + return _PolicyTestConnector( + policy_hook=ScopePolicyHook({"policy_transport_test.read_patient": "mcp:fhir.read_patient"}) + ) + + +def test_scope_policy_bypasses_when_identity_missing_like_grpc() -> None: + connector = _connector_with_scope_map() + response = asyncio.run(connector.run({"action": "read_patient", "resource_id": "x"})) + + assert response.success is True + assert response.error_code is None + + +def test_scope_policy_denies_when_identity_present_without_required_scope() -> None: + connector = _connector_with_scope_map() + response = asyncio.run( + connector.run( + {"action": "read_patient", "resource_id": "x"}, + principal="alice", + tenant_id="tenant-1", + scopes=("mcp:other.scope",), + ) + ) + + assert response.success is False + assert response.error_code == "POLICY_DENIED" + assert response.message == "Missing required scope: mcp:fhir.read_patient" + + +def test_scope_policy_default_deny_uses_conventional_scope() -> None: + hook = ScopePolicyHook({}, default_mode=DEFAULT_SCOPE_MODE_DENY) + connector = _PolicyTestConnector(policy_hook=hook) + response = asyncio.run( + connector.run( + {"action": "read_patient", "resource_id": "x"}, + principal="alice", + tenant_id="tenant-1", + scopes=("mcp:policy_transport_test.read_patient",), + ) + ) + assert response.success is True + + +def test_scope_policy_default_deny_without_fallback_scope() -> None: + hook = ScopePolicyHook({}, default_mode=DEFAULT_SCOPE_MODE_DENY) + connector = _PolicyTestConnector(policy_hook=hook) + response = asyncio.run( + connector.run( + {"action": "read_patient", "resource_id": "x"}, + principal="alice", + tenant_id="tenant-1", + scopes=("mcp:wrong",), + ) + ) + assert response.success is False + assert "Missing required scope" in (response.message or "") diff --git a/tests/test_secrets_env.py b/tests/test_secrets_env.py new file mode 100644 index 0000000..46ed087 --- /dev/null +++ b/tests/test_secrets_env.py @@ -0,0 +1,72 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for EnvSecretProvider and factory secret wiring.""" + +from __future__ import annotations + +import sys +import types + +import pytest + +from bindings.factory import _build_secret_provider +from node_wire_runtime.secrets.base import EnvSecretProvider, SecretNotFoundError +from node_wire_runtime.secrets.chained import ChainedSecretProvider + + +def test_env_secret_provider_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_ENV_SECRET_LEGACY_EMPTY", raising=False) + monkeypatch.delenv("MISSING_TEST_KEY_X", raising=False) + p = EnvSecretProvider(legacy_empty_on_missing=False) + with pytest.raises(SecretNotFoundError): + p.get_secret("MISSING_TEST_KEY_X") + + +def test_env_secret_provider_legacy_empty(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MISSING_TEST_KEY_X", raising=False) + p = EnvSecretProvider(legacy_empty_on_missing=True) + assert p.get_secret("MISSING_TEST_KEY_X") == "" + + +def test_build_secret_provider_default_env() -> None: + p = _build_secret_provider() + assert isinstance(p, EnvSecretProvider) + + +def test_build_secret_provider_aws_env_chain(monkeypatch: pytest.MonkeyPatch) -> None: + """Chained AWS+env without importing real boto3 (fake ``secrets.aws`` module).""" + monkeypatch.setenv("NW_SECRET_BACKEND", "aws_env") + monkeypatch.setenv("NW_AWS_SECRETS_MANAGER_SECRET_ID", "test-secret") + monkeypatch.setenv("AWS_REGION", "us-west-2") + monkeypatch.setenv("CHAIN_TEST_KEY", "from-env") + + class FakeAws: + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + def get_secret(self, key: str) -> str: + raise SecretNotFoundError(key) + + fake_mod = types.ModuleType("node_wire_runtime.secrets.aws") + fake_mod.AwsSecretsManagerProvider = FakeAws # type: ignore[attr-defined] + old = sys.modules.get("node_wire_runtime.secrets.aws") + sys.modules["node_wire_runtime.secrets.aws"] = fake_mod + try: + out = _build_secret_provider() + finally: + if old is not None: + sys.modules["node_wire_runtime.secrets.aws"] = old + else: + sys.modules.pop("node_wire_runtime.secrets.aws", None) + + assert isinstance(out, ChainedSecretProvider) + assert out.get_secret("CHAIN_TEST_KEY") == "from-env" + + +def test_aws_env_requires_secret_id(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_SECRET_BACKEND", "aws_env") + monkeypatch.delenv("NW_AWS_SECRETS_MANAGER_SECRET_ID", raising=False) + with pytest.raises(ValueError, match="NW_AWS_SECRETS_MANAGER_SECRET_ID"): + _build_secret_provider() diff --git a/tests/test_slack_connector.py b/tests/test_slack_connector.py new file mode 100644 index 0000000..ee738e6 --- /dev/null +++ b/tests/test_slack_connector.py @@ -0,0 +1,485 @@ +""" +Unit tests for the Node-Wire Slack connector. + +All tests are fully offline — httpx calls inside logic.py are patched with +unittest.mock so no real Slack API is contacted. + +Coverage +-------- +- Connector instantiation and connector_id +- post_message happy path +- post_message with Block Kit (string + list) +- post_message with invalid blocks JSON +- post_message: SlackAuthError maps to ErrorCategory.AUTH +- post_message: SlackRateLimitError maps to ErrorCategory.RETRYABLE +- send_direct_message happy path +- upload_file base64 happy path (all 3 upload steps mocked) +- upload_file filepath happy path +- upload_file: missing content source raises SlackUploadError +- upload_file: invalid base64 raises SlackUploadError +- upload_file: file exceeds size limit +- Token is NEVER present in log output (security boundary) +""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from node_wire_runtime import BaseConnector, ErrorCategory +from node_wire_runtime.secrets import SecretProvider + +from node_wire_slack.exceptions import ( + SlackAuthError, + SlackMessageError, + SlackPermissionError, + SlackRateLimitError, +) +from node_wire_slack.logic import ( + SlackConnector, + _complete_upload, + _resolve_blocks, +) +import node_wire_slack.registration # noqa: F401 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FAKE_TOKEN = "xoxb-test-token-0000" +_CHANNEL = "C0TEST123" +_USER_ID = "U0TEST456" + + +class FakeSecretProvider(SecretProvider): + """Returns the fake token for any key.""" + + def get_secret(self, key: str) -> str: + return _FAKE_TOKEN + + +def _make_connector() -> SlackConnector: + return SlackConnector(secret_provider=FakeSecretProvider()) + + +def _slack_ok_response(**extra: Any) -> dict[str, Any]: + return {"ok": True, "ts": "1234567890.123456", "channel": _CHANNEL, **extra} + + +# --------------------------------------------------------------------------- +# 1. Instantiation +# --------------------------------------------------------------------------- + + +def test_slack_connector_instantiation() -> None: + connector = _make_connector() + assert connector.connector_id == "slack" + assert isinstance(connector, BaseConnector) + + +def test_slack_connector_has_three_actions() -> None: + metas = SlackConnector.sdk_action_metas() + assert set(metas.keys()) == {"post_message", "send_direct_message", "upload_file"} + + +# --------------------------------------------------------------------------- +# 2. post_message — happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_message_success() -> None: + connector = _make_connector() + response_data = _slack_ok_response() + + with patch("node_wire_slack.logic._post_json", new=AsyncMock(return_value=response_data)): + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hello World"} + ) + + assert result.success is True + assert result.data["ok"] is True + assert result.data["channel"] == _CHANNEL + + +# --------------------------------------------------------------------------- +# 3. post_message with Block Kit +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_message_with_blocks_list() -> None: + """Blocks provided as a list are forwarded directly.""" + connector = _make_connector() + blocks = [{"type": "section", "text": {"type": "mrkdwn", "text": "hello"}}] + captured: dict[str, Any] = {} + + async def fake_post_json(url: str, token: str, body: dict) -> dict: + captured.update(body) + return _slack_ok_response() + + with patch("node_wire_slack.logic._post_json", new=fake_post_json): + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi", "blocks": blocks} + ) + + assert result.success is True + assert captured.get("blocks") == blocks + + +@pytest.mark.asyncio +async def test_post_message_with_blocks_json_string() -> None: + """Blocks provided as a JSON string are parsed before being sent.""" + connector = _make_connector() + blocks = [{"type": "section", "text": {"type": "mrkdwn", "text": "hello"}}] + import json + + blocks_str = json.dumps(blocks) + captured: dict[str, Any] = {} + + async def fake_post_json(url: str, token: str, body: dict) -> dict: + captured.update(body) + return _slack_ok_response() + + with patch("node_wire_slack.logic._post_json", new=fake_post_json): + await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi", "blocks": blocks_str} + ) + + assert captured.get("blocks") == blocks + + +@pytest.mark.asyncio +async def test_post_message_invalid_blocks_json_returns_error() -> None: + """Invalid blocks JSON must map to a BUSINESS error response, not an unhandled crash.""" + connector = _make_connector() + + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi", "blocks": "{not-json"} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.BUSINESS + + +# --------------------------------------------------------------------------- +# 4. post_message — auth error +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_message_auth_error_maps_to_auth_category() -> None: + connector = _make_connector() + + with patch( + "node_wire_slack.logic._post_json", + new=AsyncMock(side_effect=SlackAuthError("token_revoked")), + ): + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi"} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.AUTH + assert result.error_code == "SLACK_AUTH_ERROR" + + +# --------------------------------------------------------------------------- +# 5. post_message — permission error +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_message_permission_error_maps_to_auth_category() -> None: + connector = _make_connector() + + with patch( + "node_wire_slack.logic._post_json", + new=AsyncMock(side_effect=SlackPermissionError("missing_scope")), + ): + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi"} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.AUTH + assert result.error_code == "SLACK_PERMISSION_ERROR" + + +# --------------------------------------------------------------------------- +# 6. post_message — rate limit +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_message_rate_limit_maps_to_retryable() -> None: + connector = _make_connector() + + with patch( + "node_wire_slack.logic._post_json", + new=AsyncMock(side_effect=SlackRateLimitError("ratelimited")), + ): + result = await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "Hi"} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.RETRYABLE + assert result.error_code == "SLACK_RATE_LIMIT" + + +# --------------------------------------------------------------------------- +# 7. send_direct_message — happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_direct_message_success() -> None: + connector = _make_connector() + response_data = {**_slack_ok_response(), "channel": _USER_ID} + + with patch("node_wire_slack.logic._post_json", new=AsyncMock(return_value=response_data)): + result = await connector.run( + {"action": "send_direct_message", "channel": _USER_ID, "message": "Hey!"} + ) + + assert result.success is True + assert result.data["ok"] is True + + +# --------------------------------------------------------------------------- +# 8. upload_file — base64 happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_upload_file_base64_success() -> None: + connector = _make_connector() + content = b"Hello, file!" + b64 = base64.b64encode(content).decode() + file_id = "F0TESTFILE" + + complete_response = {"ok": True, "files": [{"id": file_id}]} + + with ( + patch( + "node_wire_slack.logic._get_upload_url", + new=AsyncMock(return_value=("https://upload.slack.com/test", file_id)), + ), + patch("node_wire_slack.logic._upload_bytes", new=AsyncMock(return_value=None)), + patch( + "node_wire_slack.logic._complete_upload", new=AsyncMock(return_value=complete_response) + ), + ): + result = await connector.run( + { + "action": "upload_file", + "channel": _CHANNEL, + "filename": "test.txt", + "content_base64": b64, + } + ) + + assert result.success is True + assert result.data["file_id"] == file_id + + +# --------------------------------------------------------------------------- +# 9. upload_file — missing content source +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_upload_file_missing_content_returns_business_error() -> None: + connector = _make_connector() + + result = await connector.run( + {"action": "upload_file", "channel": _CHANNEL, "filename": "empty.txt"} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.BUSINESS + assert result.error_code == "SLACK_UPLOAD_ERROR" + + +# --------------------------------------------------------------------------- +# 10. upload_file — invalid base64 +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_upload_file_invalid_base64_returns_business_error() -> None: + connector = _make_connector() + + result = await connector.run( + { + "action": "upload_file", + "channel": _CHANNEL, + "content_base64": "!!!not-valid-base64!!!", + } + ) + + assert result.success is False + assert result.error_category == ErrorCategory.BUSINESS + assert result.error_code == "SLACK_UPLOAD_ERROR" + + +# --------------------------------------------------------------------------- +# 11. upload_file — file too large +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_upload_file_too_large_returns_business_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + connector = _make_connector() + monkeypatch.setenv("NW_SLACK_UPLOAD_LIMIT_MB", "1") + + # 2 MB of content + content = b"x" * (2 * 1024 * 1024) + b64 = base64.b64encode(content).decode() + + result = await connector.run( + {"action": "upload_file", "channel": _CHANNEL, "content_base64": b64} + ) + + assert result.success is False + assert result.error_category == ErrorCategory.BUSINESS + assert result.error_code == "SLACK_UPLOAD_ERROR" + + +# --------------------------------------------------------------------------- +# 12. Security: token never in logs +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_token_never_appears_in_logs(caplog: pytest.LogCaptureFixture) -> None: + """The Slack bot token must NEVER appear in any log record.""" + connector = _make_connector() + + with caplog.at_level(logging.DEBUG, logger="connectors.slack"): + with patch( + "node_wire_slack.logic._post_json", + new=AsyncMock(return_value=_slack_ok_response()), + ): + await connector.run( + {"action": "post_message", "channel": _CHANNEL, "message": "secure"} + ) + + for record in caplog.records: + assert _FAKE_TOKEN not in record.getMessage(), ( + f"Token leaked in log record: {record.getMessage()!r}" + ) + assert _FAKE_TOKEN not in str(record.__dict__), ( + f"Token leaked in log record attrs: {record.__dict__!r}" + ) + + +# --------------------------------------------------------------------------- +# 13. _resolve_blocks helper +# --------------------------------------------------------------------------- + + +def test_resolve_blocks_none_returns_none() -> None: + assert _resolve_blocks(None) is None + + +def test_resolve_blocks_list_passthrough() -> None: + blocks = [{"type": "section"}] + assert _resolve_blocks(blocks) == blocks + + +def test_resolve_blocks_valid_json_string() -> None: + import json + + blocks = [{"type": "section"}] + assert _resolve_blocks(json.dumps(blocks)) == blocks + + +def test_resolve_blocks_invalid_json_raises() -> None: + with pytest.raises(SlackMessageError, match="Invalid blocks JSON"): + _resolve_blocks("{bad json") + + +def test_resolve_blocks_non_array_json_raises() -> None: + import json + + with pytest.raises(SlackMessageError, match="must be a JSON array"): + _resolve_blocks(json.dumps({"type": "section"})) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("channel_id", ["", "#general", "U0TEST456"]) +async def test_complete_upload_omits_invalid_channel_id(channel_id: str) -> None: + captured: dict[str, Any] = {} + + class FakeResponse: + status_code = 200 + + def json(self) -> dict[str, Any]: + return {"ok": True, "files": [{"id": "F0TESTFILE"}]} + + class FakeAsyncClient: + def __init__(self, timeout: float) -> None: + self.timeout = timeout + + async def __aenter__(self) -> "FakeAsyncClient": + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + async def post( + self, + url: str, + headers: dict[str, str] | None = None, + data: dict[str, Any] | None = None, + ) -> FakeResponse: + captured["url"] = url + captured["headers"] = headers or {} + captured["data"] = data or {} + return FakeResponse() + + with patch("node_wire_slack.logic.httpx.AsyncClient", new=FakeAsyncClient): + data = await _complete_upload( + _FAKE_TOKEN, + "F0TESTFILE", + "test.txt", + channel_id=channel_id, + initial_comment="hello", + ) + + assert data["ok"] is True + assert "channel_id" not in captured["data"] + assert captured["data"]["initial_comment"] == "hello" + + +@pytest.mark.asyncio +async def test_upload_file_invalid_resolved_channel_returns_business_error() -> None: + connector = _make_connector() + b64 = base64.b64encode(b"Hello, file!").decode() + + with ( + patch("node_wire_slack.logic._resolve_channel_id", new=AsyncMock(return_value="#general")), + patch("node_wire_slack.logic._get_upload_url", new=AsyncMock()) as get_upload_url, + ): + result = await connector.run( + { + "action": "upload_file", + "channel": "#general", + "filename": "test.txt", + "content_base64": b64, + } + ) + + assert result.success is False + assert result.error_category == ErrorCategory.BUSINESS + assert result.error_code == "SLACK_UPLOAD_ERROR" + assert "Could not resolve '#general' to a valid Slack channel ID" in result.message + get_upload_url.assert_not_awaited() diff --git a/tests/test_stripe.py b/tests/test_stripe.py new file mode 100644 index 0000000..73fcd6a --- /dev/null +++ b/tests/test_stripe.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from node_wire_runtime import SecretProvider +from node_wire_stripe.logic import StripeConnector +from node_wire_stripe.schema import ( + CancelSubscriptionInput, + ChargeInput, + CreatePaymentIntentInput, + CreateSubscriptionInput, + IssueRefundInput, +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "stripe_api_key": "sk_test_mock", + }[key] + + +def _connector() -> StripeConnector: + """Return a StripeConnector with mock secrets.""" + return StripeConnector(secret_provider=MockSecretProvider()) + + +# --------------------------------------------------------------------------- +# Charge +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stripe_charge_happy_path(): + connector = _connector() + params = ChargeInput(amount=1000, currency="usd", source="tok_visa") + + mock_charge = MagicMock(id="ch_123", receipt_url="http://stripe.com/receipt", paid=True) + + with patch("stripe.Charge.create", return_value=mock_charge) as mock_create: + result = await connector.charge(params, trace_id="test-trace") + + assert result.charge_id == "ch_123" + assert result.receipt_url == "http://stripe.com/receipt" + assert result.status == "succeeded" + mock_create.assert_called_once_with( + api_key="sk_test_mock", + amount=1000, + currency="usd", + source="tok_visa", + customer=None, + description=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Create Payment Intent +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stripe_create_payment_intent_happy_path(): + connector = _connector() + params = CreatePaymentIntentInput(amount=2000, currency="eur", confirm=True) + + mock_pi = MagicMock(id="pi_123", client_secret="secret_abc", status="requires_payment_method") + + with patch("stripe.PaymentIntent.create", return_value=mock_pi) as mock_create: + result = await connector.create_payment_intent(params, trace_id="test-trace") + + assert result.payment_intent_id == "pi_123" + assert result.client_secret == "secret_abc" + assert result.status == "requires_payment_method" + mock_create.assert_called_once_with( + api_key="sk_test_mock", + amount=2000, + currency="eur", + customer=None, + payment_method=None, + confirm=True, + description=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Create Subscription +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stripe_create_subscription_with_card_token(): + connector = _connector() + params = CreateSubscriptionInput( + customer_id="cus_123", price_id="price_abc", card_token="tok_visa" + ) + + mock_pm = MagicMock(id="pm_123") + mock_sub = MagicMock( + id="sub_123", status="active", pending_setup_intent=None, latest_invoice=None + ) + + with ( + patch("stripe.PaymentMethod.create", return_value=mock_pm) as mock_pm_create, + patch("stripe.PaymentMethod.attach") as mock_pm_attach, + patch("stripe.Subscription.create", return_value=mock_sub) as mock_sub_create, + ): + result = await connector.create_subscription(params, trace_id="test-trace") + + assert result.subscription_id == "sub_123" + assert result.status == "active" + + mock_pm_create.assert_called_once() + mock_pm_attach.assert_called_once_with("pm_123", api_key="sk_test_mock", customer="cus_123") + mock_sub_create.assert_called_once() + assert mock_sub_create.call_args.kwargs["default_payment_method"] == "pm_123" + assert mock_sub_create.call_args.kwargs["idempotency_key"] == "test-trace" + + +# --------------------------------------------------------------------------- +# Cancel Subscription +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stripe_cancel_subscription_immediate(): + connector = _connector() + params = CancelSubscriptionInput(subscription_id="sub_123", cancel_at_period_end=False) + + mock_sub = MagicMock(id="sub_123", status="canceled") + + with patch("stripe.Subscription.cancel", return_value=mock_sub) as mock_cancel: + result = await connector.cancel_subscription(params, trace_id="test-trace") + + assert result.subscription_id == "sub_123" + assert result.status == "canceled" + mock_cancel.assert_called_once_with( + "sub_123", api_key="sk_test_mock", idempotency_key="test-trace" + ) + + +# --------------------------------------------------------------------------- +# Issue Refund +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stripe_issue_refund_happy_path(): + connector = _connector() + params = IssueRefundInput(payment_intent_id="pi_123", amount=500) + + mock_refund = MagicMock(id="re_123", status="succeeded") + + with patch("stripe.Refund.create", return_value=mock_refund) as mock_refund_create: + result = await connector.issue_refund(params, trace_id="test-trace") + + assert result.refund_id == "re_123" + assert result.status == "succeeded" + mock_refund_create.assert_called_once_with( + api_key="sk_test_mock", + charge=None, + payment_intent="pi_123", + amount=500, + reason=None, + metadata=None, + idempotency_key="test-trace", + ) + + +# --------------------------------------------------------------------------- +# Schema Validation +# --------------------------------------------------------------------------- + + +def test_stripe_schema_validation_bounds(): + """Verify that amount and currency bounds are enforced.""" + # Valid + ChargeInput(amount=1, currency="usd", source="tok_visa") + + # Invalid amount (too small) + with pytest.raises(ValidationError): + ChargeInput(amount=0, currency="usd", source="tok_visa") + + # Invalid currency (wrong length/format) + with pytest.raises(ValidationError): + ChargeInput(amount=100, currency="us", source="tok_visa") + + with pytest.raises(ValidationError): + ChargeInput(amount=100, currency="USDT", source="tok_visa") + + +# --------------------------------------------------------------------------- +# Error Mapping +# --------------------------------------------------------------------------- + + +def test_stripe_error_mapping(): + """Verify that Stripe exceptions are correctly mapped to ErrorCategory.""" + import stripe + + connector = _connector() + from node_wire_runtime.models import ErrorCategory + + # Check specific mappings from StripeConnector.error_map + assert connector.error_map[stripe.error.CardError] == ( + ErrorCategory.BUSINESS, + "STRIPE_CARD_ERROR", + ) + assert connector.error_map[stripe.error.RateLimitError] == ( + ErrorCategory.RETRYABLE, + "STRIPE_RATE_LIMIT", + ) + assert connector.error_map[stripe.error.AuthenticationError] == ( + ErrorCategory.AUTH, + "STRIPE_AUTH_ERROR", + ) diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index 5f92366..7321146 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -1,3 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# """ Tests for the ToolHive Agent and LLM Factory ============================================= @@ -6,12 +10,12 @@ Run: pytest tests/test_toolhive_agent.py -v """ + from __future__ import annotations -import asyncio import uuid from typing import Any, Dict, List -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -22,49 +26,78 @@ LLMResponse, ToolCall, ) -from agents.toolhive import AgentRunResult, ToolHiveAgent, ToolHiveMcpClient +from agents.toolhive import ( + ToolHiveAgent, + ToolHiveMcpClient, + _is_tool_failure, + resolve_max_tool_failures, + truncate_tool_result_for_llm, +) + + +def test_truncate_tool_result_for_llm_respects_limit(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "20") + long = "x" * 100 + out = truncate_tool_result_for_llm(long) + assert len(out) > 20 + assert out.startswith("x" * 20) + assert "truncated" in out + + +def test_truncate_tool_result_for_llm_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "0") + long = "y" * 5000 + assert truncate_tool_result_for_llm(long) == long + + +def test_is_tool_failure_detects_validation_and_error_prefix() -> None: + assert _is_tool_failure("Input validation error: bad") + assert _is_tool_failure("ERROR: connection refused") + assert _is_tool_failure('{"success": false, "message": "x"}') + assert not _is_tool_failure("") + assert not _is_tool_failure('{"success": true, "data": {}}') + + +def test_resolve_max_tool_failures_env_and_override(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TOOLHIVE_MAX_TOOL_FAILURES", raising=False) + assert resolve_max_tool_failures(None) == 2 + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_FAILURES", "5") + assert resolve_max_tool_failures(None) == 5 + assert resolve_max_tool_failures(3) == 3 # --------------------------------------------------------------------------- -# Fixtures +# Fixtures (manifest-driven — must match production tools/list) # --------------------------------------------------------------------------- -SAMPLE_TOOLS = [ - { - "name": "fhir_cerner_read_patient", - "description": "Fetch a patient from Cerner FHIR", - "input_schema": { - "type": "object", - "properties": {"patient_id": {"type": "string"}}, - "required": ["patient_id"], - }, - }, - { - "name": "google_drive_upload_file", - "description": "Upload a file to Google Drive", - "input_schema": { - "type": "object", - "properties": { - "file_name": {"type": "string"}, - "content": {"type": "string"}, - }, - "required": ["file_name", "content"], - }, - }, - { - "name": "smtp_send_email", - "description": "Send an email via SMTP", - "input_schema": { - "type": "object", - "properties": { - "to_email": {"type": "string"}, - "subject": {"type": "string"}, - "body": {"type": "string"}, - }, - "required": ["to_email", "subject", "body"], - }, - }, -] + +def _mcp_tools_subset_from_manifest() -> List[Dict[str, Any]]: + """Same input_schema as McpServer.list_tools for a stable agent-test subset.""" + from bindings.factory import ConnectorFactory + from node_wire_runtime.connector_registry import auto_register + from node_wire_runtime.manifest import build_manifest + + auto_register() + factory = ConnectorFactory() + factory.load() + manifest = build_manifest(factory.list_for_protocol("mcp")) + want = {"fhir_cerner.read_patient", "google_drive.files.upload", "smtp.send_email"} + out: List[Dict[str, Any]] = [] + for entry in manifest: + name = f"{entry['connector_id']}.{entry['action']}" + if name in want: + out.append( + { + "name": name, + "description": f"{entry['connector_id']} {entry['action']}", + "input_schema": entry["input_schema"], + } + ) + assert {t["name"] for t in out} == want + return sorted(out, key=lambda t: t["name"]) + + +SAMPLE_TOOLS = _mcp_tools_subset_from_manifest() def _tool_call(name: str, args: Dict[str, Any]) -> ToolCall: @@ -78,7 +111,9 @@ def __init__(self, responses: List[LLMResponse]) -> None: self._responses = list(responses) self._call_count = 0 - def chat_with_tools(self, messages: List[LLMMessage], tools: List[Dict[str, Any]]) -> LLMResponse: + def chat_with_tools( + self, messages: List[LLMMessage], tools: List[Dict[str, Any]] + ) -> LLMResponse: idx = min(self._call_count, len(self._responses) - 1) resp = self._responses[idx] self._call_count += 1 @@ -89,42 +124,38 @@ def chat_with_tools(self, messages: List[LLMMessage], tools: List[Dict[str, Any] # LLM Factory tests # --------------------------------------------------------------------------- + def test_llm_factory_groq_created() -> None: """LLMProviderFactory.create('groq') should return a GroqProvider instance.""" - from agents.llm_factory import LLMProviderFactory - import agents.providers.groq_provider - print(f"\nDEBUG: gp file: {agents.providers.groq_provider.__file__}") - print(f"DEBUG: gp dir: {dir(agents.providers.groq_provider)}") + with patch("agents.providers.groq_provider.Groq"): provider = LLMProviderFactory.create("groq", api_key="test-key", model="llama3-8b-8192") from agents.providers.groq_provider import GroqProvider + assert isinstance(provider, GroqProvider) def test_llm_factory_openai_created() -> None: """LLMProviderFactory.create('openai') should return an OpenAIProvider instance.""" - from agents.llm_factory import LLMProviderFactory - import agents.providers.openai_provider with patch("agents.providers.openai_provider.OpenAI"): provider = LLMProviderFactory.create("openai", api_key="test-key", model="gpt-4o-mini") from agents.providers.openai_provider import OpenAIProvider + assert isinstance(provider, OpenAIProvider) def test_llm_factory_unknown_raises() -> None: """LLMProviderFactory.create with an unknown provider should raise ValueError.""" - from agents.llm_factory import LLMProviderFactory with pytest.raises(ValueError, match="Unknown LLM provider"): LLMProviderFactory.create("foobar") def test_llm_factory_case_insensitive() -> None: """Provider names should be case-insensitive.""" - from agents.llm_factory import LLMProviderFactory - import agents.providers.groq_provider with patch("agents.providers.groq_provider.Groq"): provider = LLMProviderFactory.create("GROQ", api_key="k", model="m") from agents.providers.groq_provider import GroqProvider + assert isinstance(provider, GroqProvider) @@ -132,6 +163,7 @@ def test_llm_factory_case_insensitive() -> None: # ToolHive Agent tests # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_agent_runs_three_tool_sequence() -> None: """ @@ -142,23 +174,43 @@ async def test_agent_runs_three_tool_sequence() -> None: # Step 1: Call FHIR LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "12724066"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], stop_reason="tool_calls", ), # Step 2: Call Drive LLMResponse( content=None, - tool_calls=[_tool_call("google_drive_upload_file", {"file_name": "summary.txt", "content": "Patient: John"})], + tool_calls=[ + _tool_call( + "google_drive.files.upload", + { + "name": "summary.txt", + "mime_type": "text/plain", + "content": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Step 3: Send email LLMResponse( content=None, - tool_calls=[_tool_call("smtp_send_email", {"to_email": "doc@example.com", "subject": "Summary", "body": "Patient: John"})], + tool_calls=[ + _tool_call( + "smtp.send_email", + { + "to": ["doc@example.com"], + "subject": "Summary", + "body": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Final answer - LLMResponse(content="All 3 steps completed successfully.", tool_calls=[], stop_reason="stop"), + LLMResponse( + content="All 3 steps completed successfully.", tool_calls=[], stop_reason="stop" + ), ] provider = _MockLLMProvider(responses) @@ -173,21 +225,75 @@ async def test_agent_runs_three_tool_sequence() -> None: assert result.success is True assert result.final_answer == "All 3 steps completed successfully." assert len(result.steps) == 3 - assert result.steps[0].tool_called == "fhir_cerner_read_patient" - assert result.steps[1].tool_called == "google_drive_upload_file" - assert result.steps[2].tool_called == "smtp_send_email" + assert result.steps[0].tool_called == "fhir_cerner.read_patient" + assert result.steps[1].tool_called == "google_drive.files.upload" + assert result.steps[2].tool_called == "smtp.send_email" # Verify MCP was called exactly 3 times assert mock_mcp.call_tool.await_count == 3 +@pytest.mark.asyncio +async def test_agent_run_events_emits_done_message_with_trace_id() -> None: + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], + stop_reason="tool_calls", + ), + LLMResponse(content="All done.", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = '{"status": "ok"}' + + agent = ToolHiveAgent(mcp_client=mock_mcp, llm_provider=provider, max_steps=5) + events = [event async for event in agent.run_events("Fetch patient 12724066")] + + assert events[0]["type"] == "meta" + assert any(event["type"] == "step" for event in events) + assert any(event["type"] == "final_chunk" for event in events) + assert events[-1]["type"] == "done" + assert events[-1]["success"] is True + assert events[-1]["trace_id"] == events[0]["trace_id"] + assert events[-1]["message"] == f"Streaming completed. trace_id={events[0]['trace_id']}" + + +@pytest.mark.asyncio +async def test_agent_id_first_turn_calls_read_patient_with_resource_id() -> None: + """Document ID-first flow: Cerner read uses canonical resource_id (not search_patients).""" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], + stop_reason="tool_calls", + ), + LLMResponse(content="Patient retrieved.", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = '{"success": true}' + + agent = ToolHiveAgent(mcp_client=mock_mcp, llm_provider=provider, max_steps=10) + result = await agent.run("Patient ID 12724066 — fetch from Cerner") + + assert result.success is True + mock_mcp.call_tool.assert_awaited_once() + call = mock_mcp.call_tool.await_args + assert call[0][0] == "fhir_cerner.read_patient" + assert call[0][1]["resource_id"] == "12724066" + + @pytest.mark.asyncio async def test_agent_respects_max_steps() -> None: """Agent should stop and return an error if max_steps is reached.""" # LLM always returns a tool call — never finishes infinite_response = LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "x"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "x"})], stop_reason="tool_calls", ) provider = _MockLLMProvider([infinite_response]) @@ -211,10 +317,12 @@ async def test_agent_handles_tool_error_gracefully() -> None: responses = [ LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "bad"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "bad"})], stop_reason="tool_calls", ), - LLMResponse(content="Unable to fetch patient — error recorded.", tool_calls=[], stop_reason="stop"), + LLMResponse( + content="Unable to fetch patient — error recorded.", tool_calls=[], stop_reason="stop" + ), ] provider = _MockLLMProvider(responses) @@ -244,96 +352,149 @@ async def test_agent_fails_when_mcp_unreachable() -> None: assert "Failed to list MCP tools" in (result.error or "") +@pytest.mark.asyncio +async def test_agent_stops_after_repeated_tool_failures() -> None: + """After max_tool_failures for the same tool, stop without further LLM steps.""" + fail_msg = "Input validation error: bad args" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse(content="should not run", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = fail_msg + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("Upload to Drive") + + assert result.success is False + assert len(result.steps) == 2 + assert "google_drive.files.upload" in (result.error or "") + assert "failed 2 times" in (result.final_answer or result.error or "").lower() + assert mock_mcp.call_tool.await_count == 2 + assert provider._call_count == 2 + + +@pytest.mark.asyncio +async def test_agent_success_then_two_failures_same_tool_aborts() -> None: + """Failures only increment on failed tool results; abort after second failure.""" + ok = '{"success": true, "data": {}}' + fail_msg = "Input validation error: x" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.side_effect = [ok, fail_msg, fail_msg] + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("x") + + assert result.success is False + assert len(result.steps) == 3 + assert mock_mcp.call_tool.await_count == 3 + + # --------------------------------------------------------------------------- # MCP entrypoint smoke test # --------------------------------------------------------------------------- -def test_mcp_entrypoint_registers_four_tools() -> None: - """The FastMCP server should expose exactly 4 tools.""" - # We patch all external deps before importing the module to avoid side effects - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = { - "fhir_cerner": MagicMock(), - "fhir_epic": MagicMock(), - "google_drive": MagicMock(), - "smtp": MagicMock(), - } - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn # decorator passthrough - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - # Import inside the test to ensure it picks up the mocks - from agents.mcp_entrypoint import _make_server - _make_server() - - assert len(registered_tools) == 4 - assert "fhir_cerner_read_patient" in registered_tools - assert "fhir_epic_read_patient" in registered_tools - assert "google_drive_upload_file" in registered_tools - assert "smtp_send_email" in registered_tools + +def test_mcp_entrypoint_exposes_manifest_tools() -> None: + """Unified MCP server lists all connectors enabled for MCP in config.""" + from bindings.mcp_server.server import McpServer + + server = McpServer(server_name="node-wire") + names = {t["name"] for t in server.list_tools()} + assert "fhir_cerner.read_patient" in names + assert "fhir_epic.read_patient" in names + assert "google_drive.files.upload" in names + assert "smtp.send_email" in names + assert "stripe.charge" in names + assert "http_generic.request" in names + # Broader surface than the old 8 FastMCP tools + assert len(names) >= 18 # --------------------------------------------------------------------------- -# Individual MCP server smoke tests +# Individual MCP entrypoint modules (thin wrappers) # --------------------------------------------------------------------------- -def _make_server_smoke(module_path: str, expected_tool: str) -> None: - """Helper: verify a per-connector _make_server() registers exactly one tool.""" - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = {} - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - import importlib - mod = importlib.import_module(module_path) - mod._make_server() - - assert registered_tools == [expected_tool], ( - f"{module_path}: expected [{expected_tool}], got {registered_tools}" - ) +def test_fhir_cerner_mcp_main_callable() -> None: + from agents.fhir_cerner_mcp import main + + assert callable(main) + + +def test_fhir_epic_mcp_main_callable() -> None: + from agents.fhir_epic_mcp import main + + assert callable(main) + + +def test_google_drive_mcp_main_callable() -> None: + from agents.google_drive_mcp import main + + assert callable(main) + + +def test_smtp_mcp_main_callable() -> None: + from agents.smtp_mcp import main + + assert callable(main) -def test_fhir_cerner_mcp_registers_one_tool() -> None: - """fhir_cerner_mcp._make_server() should expose exactly fhir_cerner_read_patient.""" - _make_server_smoke("agents.fhir_cerner_mcp", "fhir_cerner_read_patient") +def test_mcp_server_matches_per_connector_entrypoints() -> None: + """Per-connector scripts use connector_ids filter; tool prefixes must match.""" + from bindings.mcp_server.server import McpServer -def test_fhir_epic_mcp_registers_one_tool() -> None: - """fhir_epic_mcp._make_server() should expose exactly fhir_epic_read_patient.""" - _make_server_smoke("agents.fhir_epic_mcp", "fhir_epic_read_patient") + full = {t["name"] for t in McpServer().list_tools()} + cerner = {t["name"] for t in McpServer(connector_ids=["fhir_cerner"]).list_tools()} + assert cerner == {n for n in full if n.startswith("fhir_cerner.")} -def test_google_drive_mcp_registers_one_tool() -> None: - """google_drive_mcp._make_server() should expose exactly google_drive_upload_file.""" - _make_server_smoke("agents.google_drive_mcp", "google_drive_upload_file") + epic = {t["name"] for t in McpServer(connector_ids=["fhir_epic"]).list_tools()} + assert epic == {n for n in full if n.startswith("fhir_epic.")} + drive = {t["name"] for t in McpServer(connector_ids=["google_drive"]).list_tools()} + assert drive == {n for n in full if n.startswith("google_drive.")} + assert "google_drive.files.upload" in drive + smtp = {t["name"] for t in McpServer(connector_ids=["smtp"]).list_tools()} + assert smtp == {"smtp.send_email"} diff --git a/tests/test_toolhive_client.py b/tests/test_toolhive_client.py new file mode 100644 index 0000000..fbb9c79 --- /dev/null +++ b/tests/test_toolhive_client.py @@ -0,0 +1,87 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for ToolHiveMcpClient HTTP transport and toolhive helper edge cases.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import patch + +import httpx +import pytest + +from agents.toolhive import ( + ToolHiveMcpClient, + resolve_max_tool_failures, + truncate_tool_result_for_llm, +) + + +def test_truncate_tool_result_non_numeric_env_uses_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "not-int") + long = "z" * 5000 + assert truncate_tool_result_for_llm(long) == long + + +def test_resolve_max_tool_failures_non_numeric_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_FAILURES", "bad") + assert resolve_max_tool_failures(None) == 2 + + +def test_toolhive_mcp_client_initialize_list_tools_call_tool() -> None: + """Exercise _initialize, tools/list, and tools/call over MockTransport.""" + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content.decode()) + method = body.get("method") + req_id = body.get("id") + if method == "initialize": + return httpx.Response( + 200, + json={"jsonrpc": "2.0", "id": req_id, "result": {"protocolVersion": "2024-11-05"}}, + headers={"Mcp-Session-Id": "sess-abc"}, + ) + if method == "notifications/initialized": + return httpx.Response(200, json={}) + if method == "tools/list": + return httpx.Response( + 200, + json={ + "jsonrpc": "2.0", + "id": req_id, + "result": {"tools": [{"name": "smtp.send_email", "description": "d"}]}, + }, + ) + if method == "tools/call": + return httpx.Response( + 200, + json={ + "jsonrpc": "2.0", + "id": req_id, + "result": {"content": [{"type": "text", "text": "sent"}]}, + }, + ) + return httpx.Response(404, json={"error": "unknown"}) + + transport = httpx.MockTransport(handler) + _RealAsyncClient = httpx.AsyncClient + + def make_client(**kwargs: object) -> httpx.AsyncClient: + return _RealAsyncClient(transport=transport, timeout=float(kwargs.get("timeout", 60.0))) + + async def _run() -> None: + with patch("httpx.AsyncClient", side_effect=make_client): + client = ToolHiveMcpClient("http://127.0.0.1:9/mcp") + tools = await client.list_tools() + assert len(tools) == 1 + assert tools[0]["name"] == "smtp.send_email" + text = await client.call_tool( + "smtp.send_email", + {"to": ["a@b.com"], "subject": "s", "body": "b"}, + ) + assert text == "sent" + + asyncio.run(_run())