From 7941922c207f1ce658fe46039d29c3e41d0240b5 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sat, 4 Jul 2026 10:58:43 -0700 Subject: [PATCH 1/9] Added coverage gate for PR changes. --- .github/workflows/pytest.yml | 40 +++++++++++++++++++++++++++++++++++- pyproject.toml | 4 ++++ uv.lock | 17 +++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d0606b9..2fc14d6 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -63,6 +63,44 @@ jobs: if-no-files-found: ignore + patch-coverage: + name: Patch coverage (changed lines ≥ 80%) + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + # Full history so diff-cover can diff against the base branch. + fetch-depth: 0 + + - name: Set up Python 3.11 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 + with: + enable-cache: true + cache-dependency-glob: | + pyproject.toml + uv.lock + + - name: Install dependencies + run: uv sync --frozen --all-extras --dev + + - name: Run pytest with coverage + run: uv run pytest --cov --cov-report=xml + + - name: Enforce 80% coverage on changed lines + run: > + uv run diff-cover coverage.xml + --compare-branch origin/${{ github.base_ref }} + --fail-under 80 + --show-uncovered + playground-integration: name: Playground integration tests runs-on: ubuntu-latest @@ -98,7 +136,7 @@ jobs: GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} - + #stripe STRIPE_API_KEY: ${{ secrets.STRIPE_API_KEY }} STRIPE_TEST_CUSTOMER_ID: ${{ secrets.STRIPE_TEST_CUSTOMER_ID }} diff --git a/pyproject.toml b/pyproject.toml index d04a89e..4d320bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ dev = [ "bandit[toml]>=1.7.9", "pre-commit>=4.0.0", "pytest-playwright>=0.4.0", + "diff-cover>=9.0.0", "cyclonedx-bom==4.6.1", "reuse>=5.0.0", "licenseheaders>=0.8.8", @@ -142,6 +143,9 @@ fail_under = 80 omit = [ "*/__pycache__/*", "*/node_wire.egg-info/*", + # Generated gRPC stubs — also excluded from Ruff, Mypy, and CodeQL. + "*_pb2.py", + "*_pb2_grpc.py", ] [tool.ruff] diff --git a/uv.lock b/uv.lock index a493fad..471c3b5 100644 --- a/uv.lock +++ b/uv.lock @@ -754,6 +754,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] +[[package]] +name = "diff-cover" +version = "10.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "chardet" }, + { name = "jinja2" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/21/057e816125c162662d2a2cc2ebcd72dd333e78e51678298d07dd3146011a/diff_cover-10.3.0.tar.gz", hash = "sha256:474dbc63e815fbb7567d7b7ca5b104123e96129f25426ebdbc9a1bdbb935b2c6", size = 106546, upload-time = "2026-05-30T14:17:14.32Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/0a/a96e57a7a3fca419cd5ceff0d13dee2166520fa67103fb82624ad64700fb/diff_cover-10.3.0-py3-none-any.whl", hash = "sha256:2e47d5ab3868d1e92131c11f364f3f4a8583c97123d3bbc6b6cc8ce0a4cc2202", size = 58989, upload-time = "2026-05-30T14:17:12.858Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -2189,6 +2204,7 @@ agents = [ dev = [ { name = "bandit" }, { name = "cyclonedx-bom" }, + { name = "diff-cover" }, { name = "licenseheaders" }, { name = "mypy" }, { name = "pip-audit" }, @@ -2237,6 +2253,7 @@ provides-extras = ["agents"] dev = [ { name = "bandit", extras = ["toml"], specifier = ">=1.7.9" }, { name = "cyclonedx-bom", specifier = "==4.6.1" }, + { name = "diff-cover", specifier = ">=9.0.0" }, { name = "licenseheaders", specifier = ">=0.8.8" }, { name = "mypy", specifier = ">=1.9.0" }, { name = "pip-audit", specifier = ">=2.7.0" }, From 8574f9a42572540591ffd9f3890f1851b4521b99 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sat, 4 Jul 2026 11:25:55 -0700 Subject: [PATCH 2/9] Added more tests to push coverage to 85%. Added gate for PRs to have 80% coverage --- tests/test_grpc_bindings.py | 343 ++++++++++++++++++ tests/test_oauth2_extra.py | 497 ++++++++++++++++++++++++++ tests/test_runtime_misc.py | 605 ++++++++++++++++++++++++++++++++ tests/test_secrets_providers.py | 478 +++++++++++++++++++++++++ 4 files changed, 1923 insertions(+) create mode 100644 tests/test_grpc_bindings.py create mode 100644 tests/test_oauth2_extra.py create mode 100644 tests/test_runtime_misc.py create mode 100644 tests/test_secrets_providers.py diff --git a/tests/test_grpc_bindings.py b/tests/test_grpc_bindings.py new file mode 100644 index 0000000..33fb1d4 --- /dev/null +++ b/tests/test_grpc_bindings.py @@ -0,0 +1,343 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for gRPC server servicer and auth interceptor.""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import grpc +import pytest + +from bindings.grpc_server import connector_pb2 +from bindings.grpc_server.auth import ( + GrpcAuthInterceptor, + _extract_token, + _truthy, + get_grpc_caller_identity, + verify_grpc_token_and_identity, +) +from bindings.grpc_server.server import ConnectorServiceServicer +from node_wire_runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.rate_limit import RateLimitExceeded + + +# --------------------------------------------------------------------------- +# _truthy helper +# --------------------------------------------------------------------------- + + +def test_truthy_none_is_false() -> None: + assert _truthy(None) is False + + +def test_truthy_true_string() -> None: + for val in ("true", "TRUE", "1", "yes", "on", " True "): + assert _truthy(val) is True, val + + +def test_truthy_false_string() -> None: + for val in ("false", "0", "no", "off", ""): + assert _truthy(val) is False, val + + +# --------------------------------------------------------------------------- +# _extract_token helper +# --------------------------------------------------------------------------- + + +def test_extract_token_from_authorization_bearer() -> None: + meta = (("authorization", "Bearer mytoken123"),) + assert _extract_token(meta) == "mytoken123" + + +def test_extract_token_from_authorization_raw() -> None: + meta = (("authorization", "rawtoken"),) + assert _extract_token(meta) == "rawtoken" + + +def test_extract_token_from_x_api_key() -> None: + meta = (("x-api-key", " apikey "),) + assert _extract_token(meta) == "apikey" + + +def test_extract_token_returns_none_when_absent() -> None: + assert _extract_token(()) is None + assert _extract_token((("content-type", "application/json"),)) is None + + +# --------------------------------------------------------------------------- +# GrpcAuthInterceptor.intercept_service +# --------------------------------------------------------------------------- + + +def _make_call_details(metadata: tuple = ()) -> grpc.HandlerCallDetails: + details = MagicMock(spec=grpc.HandlerCallDetails) + details.invocation_metadata = metadata + return details + + +def test_interceptor_auth_disabled_passes_through(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_GRPC_AUTH_DISABLED", "true") + sentinel = object() + cont = MagicMock(return_value=sentinel) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details()) + assert result is sentinel + cont.assert_called_once() + + +def test_interceptor_returns_abort_when_not_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_GRPC_API_KEY", raising=False) + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock() + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details()) + # The returned handler should abort when called + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, msg = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAVAILABLE + assert "NW_GRPC_API_KEY" in msg + + +def test_interceptor_returns_unauthenticated_when_no_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "secret") + + cont = MagicMock() + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=())) + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, _ = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAUTHENTICATED + + +def test_interceptor_returns_unauthenticated_for_bad_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "correct-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock() + meta = (("x-api-key", "wrong-key"),) + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, _ = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAUTHENTICATED + + +def test_interceptor_wraps_handler_on_valid_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "valid-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + inner_handler = grpc.unary_unary_rpc_method_handler(lambda req, ctx: "ok") + cont = MagicMock(return_value=inner_handler) + meta = (("x-api-key", "valid-key"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is not None + assert result.unary_unary is not None + assert result.unary_unary(None, None) == "ok" + + +def test_interceptor_returns_none_when_continuation_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "valid-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock(return_value=None) + meta = (("x-api-key", "valid-key"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is None + + +def test_interceptor_jwt_auth(monkeypatch: pytest.MonkeyPatch) -> None: + from tests.jwt_test_helpers import mint_test_jwt + + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_GRPC_API_KEY", raising=False) + monkeypatch.setenv("NW_GRPC_JWT_SECRET", "test-grpc-secret") + + token = mint_test_jwt({"sub": "svc"}, "test-grpc-secret") + inner_handler = grpc.unary_unary_rpc_method_handler(lambda req, ctx: "ok") + cont = MagicMock(return_value=inner_handler) + meta = (("authorization", f"Bearer {token}"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is not None + assert result.unary_unary(None, None) == "ok" + + +# --------------------------------------------------------------------------- +# ConnectorServiceServicer._invoke_async +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def servicer() -> ConnectorServiceServicer: + return ConnectorServiceServicer() + + +async def test_invoke_rate_limit_exceeded(servicer: ConnectorServiceServicer) -> None: + async def _raise(*_a: Any, **_kw: Any) -> None: + raise RateLimitExceeded("too fast") + + with patch("bindings.grpc_server.server.global_rate_limiter") as mock_rl: + mock_rl.acquire = _raise + req = connector_pb2.InvokeRequest(connector_id="any", action="act") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "RATE_LIMIT_EXCEEDED" + assert resp.error_category == ErrorCategory.RETRYABLE.value + + +async def test_invoke_unknown_connector_returns_not_available( + servicer: ConnectorServiceServicer, +) -> None: + with patch.object(servicer._factory, "get_for_protocol", return_value=None): + req = connector_pb2.InvokeRequest(connector_id="no_such", action="act") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "CONNECTOR_NOT_AVAILABLE" + assert resp.error_category == ErrorCategory.BUSINESS.value + assert "no_such" in resp.message + + +async def test_invoke_invalid_json_payload(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", action="y", payload_json="not-valid-json{" + ) + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "INVALID_JSON" + assert resp.error_category == ErrorCategory.BUSINESS.value + + +async def test_invoke_success_path(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock( + return_value=ConnectorResponse( + success=True, + data={"result": "hello"}, + trace_id="trace-001", + ) + ) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", action="greet", payload_json='{"field": "val"}' + ) + resp = await servicer._invoke_async(req) + + assert resp.success is True + assert resp.trace_id == "trace-001" + assert "result" in resp.data_json + + +async def test_invoke_action_injected_into_payload(servicer: ConnectorServiceServicer) -> None: + captured_payload: list[Any] = [] + + async def mock_run(payload: Any, **_: Any) -> ConnectorResponse: + captured_payload.append(payload) + return ConnectorResponse(success=True, trace_id="t1") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", + action="do_thing", + payload_json='{"some": "data"}', + ) + await servicer._invoke_async(req) + + assert captured_payload[0]["action"] == "do_thing" + + +async def test_invoke_identity_propagated(servicer: ConnectorServiceServicer) -> None: + from node_wire_runtime.caller_identity import build_caller_identity + + identity = build_caller_identity({"sub": "grpc-svc"}, auth_type="grpc_api_key") + captured: list[Any] = [] + + async def mock_run(payload: Any, **kwargs: Any) -> ConnectorResponse: + captured.append(kwargs) + return ConnectorResponse(success=True, trace_id="t2") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with ( + patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector), + patch("bindings.grpc_server.server.get_grpc_caller_identity", return_value=identity), + ): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + await servicer._invoke_async(req) + + assert captured[0]["principal"] == identity.principal + assert captured[0]["tenant_id"] == identity.tenant_id + + +async def test_invoke_no_identity_passes_none(servicer: ConnectorServiceServicer) -> None: + captured: list[Any] = [] + + async def mock_run(payload: Any, **kwargs: Any) -> ConnectorResponse: + captured.append(kwargs) + return ConnectorResponse(success=True, trace_id="t3") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with ( + patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector), + patch("bindings.grpc_server.server.get_grpc_caller_identity", return_value=None), + ): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + await servicer._invoke_async(req) + + assert captured[0]["principal"] is None + assert captured[0]["tenant_id"] is None + + +async def test_invoke_empty_payload_json(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock( + return_value=ConnectorResponse(success=True, trace_id="t4") + ) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest(connector_id="x", action="act") + resp = await servicer._invoke_async(req) + assert resp.success is True + + +async def test_invoke_error_response_maps_error_category( + servicer: ConnectorServiceServicer, +) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock( + return_value=ConnectorResponse( + success=False, + error_code="UPSTREAM_TIMEOUT", + error_category=ErrorCategory.RETRYABLE, + message="upstream timed out", + trace_id="t5", + ) + ) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "UPSTREAM_TIMEOUT" + assert resp.error_category == ErrorCategory.RETRYABLE.value diff --git a/tests/test_oauth2_extra.py b/tests/test_oauth2_extra.py new file mode 100644 index 0000000..4bc3bd9 --- /dev/null +++ b/tests/test_oauth2_extra.py @@ -0,0 +1,497 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Additional coverage for OAuth2AuthProvider and ServiceAccountAuthProvider.""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from node_wire_runtime.auth import OAuth2AuthProvider, ServiceAccountAuthProvider +from node_wire_runtime.secrets import SecretProvider +from node_wire_runtime.secrets.base import SecretNotFoundError + + +class _DictSecretProvider(SecretProvider): + def __init__(self, data: dict[str, str]) -> None: + self._data = data + + def get_secret(self, key: str) -> str: + if key not in self._data: + raise SecretNotFoundError(key) + return self._data[key] + + +def _oauth2_client_secret_provider(**extra: str) -> _DictSecretProvider: + data = { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "super-secret", + **extra, + } + return _DictSecretProvider(data) + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — invalid grant method +# --------------------------------------------------------------------------- + + +def test_oauth2_invalid_grant_method_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + with pytest.raises(ValueError, match="Unsupported grant_method"): + OAuth2AuthProvider( + secret_provider=sp, + grant_method="magic", + token_url_secret="token_url", + client_id_secret="client_id", + ) + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — _resolve_scopes via scopes_secret +# --------------------------------------------------------------------------- + + +def test_resolve_scopes_from_secret() -> None: + sp = _DictSecretProvider({ + "token_url": "x", + "client_id": "y", + "scope_val": "openid profile", + }) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", # not used in resolve_scopes + scopes_secret="scope_val", + ) + assert provider._resolve_scopes() == "openid profile" + + +def test_resolve_scopes_falls_back_to_static_when_secret_missing() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + scopes_secret="missing_key", + scopes=["read", "write"], + ) + assert provider._resolve_scopes() == "read write" + + +def test_resolve_scopes_returns_none_when_none_configured() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + assert provider._resolve_scopes() is None + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — client_secret_post grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_client_secret_post_success() -> None: + sp = _oauth2_client_secret_provider() + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + scopes=["read"], + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "tok-cs", "expires_in": 3600} + + with patch("node_wire_runtime.auth.oauth2.httpx.AsyncClient") as mock_client_cls: + mock_client = MagicMock() + mock_client.__aenter__ = MagicMock(return_value=mock_client) + mock_client.__aexit__ = MagicMock(return_value=False) + mock_client.post = MagicMock(return_value=mock_response) + # Make post awaitable + import asyncio + mock_client.post = MagicMock(side_effect=lambda *a, **kw: mock_response) + + async def _fake_enter(_: Any) -> Any: + return mock_client + + async def _fake_exit(*_: Any) -> bool: + return False + + mock_client.__aenter__ = _fake_enter + mock_client.__aexit__ = _fake_exit + mock_client_cls.return_value = mock_client + + # Use respx-style: patch the _post_token static method directly + with patch.object( + OAuth2AuthProvider, + "_post_token", + return_value={"access_token": "tok-cs", "expires_in": 3600}, + ): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer tok-cs" + + +async def test_client_secret_post_missing_secret_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + # client_secret_secret deliberately absent + ) + with pytest.raises(ValueError, match="client_secret_secret"): + await provider._fetch_client_secret_post() + + +async def test_oauth2_client_secret_post_with_scope() -> None: + sp = _oauth2_client_secret_provider() + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + scopes=["openid"], + ) + captured_data: list[dict] = [] + + async def fake_post_token(url: str, data: dict) -> dict: + captured_data.append(data) + return {"access_token": "tok2", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post_token): + await provider.get_headers() + + assert captured_data[0]["scope"] == "openid" + assert captured_data[0]["grant_type"] == "client_credentials" + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — refresh_token grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_refresh_token_success() -> None: + sp = _DictSecretProvider({ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "sec", + "refresh_token": "rt-123", + }) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + refresh_token_secret="refresh_token", + scopes=["profile"], + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "new-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer new-tok" + assert captured[0]["grant_type"] == "refresh_token" + assert captured[0]["refresh_token"] == "rt-123" + assert captured[0]["client_secret"] == "sec" + assert captured[0]["scope"] == "profile" + + +async def test_oauth2_refresh_token_without_client_secret() -> None: + sp = _DictSecretProvider({ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "refresh_token": "rt-456", + }) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + refresh_token_secret="refresh_token", + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "no-secret-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + await provider.get_headers() + + assert "client_secret" not in captured[0] + + +async def test_oauth2_refresh_token_missing_secret_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + # refresh_token_secret absent + ) + with pytest.raises(ValueError, match="refresh_token_secret"): + await provider._fetch_refresh_token() + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — private_key_jwt grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_private_key_jwt_invalid_key_raises() -> None: + sp = _DictSecretProvider({ + "token_url": "https://idp.example.com/token", + "client_id": "client", + "private_key": "not-a-valid-pem-key", + "kid": "key-1", + }) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + private_key_secret="private_key", + kid_secret="kid", + algorithm="RS384", + ) + with pytest.raises(ValueError, match="private_key_jwt"): + await provider._fetch_private_key_jwt() + + +async def test_oauth2_private_key_jwt_missing_secrets_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + # private_key_secret and kid_secret absent + ) + with pytest.raises(ValueError, match="private_key_secret.*kid_secret"): + await provider._fetch_private_key_jwt() + + +async def test_oauth2_private_key_jwt_success_with_rsa_key() -> None: + """Test private_key_jwt success path with a real RSA key.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + sp = _DictSecretProvider({ + "token_url": "https://idp.example.com/token", + "client_id": "client-id", + "private_key": private_pem, + "kid": "rsa-key-1", + }) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + private_key_secret="private_key", + kid_secret="kid", + algorithm="RS384", + scopes=["system/*.read"], + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "pkey-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer pkey-tok" + assert captured[0]["grant_type"] == "client_credentials" + assert "client_assertion" in captured[0] + assert captured[0]["scope"] == "system/*.read" + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — _post_token non-200 error +# --------------------------------------------------------------------------- + + +async def test_post_token_non_200_raises() -> None: + transport = httpx.MockTransport( + handler=lambda req: httpx.Response(401, text="Unauthorized") + ) + + with patch( + "node_wire_runtime.auth.oauth2.httpx.AsyncClient", + return_value=httpx.AsyncClient(transport=transport), + ): + with pytest.raises(ValueError, match="HTTP 401"): + await OAuth2AuthProvider._post_token( + "https://fake.example.com/token", + {"grant_type": "client_credentials"}, + ) + + +# --------------------------------------------------------------------------- +# ServiceAccountAuthProvider +# --------------------------------------------------------------------------- + + +async def test_sa_get_headers_returns_empty() -> None: + sp = _DictSecretProvider({"sa_json": '{"type": "service_account"}'}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + headers = await provider.get_headers() + assert headers == {} + + +async def test_sa_get_client_credentials_caches() -> None: + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": json.dumps({"type": "service_account"})}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch.object(provider, "_build_credentials", return_value=fake_creds) as mock_build: + creds1 = await provider.get_client_credentials() + creds2 = await provider.get_client_credentials() + + assert creds1 is fake_creds + assert creds2 is fake_creds + mock_build.assert_called_once() + + +async def test_sa_refresh_clears_credentials() -> None: + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": "{}"}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + provider._credentials = fake_creds + + await provider.refresh() + assert provider._credentials is None + + +def test_sa_build_credentials_from_json_string() -> None: + import json as _json + + fake_creds = MagicMock() + sa_info = { + "type": "service_account", + "project_id": "proj", + "private_key_id": "key1", + "private_key": "-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----\n", + "client_email": "svc@proj.iam.gserviceaccount.com", + "client_id": "12345", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + } + sp = _DictSecretProvider({"sa_json": _json.dumps(sa_info)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + return_value=fake_creds, + ): + creds = provider._build_credentials() + assert creds is fake_creds + + +def test_sa_build_credentials_from_file_path(tmp_path: Any) -> None: + import json as _json + + sa_info = { + "type": "service_account", + "project_id": "proj", + "private_key_id": "key1", + "private_key": "fake-key", + "client_email": "svc@proj.iam.gserviceaccount.com", + "client_id": "12345", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + } + sa_file = tmp_path / "sa.json" + sa_file.write_text(_json.dumps(sa_info)) + + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": str(sa_file)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch( + "google.oauth2.service_account.Credentials.from_service_account_file", + return_value=fake_creds, + ): + creds = provider._build_credentials() + assert creds is fake_creds + + +def test_sa_build_credentials_file_not_found(tmp_path: Any) -> None: + sp = _DictSecretProvider({"sa_json": "/nonexistent/path/sa.json"}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with pytest.raises(ValueError, match="not found at path"): + provider._build_credentials() + + +def test_sa_build_credentials_invalid_json_info() -> None: + import json as _json + + bad_info = {"type": "service_account"} # missing required fields + sp = _DictSecretProvider({"sa_json": _json.dumps(bad_info)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with pytest.raises((ValueError, Exception)): + provider._build_credentials() + + +def test_sa_build_credentials_google_auth_not_installed() -> None: + import sys + + sp = _DictSecretProvider({"sa_json": '{"type": "service_account"}'}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + original = sys.modules.get("google.oauth2") + original_sa = sys.modules.get("google.oauth2.service_account") + try: + sys.modules["google.oauth2"] = None # type: ignore[assignment] + sys.modules["google.oauth2.service_account"] = None # type: ignore[assignment] + with pytest.raises(ImportError, match="google-auth"): + provider._build_credentials() + finally: + if original is None: + sys.modules.pop("google.oauth2", None) + else: + sys.modules["google.oauth2"] = original + if original_sa is None: + sys.modules.pop("google.oauth2.service_account", None) + else: + sys.modules["google.oauth2.service_account"] = original_sa diff --git a/tests/test_runtime_misc.py b/tests/test_runtime_misc.py new file mode 100644 index 0000000..accac11 --- /dev/null +++ b/tests/test_runtime_misc.py @@ -0,0 +1,605 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for rate_limit, streaming, log_sanitization gaps, observability +wrappers, and connector_registry error paths.""" + +from __future__ import annotations + +import logging +import sys +from importlib.metadata import EntryPoint +from unittest.mock import MagicMock, patch + +import pytest + +from node_wire_runtime import connector_registry +from node_wire_runtime.log_sanitization import ( + REDACTED, + SanitizingLogFilter, + _redact_sensitive_string_arg, + install_sanitizing_log_filter, + is_sensitive_key, + sanitize_log_record, + sanitize_mapping, + sanitize_value, +) +from node_wire_runtime.rate_limit import RateLimitExceeded, TokenBucket +from node_wire_runtime.streaming import ( + BufferedStreamIterator, + resolve_stream_buffer_ms, + stream_completion_log, +) + + +# --------------------------------------------------------------------------- +# TokenBucket — acquire paths +# --------------------------------------------------------------------------- + + +async def test_token_bucket_acquire_success() -> None: + bucket = TokenBucket(capacity=10, refill_rate=1) + await bucket.acquire(1) + assert bucket.tokens == 9 + + +async def test_token_bucket_acquire_multiple() -> None: + bucket = TokenBucket(capacity=10, refill_rate=1) + await bucket.acquire(5) + assert bucket.tokens == 5 + + +async def test_token_bucket_acquire_raises_when_exhausted() -> None: + bucket = TokenBucket(capacity=1, refill_rate=0) + await bucket.acquire(1) + with pytest.raises(RateLimitExceeded): + await bucket.acquire(1) + + +async def test_token_bucket_refills_over_time() -> None: + import time + + bucket = TokenBucket(capacity=10, refill_rate=1000) + await bucket.acquire(10) + assert bucket.tokens == 0 + # Artificially push last_refill back to simulate elapsed time + bucket.last_refill -= 1.0 # 1 second elapsed → 1000 tokens added, capped at 10 + await bucket.acquire(5) + assert bucket.tokens >= 4 # at least 5 refilled, minus 5 acquired + + +async def test_token_bucket_exceed_raises_with_message() -> None: + bucket = TokenBucket(capacity=0, refill_rate=0) + with pytest.raises(RateLimitExceeded, match="rate limit exceeded"): + await bucket.acquire(1) + + +# --------------------------------------------------------------------------- +# resolve_stream_buffer_ms +# --------------------------------------------------------------------------- + + +def test_resolve_stream_buffer_ms_override_used() -> None: + assert resolve_stream_buffer_ms(200) == 200 + + +def test_resolve_stream_buffer_ms_override_clamped_to_zero() -> None: + assert resolve_stream_buffer_ms(-50) == 0 + + +def test_resolve_stream_buffer_ms_override_clamped_to_max() -> None: + assert resolve_stream_buffer_ms(99999) == 30000 + + +def test_resolve_stream_buffer_ms_env_var(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_STREAM_BUFFER_MS", "500") + assert resolve_stream_buffer_ms() == 500 + + +def test_resolve_stream_buffer_ms_env_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_STREAM_BUFFER_MS", "notanumber") + assert resolve_stream_buffer_ms() == 0 + + +# --------------------------------------------------------------------------- +# stream_completion_log +# --------------------------------------------------------------------------- + + +def test_stream_completion_log_success(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.INFO, logger="runtime.streaming"): + stream_completion_log("tid-1", True, connector_id="smtp", action="send") + assert any("completed" in r.message.lower() for r in caplog.records) + + +def test_stream_completion_log_failure(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="runtime.streaming"): + stream_completion_log("tid-2", False, connector_id="smtp", action="send") + assert any("failed" in r.message.lower() for r in caplog.records) + + +# --------------------------------------------------------------------------- +# BufferedStreamIterator +# --------------------------------------------------------------------------- + + +async def _collect_async(gen) -> list: + items = [] + async for item in gen: + items.append(item) + return items + + +async def test_buffered_stream_no_buffer() -> None: + async def source(): + for i in range(3): + yield {"i": i} + + items = await _collect_async( + BufferedStreamIterator(source(), buffer_ms=0, trace_id="t1") + ) + assert items == [{"i": 0}, {"i": 1}, {"i": 2}] + + +async def test_buffered_stream_with_buffer() -> None: + async def source(): + for i in range(4): + yield {"i": i} + + items = await _collect_async( + BufferedStreamIterator(source(), buffer_ms=1000, trace_id="t2") + ) + assert items == [{"i": 0}, {"i": 1}, {"i": 2}, {"i": 3}] + + +async def test_buffered_stream_logs_on_failure() -> None: + async def bad_source(): + yield {"i": 0} + raise RuntimeError("upstream error") + + with pytest.raises(RuntimeError, match="upstream error"): + await _collect_async(BufferedStreamIterator(bad_source(), buffer_ms=0, trace_id="t3")) + + +async def test_buffered_stream_flush_mid_buffer(monkeypatch: pytest.MonkeyPatch) -> None: + """Force a mid-stream flush by rolling back last_flush time.""" + import time + + call_count = 0 + real_monotonic = time.monotonic + + def patched_monotonic() -> float: + nonlocal call_count + call_count += 1 + # First call (setup) returns normal; subsequent calls return far future + if call_count <= 2: + return real_monotonic() + return real_monotonic() + 10.0 + + async def source(): + for i in range(3): + yield {"i": i} + + with patch("node_wire_runtime.streaming.time.monotonic", side_effect=patched_monotonic): + items = await _collect_async( + BufferedStreamIterator(source(), buffer_ms=100, trace_id="t4") + ) + assert len(items) == 3 + + +# --------------------------------------------------------------------------- +# log_sanitization gaps +# --------------------------------------------------------------------------- + + +def test_is_sensitive_key_substring_match() -> None: + assert is_sensitive_key("patient_id") is True + assert is_sensitive_key("my_ssn") is True + assert is_sensitive_key("email_address") is True + + +def test_is_sensitive_key_always_redact() -> None: + assert is_sensitive_key("search_params") is True + assert is_sensitive_key("body") is True + assert is_sensitive_key("payload") is True + + +def test_is_sensitive_key_safe_key() -> None: + assert is_sensitive_key("connector_id") is False + assert is_sensitive_key("action") is False + + +def test_sanitize_value_nested_dict() -> None: + val = {"safe_key": {"nested": "value"}, "patient_id": "123"} + result = sanitize_value("outer", val) + assert result["patient_id"] == REDACTED + assert result["safe_key"] == {"nested": "value"} + + +def test_sanitize_value_list() -> None: + # List elements are recursively processed with the same key + result = sanitize_value("items", [{"secret": "s"}, {"ok": "v"}]) + assert isinstance(result, list) + # "secret" key inside the nested dict gets redacted + assert result[0]["secret"] == REDACTED + assert result[1]["ok"] == "v" + + +def test_sanitize_value_tuple() -> None: + result = sanitize_value("data", ("safe", "value")) + assert isinstance(result, tuple) + assert result == ("safe", "value") + + +def test_sanitize_value_str_body_key_redacted() -> None: + result = sanitize_value("body", "any string content") + assert result == REDACTED + + +def test_sanitize_value_str_non_body_key_passthrough() -> None: + result = sanitize_value("action", "send_email") + assert result == "send_email" + + +def test_sanitize_mapping_redacts_sensitive() -> None: + result = sanitize_mapping({"patient": "Smith", "connector_id": "smtp"}) + assert result["patient"] == REDACTED + assert result["connector_id"] == "smtp" + + +def test_redact_sensitive_string_arg_long_string() -> None: + long_str = "x" * 101 + assert _redact_sensitive_string_arg(long_str) == REDACTED + + +def test_redact_sensitive_string_arg_phi_marker() -> None: + assert _redact_sensitive_string_arg("contains_phi_marker_here") == REDACTED + + +def test_redact_sensitive_string_arg_safe() -> None: + assert _redact_sensitive_string_arg("connector_id=smtp") == "connector_id=smtp" + + +def test_sanitize_log_record_dict_args() -> None: + record = logging.LogRecord( + name="test", level=logging.INFO, pathname=__file__, lineno=1, + msg="%(patient)s %(action)s", args={"patient": "Smith", "action": "send"}, + exc_info=None, + ) + sanitize_log_record(record) + assert record.args["patient"] == REDACTED # type: ignore[index] + assert record.args["action"] == "send" # type: ignore[index] + + +def test_sanitize_log_record_tuple_args() -> None: + record = logging.LogRecord( + name="test", level=logging.INFO, pathname=__file__, lineno=1, + msg="%s %s", args=("phi_marker_data", "safe"), + exc_info=None, + ) + sanitize_log_record(record) + assert record.args[0] == REDACTED # type: ignore[index] + assert record.args[1] == "safe" + + +def test_sanitize_log_record_no_args() -> None: + record = logging.LogRecord( + name="test", level=logging.INFO, pathname=__file__, lineno=1, + msg="plain message", args=(), exc_info=None, + ) + sanitize_log_record(record) # should not raise + + +def test_install_sanitizing_filter_finds_existing_without_flag() -> None: + """When a SanitizingLogFilter is already present but the flag is False, + it should detect it and not add a duplicate.""" + import node_wire_runtime.log_sanitization as ls + + original_flag = ls._SANITIZING_FILTER_INSTALLED + original_filters = list(logging.getLogger().filters) + try: + ls._SANITIZING_FILTER_INSTALLED = False + root = logging.getLogger() + # Add filter directly without going through install_sanitizing_log_filter + existing = SanitizingLogFilter() + root.addFilter(existing) + install_sanitizing_log_filter() + count = sum(1 for f in root.filters if isinstance(f, SanitizingLogFilter)) + assert count >= 1 + assert ls._SANITIZING_FILTER_INSTALLED is True + finally: + ls._SANITIZING_FILTER_INSTALLED = original_flag + # Restore original filters + for f in list(logging.getLogger().filters): + logging.getLogger().removeFilter(f) + for f in original_filters: + logging.getLogger().addFilter(f) + + +# --------------------------------------------------------------------------- +# observability: SanitizingSpanExporter / SanitizingLogExporter +# --------------------------------------------------------------------------- + + +def test_sanitizing_span_exporter_redacts_sensitive_attributes() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + delegate.export.return_value = None + exporter = SanitizingSpanExporter(delegate) + + fake_span = MagicMock() + fake_span._attributes = {"patient_id": "12345", "connector_id": "smtp"} + exporter.export([fake_span]) + + delegate.export.assert_called_once() + assert fake_span._attributes["patient_id"] == REDACTED + assert fake_span._attributes["connector_id"] == "smtp" + + +def test_sanitizing_span_exporter_shutdown() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + exporter = SanitizingSpanExporter(delegate) + exporter.shutdown() + delegate.shutdown.assert_called_once() + + +def test_sanitizing_span_exporter_force_flush() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + delegate.force_flush.return_value = True + exporter = SanitizingSpanExporter(delegate) + result = exporter.force_flush(1000) + assert result is True + delegate.force_flush.assert_called_once_with(1000) + + +def test_sanitizing_span_exporter_force_flush_no_delegate_method() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock(spec=[]) # no force_flush attribute + exporter = SanitizingSpanExporter(delegate) + result = exporter.force_flush() + assert result is True + + +def test_sanitizing_log_exporter_redacts_attributes() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + exporter = SanitizingLogExporter(delegate) + + fake_record = MagicMock() + fake_record.attributes = {"email": "user@example.com", "action": "send"} + exporter.export([fake_record]) + + delegate.export.assert_called_once() + assert fake_record.attributes["email"] == REDACTED + assert fake_record.attributes["action"] == "send" + + +def test_sanitizing_log_exporter_shutdown() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + exporter = SanitizingLogExporter(delegate) + exporter.shutdown() + delegate.shutdown.assert_called_once() + + +def test_sanitizing_log_exporter_force_flush() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + delegate.force_flush.return_value = True + exporter = SanitizingLogExporter(delegate) + result = exporter.force_flush(500) + assert result is True + + +def test_sanitizing_log_exporter_force_flush_no_delegate() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock(spec=[]) + exporter = SanitizingLogExporter(delegate) + result = exporter.force_flush() + assert result is True + + +def test_otel_context_filter_with_valid_span() -> None: + from opentelemetry.sdk.trace import TracerProvider + + from node_wire_runtime.observability import _OtelContextFilter + + flt = _OtelContextFilter() + provider = TracerProvider() + tracer = provider.get_tracer("test") + + record = logging.LogRecord("test", logging.INFO, __file__, 1, "msg", (), None) + with tracer.start_as_current_span("test-span"): + result = flt.filter(record) + + assert result is True + assert len(record.otel_trace_id) == 32 + assert len(record.otel_span_id) == 16 + + +# --------------------------------------------------------------------------- +# connector_registry — error paths +# --------------------------------------------------------------------------- + + +def test_parse_allowed_names_empty_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_ALLOWED_CONNECTORS", raising=False) + result = connector_registry._parse_allowed_names() + assert result == set() + + +def test_parse_allowed_names_empty_string(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", " ") + result = connector_registry._parse_allowed_names() + assert result == set() + + +def test_registration_module_missing_is_silently_skipped( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ModuleNotFoundError with name == reg_name is silently ignored.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn") + ep = EntryPoint( + name="myconn", value="node_wire_myconn.logic", group="node_wire.connectors" + ) + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn.logic": + return MagicMock() + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert "node_wire_myconn.logic" in loaded + assert "node_wire_myconn.registration" not in loaded + + +def test_registration_dep_error_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + """ModuleNotFoundError for a dep inside registration module is re-raised.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn2") + ep = EntryPoint( + name="myconn2", value="node_wire_myconn2.logic", group="node_wire.connectors" + ) + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn2.logic": + return MagicMock() + err = ModuleNotFoundError("No module named 'missing_dep'") + err.name = "missing_dep" # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_registration_unexpected_exception_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn3") + ep = EntryPoint( + name="myconn3", value="node_wire_myconn3.logic", group="node_wire.connectors" + ) + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn3.logic": + return MagicMock() + raise RuntimeError("unexpected error in registration module") + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(RuntimeError, match="unexpected error"), + ): + connector_registry.auto_register() + + +def test_fallback_logic_missing_dep_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + """Fallback logic module raises ModuleNotFoundError for an internal dep → re-raise.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "badconn") + + def fake_import(name: str) -> MagicMock: + err = ModuleNotFoundError("No module named 'some_dep'") + err.name = "some_dep" # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_fallback_logic_not_found_is_skipped(monkeypatch: pytest.MonkeyPatch) -> None: + """Fallback logic module genuinely absent → skip, no error.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "absent_connector") + + def fake_import(name: str) -> MagicMock: + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert loaded == [] + + +def test_fallback_registration_missing_silently_skipped(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "noregconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_noregconn.logic": + return MagicMock() + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert "node_wire_noregconn.logic" in loaded + assert "node_wire_noregconn.registration" not in loaded + + +def test_fallback_registration_dep_error_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "regerrconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_regerrconn.logic": + return MagicMock() + if name == "node_wire_regerrconn.registration": + err = ModuleNotFoundError("No module named 'dep_x'") + err.name = "dep_x" # type: ignore[attr-defined] + raise err + raise ImportError(f"unexpected: {name}") + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_fallback_registration_unexpected_exception_reraised( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "excconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_excconn.logic": + return MagicMock() + raise ValueError("unexpected registration failure") + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ValueError, match="unexpected registration failure"), + ): + connector_registry.auto_register() diff --git a/tests/test_secrets_providers.py b/tests/test_secrets_providers.py new file mode 100644 index 0000000..ca75fad --- /dev/null +++ b/tests/test_secrets_providers.py @@ -0,0 +1,478 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for the cloud secret providers (AWS, Azure, GCP, Vault). + +The cloud SDKs are not installed in the test environment, so each fixture +injects stub modules into ``sys.modules`` before importing the provider +module, and removes the provider module afterwards so other tests see a +clean import state. +""" + +from __future__ import annotations + +import importlib +import json +import sys +import types +from types import SimpleNamespace + +import pytest + +from node_wire_runtime.secrets.base import ( + EnvSecretProvider, + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) +from node_wire_runtime.secrets.chained import ChainedSecretProvider + + +# --------------------------------------------------------------------------- +# AWS +# --------------------------------------------------------------------------- + + +@pytest.fixture +def aws(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.aws against stub boto3/botocore.""" + state = SimpleNamespace( + secret_string="{}", + error=None, + service=None, + region=None, + requested_secret_id=None, + ) + + class BotoCoreError(Exception): + pass + + class ClientError(Exception): + def __init__(self, error_response: dict, operation_name: str) -> None: + super().__init__(error_response, operation_name) + self.response = error_response + self.operation_name = operation_name + + exc_mod = types.ModuleType("botocore.exceptions") + exc_mod.BotoCoreError = BotoCoreError + exc_mod.ClientError = ClientError + botocore_mod = types.ModuleType("botocore") + botocore_mod.exceptions = exc_mod + + class FakeClient: + def get_secret_value(self, SecretId: str) -> dict: + state.requested_secret_id = SecretId + if state.error is not None: + raise state.error + return {"SecretString": state.secret_string} + + boto3_mod = types.ModuleType("boto3") + + def client(service: str, region_name: str | None = None) -> FakeClient: + state.service = service + state.region = region_name + return FakeClient() + + boto3_mod.client = client + + monkeypatch.setitem(sys.modules, "boto3", boto3_mod) + monkeypatch.setitem(sys.modules, "botocore", botocore_mod) + monkeypatch.setitem(sys.modules, "botocore.exceptions", exc_mod) + sys.modules.pop("node_wire_runtime.secrets.aws", None) + mod = importlib.import_module("node_wire_runtime.secrets.aws") + yield mod, state, exc_mod + sys.modules.pop("node_wire_runtime.secrets.aws", None) + + +def test_aws_success_and_missing_key(aws) -> None: + mod, state, _exc = aws + state.secret_string = json.dumps({"epic_client_id": "abc123"}) + + provider = mod.AwsSecretsManagerProvider("my-bundle", region="eu-west-1") + + assert state.service == "secretsmanager" + assert state.region == "eu-west-1" + assert state.requested_secret_id == "my-bundle" + assert provider.get_secret("epic_client_id") == "abc123" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_aws_resource_not_found_maps_to_secret_not_found(aws) -> None: + mod, state, exc = aws + state.error = exc.ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "GetSecretValue" + ) + with pytest.raises(SecretNotFoundError): + mod.AwsSecretsManagerProvider("missing-bundle") + + +def test_aws_other_client_error_maps_to_provider_error(aws) -> None: + mod, state, exc = aws + state.error = exc.ClientError( + {"Error": {"Code": "AccessDeniedException"}}, "GetSecretValue" + ) + with pytest.raises(SecretProviderError, match="AccessDeniedException"): + mod.AwsSecretsManagerProvider("forbidden-bundle") + + +def test_aws_botocore_error_maps_to_provider_error(aws) -> None: + mod, state, exc = aws + state.error = exc.BotoCoreError("connection refused") + with pytest.raises(SecretProviderError, match="AWS connection error"): + mod.AwsSecretsManagerProvider("unreachable-bundle") + + +def test_aws_import_error_has_install_hint(monkeypatch: pytest.MonkeyPatch) -> None: + """With boto3 absent the module import fails with an actionable hint.""" + monkeypatch.delitem(sys.modules, "boto3", raising=False) + sys.modules.pop("node_wire_runtime.secrets.aws", None) + try: + with pytest.raises(ImportError, match=r"node-wire-runtime\[aws\]"): + importlib.import_module("node_wire_runtime.secrets.aws") + finally: + sys.modules.pop("node_wire_runtime.secrets.aws", None) + + +# --------------------------------------------------------------------------- +# Azure +# --------------------------------------------------------------------------- + + +@pytest.fixture +def azure(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.azure against stub azure-* SDKs.""" + state = SimpleNamespace( + secrets={}, # azure-name -> value + error=None, + credential_error=None, + vault_url=None, + requested_names=[], + ) + + class ResourceNotFoundError(Exception): + pass + + class HttpResponseError(Exception): + pass + + class DefaultAzureCredential: + def __init__(self) -> None: + if state.credential_error is not None: + raise state.credential_error + + class SecretClient: + def __init__(self, vault_url: str, credential: object) -> None: + state.vault_url = vault_url + + def get_secret(self, name: str) -> SimpleNamespace: + state.requested_names.append(name) + if state.error is not None: + raise state.error + if name not in state.secrets: + raise ResourceNotFoundError(name) + return SimpleNamespace(value=state.secrets[name]) + + identity_mod = types.ModuleType("azure.identity") + identity_mod.DefaultAzureCredential = DefaultAzureCredential + kv_secrets_mod = types.ModuleType("azure.keyvault.secrets") + kv_secrets_mod.SecretClient = SecretClient + core_exc_mod = types.ModuleType("azure.core.exceptions") + core_exc_mod.ResourceNotFoundError = ResourceNotFoundError + core_exc_mod.HttpResponseError = HttpResponseError + + azure_mod = types.ModuleType("azure") + kv_mod = types.ModuleType("azure.keyvault") + core_mod = types.ModuleType("azure.core") + + monkeypatch.setitem(sys.modules, "azure", azure_mod) + monkeypatch.setitem(sys.modules, "azure.identity", identity_mod) + monkeypatch.setitem(sys.modules, "azure.keyvault", kv_mod) + monkeypatch.setitem(sys.modules, "azure.keyvault.secrets", kv_secrets_mod) + monkeypatch.setitem(sys.modules, "azure.core", core_mod) + monkeypatch.setitem(sys.modules, "azure.core.exceptions", core_exc_mod) + sys.modules.pop("node_wire_runtime.secrets.azure", None) + mod = importlib.import_module("node_wire_runtime.secrets.azure") + yield mod, state, core_exc_mod + sys.modules.pop("node_wire_runtime.secrets.azure", None) + + +def test_azure_success_maps_underscores_to_hyphens(azure) -> None: + mod, state, _exc = azure + state.secrets["epic-client-id"] = "azure-value" + + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + + assert state.vault_url == "https://kv.example.vault.azure.net" + assert provider.get_secret("epic_client_id") == "azure-value" + assert state.requested_names == ["epic-client-id"] + + +def test_azure_missing_secret_maps_to_secret_not_found(azure) -> None: + mod, _state, _exc = azure + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + with pytest.raises(SecretNotFoundError): + provider.get_secret("absent_key") + + +def test_azure_none_value_maps_to_secret_not_found(azure) -> None: + mod, state, _exc = azure + state.secrets["empty-key"] = None + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + with pytest.raises(SecretNotFoundError): + provider.get_secret("empty_key") + + +def test_azure_http_error_maps_to_provider_error(azure) -> None: + mod, state, exc = azure + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + state.error = exc.HttpResponseError("503 upstream unavailable") + with pytest.raises(SecretProviderError, match="Azure Key Vault HTTP error"): + provider.get_secret("any_key") + + +def test_azure_init_failure_maps_to_provider_error(azure) -> None: + mod, state, _exc = azure + state.credential_error = RuntimeError("no credential chain available") + with pytest.raises(SecretProviderError, match="Failed to initialise"): + mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + + +# --------------------------------------------------------------------------- +# GCP +# --------------------------------------------------------------------------- + + +@pytest.fixture +def gcp(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.gcp against stub google SDK modules.""" + state = SimpleNamespace(payload=b"{}", error=None, requested_name=None) + + class GoogleAPICallError(Exception): + pass + + class NotFound(GoogleAPICallError): + pass + + class SecretManagerServiceClient: + def access_secret_version(self, request: dict) -> SimpleNamespace: + state.requested_name = request["name"] + if state.error is not None: + raise state.error + return SimpleNamespace(payload=SimpleNamespace(data=state.payload)) + + sm_mod = types.ModuleType("google.cloud.secretmanager") + sm_mod.SecretManagerServiceClient = SecretManagerServiceClient + api_exc_mod = types.ModuleType("google.api_core.exceptions") + api_exc_mod.NotFound = NotFound + api_exc_mod.GoogleAPICallError = GoogleAPICallError + api_core_mod = types.ModuleType("google.api_core") + api_core_mod.exceptions = api_exc_mod + + # "google" and "google.cloud" may exist as real namespace packages; only + # provide them if absent so we never clobber installed google packages. + if "google" not in sys.modules: + monkeypatch.setitem(sys.modules, "google", types.ModuleType("google")) + if "google.cloud" not in sys.modules: + monkeypatch.setitem(sys.modules, "google.cloud", types.ModuleType("google.cloud")) + monkeypatch.setitem(sys.modules, "google.cloud.secretmanager", sm_mod) + monkeypatch.setitem(sys.modules, "google.api_core", api_core_mod) + monkeypatch.setitem(sys.modules, "google.api_core.exceptions", api_exc_mod) + sys.modules.pop("node_wire_runtime.secrets.gcp", None) + mod = importlib.import_module("node_wire_runtime.secrets.gcp") + yield mod, state, api_exc_mod + sys.modules.pop("node_wire_runtime.secrets.gcp", None) + + +def test_gcp_success_and_missing_key(gcp) -> None: + mod, state, _exc = gcp + state.payload = json.dumps({"db_password": "gcp-value"}).encode("utf-8") + + provider = mod.GcpSecretManagerProvider("proj-1", "bundle", version="7") + + assert state.requested_name == "projects/proj-1/secrets/bundle/versions/7" + assert provider.get_secret("db_password") == "gcp-value" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_gcp_default_version_is_latest(gcp) -> None: + mod, state, _exc = gcp + mod.GcpSecretManagerProvider("proj-1", "bundle") + assert state.requested_name == "projects/proj-1/secrets/bundle/versions/latest" + + +def test_gcp_not_found_maps_to_secret_not_found(gcp) -> None: + mod, state, exc = gcp + state.error = exc.NotFound("no such secret") + with pytest.raises(SecretNotFoundError): + mod.GcpSecretManagerProvider("proj-1", "missing") + + +def test_gcp_api_error_maps_to_provider_error(gcp) -> None: + mod, state, exc = gcp + state.error = exc.GoogleAPICallError("permission denied") + with pytest.raises(SecretProviderError, match="GCP Secret Manager error"): + mod.GcpSecretManagerProvider("proj-1", "forbidden") + + +# --------------------------------------------------------------------------- +# HashiCorp Vault +# --------------------------------------------------------------------------- + + +@pytest.fixture +def vault(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.vault against a stub hvac module.""" + state = SimpleNamespace( + data={}, + error=None, + authenticated=True, + url=None, + token=None, + path=None, + mount_point=None, + ) + + class VaultError(Exception): + pass + + class InvalidPath(VaultError): + pass + + class _KvV2: + def read_secret_version(self, path: str, mount_point: str = "secret") -> dict: + state.path = path + state.mount_point = mount_point + if state.error is not None: + raise state.error + return {"data": {"data": state.data}} + + class Client: + def __init__(self, url: str, token: str | None = None) -> None: + state.url = url + state.token = token + self.secrets = SimpleNamespace(kv=SimpleNamespace(v2=_KvV2())) + + def is_authenticated(self) -> bool: + return state.authenticated + + hvac_mod = types.ModuleType("hvac") + hvac_mod.Client = Client + exc_mod = types.ModuleType("hvac.exceptions") + exc_mod.VaultError = VaultError + exc_mod.InvalidPath = InvalidPath + hvac_mod.exceptions = exc_mod + + monkeypatch.setitem(sys.modules, "hvac", hvac_mod) + monkeypatch.setitem(sys.modules, "hvac.exceptions", exc_mod) + sys.modules.pop("node_wire_runtime.secrets.vault", None) + mod = importlib.import_module("node_wire_runtime.secrets.vault") + yield mod, state, exc_mod + sys.modules.pop("node_wire_runtime.secrets.vault", None) + + +def test_vault_success_and_missing_key(vault) -> None: + mod, state, _exc = vault + state.data = {"api_token": "vault-value"} + + provider = mod.HashiCorpVaultProvider( + "apps/node-wire", url="https://vault.internal:8200", token="t-1", mount_point="kv" + ) + + assert state.url == "https://vault.internal:8200" + assert state.token == "t-1" + assert state.path == "apps/node-wire" + assert state.mount_point == "kv" + assert provider.get_secret("api_token") == "vault-value" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_vault_unauthenticated_maps_to_provider_error(vault) -> None: + mod, state, _exc = vault + state.authenticated = False + with pytest.raises(SecretProviderError, match="not authenticated"): + mod.HashiCorpVaultProvider("apps/node-wire") + + +def test_vault_invalid_path_maps_to_secret_not_found(vault) -> None: + mod, state, exc = vault + state.error = exc.InvalidPath("no secret at path") + with pytest.raises(SecretNotFoundError): + mod.HashiCorpVaultProvider("apps/missing") + + +def test_vault_error_maps_to_provider_error(vault) -> None: + mod, state, exc = vault + state.error = exc.VaultError("sealed") + with pytest.raises(SecretProviderError, match="Vault error"): + mod.HashiCorpVaultProvider("apps/broken") + + +# --------------------------------------------------------------------------- +# base.py gaps +# --------------------------------------------------------------------------- + + +def test_abstract_get_secret_body_raises_not_implemented() -> None: + class PassThrough(SecretProvider): + def get_secret(self, key: str) -> str: + return super().get_secret(key) # type: ignore[safe-super] + + with pytest.raises(NotImplementedError): + PassThrough().get_secret("anything") + + +def test_env_provider_uppercase_fallback_strips_quotes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("nw_test_upper_key", raising=False) + monkeypatch.setenv("NW_TEST_UPPER_KEY", " 'quoted-value' ") + p = EnvSecretProvider(legacy_empty_on_missing=False) + assert p.get_secret("nw_test_upper_key") == "quoted-value" + + +# --------------------------------------------------------------------------- +# chained.py gaps +# --------------------------------------------------------------------------- + + +class _StaticProvider(SecretProvider): + def __init__(self, data: dict[str, str]) -> None: + self._data = data + + def get_secret(self, key: str) -> str: + try: + return self._data[key] + except KeyError: + raise SecretNotFoundError(key) + + +class _BrokenProvider(SecretProvider): + def get_secret(self, key: str) -> str: + raise SecretProviderError("IAM is on fire") + + +def test_chained_requires_at_least_one_provider() -> None: + with pytest.raises(ValueError, match="at least one provider"): + ChainedSecretProvider() + + +def test_chained_falls_through_on_not_found() -> None: + chain = ChainedSecretProvider(_StaticProvider({}), _StaticProvider({"k": "v2"})) + assert chain.get_secret("k") == "v2" + + +def test_chained_propagates_provider_error_immediately() -> None: + fallback = _StaticProvider({"k": "should-never-be-reached"}) + chain = ChainedSecretProvider(_BrokenProvider(), fallback) + with pytest.raises(SecretProviderError, match="IAM is on fire"): + chain.get_secret("k") + + +def test_chained_raises_not_found_when_all_providers_miss() -> None: + chain = ChainedSecretProvider(_StaticProvider({}), _StaticProvider({})) + with pytest.raises(SecretNotFoundError, match="not found in any of 2 provider"): + chain.get_secret("missing") From 065b4f9f089f0f88dbdc6a389ef50fa85927496d Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sat, 4 Jul 2026 11:39:30 -0700 Subject: [PATCH 3/9] Fixed health checks --- pyproject.toml | 1 + src/bindings/grpc_server/server.py | 7 ++++ src/bindings/rest_api/app.py | 14 +++++++ tests/test_factory_and_rest.py | 29 +++++++++++++++ tests/test_grpc_health.py | 59 ++++++++++++++++++++++++++++++ uv.lock | 15 ++++++++ 6 files changed, 125 insertions(+) create mode 100644 tests/test_grpc_health.py diff --git a/pyproject.toml b/pyproject.toml index 4d320bb..81db754 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "opentelemetry-instrumentation-fastapi>=0.45b0", "opentelemetry-exporter-otlp>=1.24.0", "traceloop-sdk>=0.53.0", + "grpcio-health-checking>=1.62.0", ] [project.urls] diff --git a/src/bindings/grpc_server/server.py b/src/bindings/grpc_server/server.py index 52cd335..5bb0c8b 100644 --- a/src/bindings/grpc_server/server.py +++ b/src/bindings/grpc_server/server.py @@ -11,6 +11,8 @@ from typing import Any import grpc +from grpc_health.v1 import health as grpc_health +from grpc_health.v1 import health_pb2, health_pb2_grpc from bindings.factory import ConnectorFactory from node_wire_runtime.connector_registry import auto_register @@ -117,6 +119,11 @@ def serve(port: int = 50051) -> None: host = resolve_grpc_host() configure_grpc_server_port(server, port=port, host=host, cert_path=cert_path, key_path=key_path) + health_servicer = grpc_health.HealthServicer() + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + health_servicer.set("", health_pb2.HealthCheckResponse.SERVING) + health_servicer.set("aot.connectors.ConnectorService", health_pb2.HealthCheckResponse.SERVING) + server.start() server.wait_for_termination() diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 6fee523..15da3ab 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -128,6 +128,20 @@ async def health() -> Dict[str, str]: return {"status": "ok"} +@app.get("/ready", tags=["system"]) +async def ready() -> Dict[str, str]: + try: + factory = get_factory() + if not factory.list_for_protocol("rest") and not factory.list_for_protocol("grpc"): + raise HTTPException(status_code=503, detail="no connectors loaded") + except HTTPException: + raise + except Exception as exc: + logger.warning("Readiness check failed: %s", exc) + raise HTTPException(status_code=503, detail="factory not ready") + return {"status": "ready"} + + def _http_status_for_category(category: ErrorCategory | None) -> int: if category is None: return 200 diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index 060eeae..8c83c8e 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -34,6 +34,35 @@ def test_health_endpoint(): assert resp.json() == {"status": "ok"} +def test_ready_endpoint_when_factory_has_connectors(monkeypatch: pytest.MonkeyPatch): + mock_factory = MagicMock() + mock_factory.list_for_protocol.return_value = [MagicMock()] + monkeypatch.setattr("bindings.rest_api.app._factory", mock_factory) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 200 + assert resp.json() == {"status": "ready"} + + +def test_ready_endpoint_returns_503_when_no_connectors(monkeypatch: pytest.MonkeyPatch): + mock_factory = MagicMock() + mock_factory.list_for_protocol.return_value = [] + monkeypatch.setattr("bindings.rest_api.app._factory", mock_factory) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 503 + + +def test_ready_endpoint_returns_503_on_factory_error(monkeypatch: pytest.MonkeyPatch): + def _boom(): + raise RuntimeError("factory load failed") + + monkeypatch.setattr("bindings.rest_api.app.get_factory", _boom) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 503 + + def test_agent_transport_defaults_to_stdio(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("NW_MCP_TRANSPORT", raising=False) client = TestClient(app) diff --git a/tests/test_grpc_health.py b/tests/test_grpc_health.py new file mode 100644 index 0000000..693f425 --- /dev/null +++ b/tests/test_grpc_health.py @@ -0,0 +1,59 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for gRPC health service registration in serve().""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from grpc_health.v1 import health_pb2 + + +def test_grpc_health_servicer_is_registered(): + """serve() must register a HealthServicer and set SERVING for both service names.""" + added_servicers = {} + + def fake_add_health(servicer, server): + added_servicers["health"] = servicer + + def fake_add_connector(servicer, server): + added_servicers["connector"] = servicer + + with ( + patch("bindings.grpc_server.server.grpc.server") as mock_grpc_server, + patch( + "bindings.grpc_server.server.connector_pb2_grpc.add_ConnectorServiceServicer_to_server", + fake_add_connector, + ), + patch( + "bindings.grpc_server.server.health_pb2_grpc.add_HealthServicer_to_server", + fake_add_health, + ), + patch("bindings.grpc_server.server.grpc_health.HealthServicer") as mock_health_cls, + patch("bindings.grpc_server.server.configure_grpc_server_port"), + patch("bindings.grpc_server.server._async_runner"), + patch("bindings.grpc_server.server.ConnectorServiceServicer"), + ): + fake_server = MagicMock() + fake_server.wait_for_termination.side_effect = KeyboardInterrupt + mock_grpc_server.return_value = fake_server + + mock_health_instance = MagicMock() + mock_health_cls.return_value = mock_health_instance + + from bindings.grpc_server.server import serve + + try: + serve(port=0) + except KeyboardInterrupt: + pass + + assert "health" in added_servicers, "HealthServicer was not added to the gRPC server" + + set_calls = {call.args[0]: call.args[1] for call in mock_health_instance.set.call_args_list} + assert "" in set_calls, "Overall health (empty string) not set" + assert "aot.connectors.ConnectorService" in set_calls, "ConnectorService health not set" + assert set_calls[""] == health_pb2.HealthCheckResponse.SERVING + assert set_calls["aot.connectors.ConnectorService"] == health_pb2.HealthCheckResponse.SERVING diff --git a/uv.lock b/uv.lock index 471c3b5..e0a5599 100644 --- a/uv.lock +++ b/uv.lock @@ -1211,6 +1211,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/127dc2b246096ad50ef7c8d9b7b31d757787aeb796368bcdd4454e4204c4/grpcio-1.81.0-cp314-cp314-win_amd64.whl", hash = "sha256:b93cee313cae4e113fbb3a0ce1ea5633db6f63cfde2b2dc1d817429026b2a50b", size = 5070848, upload-time = "2026-06-01T05:56:19.735Z" }, ] +[[package]] +name = "grpcio-health-checking" +version = "1.71.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/86/20994347ef36b7626fb74539f13128100dd8b7eaac67efc063264e6cdc80/grpcio_health_checking-1.71.2.tar.gz", hash = "sha256:1c21ece88c641932f432b573ef504b20603bdf030ad4e1ec35dd7fdb4ea02637", size = 16770, upload-time = "2025-06-28T04:24:08.768Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/74/7bc6ab96bf1083cab2684f9c3ae434caa638de3d5c5574e8435e2c146598/grpcio_health_checking-1.71.2-py3-none-any.whl", hash = "sha256:f91db41410d6bd18a7828c5b6ac2bebd77a63483263cbe42bf3c0c9b86cece33", size = 18918, upload-time = "2025-06-28T04:23:56.923Z" }, +] + [[package]] name = "grpcio-status" version = "1.71.2" @@ -2174,6 +2187,7 @@ dependencies = [ { name = "google-api-python-client" }, { name = "google-auth" }, { name = "grpcio" }, + { name = "grpcio-health-checking" }, { name = "grpcio-tools" }, { name = "httpx", extra = ["http2"] }, { name = "opentelemetry-api" }, @@ -2229,6 +2243,7 @@ requires-dist = [ { name = "google-generativeai", marker = "extra == 'agents'", specifier = ">=0.7.0" }, { name = "groq", marker = "extra == 'agents'", specifier = ">=0.9.0" }, { name = "grpcio", specifier = ">=1.62.0" }, + { name = "grpcio-health-checking", specifier = ">=1.62.0" }, { name = "grpcio-tools", specifier = ">=1.62.0" }, { name = "httpx", extras = ["http2"], specifier = ">=0.27.0,<0.28.0" }, { name = "mcp", marker = "extra == 'agents'", specifier = ">=1.6.0" }, From e0351d6f32a319c0886a9e43660f18518b5004b4 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sat, 4 Jul 2026 12:28:40 -0700 Subject: [PATCH 4/9] Added OpenTelemetry metrics and audit trail --- src/node_wire_runtime/base_connector.py | 58 +++++++++- src/node_wire_runtime/observability.py | 59 +++++++++- tests/test_audit_trail.py | 148 ++++++++++++++++++++++++ tests/test_metrics.py | 119 +++++++++++++++++++ tests/test_observability.py | 36 +++++- 5 files changed, 407 insertions(+), 13 deletions(-) create mode 100644 tests/test_audit_trail.py create mode 100644 tests/test_metrics.py diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index eff46b9..159577d 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -7,6 +7,7 @@ import contextvars import inspect import logging +import time import uuid from abc import ABC from collections import defaultdict @@ -26,7 +27,7 @@ List, ) -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.trace import Tracer from pybreaker import CircuitBreaker from pydantic import BaseModel, Field, RootModel, ValidationError @@ -41,6 +42,17 @@ logger = logging.getLogger("runtime.base_connector") tracer: Tracer = trace.get_tracer("runtime") +_meter = metrics.get_meter("runtime") +_invocation_counter = _meter.create_counter( + "connector.invocations", + unit="1", + description="Total number of connector invocations", +) +_invocation_duration = _meter.create_histogram( + "connector.duration_ms", + unit="ms", + description="Connector invocation wall-clock time in milliseconds", +) ErrorMapper.register(PolicyDenied, ErrorCategory.AUTH, code="POLICY_DENIED") @@ -441,6 +453,7 @@ async def run( - Maps exceptions into the standard error taxonomy """ trace_id = str(uuid.uuid4()) + _start = time.monotonic() with tracer.start_as_current_span( "connector.run", @@ -458,9 +471,15 @@ async def run( "trace_id": trace_id, "connector_id": self.connector_id, "action": self.action, + "principal": principal, + "tenant_id": tenant_id, + "scopes": list(scopes) if scopes else [], + "audit": True, + "audit_event": "invocation_start", }, ) + _response: Optional[ConnectorResponse] = None token = _caller_execution_ctx.set((principal, tenant_id, scopes)) try: try: @@ -474,12 +493,14 @@ async def run( "action": self.action, "error_type": type(exc).__name__, "error_message": str(exc), + "audit": True, + "audit_event": "invocation_validation_failure", }, ) details = [ {"loc": e["loc"], "msg": e["msg"], "type": e["type"]} for e in exc.errors() ] - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code="VALIDATION_ERROR", error_category=ErrorCategory.BUSINESS, @@ -487,6 +508,7 @@ async def run( trace_id=trace_id, details=details, ) + return _response # Policy hook if self._policy_hook is not None: @@ -518,13 +540,14 @@ async def run( }, ) mapped = ErrorMapper.resolve(exc) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=mapped.code, error_category=mapped.category, message=str(exc), trace_id=trace_id, ) + return _response execute_with_resilience = with_resilience(self._breaker_for_tenant(tenant_id)) @@ -540,14 +563,18 @@ async def _do_execute(*, trace_id: str) -> Any: "trace_id": trace_id, "connector_id": self.connector_id, "action": self.action, + "duration_ms": round((time.monotonic() - _start) * 1000, 2), + "audit": True, + "audit_event": "invocation_success", }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=True, data=output_model.model_dump(), trace_id=trace_id, ) + return _response except NestedConnectorActionError as exc: nested = exc.response logger.warning( @@ -559,7 +586,7 @@ async def _do_execute(*, trace_id: str) -> Any: "nested_trace_id": nested.trace_id, }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=nested.error_code, error_category=nested.error_category, @@ -567,6 +594,7 @@ async def _do_execute(*, trace_id: str) -> Any: trace_id=trace_id, details=_merge_nested_failure_details(nested), ) + return _response except Exception as exc: # noqa: BLE001 mapped = ErrorMapper.resolve(exc) logger.error( @@ -579,17 +607,35 @@ async def _do_execute(*, trace_id: str) -> Any: "error_category": mapped.category.value, "error_type": type(exc).__name__, "error_message": str(exc), + "duration_ms": round((time.monotonic() - _start) * 1000, 2), + "audit": True, + "audit_event": "invocation_failure", }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=mapped.code, error_category=mapped.category, message=str(exc), trace_id=trace_id, ) + return _response finally: _caller_execution_ctx.reset(token) + if _response is not None: + _duration_ms = (time.monotonic() - _start) * 1000 + _metric_attrs = { + "connector.id": self.connector_id, + "connector.action": self.action, + "success": _response.success, + "error_category": ( + _response.error_category.value + if _response.error_category + else "none" + ), + } + _invocation_counter.add(1, attributes=_metric_attrs) + _invocation_duration.record(_duration_ms, attributes=_metric_attrs) @classmethod def get_registry(cls) -> Dict[str, Type[BaseConnector]]: diff --git a/src/node_wire_runtime/observability.py b/src/node_wire_runtime/observability.py index 2f70440..0884fcb 100644 --- a/src/node_wire_runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -9,12 +9,15 @@ from typing import Optional, cast from opentelemetry._logs import set_logger_provider -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased @@ -85,6 +88,36 @@ def force_flush(self, timeout_millis: int = 30000): return True +class SanitizingMetricExporter(MetricExporter): + def __init__(self, delegate: MetricExporter): + self._delegate = delegate + + @property + def _preferred_temporality(self): # type: ignore[override] + return self._delegate._preferred_temporality + + @property + def _preferred_aggregation(self): # type: ignore[override] + return self._delegate._preferred_aggregation + + def export(self, metrics_data, timeout_millis=10_000, **kwargs): + for rm in metrics_data.resource_metrics: + for sm in rm.scope_metrics: + for metric in sm.metrics: + if hasattr(metric.data, "data_points"): + for dp in metric.data.data_points: + attrs = getattr(dp, "attributes", None) + if attrs and isinstance(attrs, dict): + _sanitize_otlp_attributes(attrs) + return self._delegate.export(metrics_data, timeout_millis=timeout_millis, **kwargs) + + def shutdown(self, timeout_millis=30_000, **kwargs): + return self._delegate.shutdown(timeout_millis=timeout_millis, **kwargs) + + def force_flush(self, timeout_millis: int = 10_000): + return self._delegate.force_flush(timeout_millis) + + def init_observability(app_name: str = "node_wire") -> None: """ Initialize OpenTelemetry + OpenLLMetry/Traceloop for the process. @@ -153,6 +186,30 @@ def init_observability(app_name: str = "node_wire") -> None: root_logger.addFilter(_OtelContextFilter()) root_logger.addHandler(LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider)) + # Metrics: export to the local OTLP collector alongside traces and logs. + metric_interval_str: str = os.getenv("AOT_METRIC_EXPORT_INTERVAL_MS", "60000") + try: + metric_interval_ms = int(metric_interval_str) + except ValueError: + logger.warning( + "Invalid AOT_METRIC_EXPORT_INTERVAL_MS %r, falling back to 60000", metric_interval_str + ) + metric_interval_ms = 60000 + + metric_exporter = SanitizingMetricExporter( + OTLPMetricExporter( + headers=dict(header.split("=", 1) for header in otlp_headers.split(",")) + if otlp_headers + else None, + ) + ) + metric_reader = PeriodicExportingMetricReader( + metric_exporter, + export_interval_millis=metric_interval_ms, + ) + meter_provider = MeterProvider(resource=resource, readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + # Initialize Traceloop/OpenLLMetry in metadata-only mode. Advanced AI features # (prompt logging, workflows, tools) are intentionally deferred. # Skip silently when no API key is configured — Traceloop is optional. diff --git a/tests/test_audit_trail.py b/tests/test_audit_trail.py new file mode 100644 index 0000000..9717b9e --- /dev/null +++ b/tests/test_audit_trail.py @@ -0,0 +1,148 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for per-invocation audit log events emitted by BaseConnector.run().""" + +from __future__ import annotations + +import logging +from typing import Literal + +import pytest +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.policy import PolicyContext, PolicyDenied, PolicyHook + + +class _In(BaseModel): + action: Literal["go"] = "go" + + +class _Out(BaseModel): + ok: bool + + +class _AuditConnector(BaseConnector): + connector_id = "audit_test" + output_model = _Out + + @nw_action("go") + async def go(self, params: _In, *, trace_id: str) -> _Out: + return _Out(ok=True) + + +class _BoomConnector(BaseConnector): + connector_id = "audit_boom" + output_model = _Out + + @nw_action("go") + async def go(self, params: _In, *, trace_id: str) -> _Out: + raise RuntimeError("something broke") + + +class _DenyAll(PolicyHook): + def check(self, ctx: PolicyContext) -> None: + raise PolicyDenied("not allowed") + + +@pytest.mark.asyncio +async def test_audit_invocation_start_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + await connector.run( + {"action": "go"}, + principal="alice", + tenant_id="t1", + scopes=("scope:read",), + ) + start_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_start" + ] + assert len(start_records) == 1 + rec = start_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_test" + assert rec.__dict__.get("principal") == "alice" + assert rec.__dict__.get("tenant_id") == "t1" + assert "scope:read" in rec.__dict__.get("scopes", []) + + +@pytest.mark.asyncio +async def test_audit_invocation_success_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is True + success_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_success" + ] + assert len(success_records) == 1 + rec = success_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_test" + assert "duration_ms" in rec.__dict__ + + +@pytest.mark.asyncio +async def test_audit_invocation_failure_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _BoomConnector() + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is False + failure_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_failure" + ] + assert len(failure_records) == 1 + rec = failure_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_boom" + assert "duration_ms" in rec.__dict__ + + +@pytest.mark.asyncio +async def test_audit_invocation_validation_failure_emitted( + caplog: pytest.LogCaptureFixture, +) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + resp = await connector.run({"action": "go", "unexpected_extra_field_that_breaks": True}) + # Validation may succeed (extra fields tolerated by default) or fail depending on model config. + # Use a clearly wrong payload instead. + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + resp = await connector.run({"action": "nonexistent_action"}) + assert resp.success is False + validation_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_validation_failure" + ] + assert len(validation_records) >= 1 + rec = validation_records[0] + assert rec.__dict__.get("audit") is True + + +@pytest.mark.asyncio +async def test_audit_policy_denial_unchanged(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector(policy_hook=_DenyAll()) + with caplog.at_level(logging.WARNING, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is False + denial_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "policy_denial" + ] + assert len(denial_records) == 1 + assert denial_records[0].__dict__.get("audit") is True + + +@pytest.mark.asyncio +async def test_audit_start_includes_empty_scopes_when_none( + caplog: pytest.LogCaptureFixture, +) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + await connector.run({"action": "go"}) + start_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_start" + ] + assert len(start_records) == 1 + assert start_records[0].__dict__.get("scopes") == [] diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..d93816c --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,119 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for OpenTelemetry metric recording in BaseConnector.run().""" + +from __future__ import annotations + +from typing import Literal +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +import node_wire_runtime.base_connector as bc_module + + +class _MIn(BaseModel): + action: Literal["run"] = "run" + + +class _MOut(BaseModel): + done: bool + + +class _MetricConnector(BaseConnector): + connector_id = "metric_test" + output_model = _MOut + + @nw_action("run") + async def run_action(self, params: _MIn, *, trace_id: str) -> _MOut: + return _MOut(done=True) + + +class _MetricFailConnector(BaseConnector): + connector_id = "metric_fail" + output_model = _MOut + + @nw_action("run") + async def run_action(self, params: _MIn, *, trace_id: str) -> _MOut: + raise ValueError("intentional failure") + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_success() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "run"}) + assert resp.success is True + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["connector.id"] == "metric_test" + assert call_attrs["success"] is True + assert call_attrs["error_category"] == "none" + + +@pytest.mark.asyncio +async def test_invocation_histogram_recorded_on_success() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + await connector.run({"action": "run"}) + mock_histogram.record.assert_called_once() + duration_val = mock_histogram.record.call_args[0][0] + assert duration_val >= 0 + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_failure() -> None: + connector = _MetricFailConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "run"}) + assert resp.success is False + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["success"] is False + assert call_attrs["error_category"] != "none" + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_validation_failure() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "nonexistent"}) + assert resp.success is False + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["success"] is False + + +@pytest.mark.asyncio +async def test_metric_attributes_include_connector_action() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + with patch.object(bc_module, "_invocation_counter", mock_counter): + with patch.object(bc_module, "_invocation_duration", MagicMock()): + await connector.run({"action": "run"}) + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.action"] == "execute" diff --git a/tests/test_observability.py b/tests/test_observability.py index f74feea..1019f0b 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -60,10 +60,14 @@ def _observability_test_patches(): with _ensure_traceloop_stub_modules(): with ( patch("opentelemetry.trace.set_tracer_provider"), + patch("opentelemetry.metrics.set_meter_provider"), patch("node_wire_runtime.observability.OTLPSpanExporter") as span_exp, patch("node_wire_runtime.observability.OTLPLogExporter") as log_exp, + patch("node_wire_runtime.observability.OTLPMetricExporter") as metric_exp, patch("node_wire_runtime.observability.BatchSpanProcessor"), patch("node_wire_runtime.observability.BatchLogRecordProcessor"), + patch("node_wire_runtime.observability.PeriodicExportingMetricReader"), + patch("node_wire_runtime.observability.MeterProvider"), patch("node_wire_runtime.observability.set_logger_provider"), patch( "node_wire_runtime.observability.LoggingHandler", @@ -72,23 +76,24 @@ def _observability_test_patches(): patch("traceloop.sdk.Traceloop") as mock_tl, ): mock_tl.init = MagicMock() - yield span_exp, log_exp, mock_tl + yield span_exp, log_exp, mock_tl, metric_exp def test_init_observability_idempotent() -> None: """Second call should not reconfigure exporters.""" - with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + with _observability_test_patches() as (span_exp, log_exp, _mock_tl, metric_exp): obs.init_observability("app-a") obs.init_observability("app-b") assert span_exp.call_count == 1 assert log_exp.call_count == 1 + assert metric_exp.call_count == 1 def test_init_observability_invalid_sampling_ratio_logs_warning( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: monkeypatch.setenv("AOT_TRACING_SAMPLING_RATIO", "not-a-number") - with _observability_test_patches(): + with _observability_test_patches() as _: with caplog.at_level(logging.WARNING, logger="runtime.observability"): obs.init_observability("app-warn") assert any("Invalid AOT_TRACING_SAMPLING_RATIO" in r.message for r in caplog.records) @@ -98,11 +103,12 @@ def test_init_observability_otel_headers_passed_to_exporters( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_HEADERS", "key=value,foo=bar") - with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + with _observability_test_patches() as (span_exp, log_exp, _mock_tl, metric_exp): obs.init_observability("app-h") expected_headers = {"key": "value", "foo": "bar"} assert span_exp.call_args.kwargs.get("headers") == expected_headers assert log_exp.call_args.kwargs.get("headers") == expected_headers + assert metric_exp.call_args.kwargs.get("headers") == expected_headers def test_otel_context_filter_sets_empty_trace_when_no_span() -> None: @@ -116,7 +122,7 @@ def test_otel_context_filter_sets_empty_trace_when_no_span() -> None: def test_init_observability_installs_sanitizing_log_filter() -> None: - with _observability_test_patches(): + with _observability_test_patches() as _: obs.init_observability("app-filter") root = logging.getLogger() assert any(isinstance(flt, SanitizingLogFilter) for flt in root.filters) @@ -127,8 +133,26 @@ def test_init_observability_traceloop_failure_does_not_raise( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("TRACELOOP_API_KEY", "test-key") - with _observability_test_patches() as (_s, _l, mock_tl): + with _observability_test_patches() as (_s, _l, mock_tl, _m): mock_tl.init = MagicMock(side_effect=RuntimeError("traceloop unavailable")) with caplog.at_level(logging.WARNING, logger="runtime.observability"): obs.init_observability("app-tl") assert any("Failed to initialize Traceloop" in r.message for r in caplog.records) + + +def test_init_observability_metric_exporter_initialized() -> None: + """MeterProvider and OTLPMetricExporter should be set up on first call.""" + with _observability_test_patches() as (_span, _log, _tl, metric_exp): + obs.init_observability("app-metrics") + assert metric_exp.call_count == 1 + + +def test_init_observability_invalid_metric_interval_logs_warning( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + monkeypatch.setenv("AOT_METRIC_EXPORT_INTERVAL_MS", "not-a-number") + with _observability_test_patches() as _: + with caplog.at_level(logging.WARNING, logger="runtime.observability"): + obs.init_observability("app-metric-warn") + assert any("Invalid AOT_METRIC_EXPORT_INTERVAL_MS" in r.message for r in caplog.records) From 3f7745e957dffc56e702d98ce8322d96c899e2e1 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sun, 5 Jul 2026 03:11:31 -0700 Subject: [PATCH 5/9] Added metrics --- src/node_wire_runtime/base_connector.py | 10 ++-- src/node_wire_runtime/rate_limit.py | 10 ++++ src/node_wire_runtime/resilience.py | 25 ++++++++ tests/test_metrics.py | 77 ++++++++++++++++++++++++- 4 files changed, 117 insertions(+), 5 deletions(-) diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index 159577d..172d4cc 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -549,7 +549,11 @@ async def run( ) return _response - execute_with_resilience = with_resilience(self._breaker_for_tenant(tenant_id)) + execute_with_resilience = with_resilience( + self._breaker_for_tenant(tenant_id), + connector_id=self.connector_id, + action=self.action, + ) @execute_with_resilience async def _do_execute(*, trace_id: str) -> Any: @@ -629,9 +633,7 @@ async def _do_execute(*, trace_id: str) -> Any: "connector.action": self.action, "success": _response.success, "error_category": ( - _response.error_category.value - if _response.error_category - else "none" + _response.error_category.value if _response.error_category else "none" ), } _invocation_counter.add(1, attributes=_metric_attrs) diff --git a/src/node_wire_runtime/rate_limit.py b/src/node_wire_runtime/rate_limit.py index d338fa8..ffb6a05 100644 --- a/src/node_wire_runtime/rate_limit.py +++ b/src/node_wire_runtime/rate_limit.py @@ -15,6 +15,15 @@ import os import time +from opentelemetry import metrics + +_meter = metrics.get_meter("runtime") +_rate_limit_rejections = _meter.create_counter( + "connector.rate_limit_rejections", + unit="1", + description="Requests rejected by the global token-bucket rate limiter", +) + class RateLimitExceeded(Exception): """Raised when the rate limit has been exceeded.""" @@ -50,6 +59,7 @@ async def acquire(self, amount: int = 1) -> None: if self.tokens >= amount: self.tokens -= amount else: + _rate_limit_rejections.add(1) raise RateLimitExceeded("Global rate limit exceeded. Please try again later.") diff --git a/src/node_wire_runtime/resilience.py b/src/node_wire_runtime/resilience.py index ce4053b..ff59438 100644 --- a/src/node_wire_runtime/resilience.py +++ b/src/node_wire_runtime/resilience.py @@ -8,6 +8,7 @@ from functools import wraps from typing import Any, Awaitable, Callable, Coroutine, TypeVar +from opentelemetry import metrics from pybreaker import CircuitBreaker, CircuitBreakerError from tenacity import ( AsyncRetrying, @@ -21,6 +22,18 @@ logger = logging.getLogger("runtime.resilience") +_meter = metrics.get_meter("runtime") +_retry_counter = _meter.create_counter( + "connector.retries", + unit="1", + description="Retryable errors encountered during execution (triggers a retry unless attempts are exhausted)", +) +_circuit_breaker_rejections = _meter.create_counter( + "connector.circuit_breaker_rejections", + unit="1", + description="Calls rejected because the tenant circuit breaker was open", +) + T = TypeVar("T") @@ -98,11 +111,17 @@ def with_resilience( max_attempts: int = 3, base_wait: float = 0.5, max_wait: float = 5.0, + connector_id: str = "", + action: str = "", ) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Coroutine[Any, Any, T]]]: """ Decorator that applies retry (Tenacity) and circuit breaking (PyBreaker) around an async function that may raise exceptions. + + ``connector_id`` and ``action`` are low-cardinality labels attached to the + retry and circuit-breaker metrics; they carry no request-specific identity. """ + _metric_attrs = {"connector.id": connector_id, "connector.action": action} def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Coroutine[Any, Any, T]]: @wraps(fn) @@ -121,6 +140,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> T: return await _run_through_breaker(breaker, fn, args, kwargs, trace_id) except Exception as exc: # noqa: BLE001 mapped = ErrorMapper.resolve(exc) + if isinstance(exc, CircuitBreakerError): + _circuit_breaker_rejections.add(1, attributes=_metric_attrs) if mapped.category is not ErrorCategory.RETRYABLE: # Non-retryable: log, then escape the retry loop entirely. logger.error( @@ -135,6 +156,10 @@ async def wrapper(*args: Any, **kwargs: Any) -> T: ) raise _AbortRetry(exc) + _retry_counter.add( + 1, + attributes={**_metric_attrs, "error_code": mapped.code}, + ) logger.warning( "Retryable error during execution; will retry", extra={ diff --git a/tests/test_metrics.py b/tests/test_metrics.py index d93816c..077763f 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -12,8 +12,13 @@ import pytest from pydantic import BaseModel -from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime import BaseConnector, ErrorCategory, ErrorMapper, nw_action import node_wire_runtime.base_connector as bc_module +import node_wire_runtime.rate_limit as rl_module +import node_wire_runtime.resilience as res_module +from node_wire_runtime.rate_limit import RateLimitExceeded, TokenBucket +from node_wire_runtime.resilience import with_resilience +from pybreaker import CircuitBreaker class _MIn(BaseModel): @@ -117,3 +122,73 @@ async def test_metric_attributes_include_connector_action() -> None: await connector.run({"action": "run"}) attrs = mock_counter.add.call_args[1]["attributes"] assert attrs["connector.action"] == "execute" + + +class _RetryableMetricError(Exception): + pass + + +@pytest.fixture +def _register_retryable() -> None: + ErrorMapper.register(_RetryableMetricError, ErrorCategory.RETRYABLE, code="RETRYABLE_METRIC") + try: + yield + finally: + ErrorMapper._registry.pop(_RetryableMetricError, None) + + +@pytest.mark.asyncio +async def test_retry_counter_incremented_on_retryable_error(_register_retryable: None) -> None: + attempts = {"n": 0} + + @with_resilience(CircuitBreaker(), connector_id="retry_cx", action="do") + async def flaky(*, trace_id: str = "t") -> str: + attempts["n"] += 1 + if attempts["n"] < 2: + raise _RetryableMetricError("boom") + return "ok" + + mock_counter = MagicMock() + with patch.object(res_module, "_retry_counter", mock_counter): + result = await flaky(trace_id="t1") + + assert result == "ok" + # One retryable failure occurred before the successful attempt. + mock_counter.add.assert_called_once() + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.id"] == "retry_cx" + assert attrs["connector.action"] == "do" + assert attrs["error_code"] == "RETRYABLE_METRIC" + + +@pytest.mark.asyncio +async def test_circuit_breaker_rejection_counter_incremented() -> None: + breaker = CircuitBreaker() + breaker.open() + + @with_resilience(breaker, connector_id="cb_cx", action="do") + async def never_runs(*, trace_id: str = "t") -> str: + return "unreachable" + + mock_counter = MagicMock() + with patch.object(res_module, "_circuit_breaker_rejections", mock_counter): + with pytest.raises(Exception): + await never_runs(trace_id="t1") + + mock_counter.add.assert_called_once() + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.id"] == "cb_cx" + assert attrs["connector.action"] == "do" + + +@pytest.mark.asyncio +async def test_rate_limit_rejection_counter_incremented() -> None: + bucket = TokenBucket(capacity=1, refill_rate=0) + await bucket.acquire() # consumes the only token + + mock_counter = MagicMock() + with patch.object(rl_module, "_rate_limit_rejections", mock_counter): + with pytest.raises(RateLimitExceeded): + await bucket.acquire() + + mock_counter.add.assert_called_once() From bb532c94b01320744b107f2fe1de025c9bd57452 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sun, 5 Jul 2026 03:15:13 -0700 Subject: [PATCH 6/9] Fixed mypy errors --- src/node_wire_runtime/base_connector.py | 3 ++- src/node_wire_runtime/observability.py | 4 ++-- tests/test_observability.py | 6 +++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index 172d4cc..e217c40 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -29,6 +29,7 @@ from opentelemetry import metrics, trace from opentelemetry.trace import Tracer +from opentelemetry.util.types import AttributeValue from pybreaker import CircuitBreaker from pydantic import BaseModel, Field, RootModel, ValidationError @@ -628,7 +629,7 @@ async def _do_execute(*, trace_id: str) -> Any: _caller_execution_ctx.reset(token) if _response is not None: _duration_ms = (time.monotonic() - _start) * 1000 - _metric_attrs = { + _metric_attrs: Dict[str, AttributeValue] = { "connector.id": self.connector_id, "connector.action": self.action, "success": _response.success, diff --git a/src/node_wire_runtime/observability.py b/src/node_wire_runtime/observability.py index 0884fcb..a1a0f0d 100644 --- a/src/node_wire_runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -114,7 +114,7 @@ def export(self, metrics_data, timeout_millis=10_000, **kwargs): def shutdown(self, timeout_millis=30_000, **kwargs): return self._delegate.shutdown(timeout_millis=timeout_millis, **kwargs) - def force_flush(self, timeout_millis: int = 10_000): + def force_flush(self, timeout_millis: float = 10_000): return self._delegate.force_flush(timeout_millis) @@ -207,7 +207,7 @@ def init_observability(app_name: str = "node_wire") -> None: metric_exporter, export_interval_millis=metric_interval_ms, ) - meter_provider = MeterProvider(resource=resource, readers=[metric_reader]) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) metrics.set_meter_provider(meter_provider) # Initialize Traceloop/OpenLLMetry in metadata-only mode. Advanced AI features diff --git a/tests/test_observability.py b/tests/test_observability.py index 1019f0b..e7a6bd4 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -143,8 +143,12 @@ def test_init_observability_traceloop_failure_does_not_raise( def test_init_observability_metric_exporter_initialized() -> None: """MeterProvider and OTLPMetricExporter should be set up on first call.""" with _observability_test_patches() as (_span, _log, _tl, metric_exp): - obs.init_observability("app-metrics") + with patch("node_wire_runtime.observability.MeterProvider") as meter_provider: + obs.init_observability("app-metrics") assert metric_exp.call_count == 1 + # The reader must be passed via the SDK's real kwarg name (`metric_readers`); + # `readers=` raises TypeError at runtime and would silently drop metrics export. + assert "metric_readers" in meter_provider.call_args.kwargs def test_init_observability_invalid_metric_interval_logs_warning( From ac82ee5d0bcf9042ee2751062589c15586875270 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sun, 5 Jul 2026 03:26:04 -0700 Subject: [PATCH 7/9] Removed unused imports --- tests/test_grpc_bindings.py | 3 --- tests/test_oauth2_extra.py | 2 -- tests/test_runtime_misc.py | 3 --- 3 files changed, 8 deletions(-) diff --git a/tests/test_grpc_bindings.py b/tests/test_grpc_bindings.py index 33fb1d4..a1a85b2 100644 --- a/tests/test_grpc_bindings.py +++ b/tests/test_grpc_bindings.py @@ -6,7 +6,6 @@ from __future__ import annotations -import os from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -18,8 +17,6 @@ GrpcAuthInterceptor, _extract_token, _truthy, - get_grpc_caller_identity, - verify_grpc_token_and_identity, ) from bindings.grpc_server.server import ConnectorServiceServicer from node_wire_runtime import ConnectorResponse, ErrorCategory diff --git a/tests/test_oauth2_extra.py b/tests/test_oauth2_extra.py index 4bc3bd9..3b5847c 100644 --- a/tests/test_oauth2_extra.py +++ b/tests/test_oauth2_extra.py @@ -127,8 +127,6 @@ async def test_oauth2_client_secret_post_success() -> None: mock_client.__aenter__ = MagicMock(return_value=mock_client) mock_client.__aexit__ = MagicMock(return_value=False) mock_client.post = MagicMock(return_value=mock_response) - # Make post awaitable - import asyncio mock_client.post = MagicMock(side_effect=lambda *a, **kw: mock_response) async def _fake_enter(_: Any) -> Any: diff --git a/tests/test_runtime_misc.py b/tests/test_runtime_misc.py index accac11..822be4a 100644 --- a/tests/test_runtime_misc.py +++ b/tests/test_runtime_misc.py @@ -8,7 +8,6 @@ from __future__ import annotations import logging -import sys from importlib.metadata import EntryPoint from unittest.mock import MagicMock, patch @@ -58,8 +57,6 @@ async def test_token_bucket_acquire_raises_when_exhausted() -> None: async def test_token_bucket_refills_over_time() -> None: - import time - bucket = TokenBucket(capacity=10, refill_rate=1000) await bucket.acquire(10) assert bucket.tokens == 0 From 1605e022ed88a7cc7e3b694110481c6c5337bf8d Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sun, 5 Jul 2026 03:29:38 -0700 Subject: [PATCH 8/9] Fixed lint issues --- tests/test_audit_trail.py | 8 ++-- tests/test_grpc_bindings.py | 4 +- tests/test_grpc_health.py | 4 +- tests/test_oauth2_extra.py | 70 ++++++++++++++++++--------------- tests/test_runtime_misc.py | 49 ++++++++++++----------- tests/test_secrets_providers.py | 4 +- 6 files changed, 73 insertions(+), 66 deletions(-) diff --git a/tests/test_audit_trail.py b/tests/test_audit_trail.py index 9717b9e..1cee2c6 100644 --- a/tests/test_audit_trail.py +++ b/tests/test_audit_trail.py @@ -114,7 +114,9 @@ async def test_audit_invocation_validation_failure_emitted( resp = await connector.run({"action": "nonexistent_action"}) assert resp.success is False validation_records = [ - r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_validation_failure" + r + for r in caplog.records + if r.__dict__.get("audit_event") == "invocation_validation_failure" ] assert len(validation_records) >= 1 rec = validation_records[0] @@ -127,9 +129,7 @@ async def test_audit_policy_denial_unchanged(caplog: pytest.LogCaptureFixture) - with caplog.at_level(logging.WARNING, logger="runtime.base_connector"): resp = await connector.run({"action": "go"}) assert resp.success is False - denial_records = [ - r for r in caplog.records if r.__dict__.get("audit_event") == "policy_denial" - ] + denial_records = [r for r in caplog.records if r.__dict__.get("audit_event") == "policy_denial"] assert len(denial_records) == 1 assert denial_records[0].__dict__.get("audit") is True diff --git a/tests/test_grpc_bindings.py b/tests/test_grpc_bindings.py index a1a85b2..d239ea2 100644 --- a/tests/test_grpc_bindings.py +++ b/tests/test_grpc_bindings.py @@ -309,9 +309,7 @@ async def mock_run(payload: Any, **kwargs: Any) -> ConnectorResponse: async def test_invoke_empty_payload_json(servicer: ConnectorServiceServicer) -> None: fake_connector = MagicMock() - fake_connector.run = AsyncMock( - return_value=ConnectorResponse(success=True, trace_id="t4") - ) + fake_connector.run = AsyncMock(return_value=ConnectorResponse(success=True, trace_id="t4")) with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): req = connector_pb2.InvokeRequest(connector_id="x", action="act") resp = await servicer._invoke_async(req) diff --git a/tests/test_grpc_health.py b/tests/test_grpc_health.py index 693f425..b282c4e 100644 --- a/tests/test_grpc_health.py +++ b/tests/test_grpc_health.py @@ -56,4 +56,6 @@ def fake_add_connector(servicer, server): assert "" in set_calls, "Overall health (empty string) not set" assert "aot.connectors.ConnectorService" in set_calls, "ConnectorService health not set" assert set_calls[""] == health_pb2.HealthCheckResponse.SERVING - assert set_calls["aot.connectors.ConnectorService"] == health_pb2.HealthCheckResponse.SERVING + assert ( + set_calls["aot.connectors.ConnectorService"] == health_pb2.HealthCheckResponse.SERVING + ) diff --git a/tests/test_oauth2_extra.py b/tests/test_oauth2_extra.py index 3b5847c..6cdf500 100644 --- a/tests/test_oauth2_extra.py +++ b/tests/test_oauth2_extra.py @@ -60,11 +60,13 @@ def test_oauth2_invalid_grant_method_raises() -> None: def test_resolve_scopes_from_secret() -> None: - sp = _DictSecretProvider({ - "token_url": "x", - "client_id": "y", - "scope_val": "openid profile", - }) + sp = _DictSecretProvider( + { + "token_url": "x", + "client_id": "y", + "scope_val": "openid profile", + } + ) provider = OAuth2AuthProvider( secret_provider=sp, grant_method="client_secret_post", @@ -192,12 +194,14 @@ async def fake_post_token(url: str, data: dict) -> dict: async def test_oauth2_refresh_token_success() -> None: - sp = _DictSecretProvider({ - "token_url": "https://idp.example.com/token", - "client_id": "my-client", - "client_secret": "sec", - "refresh_token": "rt-123", - }) + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "sec", + "refresh_token": "rt-123", + } + ) provider = OAuth2AuthProvider( secret_provider=sp, grant_method="refresh_token", @@ -224,11 +228,13 @@ async def fake_post(url: str, data: dict) -> dict: async def test_oauth2_refresh_token_without_client_secret() -> None: - sp = _DictSecretProvider({ - "token_url": "https://idp.example.com/token", - "client_id": "my-client", - "refresh_token": "rt-456", - }) + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "refresh_token": "rt-456", + } + ) provider = OAuth2AuthProvider( secret_provider=sp, grant_method="refresh_token", @@ -267,12 +273,14 @@ async def test_oauth2_refresh_token_missing_secret_raises() -> None: async def test_oauth2_private_key_jwt_invalid_key_raises() -> None: - sp = _DictSecretProvider({ - "token_url": "https://idp.example.com/token", - "client_id": "client", - "private_key": "not-a-valid-pem-key", - "kid": "key-1", - }) + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "client", + "private_key": "not-a-valid-pem-key", + "kid": "key-1", + } + ) provider = OAuth2AuthProvider( secret_provider=sp, grant_method="private_key_jwt", @@ -311,12 +319,14 @@ async def test_oauth2_private_key_jwt_success_with_rsa_key() -> None: encryption_algorithm=serialization.NoEncryption(), ).decode() - sp = _DictSecretProvider({ - "token_url": "https://idp.example.com/token", - "client_id": "client-id", - "private_key": private_pem, - "kid": "rsa-key-1", - }) + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "client-id", + "private_key": private_pem, + "kid": "rsa-key-1", + } + ) provider = OAuth2AuthProvider( secret_provider=sp, grant_method="private_key_jwt", @@ -348,9 +358,7 @@ async def fake_post(url: str, data: dict) -> dict: async def test_post_token_non_200_raises() -> None: - transport = httpx.MockTransport( - handler=lambda req: httpx.Response(401, text="Unauthorized") - ) + transport = httpx.MockTransport(handler=lambda req: httpx.Response(401, text="Unauthorized")) with patch( "node_wire_runtime.auth.oauth2.httpx.AsyncClient", diff --git a/tests/test_runtime_misc.py b/tests/test_runtime_misc.py index 822be4a..0fbdbb7 100644 --- a/tests/test_runtime_misc.py +++ b/tests/test_runtime_misc.py @@ -133,9 +133,7 @@ async def source(): for i in range(3): yield {"i": i} - items = await _collect_async( - BufferedStreamIterator(source(), buffer_ms=0, trace_id="t1") - ) + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=0, trace_id="t1")) assert items == [{"i": 0}, {"i": 1}, {"i": 2}] @@ -144,9 +142,7 @@ async def source(): for i in range(4): yield {"i": i} - items = await _collect_async( - BufferedStreamIterator(source(), buffer_ms=1000, trace_id="t2") - ) + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=1000, trace_id="t2")) assert items == [{"i": 0}, {"i": 1}, {"i": 2}, {"i": 3}] @@ -179,9 +175,7 @@ async def source(): yield {"i": i} with patch("node_wire_runtime.streaming.time.monotonic", side_effect=patched_monotonic): - items = await _collect_async( - BufferedStreamIterator(source(), buffer_ms=100, trace_id="t4") - ) + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=100, trace_id="t4")) assert len(items) == 3 @@ -260,8 +254,12 @@ def test_redact_sensitive_string_arg_safe() -> None: def test_sanitize_log_record_dict_args() -> None: record = logging.LogRecord( - name="test", level=logging.INFO, pathname=__file__, lineno=1, - msg="%(patient)s %(action)s", args={"patient": "Smith", "action": "send"}, + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="%(patient)s %(action)s", + args={"patient": "Smith", "action": "send"}, exc_info=None, ) sanitize_log_record(record) @@ -271,8 +269,12 @@ def test_sanitize_log_record_dict_args() -> None: def test_sanitize_log_record_tuple_args() -> None: record = logging.LogRecord( - name="test", level=logging.INFO, pathname=__file__, lineno=1, - msg="%s %s", args=("phi_marker_data", "safe"), + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="%s %s", + args=("phi_marker_data", "safe"), exc_info=None, ) sanitize_log_record(record) @@ -282,8 +284,13 @@ def test_sanitize_log_record_tuple_args() -> None: def test_sanitize_log_record_no_args() -> None: record = logging.LogRecord( - name="test", level=logging.INFO, pathname=__file__, lineno=1, - msg="plain message", args=(), exc_info=None, + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="plain message", + args=(), + exc_info=None, ) sanitize_log_record(record) # should not raise @@ -447,9 +454,7 @@ def test_registration_module_missing_is_silently_skipped( ) -> None: """ModuleNotFoundError with name == reg_name is silently ignored.""" monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn") - ep = EntryPoint( - name="myconn", value="node_wire_myconn.logic", group="node_wire.connectors" - ) + ep = EntryPoint(name="myconn", value="node_wire_myconn.logic", group="node_wire.connectors") def fake_import(name: str) -> MagicMock: if name == "node_wire_myconn.logic": @@ -471,9 +476,7 @@ def fake_import(name: str) -> MagicMock: def test_registration_dep_error_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: """ModuleNotFoundError for a dep inside registration module is re-raised.""" monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn2") - ep = EntryPoint( - name="myconn2", value="node_wire_myconn2.logic", group="node_wire.connectors" - ) + ep = EntryPoint(name="myconn2", value="node_wire_myconn2.logic", group="node_wire.connectors") def fake_import(name: str) -> MagicMock: if name == "node_wire_myconn2.logic": @@ -492,9 +495,7 @@ def fake_import(name: str) -> MagicMock: def test_registration_unexpected_exception_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn3") - ep = EntryPoint( - name="myconn3", value="node_wire_myconn3.logic", group="node_wire.connectors" - ) + ep = EntryPoint(name="myconn3", value="node_wire_myconn3.logic", group="node_wire.connectors") def fake_import(name: str) -> MagicMock: if name == "node_wire_myconn3.logic": diff --git a/tests/test_secrets_providers.py b/tests/test_secrets_providers.py index ca75fad..2521faa 100644 --- a/tests/test_secrets_providers.py +++ b/tests/test_secrets_providers.py @@ -110,9 +110,7 @@ def test_aws_resource_not_found_maps_to_secret_not_found(aws) -> None: def test_aws_other_client_error_maps_to_provider_error(aws) -> None: mod, state, exc = aws - state.error = exc.ClientError( - {"Error": {"Code": "AccessDeniedException"}}, "GetSecretValue" - ) + state.error = exc.ClientError({"Error": {"Code": "AccessDeniedException"}}, "GetSecretValue") with pytest.raises(SecretProviderError, match="AccessDeniedException"): mod.AwsSecretsManagerProvider("forbidden-bundle") From 40c26ab4c2e4de2dbff9ce56d77859c7421f0653 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Sun, 5 Jul 2026 03:39:45 -0700 Subject: [PATCH 9/9] Fix codeql comments --- tests/test_audit_trail.py | 2 +- tests/test_grpc_health.py | 2 +- tests/test_metrics.py | 10 ++++------ tests/test_runtime_misc.py | 4 ++-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/test_audit_trail.py b/tests/test_audit_trail.py index 1cee2c6..2557ca9 100644 --- a/tests/test_audit_trail.py +++ b/tests/test_audit_trail.py @@ -107,7 +107,7 @@ async def test_audit_invocation_validation_failure_emitted( ) -> None: connector = _AuditConnector() with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): - resp = await connector.run({"action": "go", "unexpected_extra_field_that_breaks": True}) + await connector.run({"action": "go", "unexpected_extra_field_that_breaks": True}) # Validation may succeed (extra fields tolerated by default) or fail depending on model config. # Use a clearly wrong payload instead. with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): diff --git a/tests/test_grpc_health.py b/tests/test_grpc_health.py index b282c4e..09191d4 100644 --- a/tests/test_grpc_health.py +++ b/tests/test_grpc_health.py @@ -48,7 +48,7 @@ def fake_add_connector(servicer, server): try: serve(port=0) except KeyboardInterrupt: - pass + pass # serve() raises KeyboardInterrupt on shutdown; suppress to let assertions run assert "health" in added_servicers, "HealthServicer was not added to the gRPC server" diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 077763f..8fa68ca 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -16,8 +16,6 @@ import node_wire_runtime.base_connector as bc_module import node_wire_runtime.rate_limit as rl_module import node_wire_runtime.resilience as res_module -from node_wire_runtime.rate_limit import RateLimitExceeded, TokenBucket -from node_wire_runtime.resilience import with_resilience from pybreaker import CircuitBreaker @@ -141,7 +139,7 @@ def _register_retryable() -> None: async def test_retry_counter_incremented_on_retryable_error(_register_retryable: None) -> None: attempts = {"n": 0} - @with_resilience(CircuitBreaker(), connector_id="retry_cx", action="do") + @res_module.with_resilience(CircuitBreaker(), connector_id="retry_cx", action="do") async def flaky(*, trace_id: str = "t") -> str: attempts["n"] += 1 if attempts["n"] < 2: @@ -166,7 +164,7 @@ async def test_circuit_breaker_rejection_counter_incremented() -> None: breaker = CircuitBreaker() breaker.open() - @with_resilience(breaker, connector_id="cb_cx", action="do") + @res_module.with_resilience(breaker, connector_id="cb_cx", action="do") async def never_runs(*, trace_id: str = "t") -> str: return "unreachable" @@ -183,12 +181,12 @@ async def never_runs(*, trace_id: str = "t") -> str: @pytest.mark.asyncio async def test_rate_limit_rejection_counter_incremented() -> None: - bucket = TokenBucket(capacity=1, refill_rate=0) + bucket = rl_module.TokenBucket(capacity=1, refill_rate=0) await bucket.acquire() # consumes the only token mock_counter = MagicMock() with patch.object(rl_module, "_rate_limit_rejections", mock_counter): - with pytest.raises(RateLimitExceeded): + with pytest.raises(rl_module.RateLimitExceeded): await bucket.acquire() mock_counter.add.assert_called_once() diff --git a/tests/test_runtime_misc.py b/tests/test_runtime_misc.py index 0fbdbb7..295dd33 100644 --- a/tests/test_runtime_misc.py +++ b/tests/test_runtime_misc.py @@ -13,6 +13,8 @@ import pytest +import node_wire_runtime.log_sanitization as ls + from node_wire_runtime import connector_registry from node_wire_runtime.log_sanitization import ( REDACTED, @@ -298,8 +300,6 @@ def test_sanitize_log_record_no_args() -> None: def test_install_sanitizing_filter_finds_existing_without_flag() -> None: """When a SanitizingLogFilter is already present but the flag is False, it should detect it and not add a duplicate.""" - import node_wire_runtime.log_sanitization as ls - original_flag = ls._SANITIZING_FILTER_INSTALLED original_filters = list(logging.getLogger().filters) try: