diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index fd6fbeb..22aa8d3 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,6 +38,11 @@ jobs: with: languages: python queries: security-and-quality + config: | + paths-ignore: + # grpcio-tools output — regenerated from connector.proto, not hand-edited + - '**/*_pb2.py' + - '**/*_pb2_grpc.py' - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@8272c299f21ca24af15dfe9ac0971ba969e5e0d5 # v3.36.2 diff --git a/playground/scenario_post_visit.py b/playground/scenario_post_visit.py index 434e671..56861cf 100644 --- a/playground/scenario_post_visit.py +++ b/playground/scenario_post_visit.py @@ -75,7 +75,7 @@ async def run_scenario(): print("\n=== STEP 1: Patient Discovery ===") patient_search_params = {"family": "Smith", "given": "Jason", "birthdate": "1985-01-01"} - logger.info(f"Searching for patient: {patient_search_params}") + logger.info("Searching for patient by fields: %s", ", ".join(sorted(patient_search_params))) try: patient_result = await connector.internal_execute( @@ -83,7 +83,7 @@ async def run_scenario(): trace_id=trace_id, ) patient_id = patient_result.resource.get("id") - logger.info(f"Found Patient ID: {patient_id}") + logger.info("Patient resolved successfully") except Exception as e: logger.error(f"Patient search failed: {e}") return @@ -91,7 +91,7 @@ async def run_scenario(): print("\n=== STEP 2: Encounter Identification ===") today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") encounter_params = {"patient": patient_id, "status": "finished", "date": today} - logger.info(f"Finding encounter for patient {patient_id} on {today}") + logger.info(f"Finding encounter for resolved patient on {today}") try: enc_result = await connector.internal_execute( @@ -116,7 +116,7 @@ async def run_scenario(): return encounter_id = enc_result.resources[0].get("id") - logger.info(f"Selected Encounter ID: {encounter_id}") + logger.info("Encounter selected") except Exception as e: logger.error(f"Encounter search failed: {e}") return @@ -142,7 +142,7 @@ async def run_scenario(): context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]}, ) - logger.info(f"Uploading clinical note for Encounter {encounter_id}") + logger.info("Uploading clinical note for selected encounter") try: doc_result = await connector.internal_execute(doc_input, trace_id=trace_id) logger.info(f"SUCCESS! Created DocumentReference: {doc_result.resource_id}") diff --git a/playground/scenarios.py b/playground/scenarios.py index dd6d0b2..367e93e 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, ValidationError, model_validator from dotenv import load_dotenv import os +import sys import asyncio from node_wire_runtime.errors import ErrorMapper from node_wire_runtime.models import ErrorCategory @@ -252,7 +253,6 @@ def _safe_error_return( ) -> ScenarioResponse: from node_wire_runtime.errors import ErrorMapper from node_wire_runtime.models import ErrorCategory - import logging log = logging.getLogger("playground.scenarios") @@ -266,8 +266,8 @@ def _safe_error_return( if hasattr(e, "errors") and callable(getattr(e, "errors", None)): try: safe_msg = e.errors()[0].get("msg", "Schema validation failed") - except Exception: - pass + except Exception as detail_exc: + log.debug("Could not extract validation error detail: %s", detail_exc) steps[-1].status = "error" steps[-1].details = f"[{mapped_err.category.value}] {safe_msg}" @@ -312,6 +312,12 @@ async def execute_with_retry( else: logger.error(f"Action failed after {max_retries + 1} attempts: {e}") raise last_exception + # Only reachable when max_retries < 0 leaves the loop with zero iterations. + raise ( + last_exception + if last_exception + else RuntimeError(f"execute_with_retry made no attempts (max_retries={max_retries})") + ) # Single shared factory for playground scenarios (matches REST: enabled + exposed_via includes "rest"). @@ -408,7 +414,7 @@ def add_step( add_step("Patient Discovery", "pending", display_name="Identify Patient") try: if payload.patient_id: - logger.info(f"Performing direct Patient ID lookup: {payload.patient_id}") + logger.info("Performing direct Patient ID lookup") p_res = await execute_with_retry( connector, FhirPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] ) @@ -423,7 +429,17 @@ def add_step( }.items() if v is not None } - logger.info(f"Searching for patient: {patient_search_params}") + # Log only literal field names (not the payload-derived dict) so no + # user-controlled data reaches the log record. + provided_fields = { + "family": payload.patient_family is not None, + "given": payload.patient_given is not None, + "birthdate": payload.patient_birthdate is not None, + } + logger.info( + "Searching for patient by fields: %s", + ", ".join(sorted(k for k, present in provided_fields.items() if present)), + ) p_res = await execute_with_retry( connector, FhirPatientReadInput(search_params=patient_search_params), @@ -455,16 +471,15 @@ def add_step( add_step("Encounter Identification", "pending", display_name="Locate Medical Visit") try: if payload.encounter_id: - logger.info( - f"Using manual Encounter ID: {payload.encounter_id}", extra={"trace_id": trace_id} - ) + logger.info("Using manually supplied Encounter ID", extra={"trace_id": trace_id}) encounter_id = payload.encounter_id enc_type = "Manual" enc_status = "verified" else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") logger.info( - f"Searching for encounter... patient={patient_id}, date={visit_date}", + "Searching for encounter for resolved patient on date=%s", + visit_date.replace("\r", "").replace("\n", ""), extra={"trace_id": trace_id}, ) enc_res = await execute_with_retry( @@ -505,7 +520,7 @@ def add_step( raise ValueError("The found Encounter resource is missing a valid FHIR ID.") logger.info( - f"Selected Encounter: ID={encounter_id}, Type={enc_type}, Status={enc_status}", + f"Selected Encounter: Type={enc_type}, Status={enc_status}", extra={"trace_id": trace_id}, ) @@ -717,8 +732,6 @@ def add_step( http_action = connector response = await execute_with_retry(http_action, request_input, trace_id, steps[-1]) - import json - resp_body = json.loads(response.body) steps[-1].status = "success" @@ -760,8 +773,6 @@ def add_step( add_step("Audit", "pending", display_name="Update Audit Log") try: # Simulate background task - import asyncio - await asyncio.sleep(0.4) steps[-1].status = "success" @@ -800,7 +811,7 @@ def add_step( add_step("Patient Discovery", "pending", display_name="Identify Patient") try: if payload.patient_id: - logger.info(f"Cerner: direct Patient ID lookup: {payload.patient_id}") + logger.info("Cerner: direct Patient ID lookup") p_res = await execute_with_retry( connector, FhirCernerPatientReadInput(resource_id=payload.patient_id), @@ -818,7 +829,17 @@ def add_step( }.items() if v } - logger.info(f"Cerner: searching for patient: {search_params}") + # Log only literal field names (not the payload-derived dict) so no + # user-controlled data reaches the log record. + provided_fields = { + "family": bool(payload.patient_family), + "given": bool(payload.patient_given), + "birthdate": bool(payload.patient_birthdate), + } + logger.info( + "Cerner: searching for patient by fields: %s", + ", ".join(sorted(k for k, present in provided_fields.items() if present)), + ) p_res = await execute_with_retry( connector, FhirCernerPatientReadInput(search_params=search_params), @@ -1009,8 +1030,8 @@ def add_step( decoded_text = base64.b64decode(content[0]["attachment"]["data"]).decode( "utf-8" ) - except Exception: - pass + except Exception as decode_exc: + logger.debug("Could not decode attachment content: %s", decode_exc) beautiful_data = { "id": doc_res.resource_id, @@ -1852,9 +1873,6 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: Accepts a user message + conversation history, runs through the ToolHiveAgent, and returns the agent's reply with any tool steps executed. """ - import os - import sys - trace_id = str(uuid.uuid4()) logger.info( "Agent Chat request | trace_id=%s | provider=%s", @@ -2009,8 +2027,6 @@ async def agent_chat_stream(payload: AgentChatInput) -> Any: async def stream_events(): try: - import sys - from agents.llm_factory import LLMProviderFactory from agents.toolhive import ( MultiMcpClient, @@ -2079,13 +2095,16 @@ async def stream_events(): yield json.dumps(event) + "\n" except Exception as exc: - logger.error("Agent Chat stream failed: %s", exc, exc_info=True) trace_id = str(uuid.uuid4()) + logger.error("Agent Chat stream failed (trace_id=%s): %s", trace_id, exc, exc_info=True) yield ( json.dumps( { "type": "final_chunk", - "content": f"Sorry, I encountered an error: {exc}. Please check the server configuration and try again.", + "content": ( + "Sorry, I encountered an internal error. " + f"Please check the server configuration and try again. trace_id={trace_id}" + ), } ) + "\n" @@ -2451,8 +2470,7 @@ def add_step( try: if payload.patient_id: logger.info( - "[ExtViewer] Direct Patient ID lookup: %s on %s", - payload.patient_id, + "[ExtViewer] Direct Patient ID lookup on %s", system_label, extra={"trace_id": trace_id}, ) @@ -2488,9 +2506,16 @@ def add_step( }.items() if v } + # Log only literal field names (not the payload-derived dict) so no + # user-controlled data reaches the log record. + provided_fields = { + "family": bool(payload.patient_family), + "given": bool(payload.patient_given), + "birthdate": bool(payload.patient_birthdate), + } logger.info( - "[ExtViewer] Identity-layer search: %s on %s", - search_params, + "[ExtViewer] Identity-layer search by fields [%s] on %s", + ", ".join(sorted(k for k, present in provided_fields.items() if present)), system_label, extra={"trace_id": trace_id}, ) diff --git a/src/agents/llm_base.py b/src/agents/llm_base.py new file mode 100644 index 0000000..9722f66 --- /dev/null +++ b/src/agents/llm_base.py @@ -0,0 +1,89 @@ +# +# SPDX-FileCopyrightText: 2026 AOT Technologies +# SPDX-License-Identifier: Apache-2.0 +# +""" +LLM Provider Base +================= +Provider-agnostic data models and the abstract provider interface. + +Kept separate from :mod:`agents.llm_factory` so provider implementations can +depend on the interface without importing the factory (which imports the +providers), avoiding a module-level import cycle. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +# --------------------------------------------------------------------------- +# Data models (provider-agnostic) +# --------------------------------------------------------------------------- + + +@dataclass +class ToolCall: + """A single tool-call request returned by the LLM.""" + + id: str + name: str + arguments: Dict[str, Any] + + +@dataclass +class LLMMessage: + """A single message in the conversation thread.""" + + role: str # "system" | "user" | "assistant" | "tool" + content: Optional[str] = None + tool_calls: List[ToolCall] = field(default_factory=list) + tool_call_id: Optional[str] = None # required for role="tool" responses + name: Optional[str] = None # tool name for role="tool" + + +@dataclass +class LLMResponse: + """Raw response from the LLM.""" + + content: Optional[str] + tool_calls: List[ToolCall] = field(default_factory=list) + stop_reason: str = "stop" # "stop" | "tool_calls" + + @property + def wants_tool_call(self) -> bool: + return bool(self.tool_calls) + + +# --------------------------------------------------------------------------- +# Abstract base +# --------------------------------------------------------------------------- + + +class BaseLLMProvider(ABC): + """Common interface for all LLM providers.""" + + @abstractmethod + def chat_with_tools( + self, + messages: List[LLMMessage], + tools: List[Dict[str, Any]], + ) -> LLMResponse: + """ + Send a conversation to the LLM, optionally with a set of tools. + + Parameters + ---------- + messages: + Full conversation history in provider-agnostic format. + tools: + List of MCP-style tool objects with ``name``, ``description``, + and ``input_schema`` keys. + + Returns + ------- + LLMResponse + The model's response, which may include tool_calls. + """ diff --git a/src/agents/llm_factory.py b/src/agents/llm_factory.py index 0d293b2..d4a4ff1 100644 --- a/src/agents/llm_factory.py +++ b/src/agents/llm_factory.py @@ -24,80 +24,23 @@ from __future__ import annotations import os -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Type - - -# --------------------------------------------------------------------------- -# Data models (provider-agnostic) -# --------------------------------------------------------------------------- - - -@dataclass -class ToolCall: - """A single tool-call request returned by the LLM.""" - - id: str - name: str - arguments: Dict[str, Any] - - -@dataclass -class LLMMessage: - """A single message in the conversation thread.""" - - role: str # "system" | "user" | "assistant" | "tool" - content: Optional[str] = None - tool_calls: List[ToolCall] = field(default_factory=list) - tool_call_id: Optional[str] = None # required for role="tool" responses - name: Optional[str] = None # tool name for role="tool" - - -@dataclass -class LLMResponse: - """Raw response from the LLM.""" - - content: Optional[str] - tool_calls: List[ToolCall] = field(default_factory=list) - stop_reason: str = "stop" # "stop" | "tool_calls" - - @property - def wants_tool_call(self) -> bool: - return bool(self.tool_calls) - - -# --------------------------------------------------------------------------- -# Abstract base -# --------------------------------------------------------------------------- - - -class BaseLLMProvider(ABC): - """Common interface for all LLM providers.""" - - @abstractmethod - def chat_with_tools( - self, - messages: List[LLMMessage], - tools: List[Dict[str, Any]], - ) -> LLMResponse: - """ - Send a conversation to the LLM, optionally with a set of tools. - - Parameters - ---------- - messages: - Full conversation history in provider-agnostic format. - tools: - List of MCP-style tool objects with ``name``, ``description``, - and ``input_schema`` keys. - - Returns - ------- - LLMResponse - The model's response, which may include tool_calls. - """ - +from typing import Any, Dict, Optional, Type + +# Re-exported for backward compatibility; the canonical home is agents.llm_base. +from agents.llm_base import ( # noqa: F401 + BaseLLMProvider, + LLMMessage, + LLMResponse, + ToolCall, +) + +__all__ = [ + "BaseLLMProvider", + "LLMMessage", + "LLMProviderFactory", + "LLMResponse", + "ToolCall", +] # --------------------------------------------------------------------------- # Factory diff --git a/src/agents/providers/anthropic_provider.py b/src/agents/providers/anthropic_provider.py index 8515a53..78f5273 100644 --- a/src/agents/providers/anthropic_provider.py +++ b/src/agents/providers/anthropic_provider.py @@ -17,7 +17,7 @@ import logging from typing import Any, Dict, List, Optional -from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall +from agents.llm_base import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall logger = logging.getLogger("agents.providers.anthropic") @@ -75,13 +75,13 @@ def _messages_to_claude( return result, system_prompt -anthropic: Any = None +anthropic: Any try: import anthropic as _anthropic anthropic = _anthropic except ImportError: - pass + anthropic = None class AnthropicProvider(BaseLLMProvider): diff --git a/src/agents/providers/gemini_provider.py b/src/agents/providers/gemini_provider.py index 379dd8d..7fcaec6 100644 --- a/src/agents/providers/gemini_provider.py +++ b/src/agents/providers/gemini_provider.py @@ -19,7 +19,7 @@ import uuid from typing import Any, Dict, List, Optional -from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall +from agents.llm_base import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall logger = logging.getLogger("agents.providers.gemini") @@ -36,13 +36,13 @@ def _mcp_schema_to_gemini(schema: Dict[str, Any]) -> Dict[str, Any]: return cleaned -genai: Any = None +genai: Any try: import google.generativeai as _genai genai = _genai except ImportError: - pass + genai = None class GeminiProvider(BaseLLMProvider): diff --git a/src/agents/providers/groq_provider.py b/src/agents/providers/groq_provider.py index 8105efb..ee84820 100644 --- a/src/agents/providers/groq_provider.py +++ b/src/agents/providers/groq_provider.py @@ -18,7 +18,7 @@ import logging from typing import Any, Dict, List, cast -from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall +from agents.llm_base import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall logger = logging.getLogger("agents.providers.groq") @@ -68,13 +68,13 @@ def _messages_to_groq(messages: List[LLMMessage]) -> List[Dict[str, Any]]: return result -Groq: Any = None +Groq: Any try: from groq import Groq as _Groq Groq = _Groq except ImportError: - pass + Groq = None class GroqProvider(BaseLLMProvider): diff --git a/src/agents/providers/openai_provider.py b/src/agents/providers/openai_provider.py index 29d1b50..344e9d8 100644 --- a/src/agents/providers/openai_provider.py +++ b/src/agents/providers/openai_provider.py @@ -17,7 +17,7 @@ import logging from typing import Any, Dict, List, cast -from agents.llm_factory import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall +from agents.llm_base import BaseLLMProvider, LLMMessage, LLMResponse, ToolCall logger = logging.getLogger("agents.providers.openai") @@ -66,13 +66,13 @@ def _messages_to_openai(messages: List[LLMMessage]) -> List[Dict[str, Any]]: return result -OpenAI: Any = None +OpenAI: Any try: from openai import OpenAI as _OpenAI OpenAI = _OpenAI except ImportError: - pass + OpenAI = None class OpenAIProvider(BaseLLMProvider): diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index de8cb03..07e9ae6 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -149,7 +149,7 @@ def _is_tool_failure(tool_result: str) -> bool: if isinstance(data, dict) and data.get("success") is False: return True except json.JSONDecodeError: - pass + return False return False @@ -221,9 +221,11 @@ class AgentRunResult: @runtime_checkable class McpClient(Protocol): - async def list_tools(self) -> List[Dict[str, Any]]: ... + async def list_tools(self) -> List[Dict[str, Any]]: + """Return the tool definitions exposed by the MCP server.""" - async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: ... + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: + """Invoke the named tool with the given arguments and return its result.""" class ToolHiveMcpClient: @@ -549,7 +551,7 @@ def __init__( async def run(self, task: str) -> AgentRunResult: trace_id = str(uuid.uuid4()) logger.info("Agent run started | trace_id=%s", trace_id) - logger.info("Task: %s", task) + logger.info("Task received (%d chars); content withheld from logs", len(task)) # Import here to avoid circular dependency in tests from agents.llm_factory import LLMMessage @@ -730,7 +732,7 @@ async def _run_events_inner(self, task: str, trace_id: str) -> AsyncIterator[Dic are emitted immediately after each MCP tool call completes. """ logger.info("Streaming agent run started | trace_id=%s", trace_id) - logger.info("Task: %s", task) + logger.info("Task received (%d chars); content withheld from logs", len(task)) from agents.llm_factory import LLMMessage @@ -909,7 +911,11 @@ async def _execute_task( print(f"Provider : {provider_name}") print(f"MCP info : {mcp_info}") print("=" * 60) - print(f"Task:\n{task}\n") + print( + "Task: fetch patient details from FHIR, create a summary file in Google Drive, " + "and email it to the recipient. (Patient identifiers and recipient address are " + "withheld from logs.)\n" + ) run_result = await agent.run(task) diff --git a/src/bindings/mcp_server/auth.py b/src/bindings/mcp_server/auth.py index be4d05e..6123452 100644 --- a/src/bindings/mcp_server/auth.py +++ b/src/bindings/mcp_server/auth.py @@ -97,18 +97,19 @@ def __init__(self) -> None: ) -_mcp_auth_env_bootstrapped = False +# Mutable holder so _bootstrap_mcp_auth_env can flip the flag without +# rebinding a module global. +_BOOTSTRAP_STATE: dict[str, bool] = {"done": False} def _bootstrap_mcp_auth_env() -> None: - global _mcp_auth_env_bootstrapped - if _mcp_auth_env_bootstrapped: + if _BOOTSTRAP_STATE["done"]: return # Some launch paths on Windows can miss .env loading for the MCP worker. # If MCP auth vars are missing/empty, try loading project .env once. if os.environ.get("NW_MCP_API_KEY") or os.environ.get("NW_MCP_JWT_SECRET"): - _mcp_auth_env_bootstrapped = True + _BOOTSTRAP_STATE["done"] = True return # Align with REST/bindings: when dotenv merge is disabled (pytest, CI, prod), @@ -122,7 +123,7 @@ def _bootstrap_mcp_auth_env() -> None: repo_root_env = Path(__file__).resolve().parents[3] / ".env" load_dotenv(override=False) load_dotenv(repo_root_env, override=False) - _mcp_auth_env_bootstrapped = True + _BOOTSTRAP_STATE["done"] = True def mcp_auth_disabled() -> bool: diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index adff955..14c5578 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -351,7 +351,11 @@ async def invoke_tool( run_args[page_param] = min(val, max_items) clamped_params[page_param] = run_args[page_param] except (ValueError, TypeError): - pass + logger.debug( + "Ignoring non-numeric pagination parameter %s=%r", + page_param, + current_val, + ) try: response = await connector.run( diff --git a/src/bindings/rest_api/auth.py b/src/bindings/rest_api/auth.py index 2e05cc4..3560f1f 100644 --- a/src/bindings/rest_api/auth.py +++ b/src/bindings/rest_api/auth.py @@ -20,6 +20,8 @@ import hashlib import os +import secrets +from functools import lru_cache from typing import Callable from starlette.middleware.base import BaseHTTPMiddleware @@ -34,6 +36,34 @@ REST_CALLER_STATE_KEY = "nw_rest_caller_identity" +# Per-process secret salt for fingerprinting credentials into rate-limit keys. +# Per-process suffices because the rate limiter is in-memory and process-local. +_IDENTITY_FINGERPRINT_SALT = secrets.token_bytes(32) + +# OWASP-recommended work factor for PBKDF2-HMAC-SHA256. Paid once per unique +# token per process (see the lru_cache below), not per request. +_IDENTITY_FINGERPRINT_ITERATIONS = 600_000 + + +@lru_cache(maxsize=1024) +def _fingerprint_token(token: str) -> str: + """ + Derive a non-reversible rate-limit key from a verified credential. + + Uses PBKDF2 with a per-process secret salt so a leaked key (logs, metrics) + cannot be brute-forced offline back to the token. Auth middleware runs + before any caller, so only verified credentials reach this cache — the + input space is the small set of valid tokens, and the KDF cost is + amortized to once per token per process. + """ + digest = hashlib.pbkdf2_hmac( + "sha256", + token.encode("utf-8"), + _IDENTITY_FINGERPRINT_SALT, + _IDENTITY_FINGERPRINT_ITERATIONS, + ).hex()[:16] + return f"token:{digest}" + def get_rest_caller_identity(request: Request) -> CallerIdentity | None: """Return JWT/API-key caller identity attached by middleware, if any.""" @@ -102,13 +132,14 @@ def get_request_identity_key(request: Request) -> str: Return a stable, non-sensitive identity key for request-level controls. Preference order: - 1) Auth token/API key (fingerprinted, never returned raw) + 1) Auth token/API key (PBKDF2-fingerprinted with a per-process salt, never returned raw) 2) Trusted client IP (``NW_REST_TRUSTED_PROXY_HOPS`` controls X-Forwarded-For use) + + Keys are stable within a process, which matches the process-local rate limiter. """ token = _extract_bearer_or_api_key(request) if token: - digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16] - return f"token:{digest}" + return _fingerprint_token(token) return f"ip:{_client_ip_for_rate_limit(request)}" diff --git a/src/bindings/rest_api/body_limit.py b/src/bindings/rest_api/body_limit.py index 13fd9a2..66fbb6b 100644 --- a/src/bindings/rest_api/body_limit.py +++ b/src/bindings/rest_api/body_limit.py @@ -73,11 +73,13 @@ async def __call__( content_length = _header_value(headers, "content-length") if content_length is not None: try: - if int(content_length) > max_bytes: - await _send_payload_too_large(send) - return + declared_length = int(content_length) except ValueError: - pass + # Malformed header; fall back to streaming enforcement below + declared_length = None + if declared_length is not None and declared_length > max_bytes: + await _send_payload_too_large(send) + return bytes_read = 0 diff --git a/src/node_wire_fhir_cerner/logic.py b/src/node_wire_fhir_cerner/logic.py index 5216b7c..e9df118 100644 --- a/src/node_wire_fhir_cerner/logic.py +++ b/src/node_wire_fhir_cerner/logic.py @@ -11,6 +11,7 @@ import os from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import httpx @@ -520,7 +521,8 @@ async def _create_document_reference( codings = doc_ref["type"].get("coding", []) for coding in codings: # Validate CodeSet 72 vs LOINC - if "loinc.org" in coding.get("system", ""): + system_host = urlparse(coding.get("system", "")).hostname or "" + if system_host == "loinc.org" or system_host.endswith(".loinc.org"): raise ValueError( "Cerner requires the proprietary CodeSet 72 system for DocumentReference 'type', " "not a LOINC system URL. " @@ -671,8 +673,8 @@ async def _create_document_reference( try: body = response.json() resource_id = body.get("id") - except Exception: - pass + except Exception as exc: + logger.debug("Could not parse create response body as JSON: %s", exc) if not resource_id: raise ValueError( diff --git a/src/node_wire_fhir_epic/logic.py b/src/node_wire_fhir_epic/logic.py index ee11934..7f6c017 100644 --- a/src/node_wire_fhir_epic/logic.py +++ b/src/node_wire_fhir_epic/logic.py @@ -523,8 +523,8 @@ async def _create_document_reference( try: body = response.json() resource_id = body.get("id") - except Exception: - pass + except Exception as exc: + logger.debug("Could not parse create response body as JSON: %s", exc) if not resource_id: raise ValueError( diff --git a/src/node_wire_runtime/auth/oauth2.py b/src/node_wire_runtime/auth/oauth2.py index 54aeab4..955ae2e 100644 --- a/src/node_wire_runtime/auth/oauth2.py +++ b/src/node_wire_runtime/auth/oauth2.py @@ -223,8 +223,8 @@ def _resolve_scopes(self) -> Optional[str]: val = self._sp.get_secret(self._scopes_secret) if val and val.strip(): return val.strip() - except Exception: - pass + except Exception as exc: + logger.debug("Scope secret lookup failed, falling back to static scopes: %s", exc) if self._static_scopes: return " ".join(self._static_scopes) return None diff --git a/src/node_wire_runtime/auth/service_account.py b/src/node_wire_runtime/auth/service_account.py index 5b0ea64..e1fe550 100644 --- a/src/node_wire_runtime/auth/service_account.py +++ b/src/node_wire_runtime/auth/service_account.py @@ -115,10 +115,7 @@ def _build_credentials(self) -> Any: logger.debug( "ServiceAccountAuthProvider: credentials built", - extra={ - "sa_json_secret": self._sa_json_secret, - "scopes": self._scopes, - }, + extra={"scopes": self._scopes}, ) return creds @@ -149,8 +146,5 @@ async def refresh(self) -> None: Forces :meth:`get_client_credentials` to rebuild from the secret on the next call, picking up any rotated service-account JSON. """ - logger.debug( - "ServiceAccountAuthProvider: credentials cache invalidated", - extra={"sa_json_secret": self._sa_json_secret}, - ) + logger.debug("ServiceAccountAuthProvider: credentials cache invalidated") self._credentials = None diff --git a/src/node_wire_runtime/auth/static_token.py b/src/node_wire_runtime/auth/static_token.py index 87ae1fd..55f4568 100644 --- a/src/node_wire_runtime/auth/static_token.py +++ b/src/node_wire_runtime/auth/static_token.py @@ -86,15 +86,12 @@ async def get_headers(self) -> Dict[str, str]: if self._cached_header is None: logger.debug( "StaticTokenAuthProvider: resolving secret", - extra={"secret_key": self._secret_key, "header": self._header_name}, + extra={"header": self._header_name}, ) self._cached_header = self._build_header() return dict(self._cached_header) async def refresh(self) -> None: """Invalidate the cached header so the secret is re-read on the next call.""" - logger.debug( - "StaticTokenAuthProvider: cache invalidated", - extra={"secret_key": self._secret_key}, - ) + logger.debug("StaticTokenAuthProvider: cache invalidated") self._cached_header = None diff --git a/src/node_wire_runtime/caller_identity.py b/src/node_wire_runtime/caller_identity.py index 2c38a45..d88b1db 100644 --- a/src/node_wire_runtime/caller_identity.py +++ b/src/node_wire_runtime/caller_identity.py @@ -12,6 +12,7 @@ import os import re from dataclasses import dataclass +from functools import lru_cache from typing import Any, Mapping import jwt @@ -21,8 +22,6 @@ JWT_AUDIENCE_ENV = "NW_JWT_AUDIENCE" JWT_ISSUER_ENV = "NW_JWT_ISSUER" -_jwt_aud_iss_warned = False - @dataclass(frozen=True) class CallerIdentity: @@ -101,11 +100,9 @@ def load_jwt_audience_issuer_from_env() -> tuple[str, str] | None: return str(audience).strip(), str(issuer).strip() +@lru_cache(maxsize=None) def warn_jwt_audience_issuer_not_configured() -> None: - global _jwt_aud_iss_warned - if _jwt_aud_iss_warned: - return - _jwt_aud_iss_warned = True + """Warn once per process; ``lru_cache`` suppresses repeats.""" logger.warning( "JWT secret is configured but %s and %s are not set; JWT verification will fail", JWT_AUDIENCE_ENV, diff --git a/src/node_wire_runtime/log_sanitization.py b/src/node_wire_runtime/log_sanitization.py index 60f0258..24fef27 100644 --- a/src/node_wire_runtime/log_sanitization.py +++ b/src/node_wire_runtime/log_sanitization.py @@ -67,8 +67,6 @@ } ) -_SANITIZING_FILTER_INSTALLED = False - def _normalize_key(key: str) -> str: return key.lower().replace("_", "").replace("-", "").replace(" ", "") @@ -137,16 +135,10 @@ def filter(self, record: logging.LogRecord) -> bool: # noqa: A003 def install_sanitizing_log_filter() -> None: """Attach :class:`SanitizingLogFilter` to the root logger once.""" - global _SANITIZING_FILTER_INSTALLED - if _SANITIZING_FILTER_INSTALLED: - return root = logging.getLogger() - for flt in root.filters: - if isinstance(flt, SanitizingLogFilter): - _SANITIZING_FILTER_INSTALLED = True - return + if any(isinstance(flt, SanitizingLogFilter) for flt in root.filters): + return root.addFilter(SanitizingLogFilter()) - _SANITIZING_FILTER_INSTALLED = True def fhir_log_extra(trace_id: str, *, mode: str) -> dict[str, str]: diff --git a/src/node_wire_runtime/mcp_client/client.py b/src/node_wire_runtime/mcp_client/client.py index 44a88aa..504562c 100644 --- a/src/node_wire_runtime/mcp_client/client.py +++ b/src/node_wire_runtime/mcp_client/client.py @@ -164,8 +164,8 @@ async def _initialize(self) -> None: notif = {"jsonrpc": "2.0", "method": "notifications/initialized"} try: await self._request("POST", json_body=notif) - except Exception: - pass + except Exception as exc: + logger.debug("MCP initialized notification failed (non-fatal): %s", exc) self._initialized = True async def _rpc(self, method: str, params: Dict[str, Any]) -> Any: diff --git a/src/node_wire_runtime/mcp_client/redirect_listener.py b/src/node_wire_runtime/mcp_client/redirect_listener.py index ad9dfbd..bfe4a0a 100644 --- a/src/node_wire_runtime/mcp_client/redirect_listener.py +++ b/src/node_wire_runtime/mcp_client/redirect_listener.py @@ -157,8 +157,8 @@ async def _handle( writer.close() try: await writer.wait_closed() - except Exception: - pass + except Exception as exc: + logger.debug("Redirect listener connection close failed: %s", exc) if self._server: self._server.close() diff --git a/src/node_wire_runtime/mcp_client/token_storage.py b/src/node_wire_runtime/mcp_client/token_storage.py index 91b51d7..6451268 100644 --- a/src/node_wire_runtime/mcp_client/token_storage.py +++ b/src/node_wire_runtime/mcp_client/token_storage.py @@ -170,8 +170,8 @@ def delete(self, partition_key: str) -> None: if self._keyring is not None: try: self._keyring.delete_password(_KEYRING_SERVICE, partition_key) - except Exception: - pass + except Exception as exc: + logger.debug("Keychain delete failed (entry may not exist): %s", exc) self._fallback.delete(partition_key) diff --git a/src/node_wire_runtime/observability.py b/src/node_wire_runtime/observability.py index 2f70440..abe1d79 100644 --- a/src/node_wire_runtime/observability.py +++ b/src/node_wire_runtime/observability.py @@ -26,7 +26,9 @@ logger = logging.getLogger("runtime.observability") -_INITIALIZED: bool = False +# Mutable holder so init_observability can flip the flag without rebinding a +# module global (tests reset it via _STATE["initialized"] = False). +_STATE: dict[str, bool] = {"initialized": False} class _OtelContextFilter(logging.Filter): @@ -92,8 +94,7 @@ def init_observability(app_name: str = "node_wire") -> None: This is intended to be called once at process startup (e.g. from the bindings_entrypoint main()) and is safe to call multiple times. """ - global _INITIALIZED - if _INITIALIZED: + if _STATE["initialized"]: return install_sanitizing_log_filter() @@ -166,5 +167,5 @@ def init_observability(app_name: str = "node_wire") -> None: except Exception as exc: # pragma: no cover - defensive; should not fail app startup logger.warning("Failed to initialize Traceloop/OpenLLMetry: %s", exc) - _INITIALIZED = True + _STATE["initialized"] = True logger.info("Observability initialized for app %s", app_name) diff --git a/src/node_wire_runtime/policies/mcp_scope_policy.py b/src/node_wire_runtime/policies/mcp_scope_policy.py index 5016ed9..563e4f5 100644 --- a/src/node_wire_runtime/policies/mcp_scope_policy.py +++ b/src/node_wire_runtime/policies/mcp_scope_policy.py @@ -7,6 +7,7 @@ import json import logging import os +from functools import lru_cache from pathlib import Path from typing import Mapping, Optional @@ -20,7 +21,15 @@ DEFAULT_SCOPE_MODE_ALLOW = "allow" DEFAULT_SCOPE_MODE_DENY = "deny" -_warned_implicit_scope_default = False + +@lru_cache(maxsize=None) +def _warn_implicit_scope_default() -> None: + """Warn once per process; ``lru_cache`` suppresses repeats.""" + logger.warning( + "NW_MCP_SCOPE_POLICY_DEFAULT is unset; using code default 'deny'. " + "Set NW_MCP_SCOPE_POLICY_DEFAULT explicitly and configure " + "NW_*_API_KEY_SCOPES (or JWT scopes) for each transport." + ) def _truthy_default_mode(val: str) -> str: @@ -32,16 +41,9 @@ def _truthy_default_mode(val: str) -> str: def load_scope_policy_default_from_env() -> str: """Return ``allow`` or ``deny`` from ``NW_MCP_SCOPE_POLICY_DEFAULT`` (default: deny).""" - global _warned_implicit_scope_default raw = os.environ.get("NW_MCP_SCOPE_POLICY_DEFAULT") if not raw or not str(raw).strip(): - if not _warned_implicit_scope_default: - logger.warning( - "NW_MCP_SCOPE_POLICY_DEFAULT is unset; using code default 'deny'. " - "Set NW_MCP_SCOPE_POLICY_DEFAULT explicitly and configure " - "NW_*_API_KEY_SCOPES (or JWT scopes) for each transport." - ) - _warned_implicit_scope_default = True + _warn_implicit_scope_default() return DEFAULT_SCOPE_MODE_DENY return _truthy_default_mode(str(raw)) diff --git a/src/node_wire_slack/schema.py b/src/node_wire_slack/schema.py index a544483..f9e7aea 100644 --- a/src/node_wire_slack/schema.py +++ b/src/node_wire_slack/schema.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -108,20 +108,6 @@ class SlackUploadFileInput(_BaseSlackInput): ) -# --------------------------------------------------------------------------- -# Discriminated union — used by BaseConnector internally -# --------------------------------------------------------------------------- - -_SlackOperationUnion = Annotated[ - Union[ - SlackPostMessageInput, - SlackSendDirectMessageInput, - SlackUploadFileInput, - ], - Field(discriminator="action"), -] - - # --------------------------------------------------------------------------- # Output # --------------------------------------------------------------------------- diff --git a/tests/playground/slack/conftest.py b/tests/playground/slack/conftest.py index 9c3a180..df52872 100644 --- a/tests/playground/slack/conftest.py +++ b/tests/playground/slack/conftest.py @@ -52,14 +52,15 @@ def slack_upload_channel() -> str: Falls back to SLACK_TEST_CHANNEL, but skips if that is still a name — the Slack external-upload API requires an ID, not a name. """ - if _DEFAULT_CHANNEL_ID: - return _DEFAULT_CHANNEL_ID - if _DEFAULT_CHANNEL and _DEFAULT_CHANNEL[0].upper() in ("C", "G", "D"): - return _DEFAULT_CHANNEL - pytest.skip( - "upload_file tests require a channel ID. " - "Set SLACK_TEST_CHANNEL_ID (e.g. C0ANP6RADHU) in .env." - ) + channel = _DEFAULT_CHANNEL_ID + if not channel and _DEFAULT_CHANNEL and _DEFAULT_CHANNEL[0].upper() in ("C", "G", "D"): + channel = _DEFAULT_CHANNEL + if not channel: + pytest.skip( + "upload_file tests require a channel ID. " + "Set SLACK_TEST_CHANNEL_ID (e.g. C0ANP6RADHU) in .env." + ) + return channel @pytest.fixture(scope="session") diff --git a/tests/test_grpc_async_runner.py b/tests/test_grpc_async_runner.py index 4d8d3ce..9d5f5ea 100644 --- a/tests/test_grpc_async_runner.py +++ b/tests/test_grpc_async_runner.py @@ -26,7 +26,7 @@ async def record_loop() -> int: def test_background_runner_handles_concurrent_calls() -> None: runner = BackgroundAsyncRunner() runner.start() - errors: list[BaseException] = [] + errors: list[Exception] = [] async def echo(value: int) -> int: await asyncio.sleep(0.01) @@ -35,7 +35,7 @@ async def echo(value: int) -> int: def worker(value: int) -> None: try: assert runner.run(echo(value)) == value - except BaseException as exc: + except Exception as exc: errors.append(exc) threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)] diff --git a/tests/test_observability.py b/tests/test_observability.py index f74feea..b2e7483 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -49,9 +49,9 @@ def _ensure_traceloop_stub_modules() -> Iterator[None]: @pytest.fixture(autouse=True) def reset_observability_initialized() -> None: - obs._INITIALIZED = False + obs._STATE["initialized"] = False yield - obs._INITIALIZED = False + obs._STATE["initialized"] = False @contextmanager diff --git a/tests/test_rest_connector_dispatch.py b/tests/test_rest_connector_dispatch.py index 31ceb97..7db207b 100644 --- a/tests/test_rest_connector_dispatch.py +++ b/tests/test_rest_connector_dispatch.py @@ -90,6 +90,11 @@ ] +def test_smoke_cases_cover_all_eight_connectors() -> None: + covered = {case.values[0] for case in _REST_SMOKE_CASES} + assert covered == set(_ALL_EIGHT_CONNECTOR_IDS) + + def _stub_connector(response: ConnectorResponse) -> MagicMock: connector = MagicMock() connector.run = AsyncMock(return_value=response) diff --git a/tests/test_rest_identity_key.py b/tests/test_rest_identity_key.py index eace743..ca50a2d 100644 --- a/tests/test_rest_identity_key.py +++ b/tests/test_rest_identity_key.py @@ -4,8 +4,6 @@ # from __future__ import annotations -import hashlib - import pytest from starlette.requests import Request @@ -28,14 +26,21 @@ def _request( def test_identity_key_uses_token_hash_regardless_of_xff(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("NW_REST_TRUSTED_PROXY_HOPS", "0") token = "secret-api-key" - digest = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16] request = _request( headers=[ (b"x-api-key", token.encode("utf-8")), (b"x-forwarded-for", b"203.0.113.99"), ], ) - assert get_request_identity_key(request) == f"token:{digest}" + key = get_request_identity_key(request) + assert key.startswith("token:") + # Raw token never appears in the key. + assert token not in key + # Stable across requests with the same token. + assert get_request_identity_key(request) == key + # Distinct tokens map to distinct keys. + other = _request(headers=[(b"x-api-key", b"another-key")]) + assert get_request_identity_key(other) != key def test_identity_key_ignores_spoofed_xff_when_proxy_hops_zero( diff --git a/tests/test_slack_connector.py b/tests/test_slack_connector.py index a60dade..733714e 100644 --- a/tests/test_slack_connector.py +++ b/tests/test_slack_connector.py @@ -28,6 +28,8 @@ from __future__ import annotations import base64 +import importlib +import json import logging from typing import Any from unittest.mock import AsyncMock, patch @@ -51,7 +53,10 @@ _resolve_blocks, _upload_bytes, ) -import node_wire_slack.registration # noqa: F401 + +# Load the ErrorMapper registrations the same way the production +# connector_registry does — the error_code assertions below depend on them. +importlib.import_module("node_wire_slack.registration") # --------------------------------------------------------------------------- # Helpers @@ -143,8 +148,6 @@ async def test_post_message_with_blocks_json_string() -> None: """Blocks provided as a JSON string are parsed before being sent.""" connector = _make_connector() blocks = [{"type": "section", "text": {"type": "mrkdwn", "text": "hello"}}] - import json - blocks_str = json.dumps(blocks) captured: dict[str, Any] = {} @@ -452,8 +455,6 @@ def test_resolve_blocks_list_passthrough() -> None: def test_resolve_blocks_valid_json_string() -> None: - import json - blocks = [{"type": "section"}] assert _resolve_blocks(json.dumps(blocks)) == blocks @@ -464,8 +465,6 @@ def test_resolve_blocks_invalid_json_raises() -> None: def test_resolve_blocks_non_array_json_raises() -> None: - import json - with pytest.raises(SlackMessageError, match="must be a JSON array"): _resolve_blocks(json.dumps({"type": "section"})) @@ -559,7 +558,8 @@ async def test_resolve_channel_id_network_error_falls_back_and_logs( monkeypatch.delenv("NW_SLACK_SKIP_RESOLVE", raising=False) class _RaisingClient: - def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass async def __aenter__(self) -> "_RaisingClient": return self @@ -684,9 +684,7 @@ async def post(self, *args: object, **kwargs: object) -> FakeResponse: def test_default_timeout_honors_env(monkeypatch: pytest.MonkeyPatch) -> None: """Q-7: _DEFAULT_TIMEOUT is configurable via NW_SLACK_TIMEOUT / NW_TIMEOUT.""" - import importlib - - import node_wire_slack.logic as slack_logic + slack_logic = importlib.import_module("node_wire_slack.logic") monkeypatch.setenv("NW_SLACK_TIMEOUT", "12.5") try: