diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d0606b9..2fc14d6 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -63,6 +63,44 @@ jobs: if-no-files-found: ignore + patch-coverage: + name: Patch coverage (changed lines ≥ 80%) + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + # Full history so diff-cover can diff against the base branch. + fetch-depth: 0 + + - name: Set up Python 3.11 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 + with: + enable-cache: true + cache-dependency-glob: | + pyproject.toml + uv.lock + + - name: Install dependencies + run: uv sync --frozen --all-extras --dev + + - name: Run pytest with coverage + run: uv run pytest --cov --cov-report=xml + + - name: Enforce 80% coverage on changed lines + run: > + uv run diff-cover coverage.xml + --compare-branch origin/${{ github.base_ref }} + --fail-under 80 + --show-uncovered + playground-integration: name: Playground integration tests runs-on: ubuntu-latest @@ -98,7 +136,7 @@ jobs: GOOGLE_DRIVE_SA_JSON: ${{ secrets.GOOGLE_DRIVE_SA_JSON }} GOOGLE_DRIVE_FOLDER_ID: ${{ secrets.GOOGLE_DRIVE_FOLDER_ID }} GDRIVE_TEST_RECIPIENT_EMAIL: ${{ secrets.GDRIVE_TEST_RECIPIENT_EMAIL }} - + #stripe STRIPE_API_KEY: ${{ secrets.STRIPE_API_KEY }} STRIPE_TEST_CUSTOMER_ID: ${{ secrets.STRIPE_TEST_CUSTOMER_ID }} diff --git a/pyproject.toml b/pyproject.toml index d04a89e..81db754 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "opentelemetry-instrumentation-fastapi>=0.45b0", "opentelemetry-exporter-otlp>=1.24.0", "traceloop-sdk>=0.53.0", + "grpcio-health-checking>=1.62.0", ] [project.urls] @@ -107,6 +108,7 @@ dev = [ "bandit[toml]>=1.7.9", "pre-commit>=4.0.0", "pytest-playwright>=0.4.0", + "diff-cover>=9.0.0", "cyclonedx-bom==4.6.1", "reuse>=5.0.0", "licenseheaders>=0.8.8", @@ -142,6 +144,9 @@ fail_under = 80 omit = [ "*/__pycache__/*", "*/node_wire.egg-info/*", + # Generated gRPC stubs — also excluded from Ruff, Mypy, and CodeQL. + "*_pb2.py", + "*_pb2_grpc.py", ] [tool.ruff] diff --git a/src/bindings/grpc_server/server.py b/src/bindings/grpc_server/server.py index 52cd335..5bb0c8b 100644 --- a/src/bindings/grpc_server/server.py +++ b/src/bindings/grpc_server/server.py @@ -11,6 +11,8 @@ from typing import Any import grpc +from grpc_health.v1 import health as grpc_health +from grpc_health.v1 import health_pb2, health_pb2_grpc from bindings.factory import ConnectorFactory from node_wire_runtime.connector_registry import auto_register @@ -117,6 +119,11 @@ def serve(port: int = 50051) -> None: host = resolve_grpc_host() configure_grpc_server_port(server, port=port, host=host, cert_path=cert_path, key_path=key_path) + health_servicer = grpc_health.HealthServicer() + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + health_servicer.set("", health_pb2.HealthCheckResponse.SERVING) + health_servicer.set("aot.connectors.ConnectorService", health_pb2.HealthCheckResponse.SERVING) + server.start() server.wait_for_termination() diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 6fee523..15da3ab 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -128,6 +128,20 @@ async def health() -> Dict[str, str]: return {"status": "ok"} +@app.get("/ready", tags=["system"]) +async def ready() -> Dict[str, str]: + try: + factory = get_factory() + if not factory.list_for_protocol("rest") and not factory.list_for_protocol("grpc"): + raise HTTPException(status_code=503, detail="no connectors loaded") + except HTTPException: + raise + except Exception as exc: + logger.warning("Readiness check failed: %s", exc) + raise HTTPException(status_code=503, detail="factory not ready") + return {"status": "ready"} + + def _http_status_for_category(category: ErrorCategory | None) -> int: if category is None: return 200 diff --git a/src/node_wire_runtime/base_connector.py b/src/node_wire_runtime/base_connector.py index eff46b9..e217c40 100644 --- a/src/node_wire_runtime/base_connector.py +++ b/src/node_wire_runtime/base_connector.py @@ -7,6 +7,7 @@ import contextvars import inspect import logging +import time import uuid from abc import ABC from collections import defaultdict @@ -26,8 +27,9 @@ List, ) -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.trace import Tracer +from opentelemetry.util.types import AttributeValue from pybreaker import CircuitBreaker from pydantic import BaseModel, Field, RootModel, ValidationError @@ -41,6 +43,17 @@ logger = logging.getLogger("runtime.base_connector") tracer: Tracer = trace.get_tracer("runtime") +_meter = metrics.get_meter("runtime") +_invocation_counter = _meter.create_counter( + "connector.invocations", + unit="1", + description="Total number of connector invocations", +) +_invocation_duration = _meter.create_histogram( + "connector.duration_ms", + unit="ms", + description="Connector invocation wall-clock time in milliseconds", +) ErrorMapper.register(PolicyDenied, ErrorCategory.AUTH, code="POLICY_DENIED") @@ -441,6 +454,7 @@ async def run( - Maps exceptions into the standard error taxonomy """ trace_id = str(uuid.uuid4()) + _start = time.monotonic() with tracer.start_as_current_span( "connector.run", @@ -458,9 +472,15 @@ async def run( "trace_id": trace_id, "connector_id": self.connector_id, "action": self.action, + "principal": principal, + "tenant_id": tenant_id, + "scopes": list(scopes) if scopes else [], + "audit": True, + "audit_event": "invocation_start", }, ) + _response: Optional[ConnectorResponse] = None token = _caller_execution_ctx.set((principal, tenant_id, scopes)) try: try: @@ -474,12 +494,14 @@ async def run( "action": self.action, "error_type": type(exc).__name__, "error_message": str(exc), + "audit": True, + "audit_event": "invocation_validation_failure", }, ) details = [ {"loc": e["loc"], "msg": e["msg"], "type": e["type"]} for e in exc.errors() ] - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code="VALIDATION_ERROR", error_category=ErrorCategory.BUSINESS, @@ -487,6 +509,7 @@ async def run( trace_id=trace_id, details=details, ) + return _response # Policy hook if self._policy_hook is not None: @@ -518,15 +541,20 @@ async def run( }, ) mapped = ErrorMapper.resolve(exc) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=mapped.code, error_category=mapped.category, message=str(exc), trace_id=trace_id, ) + return _response - execute_with_resilience = with_resilience(self._breaker_for_tenant(tenant_id)) + execute_with_resilience = with_resilience( + self._breaker_for_tenant(tenant_id), + connector_id=self.connector_id, + action=self.action, + ) @execute_with_resilience async def _do_execute(*, trace_id: str) -> Any: @@ -540,14 +568,18 @@ async def _do_execute(*, trace_id: str) -> Any: "trace_id": trace_id, "connector_id": self.connector_id, "action": self.action, + "duration_ms": round((time.monotonic() - _start) * 1000, 2), + "audit": True, + "audit_event": "invocation_success", }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=True, data=output_model.model_dump(), trace_id=trace_id, ) + return _response except NestedConnectorActionError as exc: nested = exc.response logger.warning( @@ -559,7 +591,7 @@ async def _do_execute(*, trace_id: str) -> Any: "nested_trace_id": nested.trace_id, }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=nested.error_code, error_category=nested.error_category, @@ -567,6 +599,7 @@ async def _do_execute(*, trace_id: str) -> Any: trace_id=trace_id, details=_merge_nested_failure_details(nested), ) + return _response except Exception as exc: # noqa: BLE001 mapped = ErrorMapper.resolve(exc) logger.error( @@ -579,17 +612,33 @@ async def _do_execute(*, trace_id: str) -> Any: "error_category": mapped.category.value, "error_type": type(exc).__name__, "error_message": str(exc), + "duration_ms": round((time.monotonic() - _start) * 1000, 2), + "audit": True, + "audit_event": "invocation_failure", }, ) - return ConnectorResponse( + _response = ConnectorResponse( success=False, error_code=mapped.code, error_category=mapped.category, message=str(exc), trace_id=trace_id, ) + return _response finally: _caller_execution_ctx.reset(token) + if _response is not None: + _duration_ms = (time.monotonic() - _start) * 1000 + _metric_attrs: Dict[str, AttributeValue] = { + "connector.id": self.connector_id, + "connector.action": self.action, + "success": _response.success, + "error_category": ( + _response.error_category.value if _response.error_category else "none" + ), + } + _invocation_counter.add(1, attributes=_metric_attrs) + _invocation_duration.record(_duration_ms, attributes=_metric_attrs) @classmethod def get_registry(cls) -> Dict[str, Type[BaseConnector]]: diff --git a/src/node_wire_runtime/observability.py b/src/node_wire_runtime/observability.py index 2f70440..a1a0f0d 100644 --- a/src/node_wire_runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -9,12 +9,15 @@ from typing import Optional, cast from opentelemetry._logs import set_logger_provider -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased @@ -85,6 +88,36 @@ def force_flush(self, timeout_millis: int = 30000): return True +class SanitizingMetricExporter(MetricExporter): + def __init__(self, delegate: MetricExporter): + self._delegate = delegate + + @property + def _preferred_temporality(self): # type: ignore[override] + return self._delegate._preferred_temporality + + @property + def _preferred_aggregation(self): # type: ignore[override] + return self._delegate._preferred_aggregation + + def export(self, metrics_data, timeout_millis=10_000, **kwargs): + for rm in metrics_data.resource_metrics: + for sm in rm.scope_metrics: + for metric in sm.metrics: + if hasattr(metric.data, "data_points"): + for dp in metric.data.data_points: + attrs = getattr(dp, "attributes", None) + if attrs and isinstance(attrs, dict): + _sanitize_otlp_attributes(attrs) + return self._delegate.export(metrics_data, timeout_millis=timeout_millis, **kwargs) + + def shutdown(self, timeout_millis=30_000, **kwargs): + return self._delegate.shutdown(timeout_millis=timeout_millis, **kwargs) + + def force_flush(self, timeout_millis: float = 10_000): + return self._delegate.force_flush(timeout_millis) + + def init_observability(app_name: str = "node_wire") -> None: """ Initialize OpenTelemetry + OpenLLMetry/Traceloop for the process. @@ -153,6 +186,30 @@ def init_observability(app_name: str = "node_wire") -> None: root_logger.addFilter(_OtelContextFilter()) root_logger.addHandler(LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider)) + # Metrics: export to the local OTLP collector alongside traces and logs. + metric_interval_str: str = os.getenv("AOT_METRIC_EXPORT_INTERVAL_MS", "60000") + try: + metric_interval_ms = int(metric_interval_str) + except ValueError: + logger.warning( + "Invalid AOT_METRIC_EXPORT_INTERVAL_MS %r, falling back to 60000", metric_interval_str + ) + metric_interval_ms = 60000 + + metric_exporter = SanitizingMetricExporter( + OTLPMetricExporter( + headers=dict(header.split("=", 1) for header in otlp_headers.split(",")) + if otlp_headers + else None, + ) + ) + metric_reader = PeriodicExportingMetricReader( + metric_exporter, + export_interval_millis=metric_interval_ms, + ) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + # Initialize Traceloop/OpenLLMetry in metadata-only mode. Advanced AI features # (prompt logging, workflows, tools) are intentionally deferred. # Skip silently when no API key is configured — Traceloop is optional. diff --git a/src/node_wire_runtime/rate_limit.py b/src/node_wire_runtime/rate_limit.py index d338fa8..ffb6a05 100644 --- a/src/node_wire_runtime/rate_limit.py +++ b/src/node_wire_runtime/rate_limit.py @@ -15,6 +15,15 @@ import os import time +from opentelemetry import metrics + +_meter = metrics.get_meter("runtime") +_rate_limit_rejections = _meter.create_counter( + "connector.rate_limit_rejections", + unit="1", + description="Requests rejected by the global token-bucket rate limiter", +) + class RateLimitExceeded(Exception): """Raised when the rate limit has been exceeded.""" @@ -50,6 +59,7 @@ async def acquire(self, amount: int = 1) -> None: if self.tokens >= amount: self.tokens -= amount else: + _rate_limit_rejections.add(1) raise RateLimitExceeded("Global rate limit exceeded. Please try again later.") diff --git a/src/node_wire_runtime/resilience.py b/src/node_wire_runtime/resilience.py index ce4053b..ff59438 100644 --- a/src/node_wire_runtime/resilience.py +++ b/src/node_wire_runtime/resilience.py @@ -8,6 +8,7 @@ from functools import wraps from typing import Any, Awaitable, Callable, Coroutine, TypeVar +from opentelemetry import metrics from pybreaker import CircuitBreaker, CircuitBreakerError from tenacity import ( AsyncRetrying, @@ -21,6 +22,18 @@ logger = logging.getLogger("runtime.resilience") +_meter = metrics.get_meter("runtime") +_retry_counter = _meter.create_counter( + "connector.retries", + unit="1", + description="Retryable errors encountered during execution (triggers a retry unless attempts are exhausted)", +) +_circuit_breaker_rejections = _meter.create_counter( + "connector.circuit_breaker_rejections", + unit="1", + description="Calls rejected because the tenant circuit breaker was open", +) + T = TypeVar("T") @@ -98,11 +111,17 @@ def with_resilience( max_attempts: int = 3, base_wait: float = 0.5, max_wait: float = 5.0, + connector_id: str = "", + action: str = "", ) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Coroutine[Any, Any, T]]]: """ Decorator that applies retry (Tenacity) and circuit breaking (PyBreaker) around an async function that may raise exceptions. + + ``connector_id`` and ``action`` are low-cardinality labels attached to the + retry and circuit-breaker metrics; they carry no request-specific identity. """ + _metric_attrs = {"connector.id": connector_id, "connector.action": action} def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Coroutine[Any, Any, T]]: @wraps(fn) @@ -121,6 +140,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> T: return await _run_through_breaker(breaker, fn, args, kwargs, trace_id) except Exception as exc: # noqa: BLE001 mapped = ErrorMapper.resolve(exc) + if isinstance(exc, CircuitBreakerError): + _circuit_breaker_rejections.add(1, attributes=_metric_attrs) if mapped.category is not ErrorCategory.RETRYABLE: # Non-retryable: log, then escape the retry loop entirely. logger.error( @@ -135,6 +156,10 @@ async def wrapper(*args: Any, **kwargs: Any) -> T: ) raise _AbortRetry(exc) + _retry_counter.add( + 1, + attributes={**_metric_attrs, "error_code": mapped.code}, + ) logger.warning( "Retryable error during execution; will retry", extra={ diff --git a/tests/test_audit_trail.py b/tests/test_audit_trail.py new file mode 100644 index 0000000..2557ca9 --- /dev/null +++ b/tests/test_audit_trail.py @@ -0,0 +1,148 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for per-invocation audit log events emitted by BaseConnector.run().""" + +from __future__ import annotations + +import logging +from typing import Literal + +import pytest +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, nw_action +from node_wire_runtime.policy import PolicyContext, PolicyDenied, PolicyHook + + +class _In(BaseModel): + action: Literal["go"] = "go" + + +class _Out(BaseModel): + ok: bool + + +class _AuditConnector(BaseConnector): + connector_id = "audit_test" + output_model = _Out + + @nw_action("go") + async def go(self, params: _In, *, trace_id: str) -> _Out: + return _Out(ok=True) + + +class _BoomConnector(BaseConnector): + connector_id = "audit_boom" + output_model = _Out + + @nw_action("go") + async def go(self, params: _In, *, trace_id: str) -> _Out: + raise RuntimeError("something broke") + + +class _DenyAll(PolicyHook): + def check(self, ctx: PolicyContext) -> None: + raise PolicyDenied("not allowed") + + +@pytest.mark.asyncio +async def test_audit_invocation_start_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + await connector.run( + {"action": "go"}, + principal="alice", + tenant_id="t1", + scopes=("scope:read",), + ) + start_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_start" + ] + assert len(start_records) == 1 + rec = start_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_test" + assert rec.__dict__.get("principal") == "alice" + assert rec.__dict__.get("tenant_id") == "t1" + assert "scope:read" in rec.__dict__.get("scopes", []) + + +@pytest.mark.asyncio +async def test_audit_invocation_success_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is True + success_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_success" + ] + assert len(success_records) == 1 + rec = success_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_test" + assert "duration_ms" in rec.__dict__ + + +@pytest.mark.asyncio +async def test_audit_invocation_failure_emitted(caplog: pytest.LogCaptureFixture) -> None: + connector = _BoomConnector() + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is False + failure_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_failure" + ] + assert len(failure_records) == 1 + rec = failure_records[0] + assert rec.__dict__.get("audit") is True + assert rec.__dict__.get("connector_id") == "audit_boom" + assert "duration_ms" in rec.__dict__ + + +@pytest.mark.asyncio +async def test_audit_invocation_validation_failure_emitted( + caplog: pytest.LogCaptureFixture, +) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + await connector.run({"action": "go", "unexpected_extra_field_that_breaks": True}) + # Validation may succeed (extra fields tolerated by default) or fail depending on model config. + # Use a clearly wrong payload instead. + with caplog.at_level(logging.ERROR, logger="runtime.base_connector"): + resp = await connector.run({"action": "nonexistent_action"}) + assert resp.success is False + validation_records = [ + r + for r in caplog.records + if r.__dict__.get("audit_event") == "invocation_validation_failure" + ] + assert len(validation_records) >= 1 + rec = validation_records[0] + assert rec.__dict__.get("audit") is True + + +@pytest.mark.asyncio +async def test_audit_policy_denial_unchanged(caplog: pytest.LogCaptureFixture) -> None: + connector = _AuditConnector(policy_hook=_DenyAll()) + with caplog.at_level(logging.WARNING, logger="runtime.base_connector"): + resp = await connector.run({"action": "go"}) + assert resp.success is False + denial_records = [r for r in caplog.records if r.__dict__.get("audit_event") == "policy_denial"] + assert len(denial_records) == 1 + assert denial_records[0].__dict__.get("audit") is True + + +@pytest.mark.asyncio +async def test_audit_start_includes_empty_scopes_when_none( + caplog: pytest.LogCaptureFixture, +) -> None: + connector = _AuditConnector() + with caplog.at_level(logging.INFO, logger="runtime.base_connector"): + await connector.run({"action": "go"}) + start_records = [ + r for r in caplog.records if r.__dict__.get("audit_event") == "invocation_start" + ] + assert len(start_records) == 1 + assert start_records[0].__dict__.get("scopes") == [] diff --git a/tests/test_factory_and_rest.py b/tests/test_factory_and_rest.py index 060eeae..8c83c8e 100644 --- a/tests/test_factory_and_rest.py +++ b/tests/test_factory_and_rest.py @@ -34,6 +34,35 @@ def test_health_endpoint(): assert resp.json() == {"status": "ok"} +def test_ready_endpoint_when_factory_has_connectors(monkeypatch: pytest.MonkeyPatch): + mock_factory = MagicMock() + mock_factory.list_for_protocol.return_value = [MagicMock()] + monkeypatch.setattr("bindings.rest_api.app._factory", mock_factory) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 200 + assert resp.json() == {"status": "ready"} + + +def test_ready_endpoint_returns_503_when_no_connectors(monkeypatch: pytest.MonkeyPatch): + mock_factory = MagicMock() + mock_factory.list_for_protocol.return_value = [] + monkeypatch.setattr("bindings.rest_api.app._factory", mock_factory) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 503 + + +def test_ready_endpoint_returns_503_on_factory_error(monkeypatch: pytest.MonkeyPatch): + def _boom(): + raise RuntimeError("factory load failed") + + monkeypatch.setattr("bindings.rest_api.app.get_factory", _boom) + client = TestClient(app) + resp = client.get("/ready") + assert resp.status_code == 503 + + def test_agent_transport_defaults_to_stdio(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv("NW_MCP_TRANSPORT", raising=False) client = TestClient(app) diff --git a/tests/test_grpc_bindings.py b/tests/test_grpc_bindings.py new file mode 100644 index 0000000..d239ea2 --- /dev/null +++ b/tests/test_grpc_bindings.py @@ -0,0 +1,338 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for gRPC server servicer and auth interceptor.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import grpc +import pytest + +from bindings.grpc_server import connector_pb2 +from bindings.grpc_server.auth import ( + GrpcAuthInterceptor, + _extract_token, + _truthy, +) +from bindings.grpc_server.server import ConnectorServiceServicer +from node_wire_runtime import ConnectorResponse, ErrorCategory +from node_wire_runtime.rate_limit import RateLimitExceeded + + +# --------------------------------------------------------------------------- +# _truthy helper +# --------------------------------------------------------------------------- + + +def test_truthy_none_is_false() -> None: + assert _truthy(None) is False + + +def test_truthy_true_string() -> None: + for val in ("true", "TRUE", "1", "yes", "on", " True "): + assert _truthy(val) is True, val + + +def test_truthy_false_string() -> None: + for val in ("false", "0", "no", "off", ""): + assert _truthy(val) is False, val + + +# --------------------------------------------------------------------------- +# _extract_token helper +# --------------------------------------------------------------------------- + + +def test_extract_token_from_authorization_bearer() -> None: + meta = (("authorization", "Bearer mytoken123"),) + assert _extract_token(meta) == "mytoken123" + + +def test_extract_token_from_authorization_raw() -> None: + meta = (("authorization", "rawtoken"),) + assert _extract_token(meta) == "rawtoken" + + +def test_extract_token_from_x_api_key() -> None: + meta = (("x-api-key", " apikey "),) + assert _extract_token(meta) == "apikey" + + +def test_extract_token_returns_none_when_absent() -> None: + assert _extract_token(()) is None + assert _extract_token((("content-type", "application/json"),)) is None + + +# --------------------------------------------------------------------------- +# GrpcAuthInterceptor.intercept_service +# --------------------------------------------------------------------------- + + +def _make_call_details(metadata: tuple = ()) -> grpc.HandlerCallDetails: + details = MagicMock(spec=grpc.HandlerCallDetails) + details.invocation_metadata = metadata + return details + + +def test_interceptor_auth_disabled_passes_through(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_GRPC_AUTH_DISABLED", "true") + sentinel = object() + cont = MagicMock(return_value=sentinel) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details()) + assert result is sentinel + cont.assert_called_once() + + +def test_interceptor_returns_abort_when_not_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_GRPC_API_KEY", raising=False) + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock() + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details()) + # The returned handler should abort when called + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, msg = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAVAILABLE + assert "NW_GRPC_API_KEY" in msg + + +def test_interceptor_returns_unauthenticated_when_no_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "secret") + + cont = MagicMock() + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=())) + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, _ = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAUTHENTICATED + + +def test_interceptor_returns_unauthenticated_for_bad_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "correct-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock() + meta = (("x-api-key", "wrong-key"),) + handler = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + context = MagicMock(spec=grpc.ServicerContext) + handler.unary_unary(None, context) + context.abort.assert_called_once() + code, _ = context.abort.call_args[0] + assert code == grpc.StatusCode.UNAUTHENTICATED + + +def test_interceptor_wraps_handler_on_valid_token(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "valid-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + inner_handler = grpc.unary_unary_rpc_method_handler(lambda req, ctx: "ok") + cont = MagicMock(return_value=inner_handler) + meta = (("x-api-key", "valid-key"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is not None + assert result.unary_unary is not None + assert result.unary_unary(None, None) == "ok" + + +def test_interceptor_returns_none_when_continuation_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.setenv("NW_GRPC_API_KEY", "valid-key") + monkeypatch.delenv("NW_GRPC_JWT_SECRET", raising=False) + + cont = MagicMock(return_value=None) + meta = (("x-api-key", "valid-key"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is None + + +def test_interceptor_jwt_auth(monkeypatch: pytest.MonkeyPatch) -> None: + from tests.jwt_test_helpers import mint_test_jwt + + monkeypatch.delenv("NW_GRPC_AUTH_DISABLED", raising=False) + monkeypatch.delenv("NW_GRPC_API_KEY", raising=False) + monkeypatch.setenv("NW_GRPC_JWT_SECRET", "test-grpc-secret") + + token = mint_test_jwt({"sub": "svc"}, "test-grpc-secret") + inner_handler = grpc.unary_unary_rpc_method_handler(lambda req, ctx: "ok") + cont = MagicMock(return_value=inner_handler) + meta = (("authorization", f"Bearer {token}"),) + result = GrpcAuthInterceptor().intercept_service(cont, _make_call_details(metadata=meta)) + assert result is not None + assert result.unary_unary(None, None) == "ok" + + +# --------------------------------------------------------------------------- +# ConnectorServiceServicer._invoke_async +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def servicer() -> ConnectorServiceServicer: + return ConnectorServiceServicer() + + +async def test_invoke_rate_limit_exceeded(servicer: ConnectorServiceServicer) -> None: + async def _raise(*_a: Any, **_kw: Any) -> None: + raise RateLimitExceeded("too fast") + + with patch("bindings.grpc_server.server.global_rate_limiter") as mock_rl: + mock_rl.acquire = _raise + req = connector_pb2.InvokeRequest(connector_id="any", action="act") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "RATE_LIMIT_EXCEEDED" + assert resp.error_category == ErrorCategory.RETRYABLE.value + + +async def test_invoke_unknown_connector_returns_not_available( + servicer: ConnectorServiceServicer, +) -> None: + with patch.object(servicer._factory, "get_for_protocol", return_value=None): + req = connector_pb2.InvokeRequest(connector_id="no_such", action="act") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "CONNECTOR_NOT_AVAILABLE" + assert resp.error_category == ErrorCategory.BUSINESS.value + assert "no_such" in resp.message + + +async def test_invoke_invalid_json_payload(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", action="y", payload_json="not-valid-json{" + ) + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "INVALID_JSON" + assert resp.error_category == ErrorCategory.BUSINESS.value + + +async def test_invoke_success_path(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock( + return_value=ConnectorResponse( + success=True, + data={"result": "hello"}, + trace_id="trace-001", + ) + ) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", action="greet", payload_json='{"field": "val"}' + ) + resp = await servicer._invoke_async(req) + + assert resp.success is True + assert resp.trace_id == "trace-001" + assert "result" in resp.data_json + + +async def test_invoke_action_injected_into_payload(servicer: ConnectorServiceServicer) -> None: + captured_payload: list[Any] = [] + + async def mock_run(payload: Any, **_: Any) -> ConnectorResponse: + captured_payload.append(payload) + return ConnectorResponse(success=True, trace_id="t1") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest( + connector_id="x", + action="do_thing", + payload_json='{"some": "data"}', + ) + await servicer._invoke_async(req) + + assert captured_payload[0]["action"] == "do_thing" + + +async def test_invoke_identity_propagated(servicer: ConnectorServiceServicer) -> None: + from node_wire_runtime.caller_identity import build_caller_identity + + identity = build_caller_identity({"sub": "grpc-svc"}, auth_type="grpc_api_key") + captured: list[Any] = [] + + async def mock_run(payload: Any, **kwargs: Any) -> ConnectorResponse: + captured.append(kwargs) + return ConnectorResponse(success=True, trace_id="t2") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with ( + patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector), + patch("bindings.grpc_server.server.get_grpc_caller_identity", return_value=identity), + ): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + await servicer._invoke_async(req) + + assert captured[0]["principal"] == identity.principal + assert captured[0]["tenant_id"] == identity.tenant_id + + +async def test_invoke_no_identity_passes_none(servicer: ConnectorServiceServicer) -> None: + captured: list[Any] = [] + + async def mock_run(payload: Any, **kwargs: Any) -> ConnectorResponse: + captured.append(kwargs) + return ConnectorResponse(success=True, trace_id="t3") + + fake_connector = MagicMock() + fake_connector.run = mock_run + with ( + patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector), + patch("bindings.grpc_server.server.get_grpc_caller_identity", return_value=None), + ): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + await servicer._invoke_async(req) + + assert captured[0]["principal"] is None + assert captured[0]["tenant_id"] is None + + +async def test_invoke_empty_payload_json(servicer: ConnectorServiceServicer) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock(return_value=ConnectorResponse(success=True, trace_id="t4")) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest(connector_id="x", action="act") + resp = await servicer._invoke_async(req) + assert resp.success is True + + +async def test_invoke_error_response_maps_error_category( + servicer: ConnectorServiceServicer, +) -> None: + fake_connector = MagicMock() + fake_connector.run = AsyncMock( + return_value=ConnectorResponse( + success=False, + error_code="UPSTREAM_TIMEOUT", + error_category=ErrorCategory.RETRYABLE, + message="upstream timed out", + trace_id="t5", + ) + ) + with patch.object(servicer._factory, "get_for_protocol", return_value=fake_connector): + req = connector_pb2.InvokeRequest(connector_id="x", action="act", payload_json="{}") + resp = await servicer._invoke_async(req) + + assert resp.success is False + assert resp.error_code == "UPSTREAM_TIMEOUT" + assert resp.error_category == ErrorCategory.RETRYABLE.value diff --git a/tests/test_grpc_health.py b/tests/test_grpc_health.py new file mode 100644 index 0000000..09191d4 --- /dev/null +++ b/tests/test_grpc_health.py @@ -0,0 +1,61 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for gRPC health service registration in serve().""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from grpc_health.v1 import health_pb2 + + +def test_grpc_health_servicer_is_registered(): + """serve() must register a HealthServicer and set SERVING for both service names.""" + added_servicers = {} + + def fake_add_health(servicer, server): + added_servicers["health"] = servicer + + def fake_add_connector(servicer, server): + added_servicers["connector"] = servicer + + with ( + patch("bindings.grpc_server.server.grpc.server") as mock_grpc_server, + patch( + "bindings.grpc_server.server.connector_pb2_grpc.add_ConnectorServiceServicer_to_server", + fake_add_connector, + ), + patch( + "bindings.grpc_server.server.health_pb2_grpc.add_HealthServicer_to_server", + fake_add_health, + ), + patch("bindings.grpc_server.server.grpc_health.HealthServicer") as mock_health_cls, + patch("bindings.grpc_server.server.configure_grpc_server_port"), + patch("bindings.grpc_server.server._async_runner"), + patch("bindings.grpc_server.server.ConnectorServiceServicer"), + ): + fake_server = MagicMock() + fake_server.wait_for_termination.side_effect = KeyboardInterrupt + mock_grpc_server.return_value = fake_server + + mock_health_instance = MagicMock() + mock_health_cls.return_value = mock_health_instance + + from bindings.grpc_server.server import serve + + try: + serve(port=0) + except KeyboardInterrupt: + pass # serve() raises KeyboardInterrupt on shutdown; suppress to let assertions run + + assert "health" in added_servicers, "HealthServicer was not added to the gRPC server" + + set_calls = {call.args[0]: call.args[1] for call in mock_health_instance.set.call_args_list} + assert "" in set_calls, "Overall health (empty string) not set" + assert "aot.connectors.ConnectorService" in set_calls, "ConnectorService health not set" + assert set_calls[""] == health_pb2.HealthCheckResponse.SERVING + assert ( + set_calls["aot.connectors.ConnectorService"] == health_pb2.HealthCheckResponse.SERVING + ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..8fa68ca --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,192 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for OpenTelemetry metric recording in BaseConnector.run().""" + +from __future__ import annotations + +from typing import Literal +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import BaseModel + +from node_wire_runtime import BaseConnector, ErrorCategory, ErrorMapper, nw_action +import node_wire_runtime.base_connector as bc_module +import node_wire_runtime.rate_limit as rl_module +import node_wire_runtime.resilience as res_module +from pybreaker import CircuitBreaker + + +class _MIn(BaseModel): + action: Literal["run"] = "run" + + +class _MOut(BaseModel): + done: bool + + +class _MetricConnector(BaseConnector): + connector_id = "metric_test" + output_model = _MOut + + @nw_action("run") + async def run_action(self, params: _MIn, *, trace_id: str) -> _MOut: + return _MOut(done=True) + + +class _MetricFailConnector(BaseConnector): + connector_id = "metric_fail" + output_model = _MOut + + @nw_action("run") + async def run_action(self, params: _MIn, *, trace_id: str) -> _MOut: + raise ValueError("intentional failure") + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_success() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "run"}) + assert resp.success is True + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["connector.id"] == "metric_test" + assert call_attrs["success"] is True + assert call_attrs["error_category"] == "none" + + +@pytest.mark.asyncio +async def test_invocation_histogram_recorded_on_success() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + await connector.run({"action": "run"}) + mock_histogram.record.assert_called_once() + duration_val = mock_histogram.record.call_args[0][0] + assert duration_val >= 0 + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_failure() -> None: + connector = _MetricFailConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "run"}) + assert resp.success is False + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["success"] is False + assert call_attrs["error_category"] != "none" + + +@pytest.mark.asyncio +async def test_invocation_counter_incremented_on_validation_failure() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + mock_histogram = MagicMock() + with ( + patch.object(bc_module, "_invocation_counter", mock_counter), + patch.object(bc_module, "_invocation_duration", mock_histogram), + ): + resp = await connector.run({"action": "nonexistent"}) + assert resp.success is False + mock_counter.add.assert_called_once() + call_attrs = mock_counter.add.call_args[1]["attributes"] + assert call_attrs["success"] is False + + +@pytest.mark.asyncio +async def test_metric_attributes_include_connector_action() -> None: + connector = _MetricConnector() + mock_counter = MagicMock() + with patch.object(bc_module, "_invocation_counter", mock_counter): + with patch.object(bc_module, "_invocation_duration", MagicMock()): + await connector.run({"action": "run"}) + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.action"] == "execute" + + +class _RetryableMetricError(Exception): + pass + + +@pytest.fixture +def _register_retryable() -> None: + ErrorMapper.register(_RetryableMetricError, ErrorCategory.RETRYABLE, code="RETRYABLE_METRIC") + try: + yield + finally: + ErrorMapper._registry.pop(_RetryableMetricError, None) + + +@pytest.mark.asyncio +async def test_retry_counter_incremented_on_retryable_error(_register_retryable: None) -> None: + attempts = {"n": 0} + + @res_module.with_resilience(CircuitBreaker(), connector_id="retry_cx", action="do") + async def flaky(*, trace_id: str = "t") -> str: + attempts["n"] += 1 + if attempts["n"] < 2: + raise _RetryableMetricError("boom") + return "ok" + + mock_counter = MagicMock() + with patch.object(res_module, "_retry_counter", mock_counter): + result = await flaky(trace_id="t1") + + assert result == "ok" + # One retryable failure occurred before the successful attempt. + mock_counter.add.assert_called_once() + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.id"] == "retry_cx" + assert attrs["connector.action"] == "do" + assert attrs["error_code"] == "RETRYABLE_METRIC" + + +@pytest.mark.asyncio +async def test_circuit_breaker_rejection_counter_incremented() -> None: + breaker = CircuitBreaker() + breaker.open() + + @res_module.with_resilience(breaker, connector_id="cb_cx", action="do") + async def never_runs(*, trace_id: str = "t") -> str: + return "unreachable" + + mock_counter = MagicMock() + with patch.object(res_module, "_circuit_breaker_rejections", mock_counter): + with pytest.raises(Exception): + await never_runs(trace_id="t1") + + mock_counter.add.assert_called_once() + attrs = mock_counter.add.call_args[1]["attributes"] + assert attrs["connector.id"] == "cb_cx" + assert attrs["connector.action"] == "do" + + +@pytest.mark.asyncio +async def test_rate_limit_rejection_counter_incremented() -> None: + bucket = rl_module.TokenBucket(capacity=1, refill_rate=0) + await bucket.acquire() # consumes the only token + + mock_counter = MagicMock() + with patch.object(rl_module, "_rate_limit_rejections", mock_counter): + with pytest.raises(rl_module.RateLimitExceeded): + await bucket.acquire() + + mock_counter.add.assert_called_once() diff --git a/tests/test_oauth2_extra.py b/tests/test_oauth2_extra.py new file mode 100644 index 0000000..6cdf500 --- /dev/null +++ b/tests/test_oauth2_extra.py @@ -0,0 +1,503 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Additional coverage for OAuth2AuthProvider and ServiceAccountAuthProvider.""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from node_wire_runtime.auth import OAuth2AuthProvider, ServiceAccountAuthProvider +from node_wire_runtime.secrets import SecretProvider +from node_wire_runtime.secrets.base import SecretNotFoundError + + +class _DictSecretProvider(SecretProvider): + def __init__(self, data: dict[str, str]) -> None: + self._data = data + + def get_secret(self, key: str) -> str: + if key not in self._data: + raise SecretNotFoundError(key) + return self._data[key] + + +def _oauth2_client_secret_provider(**extra: str) -> _DictSecretProvider: + data = { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "super-secret", + **extra, + } + return _DictSecretProvider(data) + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — invalid grant method +# --------------------------------------------------------------------------- + + +def test_oauth2_invalid_grant_method_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + with pytest.raises(ValueError, match="Unsupported grant_method"): + OAuth2AuthProvider( + secret_provider=sp, + grant_method="magic", + token_url_secret="token_url", + client_id_secret="client_id", + ) + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — _resolve_scopes via scopes_secret +# --------------------------------------------------------------------------- + + +def test_resolve_scopes_from_secret() -> None: + sp = _DictSecretProvider( + { + "token_url": "x", + "client_id": "y", + "scope_val": "openid profile", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", # not used in resolve_scopes + scopes_secret="scope_val", + ) + assert provider._resolve_scopes() == "openid profile" + + +def test_resolve_scopes_falls_back_to_static_when_secret_missing() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + scopes_secret="missing_key", + scopes=["read", "write"], + ) + assert provider._resolve_scopes() == "read write" + + +def test_resolve_scopes_returns_none_when_none_configured() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_id", + ) + assert provider._resolve_scopes() is None + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — client_secret_post grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_client_secret_post_success() -> None: + sp = _oauth2_client_secret_provider() + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + scopes=["read"], + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "tok-cs", "expires_in": 3600} + + with patch("node_wire_runtime.auth.oauth2.httpx.AsyncClient") as mock_client_cls: + mock_client = MagicMock() + mock_client.__aenter__ = MagicMock(return_value=mock_client) + mock_client.__aexit__ = MagicMock(return_value=False) + mock_client.post = MagicMock(return_value=mock_response) + mock_client.post = MagicMock(side_effect=lambda *a, **kw: mock_response) + + async def _fake_enter(_: Any) -> Any: + return mock_client + + async def _fake_exit(*_: Any) -> bool: + return False + + mock_client.__aenter__ = _fake_enter + mock_client.__aexit__ = _fake_exit + mock_client_cls.return_value = mock_client + + # Use respx-style: patch the _post_token static method directly + with patch.object( + OAuth2AuthProvider, + "_post_token", + return_value={"access_token": "tok-cs", "expires_in": 3600}, + ): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer tok-cs" + + +async def test_client_secret_post_missing_secret_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + # client_secret_secret deliberately absent + ) + with pytest.raises(ValueError, match="client_secret_secret"): + await provider._fetch_client_secret_post() + + +async def test_oauth2_client_secret_post_with_scope() -> None: + sp = _oauth2_client_secret_provider() + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="client_secret_post", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + scopes=["openid"], + ) + captured_data: list[dict] = [] + + async def fake_post_token(url: str, data: dict) -> dict: + captured_data.append(data) + return {"access_token": "tok2", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post_token): + await provider.get_headers() + + assert captured_data[0]["scope"] == "openid" + assert captured_data[0]["grant_type"] == "client_credentials" + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — refresh_token grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_refresh_token_success() -> None: + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "sec", + "refresh_token": "rt-123", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + client_secret_secret="client_secret", + refresh_token_secret="refresh_token", + scopes=["profile"], + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "new-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer new-tok" + assert captured[0]["grant_type"] == "refresh_token" + assert captured[0]["refresh_token"] == "rt-123" + assert captured[0]["client_secret"] == "sec" + assert captured[0]["scope"] == "profile" + + +async def test_oauth2_refresh_token_without_client_secret() -> None: + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "refresh_token": "rt-456", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + refresh_token_secret="refresh_token", + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "no-secret-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + await provider.get_headers() + + assert "client_secret" not in captured[0] + + +async def test_oauth2_refresh_token_missing_secret_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="refresh_token", + token_url_secret="token_url", + client_id_secret="client_id", + # refresh_token_secret absent + ) + with pytest.raises(ValueError, match="refresh_token_secret"): + await provider._fetch_refresh_token() + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — private_key_jwt grant +# --------------------------------------------------------------------------- + + +async def test_oauth2_private_key_jwt_invalid_key_raises() -> None: + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "client", + "private_key": "not-a-valid-pem-key", + "kid": "key-1", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + private_key_secret="private_key", + kid_secret="kid", + algorithm="RS384", + ) + with pytest.raises(ValueError, match="private_key_jwt"): + await provider._fetch_private_key_jwt() + + +async def test_oauth2_private_key_jwt_missing_secrets_raises() -> None: + sp = _DictSecretProvider({"token_url": "x", "client_id": "y"}) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + # private_key_secret and kid_secret absent + ) + with pytest.raises(ValueError, match="private_key_secret.*kid_secret"): + await provider._fetch_private_key_jwt() + + +async def test_oauth2_private_key_jwt_success_with_rsa_key() -> None: + """Test private_key_jwt success path with a real RSA key.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + sp = _DictSecretProvider( + { + "token_url": "https://idp.example.com/token", + "client_id": "client-id", + "private_key": private_pem, + "kid": "rsa-key-1", + } + ) + provider = OAuth2AuthProvider( + secret_provider=sp, + grant_method="private_key_jwt", + token_url_secret="token_url", + client_id_secret="client_id", + private_key_secret="private_key", + kid_secret="kid", + algorithm="RS384", + scopes=["system/*.read"], + ) + captured: list[dict] = [] + + async def fake_post(url: str, data: dict) -> dict: + captured.append(data) + return {"access_token": "pkey-tok", "expires_in": 3600} + + with patch.object(OAuth2AuthProvider, "_post_token", side_effect=fake_post): + headers = await provider.get_headers() + + assert headers["Authorization"] == "Bearer pkey-tok" + assert captured[0]["grant_type"] == "client_credentials" + assert "client_assertion" in captured[0] + assert captured[0]["scope"] == "system/*.read" + + +# --------------------------------------------------------------------------- +# OAuth2AuthProvider — _post_token non-200 error +# --------------------------------------------------------------------------- + + +async def test_post_token_non_200_raises() -> None: + transport = httpx.MockTransport(handler=lambda req: httpx.Response(401, text="Unauthorized")) + + with patch( + "node_wire_runtime.auth.oauth2.httpx.AsyncClient", + return_value=httpx.AsyncClient(transport=transport), + ): + with pytest.raises(ValueError, match="HTTP 401"): + await OAuth2AuthProvider._post_token( + "https://fake.example.com/token", + {"grant_type": "client_credentials"}, + ) + + +# --------------------------------------------------------------------------- +# ServiceAccountAuthProvider +# --------------------------------------------------------------------------- + + +async def test_sa_get_headers_returns_empty() -> None: + sp = _DictSecretProvider({"sa_json": '{"type": "service_account"}'}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + headers = await provider.get_headers() + assert headers == {} + + +async def test_sa_get_client_credentials_caches() -> None: + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": json.dumps({"type": "service_account"})}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch.object(provider, "_build_credentials", return_value=fake_creds) as mock_build: + creds1 = await provider.get_client_credentials() + creds2 = await provider.get_client_credentials() + + assert creds1 is fake_creds + assert creds2 is fake_creds + mock_build.assert_called_once() + + +async def test_sa_refresh_clears_credentials() -> None: + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": "{}"}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + provider._credentials = fake_creds + + await provider.refresh() + assert provider._credentials is None + + +def test_sa_build_credentials_from_json_string() -> None: + import json as _json + + fake_creds = MagicMock() + sa_info = { + "type": "service_account", + "project_id": "proj", + "private_key_id": "key1", + "private_key": "-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----\n", + "client_email": "svc@proj.iam.gserviceaccount.com", + "client_id": "12345", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + } + sp = _DictSecretProvider({"sa_json": _json.dumps(sa_info)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + return_value=fake_creds, + ): + creds = provider._build_credentials() + assert creds is fake_creds + + +def test_sa_build_credentials_from_file_path(tmp_path: Any) -> None: + import json as _json + + sa_info = { + "type": "service_account", + "project_id": "proj", + "private_key_id": "key1", + "private_key": "fake-key", + "client_email": "svc@proj.iam.gserviceaccount.com", + "client_id": "12345", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + } + sa_file = tmp_path / "sa.json" + sa_file.write_text(_json.dumps(sa_info)) + + fake_creds = MagicMock() + sp = _DictSecretProvider({"sa_json": str(sa_file)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with patch( + "google.oauth2.service_account.Credentials.from_service_account_file", + return_value=fake_creds, + ): + creds = provider._build_credentials() + assert creds is fake_creds + + +def test_sa_build_credentials_file_not_found(tmp_path: Any) -> None: + sp = _DictSecretProvider({"sa_json": "/nonexistent/path/sa.json"}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with pytest.raises(ValueError, match="not found at path"): + provider._build_credentials() + + +def test_sa_build_credentials_invalid_json_info() -> None: + import json as _json + + bad_info = {"type": "service_account"} # missing required fields + sp = _DictSecretProvider({"sa_json": _json.dumps(bad_info)}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + with pytest.raises((ValueError, Exception)): + provider._build_credentials() + + +def test_sa_build_credentials_google_auth_not_installed() -> None: + import sys + + sp = _DictSecretProvider({"sa_json": '{"type": "service_account"}'}) + provider = ServiceAccountAuthProvider(secret_provider=sp, sa_json_secret="sa_json") + + original = sys.modules.get("google.oauth2") + original_sa = sys.modules.get("google.oauth2.service_account") + try: + sys.modules["google.oauth2"] = None # type: ignore[assignment] + sys.modules["google.oauth2.service_account"] = None # type: ignore[assignment] + with pytest.raises(ImportError, match="google-auth"): + provider._build_credentials() + finally: + if original is None: + sys.modules.pop("google.oauth2", None) + else: + sys.modules["google.oauth2"] = original + if original_sa is None: + sys.modules.pop("google.oauth2.service_account", None) + else: + sys.modules["google.oauth2.service_account"] = original_sa diff --git a/tests/test_observability.py b/tests/test_observability.py index f74feea..e7a6bd4 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -60,10 +60,14 @@ def _observability_test_patches(): with _ensure_traceloop_stub_modules(): with ( patch("opentelemetry.trace.set_tracer_provider"), + patch("opentelemetry.metrics.set_meter_provider"), patch("node_wire_runtime.observability.OTLPSpanExporter") as span_exp, patch("node_wire_runtime.observability.OTLPLogExporter") as log_exp, + patch("node_wire_runtime.observability.OTLPMetricExporter") as metric_exp, patch("node_wire_runtime.observability.BatchSpanProcessor"), patch("node_wire_runtime.observability.BatchLogRecordProcessor"), + patch("node_wire_runtime.observability.PeriodicExportingMetricReader"), + patch("node_wire_runtime.observability.MeterProvider"), patch("node_wire_runtime.observability.set_logger_provider"), patch( "node_wire_runtime.observability.LoggingHandler", @@ -72,23 +76,24 @@ def _observability_test_patches(): patch("traceloop.sdk.Traceloop") as mock_tl, ): mock_tl.init = MagicMock() - yield span_exp, log_exp, mock_tl + yield span_exp, log_exp, mock_tl, metric_exp def test_init_observability_idempotent() -> None: """Second call should not reconfigure exporters.""" - with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + with _observability_test_patches() as (span_exp, log_exp, _mock_tl, metric_exp): obs.init_observability("app-a") obs.init_observability("app-b") assert span_exp.call_count == 1 assert log_exp.call_count == 1 + assert metric_exp.call_count == 1 def test_init_observability_invalid_sampling_ratio_logs_warning( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: monkeypatch.setenv("AOT_TRACING_SAMPLING_RATIO", "not-a-number") - with _observability_test_patches(): + with _observability_test_patches() as _: with caplog.at_level(logging.WARNING, logger="runtime.observability"): obs.init_observability("app-warn") assert any("Invalid AOT_TRACING_SAMPLING_RATIO" in r.message for r in caplog.records) @@ -98,11 +103,12 @@ def test_init_observability_otel_headers_passed_to_exporters( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("OTEL_EXPORTER_OTLP_HEADERS", "key=value,foo=bar") - with _observability_test_patches() as (span_exp, log_exp, _mock_tl): + with _observability_test_patches() as (span_exp, log_exp, _mock_tl, metric_exp): obs.init_observability("app-h") expected_headers = {"key": "value", "foo": "bar"} assert span_exp.call_args.kwargs.get("headers") == expected_headers assert log_exp.call_args.kwargs.get("headers") == expected_headers + assert metric_exp.call_args.kwargs.get("headers") == expected_headers def test_otel_context_filter_sets_empty_trace_when_no_span() -> None: @@ -116,7 +122,7 @@ def test_otel_context_filter_sets_empty_trace_when_no_span() -> None: def test_init_observability_installs_sanitizing_log_filter() -> None: - with _observability_test_patches(): + with _observability_test_patches() as _: obs.init_observability("app-filter") root = logging.getLogger() assert any(isinstance(flt, SanitizingLogFilter) for flt in root.filters) @@ -127,8 +133,30 @@ def test_init_observability_traceloop_failure_does_not_raise( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("TRACELOOP_API_KEY", "test-key") - with _observability_test_patches() as (_s, _l, mock_tl): + with _observability_test_patches() as (_s, _l, mock_tl, _m): mock_tl.init = MagicMock(side_effect=RuntimeError("traceloop unavailable")) with caplog.at_level(logging.WARNING, logger="runtime.observability"): obs.init_observability("app-tl") assert any("Failed to initialize Traceloop" in r.message for r in caplog.records) + + +def test_init_observability_metric_exporter_initialized() -> None: + """MeterProvider and OTLPMetricExporter should be set up on first call.""" + with _observability_test_patches() as (_span, _log, _tl, metric_exp): + with patch("node_wire_runtime.observability.MeterProvider") as meter_provider: + obs.init_observability("app-metrics") + assert metric_exp.call_count == 1 + # The reader must be passed via the SDK's real kwarg name (`metric_readers`); + # `readers=` raises TypeError at runtime and would silently drop metrics export. + assert "metric_readers" in meter_provider.call_args.kwargs + + +def test_init_observability_invalid_metric_interval_logs_warning( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + monkeypatch.setenv("AOT_METRIC_EXPORT_INTERVAL_MS", "not-a-number") + with _observability_test_patches() as _: + with caplog.at_level(logging.WARNING, logger="runtime.observability"): + obs.init_observability("app-metric-warn") + assert any("Invalid AOT_METRIC_EXPORT_INTERVAL_MS" in r.message for r in caplog.records) diff --git a/tests/test_runtime_misc.py b/tests/test_runtime_misc.py new file mode 100644 index 0000000..295dd33 --- /dev/null +++ b/tests/test_runtime_misc.py @@ -0,0 +1,603 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for rate_limit, streaming, log_sanitization gaps, observability +wrappers, and connector_registry error paths.""" + +from __future__ import annotations + +import logging +from importlib.metadata import EntryPoint +from unittest.mock import MagicMock, patch + +import pytest + +import node_wire_runtime.log_sanitization as ls + +from node_wire_runtime import connector_registry +from node_wire_runtime.log_sanitization import ( + REDACTED, + SanitizingLogFilter, + _redact_sensitive_string_arg, + install_sanitizing_log_filter, + is_sensitive_key, + sanitize_log_record, + sanitize_mapping, + sanitize_value, +) +from node_wire_runtime.rate_limit import RateLimitExceeded, TokenBucket +from node_wire_runtime.streaming import ( + BufferedStreamIterator, + resolve_stream_buffer_ms, + stream_completion_log, +) + + +# --------------------------------------------------------------------------- +# TokenBucket — acquire paths +# --------------------------------------------------------------------------- + + +async def test_token_bucket_acquire_success() -> None: + bucket = TokenBucket(capacity=10, refill_rate=1) + await bucket.acquire(1) + assert bucket.tokens == 9 + + +async def test_token_bucket_acquire_multiple() -> None: + bucket = TokenBucket(capacity=10, refill_rate=1) + await bucket.acquire(5) + assert bucket.tokens == 5 + + +async def test_token_bucket_acquire_raises_when_exhausted() -> None: + bucket = TokenBucket(capacity=1, refill_rate=0) + await bucket.acquire(1) + with pytest.raises(RateLimitExceeded): + await bucket.acquire(1) + + +async def test_token_bucket_refills_over_time() -> None: + bucket = TokenBucket(capacity=10, refill_rate=1000) + await bucket.acquire(10) + assert bucket.tokens == 0 + # Artificially push last_refill back to simulate elapsed time + bucket.last_refill -= 1.0 # 1 second elapsed → 1000 tokens added, capped at 10 + await bucket.acquire(5) + assert bucket.tokens >= 4 # at least 5 refilled, minus 5 acquired + + +async def test_token_bucket_exceed_raises_with_message() -> None: + bucket = TokenBucket(capacity=0, refill_rate=0) + with pytest.raises(RateLimitExceeded, match="rate limit exceeded"): + await bucket.acquire(1) + + +# --------------------------------------------------------------------------- +# resolve_stream_buffer_ms +# --------------------------------------------------------------------------- + + +def test_resolve_stream_buffer_ms_override_used() -> None: + assert resolve_stream_buffer_ms(200) == 200 + + +def test_resolve_stream_buffer_ms_override_clamped_to_zero() -> None: + assert resolve_stream_buffer_ms(-50) == 0 + + +def test_resolve_stream_buffer_ms_override_clamped_to_max() -> None: + assert resolve_stream_buffer_ms(99999) == 30000 + + +def test_resolve_stream_buffer_ms_env_var(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_STREAM_BUFFER_MS", "500") + assert resolve_stream_buffer_ms() == 500 + + +def test_resolve_stream_buffer_ms_env_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_STREAM_BUFFER_MS", "notanumber") + assert resolve_stream_buffer_ms() == 0 + + +# --------------------------------------------------------------------------- +# stream_completion_log +# --------------------------------------------------------------------------- + + +def test_stream_completion_log_success(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.INFO, logger="runtime.streaming"): + stream_completion_log("tid-1", True, connector_id="smtp", action="send") + assert any("completed" in r.message.lower() for r in caplog.records) + + +def test_stream_completion_log_failure(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="runtime.streaming"): + stream_completion_log("tid-2", False, connector_id="smtp", action="send") + assert any("failed" in r.message.lower() for r in caplog.records) + + +# --------------------------------------------------------------------------- +# BufferedStreamIterator +# --------------------------------------------------------------------------- + + +async def _collect_async(gen) -> list: + items = [] + async for item in gen: + items.append(item) + return items + + +async def test_buffered_stream_no_buffer() -> None: + async def source(): + for i in range(3): + yield {"i": i} + + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=0, trace_id="t1")) + assert items == [{"i": 0}, {"i": 1}, {"i": 2}] + + +async def test_buffered_stream_with_buffer() -> None: + async def source(): + for i in range(4): + yield {"i": i} + + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=1000, trace_id="t2")) + assert items == [{"i": 0}, {"i": 1}, {"i": 2}, {"i": 3}] + + +async def test_buffered_stream_logs_on_failure() -> None: + async def bad_source(): + yield {"i": 0} + raise RuntimeError("upstream error") + + with pytest.raises(RuntimeError, match="upstream error"): + await _collect_async(BufferedStreamIterator(bad_source(), buffer_ms=0, trace_id="t3")) + + +async def test_buffered_stream_flush_mid_buffer(monkeypatch: pytest.MonkeyPatch) -> None: + """Force a mid-stream flush by rolling back last_flush time.""" + import time + + call_count = 0 + real_monotonic = time.monotonic + + def patched_monotonic() -> float: + nonlocal call_count + call_count += 1 + # First call (setup) returns normal; subsequent calls return far future + if call_count <= 2: + return real_monotonic() + return real_monotonic() + 10.0 + + async def source(): + for i in range(3): + yield {"i": i} + + with patch("node_wire_runtime.streaming.time.monotonic", side_effect=patched_monotonic): + items = await _collect_async(BufferedStreamIterator(source(), buffer_ms=100, trace_id="t4")) + assert len(items) == 3 + + +# --------------------------------------------------------------------------- +# log_sanitization gaps +# --------------------------------------------------------------------------- + + +def test_is_sensitive_key_substring_match() -> None: + assert is_sensitive_key("patient_id") is True + assert is_sensitive_key("my_ssn") is True + assert is_sensitive_key("email_address") is True + + +def test_is_sensitive_key_always_redact() -> None: + assert is_sensitive_key("search_params") is True + assert is_sensitive_key("body") is True + assert is_sensitive_key("payload") is True + + +def test_is_sensitive_key_safe_key() -> None: + assert is_sensitive_key("connector_id") is False + assert is_sensitive_key("action") is False + + +def test_sanitize_value_nested_dict() -> None: + val = {"safe_key": {"nested": "value"}, "patient_id": "123"} + result = sanitize_value("outer", val) + assert result["patient_id"] == REDACTED + assert result["safe_key"] == {"nested": "value"} + + +def test_sanitize_value_list() -> None: + # List elements are recursively processed with the same key + result = sanitize_value("items", [{"secret": "s"}, {"ok": "v"}]) + assert isinstance(result, list) + # "secret" key inside the nested dict gets redacted + assert result[0]["secret"] == REDACTED + assert result[1]["ok"] == "v" + + +def test_sanitize_value_tuple() -> None: + result = sanitize_value("data", ("safe", "value")) + assert isinstance(result, tuple) + assert result == ("safe", "value") + + +def test_sanitize_value_str_body_key_redacted() -> None: + result = sanitize_value("body", "any string content") + assert result == REDACTED + + +def test_sanitize_value_str_non_body_key_passthrough() -> None: + result = sanitize_value("action", "send_email") + assert result == "send_email" + + +def test_sanitize_mapping_redacts_sensitive() -> None: + result = sanitize_mapping({"patient": "Smith", "connector_id": "smtp"}) + assert result["patient"] == REDACTED + assert result["connector_id"] == "smtp" + + +def test_redact_sensitive_string_arg_long_string() -> None: + long_str = "x" * 101 + assert _redact_sensitive_string_arg(long_str) == REDACTED + + +def test_redact_sensitive_string_arg_phi_marker() -> None: + assert _redact_sensitive_string_arg("contains_phi_marker_here") == REDACTED + + +def test_redact_sensitive_string_arg_safe() -> None: + assert _redact_sensitive_string_arg("connector_id=smtp") == "connector_id=smtp" + + +def test_sanitize_log_record_dict_args() -> None: + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="%(patient)s %(action)s", + args={"patient": "Smith", "action": "send"}, + exc_info=None, + ) + sanitize_log_record(record) + assert record.args["patient"] == REDACTED # type: ignore[index] + assert record.args["action"] == "send" # type: ignore[index] + + +def test_sanitize_log_record_tuple_args() -> None: + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="%s %s", + args=("phi_marker_data", "safe"), + exc_info=None, + ) + sanitize_log_record(record) + assert record.args[0] == REDACTED # type: ignore[index] + assert record.args[1] == "safe" + + +def test_sanitize_log_record_no_args() -> None: + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="plain message", + args=(), + exc_info=None, + ) + sanitize_log_record(record) # should not raise + + +def test_install_sanitizing_filter_finds_existing_without_flag() -> None: + """When a SanitizingLogFilter is already present but the flag is False, + it should detect it and not add a duplicate.""" + original_flag = ls._SANITIZING_FILTER_INSTALLED + original_filters = list(logging.getLogger().filters) + try: + ls._SANITIZING_FILTER_INSTALLED = False + root = logging.getLogger() + # Add filter directly without going through install_sanitizing_log_filter + existing = SanitizingLogFilter() + root.addFilter(existing) + install_sanitizing_log_filter() + count = sum(1 for f in root.filters if isinstance(f, SanitizingLogFilter)) + assert count >= 1 + assert ls._SANITIZING_FILTER_INSTALLED is True + finally: + ls._SANITIZING_FILTER_INSTALLED = original_flag + # Restore original filters + for f in list(logging.getLogger().filters): + logging.getLogger().removeFilter(f) + for f in original_filters: + logging.getLogger().addFilter(f) + + +# --------------------------------------------------------------------------- +# observability: SanitizingSpanExporter / SanitizingLogExporter +# --------------------------------------------------------------------------- + + +def test_sanitizing_span_exporter_redacts_sensitive_attributes() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + delegate.export.return_value = None + exporter = SanitizingSpanExporter(delegate) + + fake_span = MagicMock() + fake_span._attributes = {"patient_id": "12345", "connector_id": "smtp"} + exporter.export([fake_span]) + + delegate.export.assert_called_once() + assert fake_span._attributes["patient_id"] == REDACTED + assert fake_span._attributes["connector_id"] == "smtp" + + +def test_sanitizing_span_exporter_shutdown() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + exporter = SanitizingSpanExporter(delegate) + exporter.shutdown() + delegate.shutdown.assert_called_once() + + +def test_sanitizing_span_exporter_force_flush() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock() + delegate.force_flush.return_value = True + exporter = SanitizingSpanExporter(delegate) + result = exporter.force_flush(1000) + assert result is True + delegate.force_flush.assert_called_once_with(1000) + + +def test_sanitizing_span_exporter_force_flush_no_delegate_method() -> None: + from node_wire_runtime.observability import SanitizingSpanExporter + + delegate = MagicMock(spec=[]) # no force_flush attribute + exporter = SanitizingSpanExporter(delegate) + result = exporter.force_flush() + assert result is True + + +def test_sanitizing_log_exporter_redacts_attributes() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + exporter = SanitizingLogExporter(delegate) + + fake_record = MagicMock() + fake_record.attributes = {"email": "user@example.com", "action": "send"} + exporter.export([fake_record]) + + delegate.export.assert_called_once() + assert fake_record.attributes["email"] == REDACTED + assert fake_record.attributes["action"] == "send" + + +def test_sanitizing_log_exporter_shutdown() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + exporter = SanitizingLogExporter(delegate) + exporter.shutdown() + delegate.shutdown.assert_called_once() + + +def test_sanitizing_log_exporter_force_flush() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock() + delegate.force_flush.return_value = True + exporter = SanitizingLogExporter(delegate) + result = exporter.force_flush(500) + assert result is True + + +def test_sanitizing_log_exporter_force_flush_no_delegate() -> None: + from node_wire_runtime.observability import SanitizingLogExporter + + delegate = MagicMock(spec=[]) + exporter = SanitizingLogExporter(delegate) + result = exporter.force_flush() + assert result is True + + +def test_otel_context_filter_with_valid_span() -> None: + from opentelemetry.sdk.trace import TracerProvider + + from node_wire_runtime.observability import _OtelContextFilter + + flt = _OtelContextFilter() + provider = TracerProvider() + tracer = provider.get_tracer("test") + + record = logging.LogRecord("test", logging.INFO, __file__, 1, "msg", (), None) + with tracer.start_as_current_span("test-span"): + result = flt.filter(record) + + assert result is True + assert len(record.otel_trace_id) == 32 + assert len(record.otel_span_id) == 16 + + +# --------------------------------------------------------------------------- +# connector_registry — error paths +# --------------------------------------------------------------------------- + + +def test_parse_allowed_names_empty_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NW_ALLOWED_CONNECTORS", raising=False) + result = connector_registry._parse_allowed_names() + assert result == set() + + +def test_parse_allowed_names_empty_string(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", " ") + result = connector_registry._parse_allowed_names() + assert result == set() + + +def test_registration_module_missing_is_silently_skipped( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ModuleNotFoundError with name == reg_name is silently ignored.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn") + ep = EntryPoint(name="myconn", value="node_wire_myconn.logic", group="node_wire.connectors") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn.logic": + return MagicMock() + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert "node_wire_myconn.logic" in loaded + assert "node_wire_myconn.registration" not in loaded + + +def test_registration_dep_error_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + """ModuleNotFoundError for a dep inside registration module is re-raised.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn2") + ep = EntryPoint(name="myconn2", value="node_wire_myconn2.logic", group="node_wire.connectors") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn2.logic": + return MagicMock() + err = ModuleNotFoundError("No module named 'missing_dep'") + err.name = "missing_dep" # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_registration_unexpected_exception_is_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "myconn3") + ep = EntryPoint(name="myconn3", value="node_wire_myconn3.logic", group="node_wire.connectors") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_myconn3.logic": + return MagicMock() + raise RuntimeError("unexpected error in registration module") + + with ( + patch.object(connector_registry, "entry_points", return_value=[ep]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(RuntimeError, match="unexpected error"), + ): + connector_registry.auto_register() + + +def test_fallback_logic_missing_dep_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + """Fallback logic module raises ModuleNotFoundError for an internal dep → re-raise.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "badconn") + + def fake_import(name: str) -> MagicMock: + err = ModuleNotFoundError("No module named 'some_dep'") + err.name = "some_dep" # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_fallback_logic_not_found_is_skipped(monkeypatch: pytest.MonkeyPatch) -> None: + """Fallback logic module genuinely absent → skip, no error.""" + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "absent_connector") + + def fake_import(name: str) -> MagicMock: + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert loaded == [] + + +def test_fallback_registration_missing_silently_skipped(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "noregconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_noregconn.logic": + return MagicMock() + err = ModuleNotFoundError(f"No module named '{name}'") + err.name = name # type: ignore[attr-defined] + raise err + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + ): + loaded = connector_registry.auto_register() + + assert "node_wire_noregconn.logic" in loaded + assert "node_wire_noregconn.registration" not in loaded + + +def test_fallback_registration_dep_error_reraised(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "regerrconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_regerrconn.logic": + return MagicMock() + if name == "node_wire_regerrconn.registration": + err = ModuleNotFoundError("No module named 'dep_x'") + err.name = "dep_x" # type: ignore[attr-defined] + raise err + raise ImportError(f"unexpected: {name}") + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ModuleNotFoundError), + ): + connector_registry.auto_register() + + +def test_fallback_registration_unexpected_exception_reraised( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("NW_ALLOWED_CONNECTORS", "excconn") + + def fake_import(name: str) -> MagicMock: + if name == "node_wire_excconn.logic": + return MagicMock() + raise ValueError("unexpected registration failure") + + with ( + patch.object(connector_registry, "entry_points", return_value=[]), + patch.object(connector_registry.importlib, "import_module", side_effect=fake_import), + pytest.raises(ValueError, match="unexpected registration failure"), + ): + connector_registry.auto_register() diff --git a/tests/test_secrets_providers.py b/tests/test_secrets_providers.py new file mode 100644 index 0000000..2521faa --- /dev/null +++ b/tests/test_secrets_providers.py @@ -0,0 +1,476 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for the cloud secret providers (AWS, Azure, GCP, Vault). + +The cloud SDKs are not installed in the test environment, so each fixture +injects stub modules into ``sys.modules`` before importing the provider +module, and removes the provider module afterwards so other tests see a +clean import state. +""" + +from __future__ import annotations + +import importlib +import json +import sys +import types +from types import SimpleNamespace + +import pytest + +from node_wire_runtime.secrets.base import ( + EnvSecretProvider, + SecretNotFoundError, + SecretProvider, + SecretProviderError, +) +from node_wire_runtime.secrets.chained import ChainedSecretProvider + + +# --------------------------------------------------------------------------- +# AWS +# --------------------------------------------------------------------------- + + +@pytest.fixture +def aws(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.aws against stub boto3/botocore.""" + state = SimpleNamespace( + secret_string="{}", + error=None, + service=None, + region=None, + requested_secret_id=None, + ) + + class BotoCoreError(Exception): + pass + + class ClientError(Exception): + def __init__(self, error_response: dict, operation_name: str) -> None: + super().__init__(error_response, operation_name) + self.response = error_response + self.operation_name = operation_name + + exc_mod = types.ModuleType("botocore.exceptions") + exc_mod.BotoCoreError = BotoCoreError + exc_mod.ClientError = ClientError + botocore_mod = types.ModuleType("botocore") + botocore_mod.exceptions = exc_mod + + class FakeClient: + def get_secret_value(self, SecretId: str) -> dict: + state.requested_secret_id = SecretId + if state.error is not None: + raise state.error + return {"SecretString": state.secret_string} + + boto3_mod = types.ModuleType("boto3") + + def client(service: str, region_name: str | None = None) -> FakeClient: + state.service = service + state.region = region_name + return FakeClient() + + boto3_mod.client = client + + monkeypatch.setitem(sys.modules, "boto3", boto3_mod) + monkeypatch.setitem(sys.modules, "botocore", botocore_mod) + monkeypatch.setitem(sys.modules, "botocore.exceptions", exc_mod) + sys.modules.pop("node_wire_runtime.secrets.aws", None) + mod = importlib.import_module("node_wire_runtime.secrets.aws") + yield mod, state, exc_mod + sys.modules.pop("node_wire_runtime.secrets.aws", None) + + +def test_aws_success_and_missing_key(aws) -> None: + mod, state, _exc = aws + state.secret_string = json.dumps({"epic_client_id": "abc123"}) + + provider = mod.AwsSecretsManagerProvider("my-bundle", region="eu-west-1") + + assert state.service == "secretsmanager" + assert state.region == "eu-west-1" + assert state.requested_secret_id == "my-bundle" + assert provider.get_secret("epic_client_id") == "abc123" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_aws_resource_not_found_maps_to_secret_not_found(aws) -> None: + mod, state, exc = aws + state.error = exc.ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "GetSecretValue" + ) + with pytest.raises(SecretNotFoundError): + mod.AwsSecretsManagerProvider("missing-bundle") + + +def test_aws_other_client_error_maps_to_provider_error(aws) -> None: + mod, state, exc = aws + state.error = exc.ClientError({"Error": {"Code": "AccessDeniedException"}}, "GetSecretValue") + with pytest.raises(SecretProviderError, match="AccessDeniedException"): + mod.AwsSecretsManagerProvider("forbidden-bundle") + + +def test_aws_botocore_error_maps_to_provider_error(aws) -> None: + mod, state, exc = aws + state.error = exc.BotoCoreError("connection refused") + with pytest.raises(SecretProviderError, match="AWS connection error"): + mod.AwsSecretsManagerProvider("unreachable-bundle") + + +def test_aws_import_error_has_install_hint(monkeypatch: pytest.MonkeyPatch) -> None: + """With boto3 absent the module import fails with an actionable hint.""" + monkeypatch.delitem(sys.modules, "boto3", raising=False) + sys.modules.pop("node_wire_runtime.secrets.aws", None) + try: + with pytest.raises(ImportError, match=r"node-wire-runtime\[aws\]"): + importlib.import_module("node_wire_runtime.secrets.aws") + finally: + sys.modules.pop("node_wire_runtime.secrets.aws", None) + + +# --------------------------------------------------------------------------- +# Azure +# --------------------------------------------------------------------------- + + +@pytest.fixture +def azure(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.azure against stub azure-* SDKs.""" + state = SimpleNamespace( + secrets={}, # azure-name -> value + error=None, + credential_error=None, + vault_url=None, + requested_names=[], + ) + + class ResourceNotFoundError(Exception): + pass + + class HttpResponseError(Exception): + pass + + class DefaultAzureCredential: + def __init__(self) -> None: + if state.credential_error is not None: + raise state.credential_error + + class SecretClient: + def __init__(self, vault_url: str, credential: object) -> None: + state.vault_url = vault_url + + def get_secret(self, name: str) -> SimpleNamespace: + state.requested_names.append(name) + if state.error is not None: + raise state.error + if name not in state.secrets: + raise ResourceNotFoundError(name) + return SimpleNamespace(value=state.secrets[name]) + + identity_mod = types.ModuleType("azure.identity") + identity_mod.DefaultAzureCredential = DefaultAzureCredential + kv_secrets_mod = types.ModuleType("azure.keyvault.secrets") + kv_secrets_mod.SecretClient = SecretClient + core_exc_mod = types.ModuleType("azure.core.exceptions") + core_exc_mod.ResourceNotFoundError = ResourceNotFoundError + core_exc_mod.HttpResponseError = HttpResponseError + + azure_mod = types.ModuleType("azure") + kv_mod = types.ModuleType("azure.keyvault") + core_mod = types.ModuleType("azure.core") + + monkeypatch.setitem(sys.modules, "azure", azure_mod) + monkeypatch.setitem(sys.modules, "azure.identity", identity_mod) + monkeypatch.setitem(sys.modules, "azure.keyvault", kv_mod) + monkeypatch.setitem(sys.modules, "azure.keyvault.secrets", kv_secrets_mod) + monkeypatch.setitem(sys.modules, "azure.core", core_mod) + monkeypatch.setitem(sys.modules, "azure.core.exceptions", core_exc_mod) + sys.modules.pop("node_wire_runtime.secrets.azure", None) + mod = importlib.import_module("node_wire_runtime.secrets.azure") + yield mod, state, core_exc_mod + sys.modules.pop("node_wire_runtime.secrets.azure", None) + + +def test_azure_success_maps_underscores_to_hyphens(azure) -> None: + mod, state, _exc = azure + state.secrets["epic-client-id"] = "azure-value" + + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + + assert state.vault_url == "https://kv.example.vault.azure.net" + assert provider.get_secret("epic_client_id") == "azure-value" + assert state.requested_names == ["epic-client-id"] + + +def test_azure_missing_secret_maps_to_secret_not_found(azure) -> None: + mod, _state, _exc = azure + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + with pytest.raises(SecretNotFoundError): + provider.get_secret("absent_key") + + +def test_azure_none_value_maps_to_secret_not_found(azure) -> None: + mod, state, _exc = azure + state.secrets["empty-key"] = None + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + with pytest.raises(SecretNotFoundError): + provider.get_secret("empty_key") + + +def test_azure_http_error_maps_to_provider_error(azure) -> None: + mod, state, exc = azure + provider = mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + state.error = exc.HttpResponseError("503 upstream unavailable") + with pytest.raises(SecretProviderError, match="Azure Key Vault HTTP error"): + provider.get_secret("any_key") + + +def test_azure_init_failure_maps_to_provider_error(azure) -> None: + mod, state, _exc = azure + state.credential_error = RuntimeError("no credential chain available") + with pytest.raises(SecretProviderError, match="Failed to initialise"): + mod.AzureKeyVaultProvider("https://kv.example.vault.azure.net") + + +# --------------------------------------------------------------------------- +# GCP +# --------------------------------------------------------------------------- + + +@pytest.fixture +def gcp(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.gcp against stub google SDK modules.""" + state = SimpleNamespace(payload=b"{}", error=None, requested_name=None) + + class GoogleAPICallError(Exception): + pass + + class NotFound(GoogleAPICallError): + pass + + class SecretManagerServiceClient: + def access_secret_version(self, request: dict) -> SimpleNamespace: + state.requested_name = request["name"] + if state.error is not None: + raise state.error + return SimpleNamespace(payload=SimpleNamespace(data=state.payload)) + + sm_mod = types.ModuleType("google.cloud.secretmanager") + sm_mod.SecretManagerServiceClient = SecretManagerServiceClient + api_exc_mod = types.ModuleType("google.api_core.exceptions") + api_exc_mod.NotFound = NotFound + api_exc_mod.GoogleAPICallError = GoogleAPICallError + api_core_mod = types.ModuleType("google.api_core") + api_core_mod.exceptions = api_exc_mod + + # "google" and "google.cloud" may exist as real namespace packages; only + # provide them if absent so we never clobber installed google packages. + if "google" not in sys.modules: + monkeypatch.setitem(sys.modules, "google", types.ModuleType("google")) + if "google.cloud" not in sys.modules: + monkeypatch.setitem(sys.modules, "google.cloud", types.ModuleType("google.cloud")) + monkeypatch.setitem(sys.modules, "google.cloud.secretmanager", sm_mod) + monkeypatch.setitem(sys.modules, "google.api_core", api_core_mod) + monkeypatch.setitem(sys.modules, "google.api_core.exceptions", api_exc_mod) + sys.modules.pop("node_wire_runtime.secrets.gcp", None) + mod = importlib.import_module("node_wire_runtime.secrets.gcp") + yield mod, state, api_exc_mod + sys.modules.pop("node_wire_runtime.secrets.gcp", None) + + +def test_gcp_success_and_missing_key(gcp) -> None: + mod, state, _exc = gcp + state.payload = json.dumps({"db_password": "gcp-value"}).encode("utf-8") + + provider = mod.GcpSecretManagerProvider("proj-1", "bundle", version="7") + + assert state.requested_name == "projects/proj-1/secrets/bundle/versions/7" + assert provider.get_secret("db_password") == "gcp-value" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_gcp_default_version_is_latest(gcp) -> None: + mod, state, _exc = gcp + mod.GcpSecretManagerProvider("proj-1", "bundle") + assert state.requested_name == "projects/proj-1/secrets/bundle/versions/latest" + + +def test_gcp_not_found_maps_to_secret_not_found(gcp) -> None: + mod, state, exc = gcp + state.error = exc.NotFound("no such secret") + with pytest.raises(SecretNotFoundError): + mod.GcpSecretManagerProvider("proj-1", "missing") + + +def test_gcp_api_error_maps_to_provider_error(gcp) -> None: + mod, state, exc = gcp + state.error = exc.GoogleAPICallError("permission denied") + with pytest.raises(SecretProviderError, match="GCP Secret Manager error"): + mod.GcpSecretManagerProvider("proj-1", "forbidden") + + +# --------------------------------------------------------------------------- +# HashiCorp Vault +# --------------------------------------------------------------------------- + + +@pytest.fixture +def vault(monkeypatch: pytest.MonkeyPatch): + """Import node_wire_runtime.secrets.vault against a stub hvac module.""" + state = SimpleNamespace( + data={}, + error=None, + authenticated=True, + url=None, + token=None, + path=None, + mount_point=None, + ) + + class VaultError(Exception): + pass + + class InvalidPath(VaultError): + pass + + class _KvV2: + def read_secret_version(self, path: str, mount_point: str = "secret") -> dict: + state.path = path + state.mount_point = mount_point + if state.error is not None: + raise state.error + return {"data": {"data": state.data}} + + class Client: + def __init__(self, url: str, token: str | None = None) -> None: + state.url = url + state.token = token + self.secrets = SimpleNamespace(kv=SimpleNamespace(v2=_KvV2())) + + def is_authenticated(self) -> bool: + return state.authenticated + + hvac_mod = types.ModuleType("hvac") + hvac_mod.Client = Client + exc_mod = types.ModuleType("hvac.exceptions") + exc_mod.VaultError = VaultError + exc_mod.InvalidPath = InvalidPath + hvac_mod.exceptions = exc_mod + + monkeypatch.setitem(sys.modules, "hvac", hvac_mod) + monkeypatch.setitem(sys.modules, "hvac.exceptions", exc_mod) + sys.modules.pop("node_wire_runtime.secrets.vault", None) + mod = importlib.import_module("node_wire_runtime.secrets.vault") + yield mod, state, exc_mod + sys.modules.pop("node_wire_runtime.secrets.vault", None) + + +def test_vault_success_and_missing_key(vault) -> None: + mod, state, _exc = vault + state.data = {"api_token": "vault-value"} + + provider = mod.HashiCorpVaultProvider( + "apps/node-wire", url="https://vault.internal:8200", token="t-1", mount_point="kv" + ) + + assert state.url == "https://vault.internal:8200" + assert state.token == "t-1" + assert state.path == "apps/node-wire" + assert state.mount_point == "kv" + assert provider.get_secret("api_token") == "vault-value" + with pytest.raises(SecretNotFoundError): + provider.get_secret("nope") + + +def test_vault_unauthenticated_maps_to_provider_error(vault) -> None: + mod, state, _exc = vault + state.authenticated = False + with pytest.raises(SecretProviderError, match="not authenticated"): + mod.HashiCorpVaultProvider("apps/node-wire") + + +def test_vault_invalid_path_maps_to_secret_not_found(vault) -> None: + mod, state, exc = vault + state.error = exc.InvalidPath("no secret at path") + with pytest.raises(SecretNotFoundError): + mod.HashiCorpVaultProvider("apps/missing") + + +def test_vault_error_maps_to_provider_error(vault) -> None: + mod, state, exc = vault + state.error = exc.VaultError("sealed") + with pytest.raises(SecretProviderError, match="Vault error"): + mod.HashiCorpVaultProvider("apps/broken") + + +# --------------------------------------------------------------------------- +# base.py gaps +# --------------------------------------------------------------------------- + + +def test_abstract_get_secret_body_raises_not_implemented() -> None: + class PassThrough(SecretProvider): + def get_secret(self, key: str) -> str: + return super().get_secret(key) # type: ignore[safe-super] + + with pytest.raises(NotImplementedError): + PassThrough().get_secret("anything") + + +def test_env_provider_uppercase_fallback_strips_quotes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("nw_test_upper_key", raising=False) + monkeypatch.setenv("NW_TEST_UPPER_KEY", " 'quoted-value' ") + p = EnvSecretProvider(legacy_empty_on_missing=False) + assert p.get_secret("nw_test_upper_key") == "quoted-value" + + +# --------------------------------------------------------------------------- +# chained.py gaps +# --------------------------------------------------------------------------- + + +class _StaticProvider(SecretProvider): + def __init__(self, data: dict[str, str]) -> None: + self._data = data + + def get_secret(self, key: str) -> str: + try: + return self._data[key] + except KeyError: + raise SecretNotFoundError(key) + + +class _BrokenProvider(SecretProvider): + def get_secret(self, key: str) -> str: + raise SecretProviderError("IAM is on fire") + + +def test_chained_requires_at_least_one_provider() -> None: + with pytest.raises(ValueError, match="at least one provider"): + ChainedSecretProvider() + + +def test_chained_falls_through_on_not_found() -> None: + chain = ChainedSecretProvider(_StaticProvider({}), _StaticProvider({"k": "v2"})) + assert chain.get_secret("k") == "v2" + + +def test_chained_propagates_provider_error_immediately() -> None: + fallback = _StaticProvider({"k": "should-never-be-reached"}) + chain = ChainedSecretProvider(_BrokenProvider(), fallback) + with pytest.raises(SecretProviderError, match="IAM is on fire"): + chain.get_secret("k") + + +def test_chained_raises_not_found_when_all_providers_miss() -> None: + chain = ChainedSecretProvider(_StaticProvider({}), _StaticProvider({})) + with pytest.raises(SecretNotFoundError, match="not found in any of 2 provider"): + chain.get_secret("missing") diff --git a/uv.lock b/uv.lock index a493fad..e0a5599 100644 --- a/uv.lock +++ b/uv.lock @@ -754,6 +754,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] +[[package]] +name = "diff-cover" +version = "10.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "chardet" }, + { name = "jinja2" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/21/057e816125c162662d2a2cc2ebcd72dd333e78e51678298d07dd3146011a/diff_cover-10.3.0.tar.gz", hash = "sha256:474dbc63e815fbb7567d7b7ca5b104123e96129f25426ebdbc9a1bdbb935b2c6", size = 106546, upload-time = "2026-05-30T14:17:14.32Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/0a/a96e57a7a3fca419cd5ceff0d13dee2166520fa67103fb82624ad64700fb/diff_cover-10.3.0-py3-none-any.whl", hash = "sha256:2e47d5ab3868d1e92131c11f364f3f4a8583c97123d3bbc6b6cc8ce0a4cc2202", size = 58989, upload-time = "2026-05-30T14:17:12.858Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -1196,6 +1211,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/127dc2b246096ad50ef7c8d9b7b31d757787aeb796368bcdd4454e4204c4/grpcio-1.81.0-cp314-cp314-win_amd64.whl", hash = "sha256:b93cee313cae4e113fbb3a0ce1ea5633db6f63cfde2b2dc1d817429026b2a50b", size = 5070848, upload-time = "2026-06-01T05:56:19.735Z" }, ] +[[package]] +name = "grpcio-health-checking" +version = "1.71.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/86/20994347ef36b7626fb74539f13128100dd8b7eaac67efc063264e6cdc80/grpcio_health_checking-1.71.2.tar.gz", hash = "sha256:1c21ece88c641932f432b573ef504b20603bdf030ad4e1ec35dd7fdb4ea02637", size = 16770, upload-time = "2025-06-28T04:24:08.768Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/74/7bc6ab96bf1083cab2684f9c3ae434caa638de3d5c5574e8435e2c146598/grpcio_health_checking-1.71.2-py3-none-any.whl", hash = "sha256:f91db41410d6bd18a7828c5b6ac2bebd77a63483263cbe42bf3c0c9b86cece33", size = 18918, upload-time = "2025-06-28T04:23:56.923Z" }, +] + [[package]] name = "grpcio-status" version = "1.71.2" @@ -2159,6 +2187,7 @@ dependencies = [ { name = "google-api-python-client" }, { name = "google-auth" }, { name = "grpcio" }, + { name = "grpcio-health-checking" }, { name = "grpcio-tools" }, { name = "httpx", extra = ["http2"] }, { name = "opentelemetry-api" }, @@ -2189,6 +2218,7 @@ agents = [ dev = [ { name = "bandit" }, { name = "cyclonedx-bom" }, + { name = "diff-cover" }, { name = "licenseheaders" }, { name = "mypy" }, { name = "pip-audit" }, @@ -2213,6 +2243,7 @@ requires-dist = [ { name = "google-generativeai", marker = "extra == 'agents'", specifier = ">=0.7.0" }, { name = "groq", marker = "extra == 'agents'", specifier = ">=0.9.0" }, { name = "grpcio", specifier = ">=1.62.0" }, + { name = "grpcio-health-checking", specifier = ">=1.62.0" }, { name = "grpcio-tools", specifier = ">=1.62.0" }, { name = "httpx", extras = ["http2"], specifier = ">=0.27.0,<0.28.0" }, { name = "mcp", marker = "extra == 'agents'", specifier = ">=1.6.0" }, @@ -2237,6 +2268,7 @@ provides-extras = ["agents"] dev = [ { name = "bandit", extras = ["toml"], specifier = ">=1.7.9" }, { name = "cyclonedx-bom", specifier = "==4.6.1" }, + { name = "diff-cover", specifier = ">=9.0.0" }, { name = "licenseheaders", specifier = ">=0.8.8" }, { name = "mypy", specifier = ">=1.9.0" }, { name = "pip-audit", specifier = ">=2.7.0" },